├── LICENSE ├── onekey_core ├── transforms │ ├── __init__.py │ └── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── functional.cpython-37.pyc │ │ └── transforms.cpython-37.pyc ├── __pycache__ │ ├── utils.cpython-37.pyc │ └── __init__.cpython-37.pyc ├── models │ ├── __pycache__ │ │ ├── vgg.cpython-37.pyc │ │ ├── utils.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── _utils.cpython-37.pyc │ │ ├── alexnet.cpython-37.pyc │ │ ├── densenet.cpython-37.pyc │ │ ├── mnasnet.cpython-37.pyc │ │ ├── resnet.cpython-37.pyc │ │ ├── resnet3d.cpython-37.pyc │ │ ├── googlenet.cpython-37.pyc │ │ ├── inception.cpython-37.pyc │ │ ├── mobilenet.cpython-37.pyc │ │ ├── squeezenet.cpython-37.pyc │ │ ├── mobilenetv2.cpython-37.pyc │ │ ├── mobilenetv3.cpython-37.pyc │ │ ├── res2net_v1b.cpython-37.pyc │ │ └── shufflenetv2.cpython-37.pyc │ ├── utils.py │ ├── mobilenet.py │ ├── __init__.py │ ├── alexnet.py │ ├── _utils.py │ ├── squeezenet.py │ ├── resnet3d.py │ ├── vgg.py │ ├── shufflenetv2.py │ ├── mobilenetv2.py │ ├── mnasnet.py │ └── res2net_v1b.py ├── core │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── image_loader.cpython-37.pyc │ │ ├── losses_factory.cpython-37.pyc │ │ ├── model_factory.cpython-37.pyc │ │ ├── optimizer_lr_factory.cpython-37.pyc │ │ └── transformer_factory.cpython-37.pyc │ ├── __init__.py │ ├── image_loader.py │ ├── model_factory.py │ ├── test_factory.py │ ├── transformer_factory.py │ ├── optimizer_lr_factory.py │ └── losses_factory.py ├── __init__.py ├── extension.py └── utils.py ├── onekey_algo ├── changelog.md ├── custom │ ├── components │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── comp1.cpython-37.pyc │ │ │ ├── comp2.cpython-37.pyc │ │ │ ├── delong.cpython-37.pyc │ │ │ ├── metrics.cpython-37.pyc │ │ │ ├── stats.cpython-37.pyc │ │ │ ├── Radiology.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── nomogram.cpython-37.pyc │ │ ├── nomogram.py │ │ ├── delong.py │ │ └── comp2.py │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ ├── utils │ │ ├── __pycache__ │ │ │ └── __init__.cpython-37.pyc │ │ └── __init__.py │ └── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc ├── datasets │ ├── test_case │ │ ├── mask │ │ │ └── 2007_000584.png │ │ └── image │ │ │ └── 2007_000584.jpg │ ├── __pycache__ │ │ ├── vision.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── image_loader.cpython-37.pyc │ │ └── ClassificationDataset.cpython-37.pyc │ ├── __init__.py │ ├── image_loader.py │ └── vision.py ├── mietb │ ├── __pycache__ │ │ ├── utils.cpython-37.pyc │ │ └── __init__.cpython-37.pyc │ ├── super_resolution │ │ ├── __pycache__ │ │ │ ├── rcan.cpython-37.pyc │ │ │ ├── common.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── eval_super_res_reconstruction.cpython-37.pyc │ │ ├── __init__.py │ │ ├── RRDBNet_arch.py │ │ ├── common.py │ │ ├── eval_super_res_reconstruction.py │ │ └── rcan.py │ ├── __init__.py │ ├── utils.py │ └── README.md ├── utils │ ├── __pycache__ │ │ ├── common.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── about_log.cpython-37.pyc │ │ ├── MultiProcess.cpython-37.pyc │ │ └── dataset_utils.cpython-37.pyc │ ├── __init__.py │ ├── MultiProcess.py │ └── about_log.py ├── classification │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── eval_classification.cpython-37.pyc │ ├── __init__.py │ ├── README.md │ ├── inference.py │ └── eval_classification.py └── __init__.py ├── onekey_comp ├── comp0. MIET(医学图像处理) │ ├── config.yaml │ ├── Module1. 超分辨率采样.ipynb │ └── README.md ├── comp9-Solutions │ ├── sol1. 
传统组学-单中心-临床 │ │ ├── config.yaml │ │ ├── 点我运行.bat │ │ ├── stats.csv │ │ └── Step2. 临床基线统计分析.ipynb │ └── 点我运行.bat ├── comp1-传统组学 │ └── 点我运行.bat ├── comp2-结构化数据 │ └── 点我运行.bat ├── comp4-What(分类识别) │ ├── 点我运行.bat │ ├── 可视化.bat │ ├── 生成CV数据-List模式.ipynb │ └── What-特征提取.ipynb ├── comp7-Survival │ ├── 点我运行.bat │ ├── 列线图-Nomogram.ipynb │ └── 生存分析-KaplanMeier.ipynb ├── 点我运行.bat ├── README.md └── .gitignore └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /onekey_core/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | -------------------------------------------------------------------------------- /onekey_algo/changelog.md: -------------------------------------------------------------------------------- 1 | # onekey lite 2 | 3 | A lightweight Onekey platform. 4 | 5 | Re-initialized on May 5, 2023. -------------------------------------------------------------------------------- /onekey_core/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__init__.py: -------------------------------------------------------------------------------- 1 | # TODO: migrate components into onekey_core 2 | from . import comp1 3 | from . import comp2 4 | from . import Radiology -------------------------------------------------------------------------------- /onekey_algo/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/vgg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/vgg.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/datasets/test_case/mask/2007_000584.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/datasets/test_case/mask/2007_000584.png -------------------------------------------------------------------------------- /onekey_algo/mietb/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/mietb/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/utils/__pycache__/common.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/utils/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/datasets/__pycache__/vision.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/datasets/__pycache__/vision.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/datasets/test_case/image/2007_000584.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/datasets/test_case/image/2007_000584.jpg -------------------------------------------------------------------------------- /onekey_algo/mietb/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/mietb/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/utils/__pycache__/about_log.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/utils/__pycache__/about_log.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/_utils.cpython-37.pyc -------------------------------------------------------------------------------- 
/onekey_core/models/__pycache__/alexnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/alexnet.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/densenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/densenet.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/mnasnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/mnasnet.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/resnet3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/resnet3d.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/core/__pycache__/image_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/core/__pycache__/image_loader.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/googlenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/googlenet.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/inception.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/inception.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/squeezenet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/squeezenet.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/utils/__pycache__/MultiProcess.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/utils/__pycache__/MultiProcess.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/utils/__pycache__/dataset_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/utils/__pycache__/dataset_utils.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/core/__pycache__/losses_factory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/core/__pycache__/losses_factory.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/core/__pycache__/model_factory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/core/__pycache__/model_factory.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/mobilenetv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/mobilenetv2.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/mobilenetv3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/mobilenetv3.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/res2net_v1b.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/res2net_v1b.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/__pycache__/shufflenetv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/models/__pycache__/shufflenetv2.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/transforms/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/transforms/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/datasets/__pycache__/image_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/datasets/__pycache__/image_loader.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/transforms/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/transforms/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/transforms/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/transforms/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/classification/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/classification/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__pycache__/comp1.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/components/__pycache__/comp1.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__pycache__/comp2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/components/__pycache__/comp2.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__pycache__/delong.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/components/__pycache__/delong.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/components/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__pycache__/stats.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/components/__pycache__/stats.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/core/__pycache__/optimizer_lr_factory.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/core/__pycache__/optimizer_lr_factory.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/core/__pycache__/transformer_factory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_core/core/__pycache__/transformer_factory.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__pycache__/Radiology.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/components/__pycache__/Radiology.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/components/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/custom/components/__pycache__/nomogram.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/custom/components/__pycache__/nomogram.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/__pycache__/rcan.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/mietb/super_resolution/__pycache__/rcan.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_comp/comp0. MIET(医学图像处理)/config.yaml: -------------------------------------------------------------------------------- 1 | # Where your data lives; all data directories are searched recursively for .nii.gz files 2 | rad_dir: ONEKEYDS_HOME/CT/images 3 | # Upscaling factor for super-resolution reconstruction; currently 2x and 4x are supported. 4 | scale: 4 5 | # Where the results are saved. 6 | save_dir: . 
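# A minimal sketch (an assumption, mirroring get_config in onekey_algo/__init__.py) of how these keys are read from Python:
#     import yaml
#     with open('config.yaml', encoding='utf8') as f:
#         cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
#     rad_dir, scale, save_dir = cfg['rad_dir'], cfg['scale'], cfg['save_dir']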
-------------------------------------------------------------------------------- /onekey_algo/datasets/__pycache__/ClassificationDataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/datasets/__pycache__/ClassificationDataset.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/mietb/super_resolution/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch.hub import load_state_dict_from_url 3 | except ImportError: 4 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 5 | -------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/mietb/super_resolution/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_algo/classification/__pycache__/eval_classification.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/classification/__pycache__/eval_classification.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_comp/comp9-Solutions/sol1. 传统组学-单中心-临床/config.yaml: -------------------------------------------------------------------------------- 1 | radio_dir: C:\OnekeyPlatform\OnekeyDS\CT 2 | task_colum: label 3 | clinic_file: C:\OnekeyPlatform\OnekeyDS\survival.csv 4 | font.size: 12 5 | sel_model: LR -------------------------------------------------------------------------------- /onekey_algo/mietb/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2023/04/15 4 | # Forum: www.medai.icu 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2023 All Rights Reserved. 7 | -------------------------------------------------------------------------------- /onekey_algo/custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2021/12/22 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2021 All Rights Reserved. 7 | -------------------------------------------------------------------------------- /onekey_algo/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/22 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 
7 | -------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2023/04/15 4 | # Forum: www.medai.icu 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2023 All Rights Reserved. 7 | -------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/__pycache__/eval_super_res_reconstruction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OnekeyAI-Platform/onekey/HEAD/onekey_algo/mietb/super_resolution/__pycache__/eval_super_res_reconstruction.cpython-37.pyc -------------------------------------------------------------------------------- /onekey_core/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all 2 | from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all 3 | 4 | __all__ = mv2_all + mv3_all 5 | -------------------------------------------------------------------------------- /onekey_algo/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/22 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 7 | from .ClassificationDataset import * 8 | -------------------------------------------------------------------------------- /onekey_algo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2021/3/21 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2021 All Rights Reserved. 
7 | import os 8 | from .common import * 9 | 10 | -------------------------------------------------------------------------------- /onekey_algo/mietb/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def normalize(_min, _max, _data): 5 | return (_data - _min) / (_max - _min) 6 | 7 | 8 | def denormalize(_min, _max, _data): 9 | return (_max - _min) * _data + _min 10 | 11 | 12 | def clip(_min, _max, _data): 13 | return np.clip(_data, _min, _max) -------------------------------------------------------------------------------- /onekey_core/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .alexnet import * 2 | from .densenet import * 3 | from .googlenet import * 4 | from .inception import * 5 | from .mnasnet import * 6 | from .mobilenet import * 7 | from .resnet import * 8 | from .shufflenetv2 import * 9 | from .squeezenet import * 10 | from .vgg import * 11 | from .res2net_v1b import * 12 | from .resnet3d import * 13 | -------------------------------------------------------------------------------- /onekey_core/core/__init__.py: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # -*- coding: UTF-8 -*- 3 | # Authorized by Vlon Jang 4 | # Created on 2019/7/24 5 | # Blog: www.wangqingbaidu.cn 6 | # Email: wangqingbaidu@gmail.com 7 | # Copyright 2015-2019 All Rights Reserved. 8 | 9 | from .losses_factory import * 10 | from .model_factory import * 11 | from .optimizer_lr_factory import * 12 | from .transformer_factory import * 13 | 14 | -------------------------------------------------------------------------------- /onekey_comp/comp1-传统组学/点我运行.bat: -------------------------------------------------------------------------------- 1 | if exist C:\Users\onekey\.conda\envs\onekey ( 2 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate onekey; jupyter-notebook" 3 | ) else ( 4 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate %ONEKEY_HOME%onekey_envs; jupyter-notebook" 5 | ) -------------------------------------------------------------------------------- /onekey_comp/comp2-结构化数据/点我运行.bat: -------------------------------------------------------------------------------- 1 | if exist C:\Users\onekey\.conda\envs\onekey ( 2 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate onekey; jupyter-notebook" 3 | ) else ( 4 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate %ONEKEY_HOME%onekey_envs; jupyter-notebook" 5 | ) -------------------------------------------------------------------------------- /onekey_comp/comp4-What(分类识别)/点我运行.bat: -------------------------------------------------------------------------------- 1 | if exist C:\Users\onekey\.conda\envs\onekey ( 2 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate onekey; jupyter-notebook" 3 | ) else ( 4 | 
%windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate %ONEKEY_HOME%onekey_envs; jupyter-notebook" 5 | ) -------------------------------------------------------------------------------- /onekey_comp/comp7-Survival/点我运行.bat: -------------------------------------------------------------------------------- 1 | if exist C:\Users\onekey\.conda\envs\onekey ( 2 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate onekey; jupyter-notebook" 3 | ) else ( 4 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate %ONEKEY_HOME%onekey_envs; jupyter-notebook" 5 | ) -------------------------------------------------------------------------------- /onekey_comp/comp9-Solutions/点我运行.bat: -------------------------------------------------------------------------------- 1 | if exist C:\Users\onekey\.conda\envs\onekey ( 2 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate onekey; jupyter-notebook" 3 | ) else ( 4 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate %ONEKEY_HOME%onekey_envs; jupyter-notebook" 5 | ) -------------------------------------------------------------------------------- /onekey_comp/comp9-Solutions/sol1. 传统组学-单中心-临床/点我运行.bat: -------------------------------------------------------------------------------- 1 | if exist C:\Users\onekey\.conda\envs\onekey ( 2 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate onekey; jupyter-notebook" 3 | ) else ( 4 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate %ONEKEY_HOME%onekey_envs; jupyter-notebook" 5 | ) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython 2 | lightgbm 3 | pexpect 4 | matplotlib 5 | numpy 6 | scipy 7 | scikit-learn 8 | pandas 9 | termcolor 10 | contextlib2 11 | prettytable 12 | graphviz 13 | configparser 14 | pillow 15 | pytest 16 | pyyaml 17 | requests 18 | setuptools 19 | pycocotools 20 | opencv-python==4.3.0.36 21 | imgviz 22 | scikit-image 23 | pydicom 24 | imageio 25 | labelme 26 | thop 27 | flask 28 | pyradiomics 29 | xgboost 30 | seaborn 31 | nibabel 32 | pynrrd 33 | moviepy 34 | lxml 35 | SimpleITK 36 | statsmodels 37 | imblearn 38 | joblib -------------------------------------------------------------------------------- /onekey_comp/点我运行.bat: -------------------------------------------------------------------------------- 1 | if exist C:\Users\onekey\.conda\envs\onekey ( 2 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate onekey; set PYTHONPATH=D:\Projects\onekey_lite; jupyter-notebook" 3 | ) else ( 4 | 
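REM PYTHONPATH points Jupyter at the onekey_lite checkout; the D:\Projects\onekey_lite path below is machine-specific.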
%windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate %ONEKEY_HOME%onekey_envs; set PYTHONPATH=D:\Projects\onekey_lite; jupyter-notebook" 5 | ) -------------------------------------------------------------------------------- /onekey_comp/comp4-What(分类识别)/可视化.bat: -------------------------------------------------------------------------------- 1 | if exist C:\Users\onekey\.conda\envs\onekey ( 2 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate onekey; python C:\Users\onekey\.conda\envs\onekey\Lib\site-packages\onekey_algo\scripts\serving.py" 3 | ) else ( 4 | %windir%\System32\WindowsPowerShell\v1.0\powershell.exe -ExecutionPolicy ByPass -NoExit -Command "& 'C:\ProgramData\Anaconda3\shell\condabin\conda-hook.ps1' ; conda activate %ONEKEY_HOME%onekey_envs; python %ONEKEY_HOME%onekey_envs\Lib\site-packages\onekey_algo\scripts\serving.py" 5 | ) -------------------------------------------------------------------------------- /onekey_comp/README.md: -------------------------------------------------------------------------------- 1 | # Onekey 2 | 3 | ## Initialization 4 | 5 | These steps must be run the first time you use Onekey. Note that the directory you extract to must not contain Chinese characters. 6 | 7 | 1. Install `Anaconda3.exe` from the required-software bundle. 8 | 9 | > On this page, be sure to select `All Users (requires admin privileges)`, then keep all defaults, including the install location. 10 | 11 | ![image-20220426225135173](http://www.medai.icu/storage/attachments/2022/04/26/n1PceZmyokf1LvGAMuYD9p1nUM5OP93xQmEIKTrY.png) 12 | 13 | 2. Double-click `初始化OnekeyAI.bat` and wait for the installation to finish. If this is not the first initialization, problems may occur along the way; retry until it completes without errors. 14 | 15 | 3. Run the OKT-update.exe program from the OnekeyTools toolbox to update the Onekey components. 16 | ### Verifying the setup 17 | 18 | Run `验证OnekeyAI.bat`; if the banner below appears, initialization succeeded. 19 | 20 | ```shell 21 | ####################################################### 22 | ## Welcome to Onekey, current version: x.x.x ## 23 | ## OnekeyAI empowers your research. At your service! ## 24 | ####################################################### 25 | ``` 26 | 27 | ## Usage 28 | 29 | Go into onekey_comp, launch the notebook UI with 【点我运行.bat】, and use the corresponding component. 30 | -------------------------------------------------------------------------------- /onekey_algo/mietb/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | This is a medical image processing program written in Python, the **M**edical **I**mage **E**nhancement **T**ool **B**ox (MIETB), which can be used for a variety of medical image processing tasks. 4 | 5 | ## Features 6 | 7 | The program currently implements the following: 8 | 9 | - Image super-resolution reconstruction 10 | 11 | ## Dependencies 12 | 13 | The program depends on the following Python libraries: 14 | 15 | - numpy 16 | 17 | - matplotlib 18 | 19 | - opencv-python 20 | 21 | - scikit-image 22 | 23 | - scikit-learn 24 | 25 | They can be installed with: 26 | 27 | ```pip install numpy matplotlib opencv-python scikit-image scikit-learn``` 28 | 29 | 30 | 31 | ## Running 32 | 33 | How the program is run depends on the task. 34 | 35 | 36 | 37 | ### Super-resolution reconstruction 38 | 39 | Run the following command: 40 | 41 | ```python sr_demo.py --src_dir path/to/your/input/data --dst_dir path/to/your/output/data --scale 4 ``` 42 | 43 | Parameters: 44 | 45 | - src_dir: path/to/your/input/data 46 | - dst_dir: path/to/your/output/data 47 | - scale: upscaling factor; 2x, 3x, 4x, and 8x are supported 48 | 49 | ## TODO 50 | 51 | - Support more image processing methods 52 | 53 | - Support image-to-image translation methods, e.g. T1 to T2 54 | 
-------------------------------------------------------------------------------- /onekey_core/__init__.py: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | from . import models 3 | from . import transforms 4 | from . 
import utils 5 | 6 | __VERSION__ = '2.3.5' 7 | _image_backend = 'PIL' 8 | 9 | 10 | def set_image_backend(backend): 11 | """ 12 | Specifies the package used to load images. 13 | 14 | Args: 15 | backend (string): Name of the image backend. one of {'PIL', 'accimage'}. 16 | The :mod:`accimage` package uses the Intel IPP library. It is 17 | generally faster than PIL, but does not support as many operations. 18 | """ 19 | global _image_backend 20 | if backend not in ['PIL', 'accimage']: 21 | raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'" 22 | .format(backend)) 23 | _image_backend = backend 24 | 25 | 26 | def get_image_backend(): 27 | """ 28 | Gets the name of the package used to load images 29 | """ 30 | return _image_backend 31 | -------------------------------------------------------------------------------- /onekey_algo/datasets/image_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/25 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 7 | 8 | import nibabel as nib 9 | import numpy as np 10 | from PIL import Image 11 | 12 | 13 | def pil_loader(path): 14 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 15 | with open(path, 'rb') as f: 16 | img = Image.open(f) 17 | return img.convert('RGB') 18 | 19 | 20 | def nii_loader(path): 21 | return np.array(nib.load(path).dataobj) 22 | 23 | 24 | def accimage_loader(path): 25 | import accimage 26 | try: 27 | return accimage.Image(path) 28 | except IOError: 29 | # Potentially a decoding problem, fall back to PIL.Image 30 | return pil_loader(path) 31 | 32 | 33 | def default_loader(path): 34 | from torchvision import get_image_backend 35 | if get_image_backend() == 'accimage': 36 | return accimage_loader(path) 37 | else: 38 | if path.endswith('.nii.gz') or path.endswith('.nii'): 39 | return nii_loader(path) 40 | else: 41 | return pil_loader(path) 42 | -------------------------------------------------------------------------------- /onekey_core/extension.py: -------------------------------------------------------------------------------- 1 | _C = None 2 | 3 | 4 | def _lazy_import(): 5 | """ 6 | Make sure that CUDA versions match between the pytorch install and torchvision install 7 | """ 8 | global _C 9 | if _C is not None: 10 | return _C 11 | import torch 12 | from torchvision import _C as C 13 | _C = C 14 | if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None: 15 | tv_version = str(_C.CUDA_VERSION) 16 | if int(tv_version) < 10000: 17 | tv_major = int(tv_version[0]) 18 | tv_minor = int(tv_version[2]) 19 | else: 20 | tv_major = int(tv_version[0:2]) 21 | tv_minor = int(tv_version[3]) 22 | t_version = torch.version.cuda 23 | t_version = t_version.split('.') 24 | t_major = int(t_version[0]) 25 | t_minor = int(t_version[1]) 26 | if t_major != tv_major or t_minor != tv_minor: 27 | raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. " 28 | "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. " 29 | "Please reinstall the torchvision that matches your PyTorch install." 
30 | .format(t_major, t_minor, tv_major, tv_minor)) 31 | return _C -------------------------------------------------------------------------------- /onekey_core/core/image_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2022/02/23 4 | # Forum: www.medai.icu 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2022 All Rights Reserved. 7 | import os 8 | from typing import Optional 9 | 10 | import nibabel 11 | import nrrd 12 | import numpy as np 13 | 14 | 15 | def image_loader_3d(impath: str, root='', index_order='F') -> Optional[np.ndarray]: 16 | """ 17 | Args: 18 | impath: image path 19 | root: Where impath is relative path, use root to concat impath 20 | index_order: {'C', 'F'}, optional 21 | Specifies the index order of the resulting data array. Either 'C' (C-order) where the dimensions are ordered from 22 | slowest-varying to fastest-varying (e.g. (z, y, x)), or 'F' (Fortran-order) where the dimensions are ordered 23 | from fastest-varying to slowest-varying (e.g. (x, y, z)). 24 | 25 | Returns: 26 | 27 | """ 28 | assert index_order in ['F', 'C'] 29 | impath = os.path.join(root, impath) 30 | if impath and os.path.exists(impath): 31 | if impath.endswith('.nrrd'): 32 | nrrd_data, _ = nrrd.read(impath, index_order=index_order) 33 | return nrrd_data 34 | elif impath.endswith('.nii.gz') or impath.endswith('.nii'): 35 | image = np.asanyarray(nibabel.load(impath).dataobj)  # get_data() is deprecated and removed in recent nibabel 36 | if index_order == 'C': 37 | image = np.transpose(image, [2, 1, 0]) 38 | return image 39 | else: 40 | return None -------------------------------------------------------------------------------- /onekey_comp/comp9-Solutions/sol1. 传统组学-单中心-临床/stats.csv: -------------------------------------------------------------------------------- 1 | feature_name,train-label=ALL,train-label=0,train-label=1,pvalue,test-label=ALL,test-label=0,test-label=1,pvalue 2 | duration,35.73±21.55,34.89±21.26,37.54±22.23,0.4326775003387858,33.39±20.65,30.86±19.32,38.97±23.02,0.21036378365346842 3 | age,45.13±17.08,43.52±17.65,48.60±15.34,0.056738552598900364,43.88±16.99,43.79±16.50,44.07±18.64,0.9586586387141496 4 | BMI,23.25±2.58,23.30±2.50,23.13±2.77,0.6790647000194576,22.53±2.05,22.57±1.96,22.45±2.30,0.8515367366646331 5 | chemotherapy,,,,1.0,,,,0.7202057424438955 6 | 0,77(40.74),53(41.09),24(40.00),,19(39.58),12(36.36),7(46.67), 7 | 1,112(59.26),76(58.91),36(60.00),,29(60.42),21(63.64),8(53.33), 8 | gender,,,,0.6242573449829052,,,,1.0 9 | 0,17(8.99),13(10.08),4(6.67),,3(6.25),2(6.06),1(6.67), 10 | 1,172(91.01),116(89.92),56(93.33),,45(93.75),31(93.94),14(93.33), 11 | result,,,,0.9503540997953493,,,,1.0 12 | 0,103(54.50),71(55.04),32(53.33),,25(52.08),17(51.52),8(53.33), 13 | 1,86(45.50),58(44.96),28(46.67),,23(47.92),16(48.48),7(46.67), 14 | degree,,,,0.8449320916713244,,,,0.7736986468971447 15 | 0,43(22.75),29(22.48),14(23.33),,10(20.83),6(18.18),4(26.67), 16 | 1,144(76.19),99(76.74),45(75.00),,38(79.17),27(81.82),11(73.33), 17 | 2,2(1.06),1(0.78),1(1.67),,,,, 18 | Tstage,,,,0.044142432830072936,,,,0.32629482949200733 19 | 0,38(20.11),26(20.16),12(20.00),,14(29.17),10(30.30),4(26.67), 20 | 1,43(22.75),36(27.91),7(11.67),,16(33.33),9(27.27),7(46.67), 21 | 2,55(29.10),37(28.68),18(30.00),,10(20.83),9(27.27),1(6.67), 22 | 3,53(28.04),30(23.26),23(38.33),,8(16.67),5(15.15),3(20.00), 23 | smoke,,,,0.19387120575076355,,,,0.4033953048926282 24 | 0,51(26.98),39(30.23),12(20.00),,8(16.67),4(12.12),4(26.67), 25 | 1,138(73.02),90(69.77),48(80.00),,40(83.33),29(87.88),11(73.33), 26 | drink,,,,0.4036682203617079,,,,0.2082086705075723 27 | 0,66(34.92),42(32.56),24(40.00),,12(25.00),6(18.18),6(40.00), 28 | 1,123(65.08),87(67.44),36(60.00),,36(75.00),27(81.82),9(60.00), 29 | 
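In the table above, continuous features are reported as mean±std and categorical features as count(percent) per level, each with a p-value comparing label=0 vs label=1 within the train and test splits. A minimal loading sketch (not part of the repo):

```python
import pandas as pd

# Baseline-statistics table exported by the Step2 notebook (assumption).
stats = pd.read_csv('stats.csv')
print(stats[['feature_name', 'pvalue']].dropna().head())
```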
-------------------------------------------------------------------------------- /onekey_core/core/model_factory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/24 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 7 | 8 | from onekey_core import models 9 | __all__ = ['create_model'] 10 | 11 | 12 | def create_model(model_name, **kwargs): 13 | """Create a core (model) that torchvision supports. Supported `model_name` values are as follows. 14 | alexnet, AlexNet, 15 | ResNet, resnet18, resnet34, resnet50, resnet101, resnet152, wide_resnet50_2, wide_resnet101_2, 16 | resnext50_32x4d, resnext101_32x8d, 17 | VGG, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19, 18 | SqueezeNet, squeezenet1_0, squeezenet1_1, 19 | Inception3, inception_v3, googlenet, GoogLeNet, 20 | DenseNet, densenet121, densenet169, densenet201, densenet161, 21 | MobileNetV2, mobilenet_v2, 22 | MNASNet, mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3, 23 | ShuffleNetV2, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0, 24 | detection.*, 25 | segmentation.* 26 | 27 | :param model_name: One of the core names above. 28 | :param kwargs: other core settings. 29 | :return: the matched core. 30 | :raise: ValueError, when `model_name` is not supported. 31 | """ 32 | supported_models = [k for k in models.__dict__ 33 | if not k.startswith('_') and type(models.__dict__[k]).__name__ != 'module'] 34 | supported_models.extend(['detection', 'segmentation', 'segmentation3d', 'classification3d', 'fusion']) 35 | _modules = model_name.split('.') 36 | if len(_modules) == 1: 37 | if _modules[0] in supported_models: 38 | return models.__dict__[_modules[0]](**kwargs) 39 | elif len(_modules) == 2: 40 | if _modules[0] in supported_models and _modules[1] in models.__dict__[_modules[0]].__dict__: 41 | return models.__dict__[_modules[0]].__dict__[_modules[1]](**kwargs) 42 | 43 | raise ValueError(f'{model_name} not supported!')
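# Usage sketch (illustrative, not part of the repo; assumes no pretrained weights need downloading):
#     model = create_model('resnet18', pretrained=False, num_classes=2)
#     model = create_model('densenet121', pretrained=False)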
-------------------------------------------------------------------------------- /onekey_algo/custom/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2022/3/22 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2022 All Rights Reserved. 7 | import copy 8 | from typing import List, Union 9 | 10 | import numpy as np 11 | import pandas as pd 12 | 13 | from onekey_algo.utils.about_log import logger 14 | 15 | 16 | def map2numerical(data: pd.DataFrame, mapping_columns: Union[str, List[str]], inplace=True, map_nan: bool = False): 17 | """ 18 | Map non-numeric columns of a dataset to categorical numeric codes. 19 | Args: 20 | data: the input DataFrame 21 | mapping_columns: the columns to map 22 | inplace: bool 23 | map_nan: whether to map NaN values too; by default NaNs are left unmapped so they can be imputed later 24 | 25 | Returns: 26 | 27 | """ 28 | mapping = {} 29 | if not inplace: 30 | new_data = copy.deepcopy(data) 31 | else: 32 | new_data = data 33 | if not isinstance(mapping_columns, list): 34 | mapping_columns = [mapping_columns] 35 | assert all(c in data.columns for c in mapping_columns) 36 | if map_nan: 37 | for c in mapping_columns: 38 | unique_labels = {v: idx for idx, v in enumerate(sorted(np.unique(np.array(data[c]).astype(str))))} 39 | mapping[c] = unique_labels 40 | new_data[[c]] = new_data[[c]].applymap(lambda x: unique_labels[str(x)]) 41 | else: 42 | for c in mapping_columns: 43 | ul = sorted([ul_ for ul_ in np.unique(np.array(data[c]).astype(str)) if ul_ != 'nan']) 44 | unique_labels = {v: idx for idx, v in enumerate(ul)} 45 | mapping[c] = unique_labels 46 | new_data[[c]] = new_data[[c]].applymap(lambda x: unique_labels[str(x)] if str(x) != 'nan' else None) 47 | return new_data, mapping 48 | 49 | 50 | def print_join_info(left: pd.DataFrame, right: pd.DataFrame, on='ID'): 51 | left_set = set(left[on]) 52 | right_set = set(right[on]) 53 | if left_set == right_set: 54 | logger.info(f'Feature {on} matches completely!') 55 | else: 56 | logger.warning(f"Feature {on} does not match completely! In left but not in right: {left_set - right_set}; " 57 | f"in right but not in left: {right_set - left_set}") 58 | 
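# Usage sketch for map2numerical above (illustrative, hypothetical column names):
#     df = pd.DataFrame({'gender': ['M', 'F', 'M'], 'age': [32, 41, 29]})
#     df, mapping = map2numerical(df, 'gender')  # mapping == {'gender': {'F': 0, 'M': 1}}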
-------------------------------------------------------------------------------- /onekey_core/core/test_factory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/24 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 7 | import torch.nn as nn 8 | 9 | from onekey_core.core.losses_factory import create_losses 10 | from onekey_core.core import create_model 11 | from onekey_core.core import create_optimizer, create_lr_scheduler 12 | 13 | 14 | def test_create_model(): 15 | create_model('inception_v3', num_classes=100, pretrained=False) 16 | create_model('Inception3', num_classes=100) 17 | create_model('densenet161', pretrained=False) 18 | create_model('detection.keypointrcnn_resnet50_fpn', pretrained_backbone=False) 19 | try: 20 | create_model('mobilenet', pretrained=False) 21 | except ValueError: 22 | pass 23 | 24 | 25 | def test_create_optimizer(): 26 | model = create_model('alexnet') 27 | create_optimizer('RMSprop', model.parameters()) 28 | create_optimizer('RMSprop', [{'params': model.features.parameters(), 'lr': 0.01}, 29 | {'params': model.classifier.parameters(), 'lr': 0.001}], 30 | alpha=0.99) 31 | 32 | 33 | def test_create_lr_scheduler(): 34 | model = create_model('alexnet') 35 | rms = create_optimizer('RMSprop', [{'params': model.features.parameters(), 'lr': 0.01}, 36 | {'params': model.classifier.parameters(), 'lr': 0.001}], 37 | alpha=0.99) 38 | create_lr_scheduler('cosine', rms, T_max=10) 39 | 40 | 41 | def test_create_losses(): 42 | assert isinstance(create_losses('l1', reduction='mean'), nn.L1Loss) 43 | losses = create_losses([{'loss': 'softmax_ce'}, {'loss': 'l1'}]) 44 | assert isinstance(losses, list) and isinstance(losses[0], nn.CrossEntropyLoss) and isinstance(losses[1], nn.L1Loss) 45 | try: 46 | create_losses('l1', Error_parm='mean') 47 | except TypeError: 48 | pass 49 | 50 | try: 51 | create_losses('lxxx') 52 | except ValueError: 53 | pass 54 | 55 | try: 56 | create_losses([{'loss': 'softmax_ce', 'reduction': 'mean'}, {'loss': 'l1'}]) 57 | except AssertionError: 58 | pass 59 | -------------------------------------------------------------------------------- /onekey_core/core/transformer_factory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/23 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 7 | 8 | from monai.transforms import ( 9 | AddChannel, 10 | Compose, 11 | RandRotate90, 12 | Resize, 13 | ScaleIntensity, 14 | EnsureType, 15 | ) 16 | from torchvision.transforms import transforms 17 | 18 | __all__ = ['create_standard_image_transformer'] 19 | 20 | 21 | def create_standard_image_transformer(input_size, phase='train', normalize_method='imagenet', is_nii: bool = False, 22 | **kwargs): 23 | """Standard image transformer. 24 | 25 | :param input_size: The core's input image size. 26 | :param phase: phase of transformer, train or valid or test supported. 27 | :param normalize_method: Normalize method, imagenet or -1+1 supported. 28 | :param is_nii: whether the input is multi-channel NIfTI data treated as 2D for training 29 | :return: 30 | """ 31 | assert phase in ['train', 'valid', 'test'], "`phase` not found, only 'train', 'valid', 'test' supported!" 32 | normalize = {'imagenet': [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]], 33 | '-1+1': [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]} 34 | assert normalize_method in normalize, "`normalize_method` not found, only 'imagenet', '-1+1' supported!" 
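    # 2D branch (is_nii=False): torchvision transforms with train-time augmentation (random resized crop, horizontal flip).
    # NIfTI branch (is_nii=True): MONAI transforms resized to roi_size; train and valid currently share the same pipeline.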
35 | if not is_nii: 36 | if phase == 'train': 37 | return transforms.Compose([ 38 | transforms.RandomResizedCrop(input_size), 39 | # transforms.Resize(input_size), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor(), 42 | transforms.Normalize(*normalize[normalize_method])]) 43 | else: 44 | return transforms.Compose([ 45 | transforms.Resize(input_size), 46 | # transforms.CenterCrop(input_size), 47 | transforms.ToTensor(), 48 | transforms.Normalize(*normalize[normalize_method])]) 49 | else: 50 | roi_size = kwargs.get('roi_size', [3, 96, 96]) 51 | if phase == 'train': 52 | return Compose([ScaleIntensity(), AddChannel(), Resize(roi_size), EnsureType()]) 53 | else: 54 | return Compose([ScaleIntensity(), AddChannel(), Resize(roi_size), EnsureType()]) 55 | -------------------------------------------------------------------------------- /onekey_core/models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import load_state_dict_from_url 5 | 6 | __all__ = ['AlexNet', 'alexnet'] 7 | 8 | model_urls = { 9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 10 | } 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, in_channels=3, num_classes=1000, **kwargs): 16 | super(AlexNet, self).__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(in_channels, 64, kernel_size=11, stride=4, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 33 | self.classifier = nn.Sequential( 34 | nn.Dropout(), 35 | nn.Linear(256 * 6 * 6, 4096), 36 | nn.ReLU(inplace=True), 37 | nn.Dropout(), 38 | nn.Linear(4096, 4096), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(4096, num_classes), 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.features(x) 45 | x = self.avgpool(x) 46 | x = torch.flatten(x, 1) 47 | x = self.classifier(x) 48 | return x 49 | 50 | 51 | def alexnet(pretrained=False, progress=True, transfer_learning=True, **kwargs): 52 | r"""AlexNet core architecture from the 53 | `"One weird trick..." `_ paper. 
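    Example (a minimal usage sketch)::

        >>> model = alexnet(pretrained=False, num_classes=2)
        >>> logits = model(torch.rand(1, 3, 224, 224))  # shape (1, 2)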
54 | 
55 |     Args:
56 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
57 |         progress (bool): If True, displays a progress bar of the download to stderr
58 |         transfer_learning (bool): If True, drop the pretrained classifier weights so the head can be re-initialized for a new `num_classes`
59 |     """
60 |     model = AlexNet(**kwargs)
61 |     if pretrained:
62 |         state_dict = load_state_dict_from_url(model_urls['alexnet'], progress=progress)
63 |         parameters_list = list(state_dict.keys())
64 |         for k in parameters_list:
65 |             if k.startswith('classifier.') and transfer_learning:
66 |                 del state_dict[k]
67 |         model.load_state_dict(state_dict, strict=False)
68 |     return model
69 | 
-------------------------------------------------------------------------------- /onekey_core/models/_utils.py: --------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | 
3 | import torch
4 | from torch import nn
5 | 
6 | 
7 | class IntermediateLayerGetter(nn.ModuleDict):
8 |     """
9 |     Module wrapper that returns intermediate layers from a model
10 | 
11 |     It has a strong assumption that the modules have been registered
12 |     into the model in the same order as they are used.
13 |     This means that one should **not** reuse the same nn.Module
14 |     twice in the forward if you want this to work.
15 | 
16 |     Additionally, it is only able to query submodules that are directly
17 |     assigned to the model. So if `model` is passed, `model.feature1` can
18 |     be returned, but not `model.feature1.layer2`.
19 | 
20 |     Arguments:
21 |         model (nn.Module): model on which we will extract the features
22 |         return_layers (Dict[name, new_name]): a dict containing the names
23 |             of the modules for which the activations will be returned as
24 |             the key of the dict, and the value of the dict is the name
25 |             of the returned activation (which the user can specify).
26 | 
27 |     Examples::
28 | 
29 |         >>> m = torchvision.models.resnet18(pretrained=True)
30 |         >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
31 |         >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
32 |         >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
33 |         >>> out = new_m(torch.rand(1, 3, 224, 224))
34 |         >>> print([(k, v.shape) for k, v in out.items()])
35 |         >>> [('feat1', torch.Size([1, 64, 56, 56])),
36 |         >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
37 |     """
38 | 
39 |     def __init__(self, model, return_layers):
40 |         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
41 |             raise ValueError("return_layers are not present in model")
42 | 
43 |         orig_return_layers = return_layers
44 |         return_layers = {k: v for k, v in return_layers.items()}
45 |         layers = OrderedDict()
46 |         for name, module in model.named_children():
47 |             layers[name] = module
48 |             if name in return_layers:
49 |                 del return_layers[name]
50 |             if not return_layers:
51 |                 break
52 | 
53 |         super(IntermediateLayerGetter, self).__init__(layers)
54 |         self.return_layers = orig_return_layers
55 | 
56 |     def forward(self, x):
57 |         out = OrderedDict()
58 |         for name, module in self.named_children():
59 |             x = module(x)
60 |             if name in self.return_layers:
61 |                 out_name = self.return_layers[name]
62 |                 out[out_name] = x
63 |         return out
64 | 
-------------------------------------------------------------------------------- /onekey_algo/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Authorized by Vlon Jang
3 | # Created on 2022/1/18
4 | # Blog: www.wangqingbaidu.cn
5 | # Email: wangqingbaidu@gmail.com
6 | # Copyright 2015-2022 All Rights Reserved.
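# Typical usage (illustrative, hypothetical values): place a config.yaml next to the
# notebook you are running and read parameters with
#     get_param_in_cwd('save_dir', default='results')
# String values may contain ONEKEY_HOME / ONEKEYDS_HOME placeholders, which the
# helpers below expand to the corresponding environment locations.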
7 | __VERSION__ = '2.4.0'
8 | 
9 | import json
10 | import os
11 | 
12 | import yaml
13 | 
14 | 
15 | def hello_onekey():
16 |     from onekey_algo.custom.components import comp1
17 |     print(f"""
18 | #######################################################
19 | ## 欢迎使用Onekey,当前版本:{__VERSION__} ##
20 | ## OnekeyAI助力科研,我们将竭诚为您服务! ##
21 | #######################################################
22 | """)
23 | 
24 | 
25 | if os.environ.get('ONEKEY_HOME'):
26 |     ONEKEYDS_ROOT = os.path.join(os.environ.get('ONEKEY_HOME'), 'OnekeyDS')
27 | else:
28 |     ONEKEYDS_ROOT = os.path.expanduser(r'~/Project/OnekeyDS')  # ONEKEY_HOME is unset on this branch, so fall back directly
29 | 
30 | 
31 | class OnekeyDS:
32 |     ct = os.path.join(ONEKEYDS_ROOT, 'CT')
33 |     ct_features = os.path.join(ONEKEYDS_ROOT, 'CT', 'rad_features.csv')
34 |     tumour_stroma = os.path.join(ONEKEYDS_ROOT, 'tumour_stroma')
35 |     complaint = os.path.join(ONEKEYDS_ROOT, "complaint.csv")
36 |     grade = os.path.join(ONEKEYDS_ROOT, 'grade.csv')
37 |     Metabonomics = os.path.join(ONEKEYDS_ROOT, 'Metabonomics.csv')
38 |     phy_bio = os.path.join(ONEKEYDS_ROOT, 'phy_bio.csv')
39 |     survival = os.path.join(ONEKEYDS_ROOT, 'survival.csv')
40 | 
41 | 
42 | def get_config(directory=os.getcwd(), config_file='config.txt') -> dict:
43 |     if os.path.exists(os.path.join(directory, config_file)):
44 |         with open(os.path.join(directory, config_file), encoding='utf8') as c:
45 |             content = c.read()
46 |             if '\\\\' not in content:
47 |                 content = content.replace('\\', '\\\\')
48 |             if config_file.endswith('.yaml'):
49 |                 config = yaml.load(content, Loader=yaml.FullLoader)
50 |             else:  # treat .txt (and anything else) as JSON so `config` is always bound
51 |                 config = json.loads(content)
52 |             return config
53 |     else:
54 |         return {}
55 | 
56 | 
57 | def get_param_in_cwd(param: str, default=None, **kwargs):
58 |     directory = kwargs.get('directory', os.getcwd())
59 |     config_file = 'config.yaml' if os.path.exists(os.path.join(directory, 'config.yaml')) else 'config.txt'
60 |     config = get_config(directory, config_file)
61 |     ret = config[param] if config.get(param, None) is not None else default  # `or default` would discard falsy values such as 0 or False
62 |     if isinstance(ret, str) and 'ONEKEY_HOME' in ret:
63 |         ret = ret.replace('ONEKEY_HOME', os.environ.get('ONEKEY_HOME', ''))  # guard against an unset environment variable
64 |     if isinstance(ret, str) and 'ONEKEYDS_HOME' in ret:
65 |         ret = ret.replace('ONEKEYDS_HOME', ONEKEYDS_ROOT)
66 |     return ret
67 | 
68 | 
69 | if __name__ == '__main__':
70 |     okds = OnekeyDS()
71 |     print(okds.ct)
72 | 
-------------------------------------------------------------------------------- /onekey_comp/.gitignore: --------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /onekey_core/core/optimizer_lr_factory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/24 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 7 | import torch.optim as optim 8 | 9 | __all__ = ['create_lr_scheduler', 'create_optimizer'] 10 | 11 | 12 | def create_optimizer(opt_name, parameters, **kwargs): 13 | """ 14 | Create optimizer with specific optimizer name. Supported optimizers are as followings. 15 | 'ASGD', 'Adadelta', 'Adagrad', 'Adam', 'Adamax', 'LBFGS', 'Optimizer', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam' 16 | 17 | :param opt_name: optimizer name. 18 | :param parameters: Parameters to be optimized. 19 | :param kwargs: other optimizer settings. 20 | :return: optimizer 21 | :raises: 22 | ValueError, Optimizer not found. 23 | AssertError, `params` not found in pre-parameter settings. 24 | """ 25 | supported_optimizer = {name.lower(): name for name in optim.__dict__ 26 | if not name.startswith("__") 27 | and callable(optim.__dict__[name])} 28 | if opt_name.lower() not in supported_optimizer: 29 | raise ValueError(f'Optimizer name {opt_name} not supported!') 30 | 31 | # If Pre-parameter settings. 
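    # Usage sketch (hypothetical values) -- both plain parameter iterables and
    # torch.optim-style per-parameter-group dicts are accepted:
    #   create_optimizer('sgd', model.parameters(), lr=0.01, momentum=0.9)
    #   create_optimizer('adam', [{'params': model.features.parameters(), 'lr': 1e-3},
    #                             {'params': model.classifier.parameters(), 'lr': 1e-4}])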
32 | if isinstance(parameters, list): 33 | for param_ in parameters: 34 | if isinstance(param_, dict): 35 | assert 'params' in param_, '`params` must contains in pre-parameter settings.' 36 | 37 | return optim.__dict__[supported_optimizer[opt_name.lower()]](params=parameters, **kwargs) 38 | 39 | 40 | def create_lr_scheduler(scheduler_name: str, optimizer, **kwargs): 41 | """Learning rate scheduler to change lr dynamically. 42 | 43 | :param scheduler_name: learning rate scheduler name 44 | :param optimizer: A instance of optimizer. 45 | :param kwargs: other key args for lr scheduler. 46 | :return: learning rate scheduler. 47 | :raise: ValueError, learning rate scheduler not found. 48 | """ 49 | supported_optimizer = {'lambda': optim.lr_scheduler.LambdaLR, 50 | 'step': optim.lr_scheduler.StepLR, 51 | 'mstep': optim.lr_scheduler.MultiStepLR, 52 | 'exponential': optim.lr_scheduler.ExponentialLR, 53 | 'cosine': optim.lr_scheduler.CosineAnnealingLR, 54 | 'reduce': optim.lr_scheduler.ReduceLROnPlateau, 55 | 'circle': optim.lr_scheduler.CyclicLR} 56 | 57 | if scheduler_name.lower() not in supported_optimizer: 58 | raise ValueError(f'Scheduler name {scheduler_name} not supported!') 59 | return supported_optimizer[scheduler_name](optimizer, **kwargs) 60 | -------------------------------------------------------------------------------- /onekey_comp/comp4-What(分类识别)/生成CV数据-List模式.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "45cf0f85", 6 | "metadata": {}, 7 | "source": [ 8 | "# 拆分数据集\n", 9 | "\n", 10 | "针对What的List模式,进行数据集拆分,形成具有交叉验证或者随机划分的功能。\n", 11 | "\n", 12 | "```python\n", 13 | "def split_dataset(X_data: pd.DataFrame, y_data: pd.DataFrame = None, test_size=0.2, n_trails=10,\n", 14 | " cv: bool = False, shuffle: bool = False, random_state=None, save_dir=None):\n", 15 | " \"\"\"\n", 16 | " 数据划分。\n", 17 | " Args:\n", 18 | " X_data: 训练数据\n", 19 | " y_data: 监督数据\n", 20 | " test_size: 测试集比例\n", 21 | " n_trails: 尝试多少次寻找最佳数据集划分。\n", 22 | " cv: 是否是交叉验证,默认是False,当为True时,n_trails为交叉验证的n_fold\n", 23 | " shuffle: 是否进行随机打乱\n", 24 | " random_state: 随机种子\n", 25 | " save_dir: 信息保存的路径。\n", 26 | "\n", 27 | " Returns: 拆分之后的数据列表\n", 28 | "\n", 29 | " \"\"\"\n", 30 | " ```" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "31f0e7a6", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import os\n", 41 | "import pandas as pd\n", 42 | "\n", 43 | "from onekey_algo import OnekeyDS as okds\n", 44 | "from onekey_algo.custom.components.comp2 import split_dataset\n", 45 | "\n", 46 | "data = pd.read_csv(os.path.join(okds.ct, 'label.csv'))\n", 47 | "\n", 48 | "rt = split_dataset(data, data['label'], cv=False, save_dir='.')\n", 49 | "x1, x2 = rt[0]" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "006b7c34", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "x1" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "a5544e6d", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "x2" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "c1eabae6", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [] 79 | } 80 | ], 81 | "metadata": { 82 | "kernelspec": { 83 | "display_name": "Python 3 (ipykernel)", 84 | "language": "python", 85 | "name": "python3" 86 | }, 87 | "language_info": { 88 | "codemirror_mode": { 89 | "name": "ipython", 90 | 
"version": 3 91 | }, 92 | "file_extension": ".py", 93 | "mimetype": "text/x-python", 94 | "name": "python", 95 | "nbconvert_exporter": "python", 96 | "pygments_lexer": "ipython3", 97 | "version": "3.7.12" 98 | } 99 | }, 100 | "nbformat": 4, 101 | "nbformat_minor": 5 102 | } 103 | -------------------------------------------------------------------------------- /onekey_comp/comp0. MIET(医学图像处理)/Module1. 超分辨率采样.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "98d4f33d", 6 | "metadata": {}, 7 | "source": [ 8 | "# README\n", 9 | "\n", 10 | "这是一个用Python编写的医学图像处理程序 **M**edical **I**mage **E**nhancement **T**ool **B**ox(MIETB) ,可以用于各种医学图像处理相关的任务。\n", 11 | "\n", 12 | "## 效果说明\n", 13 | "\n", 14 | "该程序可以实现以下功能:\n", 15 | "\n", 16 | "- 图像超分辨率重建\n", 17 | " - scale 输入一个放大倍数,支持[2,4]倍\n", 18 | " - 其他高倍数的算法定制,可以联系张老师,微信:OnekeyAI4U\n", 19 | "\n", 20 | "## 获取待提取特征的文件\n", 21 | "\n", 22 | "提供两种批量处理的模式:\n", 23 | "1. 目录模式,提取指定目录下的所有.nii.gz文件的特征。默认寻找目录下所有的nii.gz数据\n", 24 | "2. 文件模式,待提取的数据存储在文件中,每行一个样本。" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "df6f2a55", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import os\n", 35 | "os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\n", 36 | "from onekey_algo import OnekeyDS as okds\n", 37 | "from onekey_algo import get_param_in_cwd\n", 38 | "\n", 39 | "# 目录模式\n", 40 | "scale = get_param_in_cwd('scale', 4)\n", 41 | "mydir = get_param_in_cwd('rad_dir', os.path.join(okds.ct, 'images'))\n", 42 | "samples = []\n", 43 | "for r, ds, fs in os.walk(mydir):\n", 44 | " samples.extend([os.path.join(r, p) for p in fs if p.endswith('.nii.gz')])\n", 45 | "\n", 46 | "# 文件模式\n", 47 | "# test_file = ''\n", 48 | "# with open(test_file) as f:\n", 49 | "# test_samples = [l.strip() for l in f.readlines()]\n", 50 | "\n", 51 | "# 自定义模式\n", 52 | "# test_sampleses = ['path2nii.gz']\n", 53 | "samples" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "577fbf57", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "from onekey_algo.mietb.super_resolution.eval_super_res_reconstruction import init as init_super\n", 64 | "from onekey_algo.mietb.super_resolution.eval_super_res_reconstruction import inference as inference_super\n", 65 | "\n", 66 | "save_dir = get_param_in_cwd('save_dir', None)\n", 67 | "print(save_dir)\n", 68 | "model, device = init_super(scale)\n", 69 | "inference_super(samples, model, device, scale, save_dir=save_dir)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "2157e73d", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [] 79 | } 80 | ], 81 | "metadata": { 82 | "kernelspec": { 83 | "display_name": "Python 3 (ipykernel)", 84 | "language": "python", 85 | "name": "python3" 86 | }, 87 | "language_info": { 88 | "codemirror_mode": { 89 | "name": "ipython", 90 | "version": 3 91 | }, 92 | "file_extension": ".py", 93 | "mimetype": "text/x-python", 94 | "name": "python", 95 | "nbconvert_exporter": "python", 96 | "pygments_lexer": "ipython3", 97 | "version": "3.7.12" 98 | } 99 | }, 100 | "nbformat": 4, 101 | "nbformat_minor": 5 102 | } 103 | -------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/RRDBNet_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import 
torch.nn.functional as F 5 | 6 | 7 | def make_layer(block, n_layers): 8 | layers = [] 9 | for _ in range(n_layers): 10 | layers.append(block()) 11 | return nn.Sequential(*layers) 12 | 13 | 14 | class ResidualDenseBlock_5C(nn.Module): 15 | def __init__(self, nf=64, gc=32, bias=True): 16 | super(ResidualDenseBlock_5C, self).__init__() 17 | # gc: growth channel, i.e. intermediate channels 18 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 19 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 20 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 21 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 22 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 23 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 24 | 25 | # initialization 26 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 27 | 28 | def forward(self, x): 29 | x1 = self.lrelu(self.conv1(x)) 30 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 31 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 32 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 33 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 34 | return x5 * 0.2 + x 35 | 36 | 37 | class RRDB(nn.Module): 38 | '''Residual in Residual Dense Block''' 39 | 40 | def __init__(self, nf, gc=32): 41 | super(RRDB, self).__init__() 42 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 43 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 44 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 45 | 46 | def forward(self, x): 47 | out = self.RDB1(x) 48 | out = self.RDB2(out) 49 | out = self.RDB3(out) 50 | return out * 0.2 + x 51 | 52 | 53 | class RRDBNet(nn.Module): 54 | def __init__(self, in_nc, out_nc, nf, nb, gc=32): 55 | super(RRDBNet, self).__init__() 56 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 57 | 58 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 59 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 60 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 61 | #### upsampling 62 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 64 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 65 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 66 | 67 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 68 | 69 | def forward(self, x): 70 | fea = self.conv_first(x) 71 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 72 | fea = fea + trunk 73 | 74 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 75 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 76 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 77 | 78 | return out -------------------------------------------------------------------------------- /onekey_algo/classification/README.md: -------------------------------------------------------------------------------- 1 | ## run_classification 2 | ```shell 3 | $ python run_classification.py -h 4 | usage: run_classification.py [-h] [--train [TRAIN [TRAIN ...]]] 5 | [--valid [VALID [VALID ...]]] 6 | [--labels_file LABELS_FILE] [-j J] 7 | [--max2use MAX2USE] 8 | [--data_pattern [DATA_PATTERN [DATA_PATTERN ...]]] 9 | [--normalize_method {-1+1,imagenet}] 10 | [--model_name MODEL_NAME] 11 | [--gpus [GPUS [GPUS ...]]] 12 | [--batch_size BATCH_SIZE] [--epochs EPOCHS] 13 | [--init_lr INIT_LR] [--optimizer OPTIMIZER] 14 | [--retrain RETRAIN] [--model_root MODEL_ROOT] 15 | [--iters_verbose ITERS_VERBOSE] 16 | 
                             [--iters_start ITERS_START] [--pretrained]
17 | 
18 | PyTorch Classification Training
19 | 
20 | optional arguments:
21 |   -h, --help            show this help message and exit
22 |   --train [TRAIN [TRAIN ...]]
23 |                         Training dataset
24 |   --valid [VALID [VALID ...]]
25 |                         Validation dataset
26 |   --labels_file LABELS_FILE
27 |                         Labels file
28 |   -j J, --worker J      Number of workers.(default=1)
29 |   --max2use MAX2USE     Maximum number of sample per class to be used!
30 |   --data_pattern [DATA_PATTERN [DATA_PATTERN ...]]
31 |                         Where to save origin image data.
32 |   --normalize_method {-1+1,imagenet}
33 |                         Normalize method.
34 |   --model_name MODEL_NAME
35 |                         Model name
36 |   --gpus [GPUS [GPUS ...]]
37 |                         GPU index to be used!
38 |   --batch_size BATCH_SIZE
39 |   --epochs EPOCHS       number of total epochs to run
40 |   --init_lr INIT_LR     initial learning rate
41 |   --optimizer OPTIMIZER
42 |                         Optimizer
43 |   --retrain RETRAIN     Retrain from path
44 |   --model_root MODEL_ROOT
45 |                         path where to save
46 |   --iters_verbose ITERS_VERBOSE
47 |                         print frequency
48 |   --iters_start ITERS_START
49 |                         Iters start
50 |   --pretrained          Use pretrained model or not
51 | ```
52 | 
53 | ### Common parameters
54 | 
55 | * `train`, `valid`: the training and validation sets respectively; neither may be empty. Two dataset formats are supported:
56 |     1. FolderDataset: pass the parent directory of the data, with the samples of each class in their own sub-folder.
57 |     2. ListDataset: pass a text file with one sample per line; 2-column and 6-column layouts are supported. The 2-column layout is the most common: the first column is the file name, the second the label. In the 6-column layout the last four columns are (`x`, `y`, `width`, `height`), i.e. the object inside that bbox is classified.
58 | 
59 | * `labels_file`: when the dataset is a FolderDataset, only the labels listed in `labels_file` are used for training; the file contains one label per line.
60 | * `data_pattern`: when the dataset is a ListDataset, the final file path is `data_pattern`\\`file_name`.
61 | * `-j`: data-loading parallelism; at most the number of CPU threads is recommended.
62 | * `max2use`: maximum number of samples to use per label; by default all samples are used.
63 | * `normalize_method`: how the data are normalized, i.e. subtract the mean and divide by the standard deviation.
64 | * `model_name`: name of the model to use, e.g. `resnet18`, `inception_v3`.
65 | * `gpus`: indices of the GPUs to use.
66 | * `retrain`: resume training from previously saved parameters.
67 | * `model_root`: where the model parameters are saved.
68 | * `iters_verbose`: how many iterations between log messages.
69 | * `iters_start`: the iteration count to start training from.
70 | * `pretrained`: switch; whether to use pretrained parameters.
71 | 
72 | ### COVID-19
73 | ```shell
74 | python run_classification.py --train $DS_ROOT/data.txt --valid $DS_ROOT/data.txt --labels_file $DS_ROOT/labels.txt
75 |     --data_pattern $DS_ROOT/images/ --model_root $DS_ROOT/model -j 2 --batch_size 16 --init_lr 0.1 --epochs 300
76 | ```
-------------------------------------------------------------------------------- /onekey_core/core/losses_factory.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Authorized by Vlon Jang
3 | # Created on 2019/7/25
4 | # Blog: www.wangqingbaidu.cn
5 | # Email: wangqingbaidu@gmail.com
6 | # Copyright 2015-2019 All Rights Reserved.
7 | import torch.nn as nn
8 | 
9 | __all__ = ['create_losses']
10 | 
11 | 
12 | def create_losses(losses, **kwargs):
13 |     r"""
14 |     Create losses with the specified loss name. Supported losses are as follows.
15 |     'AdaptiveLogSoftmaxWithLoss', 'BCELoss', 'BCEWithLogitsLoss', 'CTCLoss', 'CosineEmbeddingLoss',
16 |     'CrossEntropyLoss', 'HingeEmbeddingLoss', 'KLDivLoss', 'L1Loss', 'MSELoss', 'MarginRankingLoss',
17 |     'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'NLLLoss', 'NLLLoss2d',
18 |     'PoissonNLLLoss', 'SmoothL1Loss', 'SoftMarginLoss', 'TripletMarginLoss'
19 | 
20 |     `losses` can be a str loss name or a list of dicts for a multi-loss combination.
21 |     ```python
22 |     loss1 = create_losses('softmax_ce', reduction='mean')
23 |     # Or a list of dict type.
24 |     loss2 = create_losses([{'loss': 'softmax_ce', 'kwargs': {'reduction': 'mean'}},
25 |                            {'loss': 'sigmoid_ce', 'kwargs': {'reduction': 'mean', 'pos_weight': None}}])
26 |     ```
27 |     Each dict takes the following params:
28 |     :param: loss, specify the loss. REQUIRED!
29 |     :param: kwargs, other loss settings.
30 | 
31 |     :param losses: loss name or a list of dicts for a multi-loss combination.
32 |     :param kwargs: other loss settings.
33 |     :return: combined loss.
34 |     :raises:
35 |         ValueError, loss not found.
36 |         AssertionError,
37 |             if the type of an item in `losses` is not dict when using a loss combination.
38 |             if `loss` is not found in the multi-loss combination settings.
39 | 
40 |     """
41 |     supported_losses = {'softmax_ce': nn.CrossEntropyLoss,  # Softmax cross entropy for single label.
42 |                         'sigmoid_ce': nn.BCEWithLogitsLoss,  # Sigmoid cross entropy for single label.
43 |                         'bce': nn.BCELoss,  # Binary classification targets without sigmoid activation.
44 |                         'cosine_embedding': nn.CosineEmbeddingLoss,  # Cosine embedding loss.
45 |                         'ctc': nn.CTCLoss,  # CTC loss.
46 |                         'hinge': nn.HingeEmbeddingLoss,
47 |                         'kl': nn.KLDivLoss,  # KL divergence Loss for continuous targets.
48 |                         'l1': nn.L1Loss,
49 |                         'smooth_l1': nn.SmoothL1Loss,  # was nn.L1Loss, which silently dropped the smoothing
50 |                         'triplet': nn.TripletMarginLoss,  # Triplet loss.
51 |                         'mse': nn.MSELoss,
52 |                         'ranking': nn.MarginRankingLoss,
53 |                         'multi_sigmoid': nn.MultiLabelSoftMarginLoss  # Multi label sigmoid loss.
54 |                         }
55 | 
56 |     def _form_loss(loss_name, **spec_loss_kwargs):
57 |         if loss_name not in supported_losses:
58 |             raise ValueError(f'Loss name {loss_name} not supported!')
59 |         return supported_losses[loss_name](**spec_loss_kwargs)
60 | 
61 |     if isinstance(losses, list):
62 |         assert all(isinstance(l, dict) and 'loss' in l for l in losses)
63 |         for l in losses:
64 |             if 'kwargs' not in l:
65 |                 l['kwargs'] = {}
66 |         return [_form_loss(l['loss'], **l['kwargs']) for l in losses]
67 |     else:
68 |         return _form_loss(losses, **kwargs)
69 | 
-------------------------------------------------------------------------------- /onekey_algo/datasets/vision.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 | from typing import Any, Callable, List, Optional, Tuple
5 | 
6 | 
7 | class VisionDataset(data.Dataset):
8 |     _repr_indent = 4
9 | 
10 |     def __init__(
11 |             self,
12 |             root: str,
13 |             transforms: Optional[Callable] = None,
14 |             transform: Optional[Callable] = None,
15 |             target_transform: Optional[Callable] = None,
16 |     ) -> None:
17 |         if isinstance(root, torch._six.string_classes):  # torch._six exists only in older (1.x) PyTorch releases
18 |             root = os.path.expanduser(root)
19 |         self.root = root
20 | 
21 |         has_transforms = transforms is not None
22 |         has_separate_transform = transform is not None or target_transform is not None
23 |         if has_transforms and has_separate_transform:
24 |             raise ValueError("Only transforms or transform/target_transform can "
25 |                              "be passed as argument")
26 | 
27 |         # for backwards-compatibility
28 |         self.transform = transform
29 |         self.target_transform = target_transform
30 | 
31 |         if has_separate_transform:
32 |             transforms = StandardTransform(transform, target_transform)
33 |         self.transforms = transforms
34 | 
35 |     def __getitem__(self, index: int) -> Any:
36 |         raise NotImplementedError
37 | 
38 |     def __len__(self) -> int:
39 |         raise NotImplementedError
40 | 
41 |     def __repr__(self) -> str:
42 |         head = "Dataset " + self.__class__.__name__
43 |         body = ["Number of datapoints: {}".format(self.__len__())]
44 |         if self.root is not None:
45 |             body.append("Root location: 
{}".format(self.root)) 46 | body += self.extra_repr().splitlines() 47 | if hasattr(self, "transforms") and self.transforms is not None: 48 | body += [repr(self.transforms)] 49 | lines = [head] + [" " * self._repr_indent + line for line in body] 50 | return '\n'.join(lines) 51 | 52 | def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: 53 | lines = transform.__repr__().splitlines() 54 | return (["{}{}".format(head, lines[0])] + 55 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 56 | 57 | def extra_repr(self) -> str: 58 | return "" 59 | 60 | 61 | class StandardTransform(object): 62 | def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: 63 | self.transform = transform 64 | self.target_transform = target_transform 65 | 66 | def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: 67 | if self.transform is not None: 68 | input = self.transform(input) 69 | if self.target_transform is not None: 70 | target = self.target_transform(target) 71 | return input, target 72 | 73 | def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: 74 | lines = transform.__repr__().splitlines() 75 | return (["{}{}".format(head, lines[0])] + 76 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 77 | 78 | def __repr__(self) -> str: 79 | body = [self.__class__.__name__] 80 | if self.transform is not None: 81 | body += self._format_transform_repr(self.transform, 82 | "Transform: ") 83 | if self.target_transform is not None: 84 | body += self._format_transform_repr(self.target_transform, 85 | "Target transform: ") 86 | 87 | return '\n'.join(body) 88 | -------------------------------------------------------------------------------- /onekey_algo/classification/inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/22 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 7 | import argparse 8 | import os 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from onekey_core.core import create_model 14 | from onekey_core.core import create_standard_image_transformer 15 | from onekey_algo.datasets.image_loader import default_loader 16 | from onekey_algo.utils.MultiProcess import MultiProcess 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch Classification Inference') 19 | 20 | parser.add_argument('-c', '--config', dest='c', required=True, help='Model and transformer configuration') 21 | parser.add_argument('-m', '--core', dest='core', required=True, help='Model parameters!') 22 | parser.add_argument('-d', '--directory', dest='d', default=None, help='Inference data directory.') 23 | parser.add_argument('-l', '--list_file', dest='l', default=None, help='Inference data list file') 24 | parser.add_argument('--labels_file', default=None, help='Labels file') 25 | parser.add_argument('--gpus', type=int, nargs='*', default=None, help='GPU index to be used!') 26 | parser.add_argument('--num_process', type=int, default=1, help='Number of process!') 27 | 28 | args = parser.parse_args() 29 | 30 | 31 | def test_model(samples, thread_id, params): 32 | # config = json.loads(open(params.config).read()) 33 | config = {'transform': {'input_size': 299, 'normalize_method': '-1+1'}, 34 | 'core': {'model_name': 'inception_v3', 'num_classes': 3}} 35 | 36 | # Configuration of transformer. 
37 |     transform_config = {'phase': 'valid'}
38 |     if 'transform' in config:
39 |         transform_config.update(config['transform'])
40 |     assert 'input_size' in transform_config, '`input_size` must be in `transform`'
41 |     transformer = create_standard_image_transformer(**transform_config)
42 | 
43 |     # Configuration of core
44 |     model_config = {'pretrained': False}
45 |     if 'core' in config:
46 |         model_config.update(config['core'])
47 |     assert 'model_name' in model_config and 'num_classes' in model_config, '`model_name` and `num_classes` must be in ' \
48 |                                                                            '`core`'
49 |     model = create_model(**model_config)
50 |     # Configuration of device
51 |     device_info = 'cpu'
52 |     if params.gpus:
53 |         gpu_idx = params.gpus[thread_id % len(params.gpus)]
54 |         device_info = f"cuda:{gpu_idx}" if torch.cuda.is_available() and gpu_idx is not None else "cpu"  # gpu_idx may legitimately be 0
55 |     device = torch.device(device_info)
56 |     model = model.to(device)
57 |     state_dict = torch.load(params.core, map_location=device)['model_state_dict']  # the checkpoint path is parsed into `params.core` above
58 |     mapped_state_dict = {(k[len('module.'):] if k.startswith('module.') else k): state_dict[k] for k in state_dict}  # str.lstrip strips characters, not the 'module.' prefix
59 | 
60 |     model.load_state_dict(mapped_state_dict)
61 |     model.eval()
62 | 
63 |     # Inference
64 |     with torch.set_grad_enabled(False):
65 |         for sample in samples:
66 |             sample_ = transformer(default_loader(sample))
67 |             sample_ = sample_.to(device)
68 |             # print(sample_.size())
69 |             outputs = model(sample_.view(1, *sample_.size()))
70 |             print(sample, F.softmax(outputs, dim=1))
71 | 
72 | 
73 | if __name__ == "__main__":
74 |     if args.d is not None:
75 |         test_samples = [os.path.join(args.d, p) for p in os.listdir(args.d) if p.endswith('.jpg')]
76 |     elif args.l is not None:
77 |         with open(args.l) as f:
78 |             test_samples = [l.strip() for l in f.readlines()]
79 |     else:
80 |         raise ValueError('You must provide a directory or list file for inference.')
81 |     MultiProcess(test_samples, test_model, num_process=args.num_process, params=args).run()
82 | 
-------------------------------------------------------------------------------- /onekey_comp/comp0. MIET(医学图像处理)/README.md: --------------------------------------------------------------------------------
1 | Medical imaging plays a crucial role in the diagnosis and treatment of various diseases. However, the spatial resolution of medical images is often limited due to various factors such as hardware limitations, acquisition time, and radiation exposure. This limitation can lead to difficulties in accurate diagnosis and treatment planning. Therefore, there is a need for techniques that can improve the spatial resolution of medical images. In recent years, deep learning-based super-resolution reconstruction techniques have shown promising results in improving the spatial resolution of medical images. Now, we will provide a 3D super-resolution reconstruction technique for medical images that utilizes a generative adversarial network (GAN) as its basic architecture.
2 | 
3 | Super-resolution reconstruction is a technique that aims to improve the spatial resolution of an image beyond the physical limitations of the imaging system. In medical imaging, super-resolution reconstruction can be used to improve the quality of images obtained from various modalities such as computed tomography (CT), magnetic resonance imaging (MRI), and ultrasound. Super-resolution reconstruction techniques can be broadly classified into two categories: interpolation-based and learning-based. Interpolation-based techniques use mathematical models to estimate the high-resolution image from the low-resolution image.
However, these techniques often result in blurry images with limited improvement in spatial resolution. Learning-based techniques, on the other hand, use deep learning models to learn the mapping between low-resolution and high-resolution images. These techniques have shown promising results in improving the spatial resolution of medical images. 4 | 5 | The 3D super-resolution reconstruction technique provided in our method utilizes a GAN as its basic architecture. GANs are a type of deep learning model that consists of two networks: a generator network and a discriminator network. The generator network generates high-resolution images from low-resolution images, while the discriminator network distinguishes between real and generated images. The two networks are trained in an adversarial manner, where the generator network tries to generate images that can fool the discriminator network, and the discriminator network tries to distinguish between real and generated images. This adversarial training process helps the generator network to learn the mapping between low-resolution and high-resolution images. 6 | 7 | The dataset used to train the 3D super-resolution reconstruction technique consists of millions of medical images. The images are preprocessed to remove noise and artifacts and to normalize the intensity values. The images are then divided into low-resolution and high-resolution pairs, where the low-resolution images are obtained by downsampling the high-resolution images. The pairs are used to train the GAN model. 8 | 9 | The loss function used in the GAN model consists of three components: gradient loss, L1 loss, and perceptual loss. The gradient loss encourages the generated images to have similar gradient values as the high-resolution images. The L1 loss measures the pixel-wise difference between the generated and high-resolution images. The perceptual loss measures the difference between the feature representations of the generated and high-resolution images obtained from a pre-trained deep learning model. The combination of these loss functions helps to ensure that the generated images are visually similar to the high-resolution images. 10 | 11 | The 3D super-resolution reconstruction technique we provided has shown promising results in improving the spatial resolution of medical images. For example, it can increase the spatial resolution by 4 times while maintaining the original image size. This means that a pixel volume of 1x1x1mm can be transformed into 1x1x0.25mm. The technique has been evaluated on various medical imaging modalities such as CT, MRI, and ultrasound, and has shown significant improvement in image quality and spatial resolution. The technique has also been compared with other state-of-the-art super-resolution reconstruction techniques and has shown superior performance. 
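
To make the loss description above concrete, the following is a minimal PyTorch sketch of how the three terms can be combined. It is illustrative only: the weighting factors and the `feature_net` argument (any frozen feature extractor, for example a truncated VGG) are assumptions, not the exact configuration used by this toolbox.

```python
import torch
import torch.nn.functional as F


def gradient_loss(sr, hr):
    # Compare finite-difference spatial gradients of the generated (sr)
    # and reference (hr) images.
    def grads(x):
        return x[..., :, 1:] - x[..., :, :-1], x[..., 1:, :] - x[..., :-1, :]

    sr_dx, sr_dy = grads(sr)
    hr_dx, hr_dy = grads(hr)
    return F.l1_loss(sr_dx, hr_dx) + F.l1_loss(sr_dy, hr_dy)


def generator_loss(sr, hr, feature_net, w_grad=1.0, w_l1=1.0, w_perc=0.1):
    # Weighted sum of the three terms described above; the weights are
    # illustrative assumptions, not tuned values from this project.
    l_pix = F.l1_loss(sr, hr)                     # pixel-wise L1 term
    l_grad = gradient_loss(sr, hr)                # gradient term
    with torch.no_grad():                         # reference features need no grad
        hr_feat = feature_net(hr)
    l_perc = F.l1_loss(feature_net(sr), hr_feat)  # perceptual term
    return w_grad * l_grad + w_l1 * l_pix + w_perc * l_perc
```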
-------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size // 2), bias=bias) 11 | 12 | 13 | class MeanShift(nn.Conv2d): 14 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 15 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 16 | std = torch.Tensor(rgb_std) 17 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 18 | self.weight.data.div_(std.view(3, 1, 1, 1)) 19 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 20 | self.bias.data.div_(std) 21 | self.requires_grad = False 22 | 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, in_channels, out_channels, kernel_size, stride=1, bias=False, 27 | bn=True, act=nn.ReLU(True)): 28 | 29 | m = [nn.Conv2d( 30 | in_channels, out_channels, kernel_size, 31 | padding=(kernel_size // 2), stride=stride, bias=bias) 32 | ] 33 | if bn: m.append(nn.BatchNorm2d(out_channels)) 34 | if act is not None: m.append(act) 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | 38 | class ResBlock(nn.Module): 39 | def __init__( 40 | self, conv, n_feat, kernel_size, 41 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 42 | 43 | super(ResBlock, self).__init__() 44 | m = [] 45 | for i in range(2): 46 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 47 | if bn: m.append(nn.BatchNorm2d(n_feat)) 48 | if i == 0: m.append(act) 49 | 50 | self.body = nn.Sequential(*m) 51 | self.res_scale = res_scale 52 | 53 | def forward(self, x): 54 | res = self.body(x).mul(self.res_scale) 55 | res += x 56 | 57 | return res 58 | 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
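            # Power-of-two scales are built as log2(scale) stages of
            # (conv -> PixelShuffle(2)), each stage doubling height and width;
            # scale == 3 instead uses a single PixelShuffle(3) stage below.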
65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: m.append(nn.BatchNorm2d(n_feat)) 69 | if act: m.append(act()) 70 | elif scale == 3: 71 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 72 | m.append(nn.PixelShuffle(3)) 73 | if bn: m.append(nn.BatchNorm2d(n_feat)) 74 | if act: m.append(act()) 75 | else: 76 | raise NotImplementedError 77 | 78 | super(Upsampler, self).__init__(*m) 79 | 80 | 81 | ## add SELayer 82 | class SELayer(nn.Module): 83 | def __init__(self, channel, reduction=16): 84 | super(SELayer, self).__init__() 85 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 86 | self.conv_du = nn.Sequential( 87 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 88 | nn.ReLU(inplace=True), 89 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 90 | nn.Sigmoid() 91 | ) 92 | 93 | def forward(self, x): 94 | y = self.avg_pool(x) 95 | y = self.conv_du(y) 96 | return x * y 97 | 98 | 99 | ## add SEResBlock 100 | class SEResBlock(nn.Module): 101 | def __init__( 102 | self, conv, n_feat, kernel_size, reduction, 103 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 104 | 105 | super(SEResBlock, self).__init__() 106 | modules_body = [] 107 | for i in range(2): 108 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 109 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 110 | if i == 0: modules_body.append(act) 111 | modules_body.append(SELayer(n_feat, reduction)) 112 | self.body = nn.Sequential(*modules_body) 113 | self.res_scale = res_scale 114 | 115 | def forward(self, x): 116 | res = self.body(x) 117 | # res = self.body(x).mul(self.res_scale) 118 | res += x 119 | 120 | return res 121 | -------------------------------------------------------------------------------- /onekey_comp/comp7-Survival/列线图-Nomogram.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "84c4794a", 6 | "metadata": {}, 7 | "source": [ 8 | "## Nomogram\n", 9 | "\n", 10 | "读取`mydir`数据的,尽量保证每个数据都是英文编码。" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "fb9e7879", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import pandas as pd\n", 21 | "from onekey_algo.custom.components import nomogram\n", 22 | "from onekey_algo import OnekeyDS as okds\n", 23 | "\n", 24 | "mydir = r''\n", 25 | "mydir = okds.survival\n", 26 | "df = pd.read_csv(mydir, header=0)\n", 27 | "df.head()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "def05a4c", 33 | "metadata": {}, 34 | "source": [ 35 | "### 绘制nomogram\n", 36 | "\n", 37 | "绘制的接口和参数,COX回归进行生存时间预测。\n", 38 | "\n", 39 | "```python\n", 40 | "def nomogram(df: Union[str, DataFrame], duration: str, result: str, columns: Union[str, List[str]],\n", 41 | " survs: Union[int, List[int]], surv_names: Union[str, List[str]] = None,\n", 42 | " width: int = 960, height: int = 480) -> Image:\n", 43 | " \"\"\"\n", 44 | " 绘制nomogram图,Nomogram的图存储在当前文件夹下的nomogram.png\n", 45 | " Args:\n", 46 | " df: 数据路径,或者是读取之后的Dataframe格式。\n", 47 | " duration: OS\n", 48 | " result: OST\n", 49 | " columns: 使用那些列计算nomogram\n", 50 | " survs: 生存时间转化成x 年生存率\n", 51 | " surv_names: survs对应的列名。\n", 52 | " width: nomogram分辨率--宽度,默认960\n", 53 | " height: nomogram分辨率--宽度,默认480\n", 54 | "\n", 55 | " Returns: PIL.Image\n", 56 | " \"\"\"\n", 57 | "```" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": 
"99013dd5", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# 绘制Nomagram\n", 68 | "nomogram.nomogram(df, duration='duration', result='result', \n", 69 | " columns=['age', 'gender', 'degree', 'Tstage', 'BMI', 'chemotherapy'],\n", 70 | " survs=[36, 60], \n", 71 | " surv_names=['3-year Survival', '5-year Survival'], \n", 72 | " height=5800)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "d2d47ca6", 78 | "metadata": {}, 79 | "source": [ 80 | "## 绘制 Risk-nomogram\n", 81 | "\n", 82 | "绘制的接口和参数,使用的是logisitc回归进行风险模型建模\n", 83 | "\n", 84 | "```python\n", 85 | "def risk_nomogram(df: Union[str, DataFrame], result: str, columns: Union[str, List[str]],\n", 86 | " width: int = 960, height: int = 480) -> Image:\n", 87 | " \"\"\"\n", 88 | " 绘制nomogram图,Nomogram的图存储在当前文件夹下的nomogram.png\n", 89 | " Args:\n", 90 | " df: 数据路径,或者是读取之后的Dataframe格式。\n", 91 | " result: OST\n", 92 | " columns: 使用那些列计算nomogram\n", 93 | " width: nomogram分辨率--宽度,默认960\n", 94 | " height: nomogram分辨率--宽度,默认480\n", 95 | "\n", 96 | " Returns: PIL.Image\n", 97 | " \"\"\"\n", 98 | " ```" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "6edc49b8", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "nomogram.risk_nomogram(df, result='result', columns=['age', 'gender', 'degree', 'Tstage', 'BMI', 'chemotherapy'], \n", 109 | " height=5800)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "31a55a1f", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [] 119 | } 120 | ], 121 | "metadata": { 122 | "kernelspec": { 123 | "display_name": "Python 3 (ipykernel)", 124 | "language": "python", 125 | "name": "python3" 126 | }, 127 | "language_info": { 128 | "codemirror_mode": { 129 | "name": "ipython", 130 | "version": 3 131 | }, 132 | "file_extension": ".py", 133 | "mimetype": "text/x-python", 134 | "name": "python", 135 | "nbconvert_exporter": "python", 136 | "pygments_lexer": "ipython3", 137 | "version": "3.7.12" 138 | } 139 | }, 140 | "nbformat": 4, 141 | "nbformat_minor": 5 142 | } 143 | -------------------------------------------------------------------------------- /onekey_algo/utils/MultiProcess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2021/3/21 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2021 All Rights Reserved. 7 | 8 | # 2017-06-13 First start. 9 | # 2017-08-26 Add parameters and feed_backs. 10 | # 2018-12-28 Move to JeanWe. 11 | import logging 12 | import math 13 | import multiprocessing 14 | import time 15 | 16 | 17 | def _get_value(v, t): 18 | if t == 'i' or t == 'integer': 19 | return int(v.value) 20 | else: 21 | return float(v.value) 22 | 23 | 24 | class MultiProcess(object): 25 | """MultiProcess, Give total samples to deal with, a function and kwargs and split parts. 26 | 27 | Parameters 28 | --------------- 29 | @samples: Total samples to process, in list or tuple type. 30 | @function: A single thread function to process samples. 31 | @num_process: Number of process to run. 32 | @feed_back: `func` return values. It can be list or tuple. Only integer and double supported. 33 | eg. [`i`,'d'] which mean an integer and a double while be returned from `function`. 34 | @kwargs: Other parameters passing to the `function`. 35 | 36 | Methods: 37 | --------------- 38 | @start: Star processing. 39 | @wait: wait for all the process to exit. 
40 | @run: Useful, combine `start` and `wait`. 41 | 42 | Properties: 43 | --------------- 44 | @feed_back: `function` feed backs, in original order. 45 | shape: [num_process, len(`feed_back`)] 46 | 47 | How2use: 48 | --------------- 49 | samples = balabala 50 | function = bala 51 | # If `function` doesn't return values. 52 | MultiProcess(samples, function, numprocess=8).run() 53 | # Else 54 | mp = MultiProcess(samples, function, feed_back=['i', 'd'], num_process=8) 55 | mp.run() 56 | feed_back = mp.feed_back() 57 | """ 58 | logger = logging.getLogger('MultiProcess') 59 | logger.setLevel(logging.INFO) 60 | 61 | def __init__(self, samples, func, num_process=1, feed_back=None, **kwargs): 62 | # Set self logger 63 | 64 | if feed_back is None: 65 | feed_back = [] 66 | self.samples = list(samples) 67 | self.func = func 68 | self.kwargs = kwargs 69 | self.num_process = min(num_process, len(self.samples)) 70 | self.thread_list = [] 71 | 72 | self.share_mem = [] 73 | for t in feed_back: 74 | if not (t == 'i' or t == 'd' or t == 'integer' or t == 'double'): 75 | raise ValueError('Type of feed backs must be i(integer) or d(double).') 76 | for _ in range(num_process): 77 | fbv = [] 78 | for _ in range(len(feed_back)): 79 | fbv.append(multiprocessing.Value('d', 0)) 80 | self.share_mem.append(fbv) 81 | self.feed_back_type = feed_back 82 | 83 | def start(self, sleep_time=0): 84 | """ 85 | Exception 86 | --------------- 87 | VauleError: `samples`, `func`, `num_process` parameters setttings. 88 | """ 89 | if self.samples and self.func and self.num_process: 90 | samples_per_thread = int(math.ceil(len(self.samples) / self.num_process)) 91 | 92 | for i in range(self.num_process): 93 | kwargs = {'thread_id': i, 94 | 'samples': self.samples[samples_per_thread * i: samples_per_thread * (i + 1)]} 95 | if self.feed_back_type: 96 | kwargs.update({'feed_back': self.share_mem[i]}) 97 | kwargs.update(self.kwargs) 98 | sthread = multiprocessing.Process(target=self.func, kwargs=kwargs) 99 | self.logger.info('Starting process %d from %d to %d ...' 100 | % (i, samples_per_thread * i, min(len(self.samples), samples_per_thread * (i + 1)))) 101 | sthread.start() 102 | self.thread_list.append(sthread) 103 | if sleep_time and sleep_time > 0: 104 | time.sleep(sleep_time) 105 | else: 106 | raise ValueError('Error while starting multi process', 107 | 'Please check `samples`, `func`, `num_process` parameters') 108 | 109 | def wait(self): 110 | for t in self.thread_list: 111 | t.join() 112 | 113 | def run(self, sleep_time=0): 114 | self.start(sleep_time) 115 | self.wait() 116 | return self 117 | 118 | @property 119 | def feed_back(self): 120 | fb_list = [] 121 | for fb in self.share_mem: 122 | fb_list += [_get_value(v, t) for v, t in zip(fb, self.feed_back_type)] 123 | return fb_list 124 | -------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/eval_super_res_reconstruction.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2023/04/15 4 | # Forum: www.medai.icu 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2023 All Rights Reserved. 
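# Standalone usage (flags are defined in the __main__ block at the bottom), e.g.:
#   python eval_super_res_reconstruction.py --input_data volume.nii.gz --scale 4
# Pretrained RCAN weights are looked up under $ONEKEY_HOME/pretrain/RCAN_BIX{scale}.pt
# unless an explicit model_path is passed to init().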
7 | import argparse 8 | import os 9 | 10 | import nibabel as nib 11 | import numpy as np 12 | import torch 13 | import tqdm 14 | 15 | from onekey_algo.mietb.super_resolution.rcan import RCAN 16 | from onekey_algo.mietb.utils import normalize, denormalize, clip 17 | from onekey_algo.utils.about_log import logger 18 | 19 | 20 | def init(scale, model_path=None, img_range=255): 21 | model_path = model_path or os.path.join(os.environ.get('ONEKEY_HOME'), 'pretrain', f'RCAN_BIX{scale}.pt') 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | torch.set_grad_enabled(False) 24 | if torch.cuda.is_available(): 25 | torch.backends.cudnn.enabled = True 26 | torch.backends.cudnn.benchmark = True 27 | model = RCAN(scale=scale, n_colors=3, img_range=img_range) 28 | model.load_state_dict(torch.load(model_path, map_location=device), strict=False) 29 | model.eval() 30 | model.to(device) 31 | return model, device 32 | 33 | 34 | def inference(input_data, model, device, scale, img_range=255, save_dir=None): 35 | assert scale in [2, 4], f'{scale}不支持的重建倍数,目前只支持2x和4x的超清采样。' 36 | if save_dir is not None: 37 | os.makedirs(save_dir, exist_ok=True) 38 | if isinstance(input_data, str): 39 | if os.path.exists(input_data): 40 | input_data = [input_data] 41 | else: 42 | raise ValueError(f'input_data数据错误,{input_data}不存在。') 43 | for idx, input_data_path in enumerate(input_data): 44 | filename = os.path.basename(input_data_path) 45 | logger.info("count: {}, data path: {}".format(idx, input_data_path)) 46 | input_nii = nib.load(input_data_path) 47 | input_nii_data = input_nii.get_fdata() 48 | input_nii_affine = input_nii.get_affine() 49 | 50 | voxel_size = input_nii.header.get_zooms() 51 | input_nii.header.set_zooms(tuple([voxel_size[0] / scale, voxel_size[1] / scale, voxel_size[2]])) 52 | 53 | rlt = [] 54 | for jdx in tqdm.tqdm(range(input_nii_data.shape[2])): 55 | img = input_nii_data[:, :, jdx:jdx + 1].repeat(3, axis=2) 56 | _min, _max = img.min(), img.max() 57 | img = normalize(_min, _max, img).astype(np.float32) * 255. 
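            # Each axial slice is replicated to 3 channels (the RCAN here is an
            # RGB model), min-max normalized per slice, and rescaled to [0, 255]
            # to match the network's img_range before the forward pass.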
58 |             img_torch = torch.tensor(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
59 |             output = model(img_torch)
60 |             output = output[0].cpu().numpy().transpose(1, 2, 0)[:, :, 0]
61 |             # if args.save_rlt and jdx % 10 == 0:
62 |             #     filepath = os.path.join(args.dst_dir, filename[:-7])
63 |             #     os.makedirs(filepath, exist_ok=True)
64 |             #     img = cv2.resize(img, (img.shape[0] * args.scale, img.shape[1] * args.scale),
65 |             #                      interpolation=cv2.INTER_NEAREST)
66 |             #     cv2.imwrite(os.path.join(filepath, "{:03d}_input.png".format(jdx)), img)
67 |             #     cv2.imwrite(os.path.join(filepath, "{:03d}_output_{}X.png".format(jdx, args.scale)), output)
68 |             output = output / img_range
69 |             output = clip(0, 1, output)
70 |             output = denormalize(_min, _max, output)
71 |             rlt.append(output)
72 |         if save_dir is not None:
73 |             output_nii_path = os.path.join(save_dir, filename)
74 |         else:
75 |             output_nii_path = input_data_path.replace('.nii.gz', f'_X{scale}.nii.gz')
76 |         flow_nii_data = np.stack(rlt, axis=2)
77 |         nib.Nifti1Image(flow_nii_data, input_nii_affine, input_nii.header).to_filename(output_nii_path)
78 | 
79 | 
80 | if __name__ == '__main__':
81 |     parser = argparse.ArgumentParser(description='Super-resolution reconstruction for NIfTI volumes.')
82 |     # changed configs
83 |     parser.add_argument('--input_data', type=str, nargs='*', default="./demo/input",
84 |                         help="path to input data; only .nii.gz data is supported")
85 |     parser.add_argument('--dst_dir', type=str, default=None,
86 |                         help="path to result data")
87 |     parser.add_argument('--scale', type=int, default=4,
88 |                         help="choose your upsample scale factor from [2,4]")
89 |     parser.add_argument('--img_range', type=int, default=255,
90 |                         help="choose img range for test")
91 |     args = parser.parse_args()
92 |     model_path = os.path.join(os.environ.get('ONEKEY_HOME'), 'pretrain', f'RCAN_BIX{args.scale}.pt')
93 |     model, device = init(args.scale, model_path=model_path)  # init's signature is init(scale, model_path=...); the old positional call passed the path as the scale
94 |     # args.input_data = r'C:\Users\onekey\Project\OnekeyDS\CT\images/0.nii.gz'  # debug override; uncomment to hard-code an input and bypass --input_data
95 |     inference(args.input_data, model, device, scale=args.scale, save_dir=args.dst_dir)
96 | 
-------------------------------------------------------------------------------- /onekey_core/utils.py: --------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | 
5 | irange = range
6 | 
7 | 
8 | def make_grid(tensor, nrow=8, padding=2,
9 |               normalize=False, range=None, scale_each=False, pad_value=0):
10 |     """Make a grid of images.
11 | 
12 |     Args:
13 |         tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
14 |             or a list of images all of the same size.
15 |         nrow (int, optional): Number of images displayed in each row of the grid.
16 |             The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
17 |         padding (int, optional): amount of padding. Default: ``2``.
18 |         normalize (bool, optional): If True, shift the image to the range (0, 1),
19 |             by the min and max values specified by :attr:`range`. Default: ``False``.
20 |         range (tuple, optional): tuple (min, max) where min and max are numbers,
21 |             then these numbers are used to normalize the image. By default, min and max
22 |             are computed from the tensor.
23 |         scale_each (bool, optional): If ``True``, scale each image in the batch of
24 |             images separately rather than the (min, max) over all images. Default: ``False``.
25 |         pad_value (float, optional): Value for the padded pixels. Default: ``0``.
26 | 27 | Example: 28 | See this notebook `here `_ 29 | 30 | """ 31 | if not (torch.is_tensor(tensor) or 32 | (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 33 | raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor))) 34 | 35 | # if list of tensors, convert to a 4D mini-batch Tensor 36 | if isinstance(tensor, list): 37 | tensor = torch.stack(tensor, dim=0) 38 | 39 | if tensor.dim() == 2: # single image H x W 40 | tensor = tensor.unsqueeze(0) 41 | if tensor.dim() == 3: # single image 42 | if tensor.size(0) == 1: # if single-channel, convert to 3-channel 43 | tensor = torch.cat((tensor, tensor, tensor), 0) 44 | tensor = tensor.unsqueeze(0) 45 | 46 | if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images 47 | tensor = torch.cat((tensor, tensor, tensor), 1) 48 | 49 | if normalize is True: 50 | tensor = tensor.clone() # avoid modifying tensor in-place 51 | if range is not None: 52 | assert isinstance(range, tuple), \ 53 | "range has to be a tuple (min, max) if specified. min and max are numbers" 54 | 55 | def norm_ip(img, min, max): 56 | img.clamp_(min=min, max=max) 57 | img.add_(-min).div_(max - min + 1e-5) 58 | 59 | def norm_range(t, range): 60 | if range is not None: 61 | norm_ip(t, range[0], range[1]) 62 | else: 63 | norm_ip(t, float(t.min()), float(t.max())) 64 | 65 | if scale_each is True: 66 | for t in tensor: # loop over mini-batch dimension 67 | norm_range(t, range) 68 | else: 69 | norm_range(tensor, range) 70 | 71 | if tensor.size(0) == 1: 72 | return tensor.squeeze(0) 73 | 74 | # make the mini-batch of images into a grid 75 | nmaps = tensor.size(0) 76 | xmaps = min(nrow, nmaps) 77 | ymaps = int(math.ceil(float(nmaps) / xmaps)) 78 | height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) 79 | grid = tensor.new_full((3, height * ymaps + padding, width * xmaps + padding), pad_value) 80 | k = 0 81 | for y in irange(ymaps): 82 | for x in irange(xmaps): 83 | if k >= nmaps: 84 | break 85 | grid.narrow(1, y * height + padding, height - padding) \ 86 | .narrow(2, x * width + padding, width - padding) \ 87 | .copy_(tensor[k]) 88 | k = k + 1 89 | return grid 90 | 91 | 92 | def save_image(tensor, filename, nrow=8, padding=2, 93 | normalize=False, range=None, scale_each=False, pad_value=0): 94 | """Save a given Tensor into an image file. 95 | 96 | Args: 97 | tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, 98 | saves the tensor as a grid of images by calling ``make_grid``. 99 | **kwargs: Other arguments are documented in ``make_grid``. 
100 | """ 101 | from PIL import Image 102 | grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, 103 | normalize=normalize, range=range, scale_each=scale_each) 104 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 105 | ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 106 | im = Image.fromarray(ndarr) 107 | im.save(filename) 108 | -------------------------------------------------------------------------------- /onekey_algo/mietb/super_resolution/rcan.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from onekey_algo.mietb.super_resolution import common 4 | 5 | 6 | ## Channel Attention (CA) Layer 7 | class CALayer(nn.Module): 8 | def __init__(self, channel, reduction=16): 9 | super(CALayer, self).__init__() 10 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 11 | self.conv_du = nn.Sequential( 12 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 15 | nn.Sigmoid() 16 | ) 17 | 18 | def forward(self, x): 19 | y = self.avg_pool(x) 20 | y = self.conv_du(y) 21 | return x * y 22 | 23 | 24 | ## Residual Channel Attention Block (RCAB) 25 | class RCAB(nn.Module): 26 | def __init__( 27 | self, conv, n_feat, kernel_size, reduction, 28 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 29 | 30 | super(RCAB, self).__init__() 31 | modules_body = [] 32 | for i in range(2): 33 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 34 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 35 | if i == 0: modules_body.append(act) 36 | modules_body.append(CALayer(n_feat, reduction)) 37 | self.body = nn.Sequential(*modules_body) 38 | self.res_scale = res_scale 39 | 40 | def forward(self, x): 41 | res = self.body(x) 42 | # res = self.body(x).mul(self.res_scale) 43 | res += x 44 | return res 45 | 46 | 47 | ## Residual Group (RG) 48 | class ResidualGroup(nn.Module): 49 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 50 | super(ResidualGroup, self).__init__() 51 | modules_body = [] 52 | modules_body = [ 53 | RCAB( 54 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 55 | for _ in range(n_resblocks)] 56 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 57 | self.body = nn.Sequential(*modules_body) 58 | 59 | def forward(self, x): 60 | res = self.body(x) 61 | res += x 62 | return res 63 | 64 | 65 | class RCAN(nn.Module): 66 | def __init__(self, scale, n_colors, img_range, conv=common.default_conv): 67 | super(RCAN, self).__init__() 68 | 69 | n_resgroups = 10 70 | n_resblocks = 20 71 | n_feats = 64 72 | kernel_size = 3 73 | reduction = 16 74 | act = nn.ReLU(True) 75 | 76 | rgb_mean = (0.4488, 0.4371, 0.4040) 77 | rgb_std = (1.0, 1.0, 1.0) 78 | self.sub_mean = common.MeanShift(1, rgb_mean, rgb_std) 79 | 80 | # define head module 81 | modules_head = [conv(n_colors, n_feats, kernel_size)] 82 | 83 | # define body module 84 | modules_body = [ 85 | ResidualGroup( 86 | conv, n_feats, kernel_size, reduction, act=act, res_scale=1, n_resblocks=n_resblocks) \ 87 | for _ in range(n_resgroups)] 88 | 89 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 90 | 91 | # define tail module 92 | modules_tail = [ 93 | common.Upsampler(conv, scale, n_feats, act=False), 94 | conv(n_feats, n_colors, kernel_size)] 95 | 96 | self.add_mean = 
common.MeanShift(img_range, rgb_mean, rgb_std, 1) 97 | 98 | self.head = nn.Sequential(*modules_head) 99 | self.body = nn.Sequential(*modules_body) 100 | self.tail = nn.Sequential(*modules_tail) 101 | 102 | def forward(self, x): 103 | x = self.sub_mean(x) 104 | x = self.head(x) 105 | 106 | res = self.body(x) 107 | res += x 108 | 109 | x = self.tail(res) 110 | x = self.add_mean(x) 111 | 112 | return x 113 | 114 | def load_state_dict(self, state_dict, strict=False): 115 | own_state = self.state_dict() 116 | for name, param in state_dict.items(): 117 | if name in own_state: 118 | if isinstance(param, nn.Parameter): 119 | param = param.data 120 | try: 121 | own_state[name].copy_(param) 122 | except Exception: 123 | if name.find('tail') >= 0: 124 | print('Replace pre-trained upsampler to new one...') 125 | else: 126 | raise RuntimeError('While copying the parameter named {}, ' 127 | 'whose dimensions in the model are {} and ' 128 | 'whose dimensions in the checkpoint are {}.' 129 | .format(name, own_state[name].size(), param.size())) 130 | elif strict: 131 | if name.find('tail') == -1: 132 | raise KeyError('unexpected key "{}" in state_dict' 133 | .format(name)) 134 | 135 | if strict: 136 | missing = set(own_state.keys()) - set(state_dict.keys()) 137 | if len(missing) > 0: 138 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 139 | -------------------------------------------------------------------------------- /onekey_algo/custom/components/nomogram.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2022/4/20 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2022 All Rights Reserved. 7 | import os 8 | from typing import List, Union, Iterable 9 | 10 | import pandas as pd 11 | import rpy2.robjects as robjects 12 | from PIL import Image 13 | from pandas import DataFrame 14 | from rpy2.robjects import globalenv 15 | from rpy2.robjects import pandas2ri 16 | from rpy2.robjects.packages import importr 17 | 18 | rms = importr("rms") 19 | survival = importr("survival") 20 | RTEMP = """ 21 | dd=datadist(rdf) 22 | options(datadist="dd") 23 | f2 <- cph(Surv({duration},{result}) ~ {columns}, data=rdf, x=TRUE, y=TRUE, surv=TRUE) 24 | 25 | # med <- Quantile(f2) 26 | surv <- Survival(f2) 27 | 28 | png( 29 | filename = "{save_name}", 30 | width = {width}, 31 | height = {height}, 32 | units = "px", 33 | bg = "white", 34 | res = 600) 35 | 36 | nom <- nomogram(f2, fun=list( 37 | {func} 38 | # function(x) med(lp=x) 39 | ), lp=F, 40 | funlabel=c({funlabel}), fun.at=c({x_range})) 41 | plot(nom, xfrac=.2) 42 | dev.off() 43 | """ 44 | FUNCTEMP = " function(x) surv({time}, x)" 45 | 46 | RRISKTEMP = """ 47 | dd=datadist(rdf) 48 | options(datadist="dd") 49 | f2 <- lrm({result} ~ {columns}, data = rdf) 50 | 51 | png( 52 | filename = "{save_name}", 53 | width = {width}, 54 | height = {height}, 55 | units = "px", 56 | bg = "white", 57 | res = 600) 58 | 59 | nom <- nomogram(f2, fun= function(x)1/(1+exp(-x)), 60 | lp=F, funlabel="Risk", fun.at=c({x_range})) 61 | plot(nom, xfrac=.2) 62 | dev.off() 63 | """ 64 | 65 | 66 | def nomogram(df: Union[str, DataFrame], duration: str, result: str, columns: Union[str, List[str]], 67 | survs: Union[int, List[int]], surv_names: Union[str, List[str]] = None, x_range='0.01,0.5,0.99', 68 | width: int = 8000, height: int = 3200, save_name='nomogram.png', with_r: bool = False) -> Image: 69 | """ 70 | 
Draw a nomogram; the figure is saved to `save_name` (nomogram.png by default) in the current working directory.
71 |     Args:
72 |         df: path to the data file, or an already-loaded DataFrame.
73 |         duration: survival-duration column, e.g. OS.
74 |         result: event/status column, e.g. OST.
75 |         columns: which columns to use when fitting the nomogram.
76 |         survs: survival times to convert into x-year survival-probability scales.
77 |         surv_names: scale labels corresponding to `survs`.
78 |         x_range: tick values for the probability scales.
79 |         width: output resolution (width, in px); default 8000.
80 |         height: output resolution (height, in px); default 3200.
81 |         save_name: file name to save the figure to.
82 |         with_r: whether to print the generated R script.
83 | 
84 |     Returns: PIL.Image
85 |     """
86 |     if not isinstance(survs, Iterable):
87 |         survs = [survs]
88 |     if isinstance(surv_names, Iterable):
89 |         assert len(survs) == len(surv_names), "The number of scale names must equal the number of scales."
90 |     else:
91 |         surv_names = [surv_names] * len(survs)
92 |     if isinstance(columns, str) or not isinstance(columns, Iterable):
93 |         columns = [columns]
94 |     if isinstance(df, str) and os.path.exists(df):
95 |         df = pd.read_csv(df, header=0)
96 |     assert all(c_ in df.columns for c_ in [duration, result] + columns), 'All column names must exist in `df`.'
97 |     if isinstance(x_range, (list, tuple)):
98 |         x_range = ','.join(map(lambda x: str(x), x_range))
99 |     pandas2ri.activate()
100 |     rdf = pandas2ri.py2rpy(df)
101 |     globalenv['rdf'] = rdf
102 |     columns = '+'.join(map(lambda x: str(x), columns))
103 |     func = ','.join(FUNCTEMP.format(time=surv) for surv in survs)
104 |     funlabel = ','.join([f'"{surv_name}"' if surv_name is not None else f'"{surv} Survival"'
105 |                          for surv, surv_name in zip(survs, surv_names)])
106 |     rscript = RTEMP.format(duration=duration, result=result, columns=columns, func=func, funlabel=funlabel,
107 |                            width=width, height=height, save_name=save_name, x_range=x_range)
108 |     if with_r:
109 |         print(rscript)
110 |     robjects.r(rscript)
111 |     return Image.open(save_name)
112 | 
113 | 
114 | def risk_nomogram(df: Union[str, DataFrame], result: str, columns: Union[str, List[str]], x_range='0.01,0.5,0.99',
115 |                   width: int = 8000, height: int = 3200, save_name='nomogram.png', with_r: bool = False) -> Image:
116 |     """
117 |     Draw a (logistic) risk nomogram; the figure is saved to `save_name` (nomogram.png by default).
118 |     Args:
119 |         df: path to the data file, or an already-loaded DataFrame.
120 |         result: outcome column, e.g. OST.
121 |         columns: which columns to use when fitting the nomogram.
122 |         x_range: tick values for the risk scale.
123 |         width: output resolution (width, in px); default 8000.
124 |         height: output resolution (height, in px); default 3200.
125 |         save_name: file name to save the figure to.
126 |         with_r: whether to print the generated R script.
127 | 
128 |     Returns: PIL.Image
129 |     """
130 |     if isinstance(columns, str) or not isinstance(columns, Iterable):
131 |         columns = [columns]
132 |     if isinstance(df, str) and os.path.exists(df):
133 |         df = pd.read_csv(df, header=0)
134 |     assert all(c_ in df.columns for c_ in [result] + columns), 'All column names must exist in `df`.'
135 |     if isinstance(x_range, (list, tuple)):
136 |         x_range = ','.join(map(lambda x: str(x), x_range))
137 |     pandas2ri.activate()
138 |     rdf = pandas2ri.py2rpy(df)
139 |     globalenv['rdf'] = rdf
140 |     columns = '+'.join(map(lambda x: str(x), columns))
141 |     rscript = RRISKTEMP.format(result=result, columns=columns, width=width, height=height, save_name=save_name,
142 |                                x_range=x_range)
143 |     if with_r:
144 |         print(rscript)
145 |     robjects.r(rscript)
146 |     return Image.open(save_name)
--------------------------------------------------------------------------------
/onekey_algo/utils/about_log.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Authorized by Vlon Jang
3 | # Created on 2021/3/21
4 | # Blog: www.wangqingbaidu.cn
5 | # Email: wangqingbaidu@gmail.com
6 | # Copyright 2015-2021 All Rights Reserved.
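# Usage sketch (illustrative, not part of the original module): the root logger configured
# below is shared process-wide, and ProcessPrinter reports throughput at fixed intervals:
#
#     from onekey_algo.utils.about_log import logger, ProcessPrinter
#     printer = ProcessPrinter(total_num=1000)
#     for i in range(1, 1001):
#         printer.print_log(i)  # logs roughly every 10% of progress (ratio=0.1)
#     logger.info('all done')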
7 | 
8 | import logging
9 | import os
10 | import time
11 | from configparser import ConfigParser
12 | from datetime import datetime
13 | 
14 | from termcolor import colored
15 | 
16 | logger = logging.root
17 | logger.setLevel(logging.INFO)
18 | handler = logging.StreamHandler()
19 | handler.setFormatter(logging.Formatter("[%(asctime)s - %(filename)s:%(lineno)4s]\t%(levelname)s\t%(message)s",
20 |                                        '%Y-%m-%d %H:%M:%S'))
21 | logger.handlers = [handler]
22 | PIPE_HINT = colored('PIPE INPUT SUPPORTED!', 'green', attrs=['blink'])
23 | REQUIRED = colored('REQUIRED!', 'red', attrs=['blink'])
24 | 
25 | 
26 | class ProcessPrinter(object):
27 |     """
28 |     Report processing progress.
29 |     """
30 | 
31 |     def __init__(self, total_num=None):
32 |         self.last_time = time.time()
33 |         self.start_time = time.time()
34 |         self.last_num = 0
35 |         self.total_num = total_num
36 | 
37 |     def print_log(self, done_num, total_num=None, ratio=0.1, thread_id=None):
38 |         """
39 |         Print log every total_num * ratio number of examples.
40 | 
41 |         :param done_num: Already parsed number of samples.
42 |         :param total_num: Total number of samples to parse.
43 |         :param ratio: Print a log line after each `ratio` fraction of the samples.
44 |         :param thread_id: Default None. If not None, the thread_id flag is used.
45 |         """
46 |         total_num = total_num if total_num else self.total_num
47 |         assert total_num, "`total_num` must be supplied!"
48 |         prefix = 'Thread id %s\t' % str(thread_id) if thread_id is not None else ''
49 | 
50 |         if done_num % (int(total_num * ratio) if int(total_num * ratio) else 1) == 0:
51 |             speed = (done_num - self.last_num) / (time.time() - self.last_time + 1e-10)
52 |             self.last_num = done_num
53 |             self.last_time = time.time()
54 |             logger.info('%sParse %.3d%% ====>> %.7d/%.7d\tSpeed: %.6f/s' % (prefix,
55 |                                                                             done_num * 100 // total_num,
56 |                                                                             done_num, total_num, speed))
57 |         if done_num == total_num:
58 |             speed = total_num / (time.time() - self.start_time + 1e-10)
59 |             self.last_num = 0
60 |             self.last_time = time.time()
61 |             self.start_time = time.time()
62 |             logger.info('%sCompleted!\tSpeed: %f/s' % (prefix, speed))
63 | 
64 | 
65 | class ColorPrinter(object):
66 |     def __init__(self):
67 |         self._print_fn = None
68 | 
69 |     def _get_color_print(self):
70 |         try:
71 |             from termcolor import cprint
72 | 
73 |             def color_print(msg, color='green'):
74 |                 cprint(msg, color)
75 | 
76 |             self._print_fn = color_print
77 |         except Exception as e:
78 |             print("Can't use color print because of %s" % e)
79 | 
80 |             def print_ignore_color(msg, *used, **unused):
81 |                 print(msg)
82 | 
83 |             self._print_fn = print_ignore_color
84 | 
85 |     def print_with_color(self, info, flag=True):
86 |         color = 'green' if flag else 'red'
87 |         self.cprint(info, color)
88 | 
89 |     @property
90 |     def cprint(self):
91 |         """ Print text with given color.
92 | 
93 |         :param msg: message to be colored
94 |         :param color: which color to use, [grey, red, green, yellow, blue, magenta, cyan, white]
95 |         """
96 |         if not self._print_fn:
97 |             self._get_color_print()
98 |         return self._print_fn
99 | 
100 |     @staticmethod
101 |     def color_text(msg, color='green', attrs=None):
102 |         """ Color text if sys has termcolor.
103 | 
104 |         Parameters
105 |         ---------------
106 |         :param msg: message to be colored
107 |         :param color: which color to use, [grey, red, green, yellow, blue, magenta, cyan, white]
108 |         :param attrs: [bold, dark, underline, blink, reverse, concealed]
109 |         :return Colored message.
110 |         """
111 |         try:
112 |             if type(attrs) == str:
113 |                 attrs = [attrs]
114 |             return colored(msg, color, attrs=attrs)
115 |         except Exception as e:
116 |             print(e)
117 |             return msg
118 | 
119 | 
120 | def save_report(prefix, args, save_to):
121 |     """Save report into `save_to`.
122 | 
123 |     :param prefix: Prefix.
124 |     :param args: Must be a dict of sections, each mapping option names to values.
125 |     :param save_to: Save to specific path with date added.
126 |     """
127 |     cp = ConfigParser()
128 |     for s in args:
129 |         cp.add_section(s)
130 |         for o in args[s]:
131 |             cp.set(s, o, args[s][o])
132 | 
133 |     save_to = os.path.join(save_to, datetime.now().strftime("%Y%m%d"))
134 |     os.makedirs(save_to, exist_ok=True)
135 |     with open(os.path.join(save_to, "%s-%s" % (prefix, datetime.now().strftime("%Y%m%d-%H%M%S"))), "w") as f:
136 |         cp.write(f)
137 | 
138 | 
139 | def log_long_str(string, logger_type=logger.info, prefix=''):
140 |     """
141 |     Log a long (multi-line) string line by line.
142 | 
143 |     :param string: string
144 |     :param logger_type: info, warning or error.
145 |     :param prefix: Prefix of each line.
146 |     :return:
147 |     """
148 |     for l in str(string).split('\n'):
149 |         logger_type(f"{prefix}{l}")
--------------------------------------------------------------------------------
/onekey_algo/custom/components/delong.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Authorized by Vlon Jang
3 | # Created on 2022/3/7
4 | # Blog: www.wangqingbaidu.cn
5 | # Email: wangqingbaidu@gmail.com
6 | # Copyright 2015-2022 All Rights Reserved.
7 | import numpy as np
8 | import scipy.stats
9 | 
10 | 
11 | def compute_midrank(x):
12 |     """Computes midranks.
13 |     Args:
14 |         x - a 1D numpy array
15 |     Returns:
16 |         array of midranks
17 |     """
18 |     J = np.argsort(x)
19 |     Z = x[J]
20 |     N = len(x)
21 |     T = np.zeros(N, dtype=np.float64)
22 |     i = 0
23 |     while i < N:
24 |         j = i
25 |         while j < N and Z[j] == Z[i]:
26 |             j += 1
27 |         T[i:j] = 0.5 * (i + j - 1)
28 |         i = j
29 |     T2 = np.empty(N, dtype=np.float64)
30 |     # Note(kazeevn) +1 is due to Python using 0-based indexing
31 |     # instead of 1-based in the AUC formula in the paper
32 |     T2[J] = T + 1
33 |     return T2
34 | 
35 | 
36 | def fastDeLong(predictions_sorted_transposed, label_1_count):
37 |     """
38 |     The fast version of DeLong's method for computing the covariance of
39 |     unadjusted AUC.
40 |     Args:
41 |         predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
42 |             sorted such that the examples with label "1" come first
43 |     Returns:
44 |         (AUC value, DeLong covariance)
45 |     Reference:
46 |         @article{sun2014fast,
47 |             title={Fast Implementation of DeLong's Algorithm for
48 |                    Comparing the Areas Under Correlated Receiver Operating Characteristic Curves},
49 |             author={Xu Sun and Weichao Xu},
50 |             journal={IEEE Signal Processing Letters},
51 |             volume={21},
52 |             number={11},
53 |             pages={1389--1393},
54 |             year={2014},
55 |             publisher={IEEE}
56 |         }
57 |     """
58 |     # Short variables are named as they are in the paper
59 |     m = label_1_count
60 |     n = predictions_sorted_transposed.shape[1] - m
61 |     positive_examples = predictions_sorted_transposed[:, :m]
62 |     negative_examples = predictions_sorted_transposed[:, m:]
63 |     k = predictions_sorted_transposed.shape[0]
64 | 
65 |     tx = np.empty([k, m], dtype=np.float64)
66 |     ty = np.empty([k, n], dtype=np.float64)
67 |     tz = np.empty([k, m + n], dtype=np.float64)
68 |     for r in range(k):
69 |         tx[r, :] = compute_midrank(positive_examples[r, :])
70 |         ty[r, :] = compute_midrank(negative_examples[r, :])
71 |         tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :])
72 |     aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n
73 |     v01 = (tz[:, :m] - tx[:, :]) / n
74 |     v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m
75 |     sx = np.cov(v01)
76 |     sy = np.cov(v10)
77 |     delongcov = sx / m + sy / n
78 |     return aucs, delongcov
79 | 
80 | 
81 | def calc_pvalue(aucs, sigma, logv=10):
82 |     """Computes the two-sided p-value for the difference of two AUCs.
83 |     Args:
84 |         aucs: 1D array of AUCs
85 |         sigma: AUC DeLong covariances
86 |         logv: log base used internally; with the default 10 the return value is the p-value itself
87 |     Returns:
88 |         the p-value, 2 * P(Z > z) for the observed z-score (not its logarithm, despite the original name)
89 |     """
90 |     l = np.array([[1, -1]])
91 |     z = np.abs(np.diff(aucs)) / np.sqrt(np.dot(np.dot(l, sigma), l.T))
92 |     # print(np.sqrt(np.dot(np.dot(l, sigma), l.T)))
93 |     return 10 ** (np.log10(2) + scipy.stats.norm.logsf(z, loc=0, scale=1) / np.log(logv))
94 | 
95 | 
96 | def compute_ground_truth_statistics(ground_truth):
97 |     assert np.array_equal(np.unique(ground_truth), [0, 1])
98 |     order = (-ground_truth).argsort()
99 |     label_1_count = int(ground_truth.sum())
100 |     return order, label_1_count
101 | 
102 | 
103 | def delong_roc_variance(ground_truth, predictions):
104 |     """
105 |     Computes ROC AUC variance for a single set of predictions
106 |     Args:
107 |         ground_truth: np.array of 0 and 1
108 |         predictions: np.array of floats of the probability of being class 1
109 |     """
110 |     order, label_1_count = compute_ground_truth_statistics(ground_truth)
111 |     predictions_sorted_transposed = predictions[np.newaxis, order]
112 |     aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count)
113 |     assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers"
114 |     return aucs[0], delongcov
115 | 
116 | 
117 | def delong_roc_test(ground_truth, predictions_one, predictions_two, logv=10):
118 |     """
119 |     Computes the p-value for the hypothesis that two ROC AUCs are different
120 |     Args:
121 |         ground_truth: np.array of 0 and 1
122 |         predictions_one: predictions of the first model,
123 |             np.array of floats of the probability of being class 1
124 |         predictions_two: predictions of the second model,
125 |             np.array of floats of the probability of being class 1
126 |         logv: log base used internally (see `calc_pvalue`)
127 |     """
128 |     order, label_1_count = compute_ground_truth_statistics(ground_truth)
129 |     predictions_sorted_transposed = np.vstack((predictions_one, predictions_two))[:, order]
130 |     aucs, delongcov = 
fastDeLong(predictions_sorted_transposed, label_1_count)
131 |     # print(aucs, delongcov)
132 |     return calc_pvalue(aucs, delongcov, logv=logv)
133 | 
134 | 
135 | def calc_95_CI(ground_truth, predictions, alpha=0.95, with_auc: bool = True):
136 |     auc, auc_cov = delong_roc_variance(ground_truth, predictions)
137 |     auc_std = np.sqrt(auc_cov)
138 |     lower_upper_q = np.abs(np.array([0, 1]) - (1 - alpha) / 2)
139 |     ci = scipy.stats.norm.ppf(lower_upper_q, loc=auc, scale=auc_std)
140 |     ci[ci > 1] = 1
141 |     ci[ci < 0] = 0
142 |     if with_auc:
143 |         return auc, ci
144 |     else:
145 |         return ci
146 | 
147 | 
148 | if __name__ == '__main__':
149 |     preds_A = np.array([.5, .5, .5, .5, .5, .5, .5, .5, .5, .8])
150 |     preds_B = np.array([.2, .5, .1, .4, .9, .8, .7, .5, .9, .8])
151 |     actual = np.array([0, 0, 0, 0, 1, 0, 1, 1, 1, 1])
152 |     print(delong_roc_test(actual, preds_A, preds_B))
--------------------------------------------------------------------------------
/onekey_comp/comp7-Survival/生存分析-KaplanMeier.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "637adcb1",
6 |    "metadata": {},
7 |    "source": [
8 |     "## Survival Analysis\n",
9 |     "\n",
10 |     "The Kaplan-Meier (KM) estimator is typically used to fit a single variable at a time; it also lets you compare the predictive effect across the strata of that variable."
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "code",
15 |    "execution_count": null,
16 |    "id": "f50c8eb1",
17 |    "metadata": {},
18 |    "outputs": [],
19 |    "source": [
20 |     "import pandas as pd\n",
21 |     "import matplotlib.pyplot as plt\n",
22 |     "from onekey_algo import OnekeyDS as okds\n",
23 |     "\n",
24 |     "# Plotting parameters.\n",
25 |     "plt.rcParams['figure.figsize'] = (10.0, 8.0)\n",
26 |     "plt.rcParams['font.sans-serif'] = 'SimHei'\n",
27 |     "plt.rcParams['axes.unicode_minus'] = False\n",
28 |     "\n",
29 |     "mydir = okds.survival\n",
30 |     "data = pd.read_csv(mydir)\n",
31 |     "data.head()"
32 |    ]
33 |   },
34 |   {
35 |    "cell_type": "markdown",
36 |    "id": "49fcae05",
37 |    "metadata": {},
38 |    "source": [
39 |     "### KM Estimator\n",
40 |     "To estimate the survival function, we first will use the Kaplan-Meier Estimate, defined:\n",
41 |     "\n",
42 |     "$\hat{S}(t) = \prod_{t_i \lt t} \frac{n_i - d_i}{n_i} $\n",
43 |     "\n",
44 |     "where $d_i$ are the number of death events at time $t$ and $n_i$ is the number of subjects at risk of death just prior to time t."
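,
    "\n",
    "As a worked example (illustrative numbers, not taken from this dataset): if $n_i = 10$ subjects are at risk just before $t_i$ and $d_i = 2$ events occur at $t_i$, that time point contributes a factor of $(10 - 2)/10 = 0.8$ to $\hat{S}(t)$."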
45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "7b69afca", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "from lifelines import KaplanMeierFitter\n", 55 | "kmf = KaplanMeierFitter()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "a3bfcca4", 61 | "metadata": {}, 62 | "source": [ 63 | "### 定义生存时间和最终状态\n", 64 | "Other ways to estimate the survival function in lifelines are discussed below.\n", 65 | "\n", 66 | "For this estimation, we need the duration each leader was/has been in office, and whether or not they were observed to have left office (leaders who died in office or were in office in 2008, the latest date this data was record at, do not have observed death events)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "7d5b1bb4", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "T = data[\"duration\"]\n", 77 | "E = data[[\"result\"]]\n", 78 | "\n", 79 | "kmf.fit(T, event_observed=E)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "708c07d8", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "from matplotlib import pyplot as plt\n", 90 | "\n", 91 | "kmf.plot_survival_function()\n", 92 | "plt.title('Survival function of political regimes')\n", 93 | "plt.show()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "1b81d911", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "ax = plt.subplot(111)\n", 104 | "\n", 105 | "dem = (data[\"smoke\"] == 1)\n", 106 | "\n", 107 | "kmf.fit(T[dem], event_observed=E[dem], label=\"smoke\")\n", 108 | "kmf.plot_survival_function(ax=ax)\n", 109 | "\n", 110 | "kmf.fit(T[~dem], event_observed=E[~dem], label=\"Non-smoke\")\n", 111 | "kmf.plot_survival_function(ax=ax)\n", 112 | "\n", 113 | "plt.title(\"Lifespans of different global regimes\");" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "04b44640", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "import numpy as np\n", 124 | "\n", 125 | "ax = plt.subplot(111)\n", 126 | "\n", 127 | "t = np.linspace(20, 80, 41)\n", 128 | "kmf.fit(T[dem], event_observed=E[dem], timeline=t, label=\"Democratic Regimes\")\n", 129 | "ax = kmf.plot_survival_function(ax=ax)\n", 130 | "\n", 131 | "kmf.fit(T[~dem], event_observed=E[~dem], timeline=t, label=\"Non-democratic Regimes\")\n", 132 | "ax = kmf.plot_survival_function(ax=ax)\n", 133 | "\n", 134 | "plt.title(\"Lifespans of different global regimes\");" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "a93d4c7d", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "regime_types = data['Tstage'].unique()\n", 145 | "\n", 146 | "for i, regime_type in enumerate(regime_types):\n", 147 | " ax = plt.subplot(2, 2, i + 1)\n", 148 | "\n", 149 | " ix = data['Tstage'] == regime_type\n", 150 | " kmf.fit(T[ix], E[ix], label=regime_type)\n", 151 | " kmf.plot_survival_function(ax=ax, legend=False)\n", 152 | "\n", 153 | " plt.title(regime_type)\n", 154 | " plt.xlim(20, 83)\n", 155 | "\n", 156 | " if i==0:\n", 157 | " plt.ylabel('Frac. 
in power after $n$ years')\n", 158 | "\n", 159 | "plt.tight_layout()" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "id": "20c9fd42", 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "kmf = KaplanMeierFitter().fit(T, E, label=\"all_regimes\")\n", 170 | "kmf.plot_survival_function(at_risk_counts=True)\n", 171 | "plt.tight_layout()" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "cf716b5b", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python 3 (ipykernel)", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.7.12" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 5 204 | } 205 | -------------------------------------------------------------------------------- /onekey_algo/custom/components/comp2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/22 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 7 | import argparse 8 | import json 9 | import os 10 | from functools import partial 11 | from typing import Iterable 12 | 13 | import numpy as np 14 | import torch 15 | 16 | from onekey_algo.datasets.image_loader import default_loader 17 | from onekey_algo.utils.about_log import logger 18 | from onekey_core.core import create_model 19 | from onekey_core.core import create_standard_image_transformer 20 | 21 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 22 | 23 | 24 | def extract(samples, model, transformer, device=None, fp=None): 25 | results = [] 26 | # Inference 27 | if not isinstance(samples, (list, tuple)): 28 | samples = [samples] 29 | with torch.set_grad_enabled(False): 30 | for sample in samples: 31 | fp.write(f"{os.path.basename(sample)},") 32 | sample_ = transformer(default_loader(sample)) 33 | sample_ = sample_.to(device) 34 | # print(sample_.size()) 35 | outputs = model(sample_.view(1, *sample_.size())) 36 | results.append(outputs) 37 | return results 38 | 39 | 40 | def print_feature_hook(module, inp, outp, fp): 41 | print(','.join(map(lambda x: f"{x:.6f}", np.reshape(outp.cpu().numpy(), -1))), file=fp) 42 | 43 | 44 | def reg_hook_on_module(name, model, hook): 45 | find_ = 0 46 | for n, m in model.named_modules(): 47 | if name == n: 48 | m.register_forward_hook(hook) 49 | find_ += 1 50 | if find_ == 0: 51 | logger.warning(f'{name} not found in {model}') 52 | elif find_ > 1: 53 | logger.info(f'Found {find_} features named {name} in {model}') 54 | return find_ 55 | 56 | 57 | def init_from_model(model_name, model_path=None, num_classes=1000, model_state='model_state_dict', 58 | img_size=(224, 224), **kwargs): 59 | # Configuration of core 60 | kwargs.update({'pretrained': True if model_path is None else False, 61 | 'model_name': model_name, 'num_classes': num_classes}) 62 | model = create_model(**kwargs).eval() 63 | # Config device automatically 64 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 65 | model = model.to(device) 66 | if model_path and os.path.exists(model_path): 67 | 
state_dict = torch.load(model_path, map_location=device)[model_state] 68 | model.load_state_dict(state_dict) 69 | if 'inception' in model_name.lower(): 70 | if isinstance(img_size, int): 71 | if img_size != 299: 72 | logger.warning(f'{model_name} is inception structure, `img_size` is set to be 299 * 299.') 73 | img_size = 299 74 | elif isinstance(img_size, Iterable): 75 | if 299 not in img_size: 76 | logger.warning(f'{model_name} is inception structure, `img_size` is set to be 299 * 299.') 77 | img_size = (299, 299) 78 | transformer = create_standard_image_transformer(img_size, phase='valid') 79 | return model, transformer, device 80 | 81 | 82 | def init_from_onekey(config_path): 83 | config = json.loads(open(os.path.join(config_path, 'task.json')).read()) 84 | model_path = os.path.join(config_path, 'BEST-training-params.pth') 85 | assert 'model_name' in config and 'num_classes' in config and 'transform' in config 86 | # Configuration of transformer. 87 | transform_config = {'phase': 'valid'} 88 | transform_config.update(config['transform']) 89 | assert 'input_size' in transform_config, '`input_size` must in `transform`' 90 | transformer = create_standard_image_transformer(**transform_config) 91 | 92 | # Configuration of core 93 | model_config = {'pretrained': False, 'model_name': config['model_name'], 'num_classes': config['num_classes']} 94 | model = create_model(**model_config) 95 | # Configuration of device 96 | device_info = 'cpu' 97 | device = torch.device(device_info) 98 | model = model.to(device) 99 | state_dict = torch.load(model_path, map_location=device)['model_state_dict'] 100 | for key in list(state_dict.keys()): 101 | if key.startswith('module.'): 102 | new_key = key[7:] 103 | state_dict[new_key] = state_dict[key] 104 | del state_dict[key] 105 | model.load_state_dict(state_dict) 106 | model.eval() 107 | return model, transformer, device 108 | 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser(description='PyTorch Classification Inference') 112 | 113 | parser.add_argument('-c', '--config_path', dest='c', default='20211014/resnet18/viz', 114 | help='Model and transformer configuration') 115 | parser.add_argument('-d', '--directory', dest='d', 116 | default=r'C:\Users\onekey\Project\data\labelme', help='Inference data directory.') 117 | parser.add_argument('-l', '--list_file', dest='l', default=None, help='Inference data list file') 118 | 119 | args = parser.parse_args() 120 | if args.d is not None: 121 | test_samples = [os.path.join(args.d, p) for p in os.listdir(args.d) if p.endswith('.jpg')] 122 | elif args.l is not None: 123 | with open(args.l) as f: 124 | test_samples = [l.strip() for l in f.readlines()] 125 | else: 126 | raise ValueError('You must provide a directory or list file for inference.') 127 | model_name = 'resnet18' 128 | model, transformer, device = init_from_model(model_name=model_name) 129 | # print(model) 130 | # for n, m in model.named_modules(): 131 | # print(n, m) 132 | feature_name = 'avgpool' 133 | outfile = open('feature.txt', 'w') 134 | hook = partial(print_feature_hook, fp=outfile) 135 | find_num = reg_hook_on_module(feature_name, model, hook) 136 | results_ = extract(test_samples[:5], model, transformer, device, fp=outfile) 137 | print(json.dumps(results_, ensure_ascii=False, indent=True)) 138 | -------------------------------------------------------------------------------- /onekey_core/models/squeezenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | from .utils import load_state_dict_from_url 6 | 7 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] 8 | 9 | model_urls = { 10 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 11 | 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', 12 | } 13 | 14 | 15 | class Fire(nn.Module): 16 | 17 | def __init__(self, inplanes, squeeze_planes, 18 | expand1x1_planes, expand3x3_planes): 19 | super(Fire, self).__init__() 20 | self.inplanes = inplanes 21 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 22 | self.squeeze_activation = nn.ReLU(inplace=True) 23 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 24 | kernel_size=1) 25 | self.expand1x1_activation = nn.ReLU(inplace=True) 26 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 27 | kernel_size=3, padding=1) 28 | self.expand3x3_activation = nn.ReLU(inplace=True) 29 | 30 | def forward(self, x): 31 | x = self.squeeze_activation(self.squeeze(x)) 32 | return torch.cat([ 33 | self.expand1x1_activation(self.expand1x1(x)), 34 | self.expand3x3_activation(self.expand3x3(x)) 35 | ], 1) 36 | 37 | 38 | class SqueezeNet(nn.Module): 39 | 40 | def __init__(self, version='1_0', in_channels: int = 3, num_classes=1000, **kwargs): 41 | super(SqueezeNet, self).__init__() 42 | self.num_classes = num_classes 43 | if version == '1_0': 44 | self.features = nn.Sequential( 45 | nn.Conv2d(in_channels, 96, kernel_size=7, stride=2), 46 | nn.ReLU(inplace=True), 47 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 48 | Fire(96, 16, 64, 64), 49 | Fire(128, 16, 64, 64), 50 | Fire(128, 32, 128, 128), 51 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 52 | Fire(256, 32, 128, 128), 53 | Fire(256, 48, 192, 192), 54 | Fire(384, 48, 192, 192), 55 | Fire(384, 64, 256, 256), 56 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 57 | Fire(512, 64, 256, 256), 58 | ) 59 | elif version == '1_1': 60 | self.features = nn.Sequential( 61 | nn.Conv2d(3, 64, kernel_size=3, stride=2), 62 | nn.ReLU(inplace=True), 63 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 64 | Fire(64, 16, 64, 64), 65 | Fire(128, 16, 64, 64), 66 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 67 | Fire(128, 32, 128, 128), 68 | Fire(256, 32, 128, 128), 69 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 70 | Fire(256, 48, 192, 192), 71 | Fire(384, 48, 192, 192), 72 | Fire(384, 64, 256, 256), 73 | Fire(512, 64, 256, 256), 74 | ) 75 | else: 76 | # FIXME: Is this needed? 
SqueezeNet should only be called from the 77 | # FIXME: squeezenet1_x() functions 78 | # FIXME: This checking is not done for the other models 79 | raise ValueError("Unsupported SqueezeNet version {version}:" 80 | "1_0 or 1_1 expected".format(version=version)) 81 | 82 | # Final convolution is initialized differently from the rest 83 | final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) 84 | self.classifier = nn.Sequential( 85 | nn.Dropout(p=0.5), 86 | final_conv, 87 | nn.ReLU(inplace=True), 88 | nn.AdaptiveAvgPool2d((1, 1)) 89 | ) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | if m is final_conv: 94 | init.normal_(m.weight, mean=0.0, std=0.01) 95 | else: 96 | init.kaiming_uniform_(m.weight) 97 | if m.bias is not None: 98 | init.constant_(m.bias, 0) 99 | 100 | def forward(self, x): 101 | x = self.features(x) 102 | x = self.classifier(x) 103 | return torch.flatten(x, 1) 104 | 105 | 106 | def _squeezenet(version, pretrained, progress, **kwargs): 107 | model = SqueezeNet(version, **kwargs) 108 | if pretrained: 109 | arch = 'squeezenet' + version 110 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 111 | parameters_list = list(state_dict.keys()) 112 | for k in parameters_list: 113 | if 'classifier' in k: 114 | del state_dict[k] 115 | model.load_state_dict(state_dict, strict=False) 116 | return model 117 | 118 | 119 | def squeezenet1_0(pretrained=False, progress=True, **kwargs): 120 | r"""SqueezeNet core architecture from the `"SqueezeNet: AlexNet-level 121 | accuracy with 50x fewer parameters and <0.5MB core size" 122 | `_ paper. 123 | 124 | Args: 125 | pretrained (bool): If True, returns a core pre-trained on ImageNet 126 | progress (bool): If True, displays a progress bar of the download to stderr 127 | """ 128 | return _squeezenet('1_0', pretrained, progress, **kwargs) 129 | 130 | 131 | def squeezenet1_1(pretrained=False, progress=True, **kwargs): 132 | r"""SqueezeNet 1.1 core from the `official SqueezeNet repo 133 | `_. 134 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters 135 | than SqueezeNet 1.0, without sacrificing accuracy. 136 | 137 | Args: 138 | pretrained (bool): If True, returns a core pre-trained on ImageNet 139 | progress (bool): If True, displays a progress bar of the download to stderr 140 | """ 141 | return _squeezenet('1_1', pretrained, progress, **kwargs) 142 | -------------------------------------------------------------------------------- /onekey_algo/classification/eval_classification.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Authorized by Vlon Jang 3 | # Created on 2019/7/22 4 | # Blog: www.wangqingbaidu.cn 5 | # Email: wangqingbaidu@gmail.com 6 | # Copyright 2015-2019 All Rights Reserved. 
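# Usage sketch (illustrative; the config path and image names below are placeholders, not
# files shipped with this repo): load a model trained with Onekey via `init`, then run
# `inference` on a list of images:
#
#     model, transformer, labels, device = init('20211014/resnet18/viz')
#     results = inference(['a.jpg', 'b.jpg'], model, transformer, labels, device)
#     # each result is (filename, {label: prob, ...}, predicted_label)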
7 | import argparse 8 | import json 9 | import math 10 | import os 11 | import time 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils as utils 16 | import pandas as pd 17 | 18 | from onekey_algo.datasets.ClassificationDataset import ListDataset4Test 19 | from onekey_algo.datasets.image_loader import default_loader 20 | from onekey_algo.utils.about_log import logger 21 | from onekey_core.core import create_model 22 | from onekey_core.core import create_standard_image_transformer 23 | 24 | 25 | def inference(samples, model, transformer, labels=None, device=None): 26 | # Inference 27 | if labels is None: 28 | labels = {} 29 | results = [] 30 | if not isinstance(samples, (list, tuple)): 31 | samples = [samples] 32 | with torch.set_grad_enabled(False): 33 | for idx, sample in enumerate(samples): 34 | if len(samples) > 1e4 and idx % 1000 == 0: 35 | logger.info(f'正在预测中,已完成:{idx}, 完成率:{idx * 100 / len(samples):.4f}%') 36 | sample_ = transformer(default_loader(sample)) 37 | sample_ = sample_.to(device) 38 | # print(sample_.size()) 39 | outputs = model(sample_.view(1, *sample_.size())) 40 | prob = F.softmax(outputs, dim=1)[0].cpu() 41 | prediction = torch.argmax(prob) 42 | results.append((os.path.basename(sample), dict(zip(labels, prob.numpy().tolist())), 43 | labels[prediction.item()])) 44 | return results 45 | 46 | 47 | def inference_dataloader(samples, model, transformer, labels=None, device=None, batch_size=1, num_workers=1, 48 | cached_dir: str = None): 49 | # Inference 50 | if labels is None: 51 | labels = {} 52 | results = [] 53 | if not isinstance(samples, (list, tuple)): 54 | samples = [samples] 55 | dataloader = utils.data.DataLoader(ListDataset4Test(samples, transformer), batch_size=batch_size, drop_last=False, 56 | shuffle=False, num_workers=num_workers) 57 | calc_num = 0 58 | with torch.set_grad_enabled(False): 59 | start_time = time.time() 60 | for idx, (sample_, fnames) in enumerate(dataloader): 61 | sample_ = sample_.to(device) 62 | if len(samples) > 1e4 and idx % 100 == 0 and idx != 0: 63 | speed = (time.time() - start_time) * 1000 / (idx * batch_size) 64 | logger.info( 65 | f'正在预测中,已完成:{idx * batch_size}, 完成率:{idx * batch_size * 100 / len(samples):.4f}%,' 66 | f'移动平均速度是:{speed:.4f} msec/img') 67 | outputs = model(sample_) 68 | probs = F.softmax(outputs, dim=1).cpu() 69 | predictions = torch.argmax(probs, dim=1).cpu() 70 | # print(probs.shape, predictions.shape) 71 | for fname, prob, prediction in zip(fnames, probs, predictions): 72 | results.append((fname, json.dumps(dict(zip(labels, prob.numpy().tolist())), ensure_ascii=False), 73 | labels[prediction.item()])) 74 | calc_num += 1 75 | if cached_dir is not None and len(samples) > 100 and calc_num % math.ceil(len(samples) / 100) == 0: 76 | logger.info(f'Saving cached {calc_num * 100 // len(samples):03d} results...') 77 | os.makedirs(cached_dir, exist_ok=True) 78 | r = pd.DataFrame(results, columns=['fname', 'prob', 'label']) 79 | r.to_csv(os.path.join(cached_dir, f'cached_{calc_num // math.ceil(len(samples) / 100):03d}.csv'), 80 | index=False) 81 | results = [] 82 | if cached_dir is not None and len(results): 83 | os.makedirs(cached_dir, exist_ok=True) 84 | r = pd.DataFrame(results, columns=['fname', 'prob', 'label']) 85 | r.to_csv(os.path.join(cached_dir, f'cached_100.csv'), index=False) 86 | results = [] 87 | return results 88 | 89 | 90 | def init(config_path): 91 | config = json.loads(open(os.path.join(config_path, 'task.json')).read()) 92 | labels = [l.strip() for l in 
open(os.path.join(config_path, 'labels.txt'), encoding='utf8').readlines()] 93 | model_path = os.path.join(config_path, 'BEST-training-params.pth') 94 | assert 'model_name' in config and 'num_classes' in config and 'transform' in config 95 | # Configuration of transformer. 96 | transform_config = {'phase': 'valid'} 97 | transform_config.update(config['transform']) 98 | assert 'input_size' in transform_config, '`input_size` must in `transform`' 99 | transformer = create_standard_image_transformer(**transform_config) 100 | 101 | # Configuration of core 102 | model_config = {'pretrained': False, 'model_name': config['model_name'], 'num_classes': config['num_classes']} 103 | model = create_model(**model_config) 104 | # Configuration of device 105 | device_info = config.get('device', None) or 'cpu' 106 | device = torch.device(device_info) 107 | model = model.to(device) 108 | state_dict = torch.load(model_path, map_location=device)['model_state_dict'] 109 | 110 | for key in list(state_dict.keys()): 111 | if key.startswith('module.'): 112 | new_key = key[7:] 113 | state_dict[new_key] = state_dict[key] 114 | del state_dict[key] 115 | model.load_state_dict(state_dict) 116 | model.eval() 117 | return model, transformer, labels, device 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser(description='PyTorch Classification Inference') 122 | 123 | parser.add_argument('-c', '--config_path', dest='c', default='20211014/resnet18/viz', 124 | help='Model and transformer configuration') 125 | parser.add_argument('-d', '--directory', dest='d', 126 | default=r'G:\skin_classification\images', help='Inference data directory.') 127 | parser.add_argument('-l', '--list_file', dest='l', default=None, help='Inference data list file') 128 | 129 | args = parser.parse_args() 130 | 131 | if args.d is not None: 132 | test_samples = [os.path.join(args.d, p) for p in os.listdir(args.d) if p.endswith('.jpg')] 133 | elif args.l is not None: 134 | with open(args.l) as f: 135 | test_samples = [l.strip() for l in f.readlines()] 136 | else: 137 | raise ValueError('You must provide a directory or list file for inference.') 138 | model, transformer, labels_, device = init(config_path=args.c) 139 | results_ = inference(test_samples, model, transformer, labels_, device) 140 | print(json.dumps(results_, ensure_ascii=False, indent=True)) 141 | -------------------------------------------------------------------------------- /onekey_comp/comp4-What(分类识别)/What-特征提取.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "13d5c640", 6 | "metadata": {}, 7 | "source": [ 8 | "## 深度学习特征\n", 9 | "\n", 10 | "提取CT、MRI、内镜、Xray等影像数据的深度学习特征。\n", 11 | "\n", 12 | "### Onekey步骤\n", 13 | "\n", 14 | "1. 将待提取的数据转化成jpg,可以参考使用OKT-convert2jpg或者OKT-crop_max_roi两个Onekey工具。\n", 15 | "2. 获取到指定目录的所有图像数据。\n", 16 | "3. 选择要提取什么样的模型的深度学习特征,目前Onekey支持主流的深度学习模型。(可以考虑使用Onekey进行迁移学习)\n", 17 | "4. 提取特征,保存特征文件。" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "id": "b21b469b", 23 | "metadata": {}, 24 | "source": [ 25 | "## 获取待提取特征的文件\n", 26 | "\n", 27 | "提供两种批量处理的模式:\n", 28 | "1. 目录模式,提取指定目录下的所有jpg文件的特征。\n", 29 | "2. 
文件模式,待提取的数据存储在文件中,每行一个样本。\n", 30 | "\n", 31 | "当然也可以在最后自己指定手动提取指定若干文件。" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "f8dee942", 38 | "metadata": { 39 | "scrolled": true 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "import os\n", 44 | "# 目录模式\n", 45 | "mydir = r'C:\\Users\\onekey\\Project\\OnekeyDS\\CT\\crop'\n", 46 | "# mydir = r'C:\\Users\\onekey\\Project\\OnekeyDS\\CT\\full'\n", 47 | "directory = os.path.expanduser(mydir)\n", 48 | "test_samples = [os.path.join(directory, p) for p in os.listdir(directory) if p.endswith('.png') or p.endswith('.jpg')]\n", 49 | "\n", 50 | "# 文件模式\n", 51 | "# test_file = ''\n", 52 | "# with open(test_file) as f:\n", 53 | "# test_samples = [l.strip() for l in f.readlines()]\n", 54 | "\n", 55 | "# 自定义模式\n", 56 | "# test_sampleses = ['path2jpg']\n", 57 | "test_samples" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "26847144", 63 | "metadata": {}, 64 | "source": [ 65 | "## 确定提取特征\n", 66 | "\n", 67 | "通过关键词获取要提取那一层的特征。\n", 68 | "\n", 69 | "### 支持的模型名称\n", 70 | "\n", 71 | "模型名称替换代码中的 `model_name`变量的值。\n", 72 | "\n", 73 | "| **模型系列** | **模型名称** |\n", 74 | "| ------------ | ------------------------------------------------------------ |\n", 75 | "| AlexNet | alexnet |\n", 76 | "| VGG | vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19 |\n", 77 | "| ResNet | resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2 |\n", 78 | "| DenseNet | densenet121, densenet169, densenet201, densenet161 |\n", 79 | "| Inception | googlenet, inception_v3 |\n", 80 | "| SqueezeNet | squeezenet1_0, squeezenet1_1 |\n", 81 | "| ShuffleNetV2 | shufflenet_v2_x2_0, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5 |\n", 82 | "| MobileNet | mobilenet_v2, mobilenet_v3_large, mobilenet_v3_small |\n", 83 | "| MNASNet | mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3 |" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "9e8d607a", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "from onekey_algo.custom.components.comp2 import extract, print_feature_hook, reg_hook_on_module, \\\n", 94 | " init_from_model, init_from_onekey\n", 95 | "\n", 96 | "model_name = 'resnet50'\n", 97 | "model, transformer, device = init_from_model(model_name=model_name)\n", 98 | "# model, transformer, device = init_from_onekey(r'')\n", 99 | "for n, m in model.named_modules():\n", 100 | " print('Feature name:', n, \"|| Module:\", m)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "id": "87f29370", 106 | "metadata": {}, 107 | "source": [ 108 | "## 提取特征\n", 109 | "\n", 110 | "`Feature name:` 之后的名称为要提取的特征名,例如`layer3.0.conv2`, 一般深度学习特征提取最后一层,例如`avgpool`" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "541bfc3f", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "from functools import partial\n", 121 | "feature_name = 'avgpool'\n", 122 | "with open('feature.csv', 'w') as outfile:\n", 123 | " hook = partial(print_feature_hook, fp=outfile)\n", 124 | " find_num = reg_hook_on_module(feature_name, model, hook)\n", 125 | " results = extract(test_samples, model, transformer, device, fp=outfile)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "id": "9d3cbd6f", 131 | "metadata": {}, 132 | "source": [ 133 | "## 读取数据" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "a7ac1309", 140 | "metadata": 
{}, 141 | "outputs": [], 142 | "source": [ 143 | "import pandas as pd\n", 144 | "features = pd.read_csv('feature.csv', header=None)\n", 145 | "features.columns=['ID'] + list(features.columns[1:])\n", 146 | "features.head()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "9634e5f2", 152 | "metadata": {}, 153 | "source": [ 154 | "### 深度特征压缩\n", 155 | "\n", 156 | "深度学习特征压缩,注意压缩到的维度需要小于样本数\n", 157 | "\n", 158 | "```python\n", 159 | "def compress_df_feature(features: pd.DataFrame, dim: int, not_compress: Union[str, List[str]] = None,\n", 160 | " prefix='') -> pd.DataFrame:\n", 161 | " \"\"\"\n", 162 | " 压缩深度学习特征\n", 163 | " Args:\n", 164 | " features: 特征DataFrame\n", 165 | " dim: 需要压缩到的维度,此值需要小于样本数\n", 166 | " not_compress: 不进行压缩的列。\n", 167 | " prefix: 所有特征的前缀。\n", 168 | "\n", 169 | " Returns:\n", 170 | "\n", 171 | " \"\"\"\n", 172 | "```" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "8c649a10", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "from onekey_algo.custom.components.comp1 import compress_df_feature\n", 183 | "\n", 184 | "cm_features = compress_df_feature(features=features, dim=32, prefix='DL_', not_compress='ID')\n", 185 | "cm_features.to_csv('compress_features.csv', header=True, index=False)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "id": "b1285696", 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "Python 3 (ipykernel)", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.7.12" 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 5 218 | } 219 | -------------------------------------------------------------------------------- /onekey_core/models/resnet3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = [ 9 | 'ResNet3D', 'resnet10_3d', 'resnet18_3d', 'resnet34_3d', 'resnet50_3d', 'resnet101_3d', 10 | 'resnet152_3d', 'resnet200_3d' 11 | ] 12 | 13 | 14 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1): 15 | # 3x3x3 convolution with padding 16 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, dilation=dilation, stride=stride, 17 | padding=dilation, bias=False) 18 | 19 | 20 | def downsample_basic_block(x, planes, stride, no_cuda=False): 21 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 22 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)).zero_() 23 | if not no_cuda: 24 | if isinstance(out.data, torch.cuda.FloatTensor): 25 | zero_pads = zero_pads.cuda() 26 | 27 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 28 | 29 | return out 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation) 38 | self.bn1 = nn.BatchNorm3d(planes) 39 | self.relu = 
nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3x3(planes, planes, dilation=dilation) 41 | self.bn2 = nn.BatchNorm3d(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | self.dilation = dilation 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = nn.BatchNorm3d(planes) 71 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, 72 | bias=False) 73 | self.bn2 = nn.BatchNorm3d(planes) 74 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = nn.BatchNorm3d(planes * 4) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | self.dilation = dilation 80 | 81 | def forward(self, x): 82 | residual = x 83 | 84 | out = self.conv1(x) 85 | out = self.bn1(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv2(out) 89 | out = self.bn2(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv3(out) 93 | out = self.bn3(out) 94 | 95 | if self.downsample is not None: 96 | residual = self.downsample(x) 97 | 98 | out += residual 99 | out = self.relu(out) 100 | 101 | return out 102 | 103 | 104 | class ResNet3D(nn.Module): 105 | 106 | def __init__(self, block, layers, sample_input_D, sample_input_H, sample_input_W, num_seg_classes, 107 | shortcut_type='B', no_cuda=False): 108 | self.inplanes = 64 109 | self.no_cuda = no_cuda 110 | super(ResNet3D, self).__init__() 111 | self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False) 112 | 113 | self.bn1 = nn.BatchNorm3d(64) 114 | self.relu = nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 116 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 117 | self.layer2 = self._make_layer( 118 | block, 128, layers[1], shortcut_type, stride=2) 119 | self.layer3 = self._make_layer( 120 | block, 256, layers[2], shortcut_type, stride=1, dilation=2) 121 | self.layer4 = self._make_layer( 122 | block, 512, layers[3], shortcut_type, stride=1, dilation=4) 123 | 124 | self.conv_seg = nn.Sequential( 125 | nn.ConvTranspose3d(512 * block.expansion, 32, 2, stride=2), 126 | nn.BatchNorm3d(32), 127 | nn.ReLU(inplace=True), 128 | nn.Conv3d(32, 32, kernel_size=3, stride=(1, 1, 1), padding=(1, 1, 1), bias=False), 129 | nn.BatchNorm3d(32), 130 | nn.ReLU(inplace=True), 131 | nn.Conv3d(32, num_seg_classes, kernel_size=1, stride=(1, 1, 1), bias=False) 132 | ) 133 | 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv3d): 136 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 137 | elif isinstance(m, nn.BatchNorm3d): 138 | m.weight.data.fill_(1) 139 | m.bias.data.zero_() 140 | 141 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1): 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | if shortcut_type == 'A': 145 | downsample = partial( 146 | downsample_basic_block, 147 | planes=planes * 
block.expansion,
148 |                     stride=stride,
149 |                     no_cuda=self.no_cuda)
150 |             else:
151 |                 downsample = nn.Sequential(
152 |                     nn.Conv3d(
153 |                         self.inplanes,
154 |                         planes * block.expansion,
155 |                         kernel_size=1,
156 |                         stride=stride,
157 |                         bias=False), nn.BatchNorm3d(planes * block.expansion))
158 | 
159 |         layers = []
160 |         layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
161 |         self.inplanes = planes * block.expansion
162 |         for i in range(1, blocks):
163 |             layers.append(block(self.inplanes, planes, dilation=dilation))
164 | 
165 |         return nn.Sequential(*layers)
166 | 
167 |     def forward(self, x):
168 |         x = self.conv1(x)
169 |         x = self.bn1(x)
170 |         x = self.relu(x)
171 |         x = self.maxpool(x)
172 |         x = self.layer1(x)
173 |         x = self.layer2(x)
174 |         x = self.layer3(x)
175 |         x = self.layer4(x)
176 |         x = self.conv_seg(x)
177 | 
178 |         return x
179 | 
180 | 
181 | def resnet10_3d(**kwargs):
182 |     """Constructs a ResNet-10 model.
183 |     """
184 |     model = ResNet3D(BasicBlock, [1, 1, 1, 1], **kwargs)
185 |     return model
186 | 
187 | 
188 | def resnet18_3d(**kwargs):
189 |     """Constructs a ResNet-18 model.
190 |     """
191 |     model = ResNet3D(BasicBlock, [2, 2, 2, 2], **kwargs)
192 |     return model
193 | 
194 | 
195 | def resnet34_3d(**kwargs):
196 |     """Constructs a ResNet-34 model.
197 |     """
198 |     model = ResNet3D(BasicBlock, [3, 4, 6, 3], **kwargs)
199 |     return model
200 | 
201 | 
202 | def resnet50_3d(**kwargs):
203 |     """Constructs a ResNet-50 model.
204 |     """
205 |     model = ResNet3D(Bottleneck, [3, 4, 6, 3], **kwargs)
206 |     return model
207 | 
208 | 
209 | def resnet101_3d(**kwargs):
210 |     """Constructs a ResNet-101 model.
211 |     """
212 |     model = ResNet3D(Bottleneck, [3, 4, 23, 3], **kwargs)
213 |     return model
214 | 
215 | 
216 | def resnet152_3d(**kwargs):
217 |     """Constructs a ResNet-152 model.
218 |     """
219 |     model = ResNet3D(Bottleneck, [3, 8, 36, 3], **kwargs)
220 |     return model
221 | 
222 | 
223 | def resnet200_3d(**kwargs):
224 |     """Constructs a ResNet-200 model. 
225 | """ 226 | model = ResNet3D(Bottleneck, [3, 24, 36, 3], **kwargs) 227 | return model 228 | -------------------------------------------------------------------------------- /onekey_core/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import load_state_dict_from_url 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | model_urls = { 12 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 13 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 14 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 15 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 16 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 17 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 18 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 19 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, features, num_classes=1000, init_weights=True, **kwargs): 26 | super(VGG, self).__init__() 27 | self.features = features 28 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 29 | self.classifier = nn.Sequential( 30 | nn.Linear(512 * 7 * 7, 4096), 31 | nn.ReLU(True), 32 | nn.Dropout(), 33 | nn.Linear(4096, 4096), 34 | nn.ReLU(True), 35 | nn.Dropout(), 36 | nn.Linear(4096, num_classes), 37 | ) 38 | if init_weights: 39 | self._initialize_weights() 40 | 41 | def forward(self, x): 42 | x = self.features(x) 43 | x = self.avgpool(x) 44 | x = torch.flatten(x, 1) 45 | x = self.classifier(x) 46 | return x 47 | 48 | def _initialize_weights(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | elif isinstance(m, nn.BatchNorm2d): 55 | nn.init.constant_(m.weight, 1) 56 | nn.init.constant_(m.bias, 0) 57 | elif isinstance(m, nn.Linear): 58 | nn.init.normal_(m.weight, 0, 0.01) 59 | nn.init.constant_(m.bias, 0) 60 | 61 | 62 | def make_layers(cfg, in_channels: int = 3, batch_norm=False): 63 | layers = [] 64 | for v in cfg: 65 | if v == 'M': 66 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 67 | else: 68 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 69 | if batch_norm: 70 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 71 | else: 72 | layers += [conv2d, nn.ReLU(inplace=True)] 73 | in_channels = v 74 | return nn.Sequential(*layers) 75 | 76 | 77 | cfgs = { 78 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 79 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 80 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 81 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 82 | } 83 | 84 | 85 | def _vgg(arch, cfg, batch_norm, pretrained, progress, transfer_learning=True, **kwargs): 86 | if pretrained: 87 | kwargs['init_weights'] = False 88 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 89 | if pretrained: 90 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 91 | parameters_list = list(state_dict.keys()) 92 | for k in parameters_list: 93 | 
if k.startswith('classifier.') and transfer_learning: 94 | del state_dict[k] 95 | model.load_state_dict(state_dict, strict=False) 96 | return model 97 | 98 | 99 | def vgg11(pretrained=False, progress=True, **kwargs): 100 | r"""VGG 11-layer core (configuration "A") from 101 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_ 102 | 103 | Args: 104 | pretrained (bool): If True, returns a core pre-trained on ImageNet 105 | progress (bool): If True, displays a progress bar of the download to stderr 106 | """ 107 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 108 | 109 | 110 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 111 | r"""VGG 11-layer core (configuration "A") with batch normalization 112 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_ 113 | 114 | Args: 115 | pretrained (bool): If True, returns a core pre-trained on ImageNet 116 | progress (bool): If True, displays a progress bar of the download to stderr 117 | """ 118 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 119 | 120 | 121 | def vgg13(pretrained=False, progress=True, **kwargs): 122 | r"""VGG 13-layer core (configuration "B") 123 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_ 124 | 125 | Args: 126 | pretrained (bool): If True, returns a core pre-trained on ImageNet 127 | progress (bool): If True, displays a progress bar of the download to stderr 128 | """ 129 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 130 | 131 | 132 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 133 | r"""VGG 13-layer core (configuration "B") with batch normalization 134 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_ 135 | 136 | Args: 137 | pretrained (bool): If True, returns a core pre-trained on ImageNet 138 | progress (bool): If True, displays a progress bar of the download to stderr 139 | """ 140 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 141 | 142 | 143 | def vgg16(pretrained=False, progress=True, **kwargs): 144 | r"""VGG 16-layer core (configuration "D") 145 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_ 146 | 147 | Args: 148 | pretrained (bool): If True, returns a core pre-trained on ImageNet 149 | progress (bool): If True, displays a progress bar of the download to stderr 150 | """ 151 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 152 | 153 | 154 | def vgg16_bn(pretrained=False, progress=True, **kwargs): 155 | r"""VGG 16-layer core (configuration "D") with batch normalization 156 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_ 157 | 158 | Args: 159 | pretrained (bool): If True, returns a core pre-trained on ImageNet 160 | progress (bool): If True, displays a progress bar of the download to stderr 161 | """ 162 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 163 | 164 | 165 | def vgg19(pretrained=False, progress=True, **kwargs): 166 | r"""VGG 19-layer core (configuration "E") 167 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_ 168 | 169 | Args: 170 | pretrained (bool): If True, returns a core pre-trained on ImageNet 171 | progress (bool): If True, displays a progress bar of the download to stderr 172 | """ 173 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 174 | 175 | 176 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 177 | r"""VGG 19-layer core (configuration 'E') with batch
normalization 178 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_ 179 | 180 | Args: 181 | pretrained (bool): If True, returns a core pre-trained on ImageNet 182 | progress (bool): If True, displays a progress bar of the download to stderr 183 | """ 184 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 185 | -------------------------------------------------------------------------------- /onekey_comp/comp9-Solutions/sol1. 传统组学-单中心-临床/Step2. 临床基线统计分析.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Statistical Analysis\n", 8 | "\n", 9 | "Given the columns selected for analysis, a p_value is computed for each feature; all p_values are based on a t-test. Statistics can also be computed per `group`, e.g. for train, val and test splits.\n", 10 | "\n", 11 | "Two broad kinds of features are handled:\n", 12 | "\n", 13 | "1. Discrete features: counts and proportions.\n", 14 | "2. Continuous features: mean and variance." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "scrolled": false 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import pandas as pd\n", 26 | "import numpy as np\n", 27 | "from onekey_algo import OnekeyDS as okds\n", 28 | "from onekey_algo import get_param_in_cwd\n", 29 | "from onekey_algo.custom.utils import print_join_info\n", 30 | "\n", 31 | "task = get_param_in_cwd('task_column') or 'label'\n", 32 | "p_value = get_param_in_cwd('p_value') or 0.05\n", 33 | "# Change this to your own clinical data file.\n", 34 | "test_data = pd.read_csv(get_param_in_cwd('clinic_file') or okds.survival)\n", 35 | "stats_columns_settings = get_param_in_cwd('stats_columns')\n", 36 | "continuous_columns_settings = get_param_in_cwd('continuous_columns')\n", 37 | "mapping_columns_settings = get_param_in_cwd('mapping_columns')\n", 38 | "test_data = test_data[[c for c in test_data.columns if c != task]]\n", 39 | "test_data['ID'] = test_data['ID'].map(lambda x: f\"{x}.nii.gz\" if not (f\"{x}\".endswith('.nii.gz') or f\"{x}\".endswith('.nii')) else x)\n", 40 | "group_info = pd.read_csv('group.csv')\n", 41 | "print_join_info(test_data, group_info)\n", 42 | "test_data = pd.merge(test_data, group_info, on='ID', how='inner')\n", 43 | "test_data" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "# Feature Name Cleanup\n", 51 | "\n", 52 | "Strip special characters from all feature names." 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import re\n", 62 | "\n", 63 | "def map_cnames(x):\n", 64 | " x = re.split('[(|(]', x)[0]\n", 65 | " x = x.replace('-', '_').replace(' ', '_').replace('>', '').replace('/', '_')\n", 66 | " return x.strip()\n", 67 | "\n", 68 | "test_data.columns = list(map(map_cnames, test_data.columns))\n", 69 | "test_data.columns" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# Data to Analyze\n", 77 | "\n", 78 | "Get the names of the feature columns to analyze; if not specified, they are detected automatically." 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "stats_columns = stats_columns_settings or list(test_data.columns[1:-2])\n", 88 | "test_data = test_data.copy()[['ID'] + stats_columns + ['group', 'label']]\n", 89 | "test_data" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "# Feature Column Mapping\n", 97 | "\n", 98 | "Columns whose values need to be mapped; if the value range is not specified, it can be inferred automatically." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "
"mapping_columns = mapping_columns_settings or [c for c in test_data.columns[1:-2] if test_data[c].dtype == object]\n", 108 | "mapping_columns" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "# 数据映射\n", 116 | "\n", 117 | "针对所有非数值形式的数据,可以进行类别映射。" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "from onekey_algo.custom.utils import map2numerical\n", 127 | "\n", 128 | "data, mapping = map2numerical(test_data, mapping_columns=mapping_columns)\n", 129 | "mapping" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "# 连续特征列\n", 137 | "\n", 138 | "自动识别所有可能的连续特征列。如果列不是整数,或者列的元素超过5个,则呗认定为连续特征。" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "continuous_columns = []\n", 148 | "for c in stats_columns:\n", 149 | " if len(np.unique(test_data[c])) > 5 or not np.int8 <= test_data[c].dtype <= np.int64:\n", 150 | " continuous_columns.append(c)\n", 151 | " \n", 152 | "continuous_columns = continuous_columns_settings or continuous_columns\n", 153 | "continuous_columns" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "# 缺失值填充" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "import os\n", 170 | "from onekey_algo.custom.components.comp1 import fillna\n", 171 | "os.makedirs('data', exist_ok=True)\n", 172 | "data.to_csv('data/clinical.csv', index=False)\n", 173 | "data = fillna(data)\n", 174 | "data" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "### 统计分析\n", 182 | "\n", 183 | "支持两种格式数据,分别对应`pretty`参数的`True`和`False`, 当为`True`时,输出的是表格模式,反之则为dict数据。\n", 184 | "\n", 185 | "```python\n", 186 | "def clinic_stats(data: DataFrame, stats_columns: Union[str, List[str]], label_column='label',\n", 187 | " group_column: str = None, continuous_columns: Union[str, List[str]] = None,\n", 188 | " pretty: bool = True) -> Union[dict, DataFrame]:\n", 189 | " \"\"\"\n", 190 | "\n", 191 | " Args:\n", 192 | " data: 数据\n", 193 | " stats_columns: 需要统计的列名\n", 194 | " label_column: 二分类的标签列,默认`label`\n", 195 | " group_column: 分组统计依据,例如区分训练组、测试组、验证组。\n", 196 | " continuous_columns: 那些列是连续变量,连续变量统计均值方差。\n", 197 | " pretty: bool, 是否对结果进行格式美化。\n", 198 | "\n", 199 | " Returns:\n", 200 | " stats DataFrame or json\n", 201 | "\n", 202 | " \"\"\"\n", 203 | "```" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": { 210 | "scrolled": false 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "from onekey_algo.custom.components.stats import clinic_stats\n", 215 | "\n", 216 | "pd.set_option('display.max_rows', None)\n", 217 | "stats = clinic_stats(data, \n", 218 | " stats_columns= stats_columns,\n", 219 | " label_column=task, \n", 220 | " group_column='group', \n", 221 | " continuous_columns= continuous_columns, \n", 222 | " pretty=True, verbose=False)\n", 223 | "stats.to_csv('stats.csv', index=False, encoding='utf_8_sig')\n", 224 | "stats" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "sel_idx = [True if (isinstance(pv[0], str) and pv[0] != '') or (isinstance(pv[0], float) and pv[0] < 
p_value) else False \n", 234 | " for pv in np.array(stats['pvalue'])]\n", 235 | "sel_data = data[['ID'] + list(stats[sel_idx]['feature_name']) + ['group', 'label']]\n", 236 | "sel_data.to_csv('clinic_sel.csv', index=False)\n", 237 | "sel_data" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Python 3 (ipykernel)", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.7.12" 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 4 269 | } 270 | -------------------------------------------------------------------------------- /onekey_core/models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import load_state_dict_from_url 5 | 6 | __all__ = [ 7 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 8 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' 9 | ] 10 | 11 | model_urls = { 12 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 13 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 14 | 'shufflenetv2_x1.5': None, 15 | 'shufflenetv2_x2.0': None, 16 | } 17 | 18 | 19 | def channel_shuffle(x, groups): 20 | batchsize, num_channels, height, width = x.data.size() 21 | channels_per_group = num_channels // groups 22 | 23 | # reshape 24 | x = x.view(batchsize, groups, 25 | channels_per_group, height, width) 26 | 27 | x = torch.transpose(x, 1, 2).contiguous() 28 | 29 | # flatten 30 | x = x.view(batchsize, -1, height, width) 31 | 32 | return x 33 | 34 | 35 | class InvertedResidual(nn.Module): 36 | def __init__(self, inp, oup, stride): 37 | super(InvertedResidual, self).__init__() 38 | 39 | if not (1 <= stride <= 3): 40 | raise ValueError('illegal stride value') 41 | self.stride = stride 42 | 43 | branch_features = oup // 2 44 | assert (self.stride != 1) or (inp == branch_features << 1) 45 | 46 | if self.stride > 1: 47 | self.branch1 = nn.Sequential( 48 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 49 | nn.BatchNorm2d(inp), 50 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 51 | nn.BatchNorm2d(branch_features), 52 | nn.ReLU(inplace=True), 53 | ) 54 | 55 | self.branch2 = nn.Sequential( 56 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 57 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 58 | nn.BatchNorm2d(branch_features), 59 | nn.ReLU(inplace=True), 60 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 61 | nn.BatchNorm2d(branch_features), 62 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 63 | nn.BatchNorm2d(branch_features), 64 | nn.ReLU(inplace=True), 65 | ) 66 | 67 | @staticmethod 68 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 69 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 70 | 71 | def forward(self, x): 72 | if self.stride == 1: 73 | x1, x2 = x.chunk(2, dim=1) 74 | out = 
torch.cat((x1, self.branch2(x2)), dim=1) 75 | else: 76 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 77 | 78 | out = channel_shuffle(out, 2) 79 | 80 | return out 81 | 82 | 83 | class ShuffleNetV2(nn.Module): 84 | def __init__(self, stages_repeats, stages_out_channels, input_channels: int = 3, num_classes=1000, **kwargs): 85 | super(ShuffleNetV2, self).__init__() 86 | 87 | if len(stages_repeats) != 3: 88 | raise ValueError('expected stages_repeats as list of 3 positive ints') 89 | if len(stages_out_channels) != 5: 90 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 91 | self._stage_out_channels = stages_out_channels 92 | 93 | output_channels = self._stage_out_channels[0] 94 | self.conv1 = nn.Sequential( 95 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 96 | nn.BatchNorm2d(output_channels), 97 | nn.ReLU(inplace=True), 98 | ) 99 | input_channels = output_channels 100 | 101 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 102 | 103 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 104 | for name, repeats, output_channels in zip( 105 | stage_names, stages_repeats, self._stage_out_channels[1:]): 106 | seq = [InvertedResidual(input_channels, output_channels, 2)] 107 | for i in range(repeats - 1): 108 | seq.append(InvertedResidual(output_channels, output_channels, 1)) 109 | setattr(self, name, nn.Sequential(*seq)) 110 | input_channels = output_channels 111 | 112 | output_channels = self._stage_out_channels[-1] 113 | self.conv5 = nn.Sequential( 114 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 115 | nn.BatchNorm2d(output_channels), 116 | nn.ReLU(inplace=True), 117 | ) 118 | 119 | self.fc = nn.Linear(output_channels, num_classes) 120 | 121 | def forward(self, x): 122 | x = self.conv1(x) 123 | x = self.maxpool(x) 124 | x = self.stage2(x) 125 | x = self.stage3(x) 126 | x = self.stage4(x) 127 | x = self.conv5(x) 128 | x = x.mean([2, 3]) # globalpool 129 | x = self.fc(x) 130 | return x 131 | 132 | 133 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): 134 | model = ShuffleNetV2(*args, **kwargs) 135 | 136 | if pretrained: 137 | model_url = model_urls[arch] 138 | if model_url is None: 139 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 140 | else: 141 | state_dict = load_state_dict_from_url(model_url, progress=progress) 142 | parameters_list = list(state_dict.keys()) 143 | for k in parameters_list: 144 | if k.startswith('fc.'): 145 | del state_dict[k] 146 | model.load_state_dict(state_dict, strict=False) 147 | 148 | return model 149 | 150 | 151 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): 152 | """ 153 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 154 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 155 | <https://arxiv.org/abs/1807.11164>`_. 156 | 157 | Args: 158 | pretrained (bool): If True, returns a core pre-trained on ImageNet 159 | progress (bool): If True, displays a progress bar of the download to stderr 160 | """ 161 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 162 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 163 | 164 | 165 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): 166 | """ 167 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 168 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 169 | <https://arxiv.org/abs/1807.11164>`_.
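    (Added note, derived from the return statement below: this factory expands to
    ``ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)`` -- three stages
    of 4/8/4 InvertedResidual blocks with the listed stage output channels.)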
170 | 171 | Args: 172 | pretrained (bool): If True, returns a core pre-trained on ImageNet 173 | progress (bool): If True, displays a progress bar of the download to stderr 174 | """ 175 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, 176 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 177 | 178 | 179 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): 180 | """ 181 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in 182 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 183 | <https://arxiv.org/abs/1807.11164>`_. 184 | 185 | Args: 186 | pretrained (bool): If True, returns a core pre-trained on ImageNet 187 | progress (bool): If True, displays a progress bar of the download to stderr 188 | """ 189 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, 190 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 191 | 192 | 193 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): 194 | """ 195 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in 196 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 197 | <https://arxiv.org/abs/1807.11164>`_. 198 | 199 | Args: 200 | pretrained (bool): If True, returns a core pre-trained on ImageNet 201 | progress (bool): If True, displays a progress bar of the download to stderr 202 | """ 203 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, 204 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 205 | -------------------------------------------------------------------------------- /onekey_core/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, List 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch import nn 6 | 7 | from .utils import load_state_dict_from_url 8 | 9 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 10 | 11 | model_urls = { 12 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 13 | } 14 | 15 | 16 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 17 | """ 18 | This function is taken from the original tf repo. 19 | It ensures that all layers have a channel number that is divisible by 8 20 | It can be seen here: 21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 22 | """ 23 | if min_value is None: 24 | min_value = divisor 25 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 26 | # Make sure that round down does not go down by more than 10%.
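# Added sketch: a worked example of the rule above and the guard below, using
# divisor=8. _make_divisible(32 * 0.75, 8) == 24, since rounding 24.0 to the
# nearest multiple of 8 already gives 24; but _make_divisible(32 * 0.35, 8) == 16,
# because plain rounding of 11.2 gives 8, which is more than 10% below 11.2,
# so the guard adds one extra divisor.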
27 | if new_v < 0.9 * v: 28 | new_v += divisor 29 | return new_v 30 | 31 | 32 | class ConvBNActivation(nn.Sequential): 33 | def __init__( 34 | self, 35 | in_planes: int, 36 | out_planes: int, 37 | kernel_size: int = 3, 38 | stride: int = 1, 39 | groups: int = 1, 40 | norm_layer: Optional[Callable[..., nn.Module]] = None, 41 | activation_layer: Optional[Callable[..., nn.Module]] = None, 42 | dilation: int = 1, 43 | ) -> None: 44 | padding = (kernel_size - 1) // 2 * dilation 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm2d 47 | if activation_layer is None: 48 | activation_layer = nn.ReLU6 49 | super().__init__( 50 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups, 51 | bias=False), 52 | norm_layer(out_planes), 53 | activation_layer(inplace=True) 54 | ) 55 | self.out_channels = out_planes 56 | 57 | 58 | # necessary for backwards compatibility 59 | ConvBNReLU = ConvBNActivation 60 | 61 | 62 | class InvertedResidual(nn.Module): 63 | def __init__( 64 | self, 65 | inp: int, 66 | oup: int, 67 | stride: int, 68 | expand_ratio: int, 69 | norm_layer: Optional[Callable[..., nn.Module]] = None 70 | ) -> None: 71 | super(InvertedResidual, self).__init__() 72 | self.stride = stride 73 | assert stride in [1, 2] 74 | 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | 78 | hidden_dim = int(round(inp * expand_ratio)) 79 | self.use_res_connect = self.stride == 1 and inp == oup 80 | 81 | layers: List[nn.Module] = [] 82 | if expand_ratio != 1: 83 | # pw 84 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 85 | layers.extend([ 86 | # dw 87 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 88 | # pw-linear 89 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 90 | norm_layer(oup), 91 | ]) 92 | self.conv = nn.Sequential(*layers) 93 | self.out_channels = oup 94 | self._is_cn = stride > 1 95 | 96 | def forward(self, x: Tensor) -> Tensor: 97 | if self.use_res_connect: 98 | return x + self.conv(x) 99 | else: 100 | return self.conv(x) 101 | 102 | 103 | class MobileNetV2(nn.Module): 104 | def __init__( 105 | self, 106 | in_channels: int = 3, 107 | num_classes: int = 1000, 108 | width_mult: float = 1.0, 109 | inverted_residual_setting: Optional[List[List[int]]] = None, 110 | round_nearest: int = 8, 111 | block: Optional[Callable[..., nn.Module]] = None, 112 | norm_layer: Optional[Callable[..., nn.Module]] = None, 113 | **kwargs 114 | ) -> None: 115 | """ 116 | MobileNet V2 main class 117 | 118 | Args: 119 | num_classes (int): Number of classes 120 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 121 | inverted_residual_setting: Network structure 122 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 123 | Set to 1 to turn off rounding 124 | block: Module specifying inverted residual building block for mobilenet 125 | norm_layer: Module specifying the normalization layer to use 126 | 127 | """ 128 | super(MobileNetV2, self).__init__() 129 | 130 | if block is None: 131 | block = InvertedResidual 132 | 133 | if norm_layer is None: 134 | norm_layer = nn.BatchNorm2d 135 | 136 | input_channel = 32 137 | last_channel = 1280 138 | 139 | if inverted_residual_setting is None: 140 | inverted_residual_setting = [ 141 | # t, c, n, s 142 | [1, 16, 1, 1], 143 | [6, 24, 2, 2], 144 | [6, 32, 3, 2], 145 | [6, 64, 4, 2], 146 | [6, 96, 3, 1], 147 | [6, 160, 3, 2], 148 | [6, 320, 1, 1], 149 | ] 150 
| 151 | # only check the first element, assuming user knows t,c,n,s are required 152 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 153 | raise ValueError("inverted_residual_setting should be non-empty " 154 | "or a 4-element list, got {}".format(inverted_residual_setting)) 155 | 156 | # building first layer 157 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 158 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 159 | features: List[nn.Module] = [ConvBNReLU(in_channels, input_channel, stride=2, norm_layer=norm_layer)] 160 | # building inverted residual blocks 161 | for t, c, n, s in inverted_residual_setting: 162 | output_channel = _make_divisible(c * width_mult, round_nearest) 163 | for i in range(n): 164 | stride = s if i == 0 else 1 165 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 166 | input_channel = output_channel 167 | # building last several layers 168 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 169 | # make it nn.Sequential 170 | self.features = nn.Sequential(*features) 171 | 172 | # building classifier 173 | self.classifier = nn.Sequential( 174 | nn.Dropout(0.2), 175 | nn.Linear(self.last_channel, num_classes), 176 | ) 177 | 178 | # weight initialization 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 182 | if m.bias is not None: 183 | nn.init.zeros_(m.bias) 184 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 185 | nn.init.ones_(m.weight) 186 | nn.init.zeros_(m.bias) 187 | elif isinstance(m, nn.Linear): 188 | nn.init.normal_(m.weight, 0, 0.01) 189 | nn.init.zeros_(m.bias) 190 | 191 | def _forward_impl(self, x: Tensor) -> Tensor: 192 | # This exists since TorchScript doesn't support inheritance, so the superclass method 193 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 194 | x = self.features(x) 195 | # Cannot use "squeeze" as batch-size can be 1 196 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) 197 | x = torch.flatten(x, 1) 198 | x = self.classifier(x) 199 | return x 200 | 201 | def forward(self, x: Tensor) -> Tensor: 202 | return self._forward_impl(x) 203 | 204 | 205 | def mobilenet_v2(pretrained=False, progress=True, transfer_learning=True, **kwargs): 206 | """ 207 | Constructs a MobileNetV2 architecture from 208 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
209 | 210 | Args: 211 | pretrained (bool): If True, returns a core pre-trained on ImageNet 212 | progress (bool): If True, displays a progress bar of the download to stderr 213 | transfer_learning: 214 | """ 215 | model = MobileNetV2(**kwargs) 216 | if pretrained: 217 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], progress=progress) 218 | parameters_list = list(state_dict.keys()) 219 | for k in parameters_list: 220 | if k.startswith('classifier.') and transfer_learning: 221 | del state_dict[k] 222 | model.load_state_dict(state_dict, strict=False) 223 | return model 224 | -------------------------------------------------------------------------------- /onekey_core/models/mnasnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import load_state_dict_from_url 5 | 6 | __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] 7 | 8 | _MODEL_URLS = { 9 | "mnasnet0_5": 10 | "https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth", 11 | "mnasnet0_75": None, 12 | "mnasnet1_0": 13 | "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", 14 | "mnasnet1_3": None 15 | } 16 | 17 | # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is 18 | # 1.0 - tensorflow. 19 | _BN_MOMENTUM = 1 - 0.9997 20 | 21 | 22 | class _InvertedResidual(nn.Module): 23 | 24 | def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, 25 | bn_momentum=0.1): 26 | super(_InvertedResidual, self).__init__() 27 | assert stride in [1, 2] 28 | assert kernel_size in [3, 5] 29 | mid_ch = in_ch * expansion_factor 30 | self.apply_residual = (in_ch == out_ch and stride == 1) 31 | self.layers = nn.Sequential( 32 | # Pointwise 33 | nn.Conv2d(in_ch, mid_ch, 1, bias=False), 34 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum), 35 | nn.ReLU(inplace=True), 36 | # Depthwise 37 | nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, 38 | stride=stride, groups=mid_ch, bias=False), 39 | nn.BatchNorm2d(mid_ch, momentum=bn_momentum), 40 | nn.ReLU(inplace=True), 41 | # Linear pointwise. Note that there's no activation. 42 | nn.Conv2d(mid_ch, out_ch, 1, bias=False), 43 | nn.BatchNorm2d(out_ch, momentum=bn_momentum)) 44 | 45 | def forward(self, input): 46 | if self.apply_residual: 47 | return self.layers(input) + input 48 | else: 49 | return self.layers(input) 50 | 51 | 52 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, 53 | bn_momentum): 54 | """ Creates a stack of inverted residuals. """ 55 | assert repeats >= 1 56 | # First one has no skip, because feature map size changes. 57 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, 58 | bn_momentum=bn_momentum) 59 | remaining = [] 60 | for _ in range(1, repeats): 61 | remaining.append( 62 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, 63 | bn_momentum=bn_momentum)) 64 | return nn.Sequential(first, *remaining) 65 | 66 | 67 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9): 68 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default 69 | bias, will round up, unless the number is no more than 10% greater than the 70 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. 
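    The two cases above as doctests (added, derived directly from that line):
    >>> _round_to_multiple_of(83, 8)
    80
    >>> _round_to_multiple_of(84, 8)
    88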
""" 71 | assert 0.0 < round_up_bias < 1.0 72 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) 73 | return new_val if new_val >= round_up_bias * val else new_val + divisor 74 | 75 | 76 | def _scale_depths(depths, alpha): 77 | """ Scales tensor depths as in reference MobileNet code, prefers rouding up 78 | rather than down. """ 79 | return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] 80 | 81 | 82 | class MNASNet(torch.nn.Module): 83 | """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. 84 | >>> core = MNASNet(1000, 1.0) 85 | >>> x = torch.rand(1, 3, 224, 224) 86 | >>> y = core(x) 87 | >>> y.dim() 88 | 1 89 | >>> y.nelement() 90 | 1000 91 | """ 92 | 93 | def __init__(self, alpha, in_channels: int = 3, num_classes=1000, dropout=0.2, **kwargs): 94 | super(MNASNet, self).__init__() 95 | depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) 96 | layers = [ 97 | # First layer: regular conv. 98 | nn.Conv2d(in_channels, 32, 3, padding=1, stride=2, bias=False), 99 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 100 | nn.ReLU(inplace=True), 101 | # Depthwise separable, no skip. 102 | nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False), 103 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), 106 | nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), 107 | # MNASNet blocks: stacks of inverted residuals. 108 | _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM), 109 | _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM), 110 | _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM), 111 | _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM), 112 | _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM), 113 | _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM), 114 | # Final mapping to classifier input. 115 | nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False), 116 | nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), 117 | nn.ReLU(inplace=True), 118 | ] 119 | self.layers = nn.Sequential(*layers) 120 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), 121 | nn.Linear(1280, num_classes)) 122 | self._initialize_weights() 123 | 124 | def forward(self, x): 125 | x = self.layers(x) 126 | # Equivalent to global avgpool and removing H and W dimensions. 
127 | x = x.mean([2, 3]) 128 | return self.classifier(x) 129 | 130 | def _initialize_weights(self): 131 | for m in self.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | nn.init.kaiming_normal_(m.weight, mode="fan_out", 134 | nonlinearity="relu") 135 | if m.bias is not None: 136 | nn.init.zeros_(m.bias) 137 | elif isinstance(m, nn.BatchNorm2d): 138 | nn.init.ones_(m.weight) 139 | nn.init.zeros_(m.bias) 140 | elif isinstance(m, nn.Linear): 141 | nn.init.normal_(m.weight, 0, 0.01) 142 | nn.init.zeros_(m.bias) 143 | 144 | 145 | def _load_pretrained(model_name, model, progress): 146 | if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: 147 | raise ValueError( 148 | "No checkpoint is available for core type {}".format(model_name)) 149 | checkpoint_url = _MODEL_URLS[model_name] 150 | state_dict = load_state_dict_from_url(checkpoint_url, progress=progress) 151 | parameters_list = list(state_dict.keys()) 152 | for k in parameters_list: 153 | if k.startswith('classifier.'): 154 | del state_dict[k] 155 | model.load_state_dict(state_dict, strict=False) 156 | 157 | 158 | def mnasnet0_5(pretrained=False, progress=True, **kwargs): 159 | """MNASNet with depth multiplier of 0.5 from 160 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 161 | <https://arxiv.org/abs/1807.11626>`_. 162 | Args: 163 | pretrained (bool): If True, returns a core pre-trained on ImageNet 164 | progress (bool): If True, displays a progress bar of the download to stderr 165 | """ 166 | model = MNASNet(0.5, **kwargs) 167 | if pretrained: 168 | _load_pretrained("mnasnet0_5", model, progress) 169 | return model 170 | 171 | 172 | def mnasnet0_75(pretrained=False, progress=True, **kwargs): 173 | """MNASNet with depth multiplier of 0.75 from 174 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 175 | <https://arxiv.org/abs/1807.11626>`_. 176 | Args: 177 | pretrained (bool): If True, returns a core pre-trained on ImageNet 178 | progress (bool): If True, displays a progress bar of the download to stderr 179 | """ 180 | model = MNASNet(0.75, **kwargs) 181 | if pretrained: 182 | _load_pretrained("mnasnet0_75", model, progress) 183 | return model 184 | 185 | 186 | def mnasnet1_0(pretrained=False, progress=True, **kwargs): 187 | """MNASNet with depth multiplier of 1.0 from 188 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 189 | <https://arxiv.org/abs/1807.11626>`_. 190 | Args: 191 | pretrained (bool): If True, returns a core pre-trained on ImageNet 192 | progress (bool): If True, displays a progress bar of the download to stderr 193 | """ 194 | model = MNASNet(1.0, **kwargs) 195 | if pretrained: 196 | _load_pretrained("mnasnet1_0", model, progress) 197 | return model 198 | 199 | 200 | def mnasnet1_3(pretrained=False, progress=True, **kwargs): 201 | """MNASNet with depth multiplier of 1.3 from 202 | `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" 203 | <https://arxiv.org/abs/1807.11626>`_.
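    (Added note: _MODEL_URLS["mnasnet1_3"] is None above, so pretrained=True
    raises the ValueError in _load_pretrained; the same holds for mnasnet0_75.)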
204 | Args: 205 | pretrained (bool): If True, returns a core pre-trained on ImageNet 206 | progress (bool): If True, displays a progress bar of the download to stderr 207 | """ 208 | model = MNASNet(1.3, **kwargs) 209 | if pretrained: 210 | _load_pretrained("mnasnet1_3", model, progress) 211 | return model 212 | -------------------------------------------------------------------------------- /onekey_core/models/res2net_v1b.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.hub import load_state_dict_from_url 7 | 8 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 9 | 10 | model_urls = { 11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 13 | } 14 | 15 | 16 | class Bottle2neck(nn.Module): 17 | expansion = 4 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 20 | """ Constructor 21 | Args: 22 | inplanes: input channel dimensionality 23 | planes: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | downsample: None when stride = 1 26 | baseWidth: basic width of conv3x3 27 | scale: number of scale. 28 | stype: 'normal': normal set. 'stage': first block of a new stage. 29 | """ 30 | super(Bottle2neck, self).__init__() 31 | 32 | width = int(math.floor(planes * (baseWidth / 64.0))) 33 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(width * scale) 35 | 36 | if scale == 1: 37 | self.nums = 1 38 | else: 39 | self.nums = scale - 1 40 | if stype == 'stage': 41 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) 42 | convs = [] 43 | bns = [] 44 | for i in range(self.nums): 45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) 46 | bns.append(nn.BatchNorm2d(width)) 47 | self.convs = nn.ModuleList(convs) 48 | self.bns = nn.ModuleList(bns) 49 | 50 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 52 | 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = downsample 55 | self.stype = stype 56 | self.scale = scale 57 | self.width = width 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | spx = torch.split(out, self.width, 1) 67 | for i in range(self.nums): 68 | if i == 0 or self.stype == 'stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i == 0: 75 | out = sp 76 | else: 77 | out = torch.cat((out, sp), 1) 78 | if self.scale != 1 and self.stype == 'normal': 79 | out = torch.cat((out, spx[self.nums]), 1) 80 | elif self.scale != 1 and self.stype == 'stage': 81 | out = torch.cat((out, self.pool(spx[self.nums])), 1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class Res2Net(nn.Module): 96 | 97 | def __init__(self, block, layers, in_channels: int = 3, baseWidth=26, scale=4,
num_classes=1000, **kwargs): 98 | self.inplanes = 64 99 | super(Res2Net, self).__init__() 100 | self.baseWidth = baseWidth 101 | self.scale = scale 102 | self.conv1 = nn.Sequential( 103 | nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False), 104 | nn.BatchNorm2d(32), 105 | nn.ReLU(inplace=True), 106 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 107 | nn.BatchNorm2d(32), 108 | nn.ReLU(inplace=True), 109 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 110 | ) 111 | self.bn1 = nn.BatchNorm2d(64) 112 | self.relu = nn.ReLU() 113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 114 | self.layer1 = self._make_layer(block, 64, layers[0]) 115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 116 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 118 | self.avgpool = nn.AdaptiveAvgPool2d(1) 119 | self.fc = nn.Linear(512 * block.expansion, num_classes) 120 | 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 124 | elif isinstance(m, nn.BatchNorm2d): 125 | nn.init.constant_(m.weight, 1) 126 | nn.init.constant_(m.bias, 0) 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1): 129 | downsample = None 130 | if stride != 1 or self.inplanes != planes * block.expansion: 131 | downsample = nn.Sequential( 132 | nn.AvgPool2d(kernel_size=stride, stride=stride, 133 | ceil_mode=True, count_include_pad=False), 134 | nn.Conv2d(self.inplanes, planes * block.expansion, 135 | kernel_size=1, stride=1, bias=False), 136 | nn.BatchNorm2d(planes * block.expansion), 137 | ) 138 | 139 | layers = [] 140 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 141 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 142 | self.inplanes = planes * block.expansion 143 | for i in range(1, blocks): 144 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | x = self.conv1(x) 150 | x = self.bn1(x) 151 | x = self.relu(x) 152 | x = self.maxpool(x) 153 | 154 | x = self.layer1(x) 155 | x = self.layer2(x) 156 | x = self.layer3(x) 157 | x = self.layer4(x) 158 | 159 | x = self.avgpool(x) 160 | x = x.view(x.size(0), -1) 161 | x = self.fc(x) 162 | 163 | return x 164 | 165 | 166 | def res2net50_v1b(pretrained=False, **kwargs): 167 | """Constructs a Res2Net-50_v1b lib. 168 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 169 | Args: 170 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 171 | """ 172 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 173 | if pretrained: 174 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 175 | return model 176 | 177 | 178 | def res2net101_v1b(pretrained=False, **kwargs): 179 | """Constructs a Res2Net-101_v1b lib. 180 | Args: 181 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 182 | """ 183 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 184 | if pretrained: 185 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'], 186 | map_location=torch.device('cpu'))) 187 | return model 188 | 189 | 190 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 191 | """Constructs a Res2Net-50_v1b_26w_4s lib.
192 | Args: 193 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 194 | """ 195 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 196 | if pretrained: 197 | model_state = load_state_dict_from_url(model_urls['res2net50_v1b_26w_4s'], progress=True, 198 | map_location=torch.device('cpu')) 199 | model.load_state_dict(model_state) 200 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 201 | return model 202 | 203 | 204 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 205 | """Constructs a Res2Net-101_v1b_26w_4s lib. 206 | Args: 207 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 208 | """ 209 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'], 212 | map_location=torch.device('cpu'))) 213 | return model 214 | 215 | 216 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 217 | """Constructs a Res2Net-152_v1b_26w_4s lib. 218 | Args: 219 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 220 | """ 221 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs) 222 | if pretrained: # no 'res2net152_v1b_26w_4s' checkpoint URL is registered in model_urls above 223 | raise NotImplementedError('pretrained res2net152_v1b_26w_4s is not available: ' 224 | 'no checkpoint URL is defined in model_urls') 225 | return model 226 | 227 | 228 | if __name__ == '__main__': 229 | images = torch.rand(1, 3, 224, 224).cuda(0) 230 | model = res2net50_v1b_26w_4s(pretrained=True) 231 | model = model.cuda(0) 232 | print(model(images).size()) 233 | --------------------------------------------------------------------------------