├── demo
│   ├── Demo_MultiTask_MMoE.py
│   ├── Demo_MultiTask_PLE.py
│   ├── Demo_SeBlockwithResnet_MultiInstanseClassification.py
│   ├── Demo_dualbranchnet_forMultiInput_classification.py
│   ├── Demo0_VGG_SingleLabelClassification.py
│   ├── multi_label
│   │   └── generate_multilabel_dataset.py
│   ├── Demo1_ResNet_SingleLabelClassification.py
│   ├── Demo2_ResNet_MultiLabelClassification.py
│   ├── Demo3_ResNetUnet_SingleLabelSegmentation.py
│   ├── Demo4_ResNetUnet_MultiLabelSegmentation.py
│   ├── Demo6_UnetwithFPN_segmentation.py
│   ├── _explain_how2useVitAsNeck.py
│   ├── Demo5_MultiTask_SegAndCls.py
│   ├── Demo7_2D_TransUnet_Segmentation.py
│   ├── Demo8_3D_TransUnet_Segmentation.py
│   └── Demo_BilinearPooling.py
├── wama_modules
│   ├── __init__.py
│   ├── thirdparty_lib
│   │   ├── __init__.py
│   │   ├── C3D_jfzhang95
│   │   │   ├── __init__.py
│   │   │   └── c3d.py
│   │   ├── C3D_yyuanad
│   │   │   ├── __init__.py
│   │   │   └── c3d.py
│   │   ├── MedicalNet_Tencent
│   │   │   ├── __init__.py
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   └── __pycache__
│   │   │   │       ├── resnet.cpython-38.pyc
│   │   │   │       └── __init__.cpython-38.pyc
│   │   │   └── model.py
│   │   ├── VC3D_kenshohara
│   │   │   ├── __init__.py
│   │   │   ├── wide_resnet.py
│   │   │   └── resnext.py
│   │   ├── Efficient3D_okankop
│   │   │   ├── __init__.py
│   │   │   └── models
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   ├── c3d.cpython-38.pyc
│   │   │       │   ├── resnet.cpython-38.pyc
│   │   │       │   ├── __init__.cpython-38.pyc
│   │   │       │   ├── resnext.cpython-38.pyc
│   │   │       │   ├── mobilenet.cpython-38.pyc
│   │   │       │   ├── mobilenetv2.cpython-38.pyc
│   │   │       │   ├── shufflenet.cpython-38.pyc
│   │   │       │   ├── squeezenet.cpython-38.pyc
│   │   │       │   └── shufflenetv2.cpython-38.pyc
│   │   │       ├── c3d.py
│   │   │       ├── mobilenet.py
│   │   │       ├── squeezenet.py
│   │   │       ├── shufflenet.py
│   │   │       └── mobilenetv2.py
│   │   ├── ResNets3D_kenshohara
│   │   │   └── __init__.py
│   │   ├── get_model.py
│   │   └── SMP_qubvel
│   │       ├── __init__.py
│   │       ├── __version__.py
│   │       ├── encoders
│   │       │   ├── __pycache__
│   │       │   │   ├── dpn.cpython-38.pyc
│   │       │   │   ├── vgg.cpython-38.pyc
│   │       │   │   ├── _base.cpython-38.pyc
│   │       │   │   ├── senet.cpython-38.pyc
│   │       │   │   ├── __init__.cpython-38.pyc
│   │       │   │   ├── _utils.cpython-38.pyc
│   │       │   │   ├── densenet.cpython-38.pyc
│   │       │   │   ├── resnet.cpython-38.pyc
│   │       │   │   ├── xception.cpython-38.pyc
│   │       │   │   ├── mobilenet.cpython-38.pyc
│   │       │   │   ├── timm_sknet.cpython-38.pyc
│   │       │   │   ├── efficientnet.cpython-38.pyc
│   │       │   │   ├── inceptionv4.cpython-38.pyc
│   │       │   │   ├── timm_gernet.cpython-38.pyc
│   │       │   │   ├── timm_regnet.cpython-38.pyc
│   │       │   │   ├── timm_res2net.cpython-38.pyc
│   │       │   │   ├── timm_resnest.cpython-38.pyc
│   │       │   │   ├── _preprocessing.cpython-38.pyc
│   │       │   │   ├── mix_transformer.cpython-38.pyc
│   │       │   │   ├── timm_universal.cpython-38.pyc
│   │       │   │   ├── inceptionresnetv2.cpython-38.pyc
│   │       │   │   ├── timm_efficientnet.cpython-38.pyc
│   │       │   │   └── timm_mobilenetv3.cpython-38.pyc
│   │       │   ├── _preprocessing.py
│   │       │   ├── timm_universal.py
│   │       │   ├── _utils.py
│   │       │   ├── _base.py
│   │       │   ├── xception.py
│   │       │   ├── mobilenet.py
│   │       │   ├── inceptionresnetv2.py
│   │       │   ├── inceptionv4.py
│   │       │   ├── timm_sknet.py
│   │       │   ├── __init__.py
│   │       │   ├── timm_gernet.py
│   │       │   ├── densenet.py
│   │       │   ├── vgg.py
│   │       │   ├── timm_res2net.py
│   │       │   ├── senet.py
│   │       │   ├── dpn.py
│   │       │   ├── timm_mobilenetv3.py
│   │       │   └── efficientnet.py
│   │       ├── base
│   │       │   ├── __init__.py
│   │       │   ├── initialization.py
│   │       │   ├── heads.py
│   │       │   ├── model.py
│   │       │   └── modules.py
│   │       └── utils
│   │           ├── __init__.py
│   │           ├── losses.py
│   │           ├── meter.py
│   │           ├── base.py
│   │           ├── metrics.py
│   │           ├── train.py
│   │           └── functional.py
│   ├── Head.py
│   ├── utils.py
│   ├── Decoder.py
│   └── Neck.py
├── make_rqs.py
├── images
│   └── transUnet.png
├── requirements.txt
├── setup.py
├── LICENSE
└── Document_allmodules.md
/demo/Demo_MultiTask_MMoE.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/demo/Demo_MultiTask_PLE.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/C3D_jfzhang95/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/C3D_yyuanad/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/demo/Demo_SeBlockwithResnet_MultiInstanseClassification.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/demo/Demo_dualbranchnet_forMultiInput_classification.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/MedicalNet_Tencent/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/VC3D_kenshohara/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/ResNets3D_kenshohara/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/make_rqs.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.system('pipreqs ./ --encoding=utf8 --force')
--------------------------------------------------------------------------------
/images/transUnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/images/transUnet.png
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/get_model.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/__init__.py:
--------------------------------------------------------------------------------
1 | from . import encoders
2 | from .__version__ import __version__
3 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/__version__.py:
--------------------------------------------------------------------------------
1 | VERSION = (0, 3, 0)
2 |
3 | __version__ = ".".join(map(str, VERSION))
4 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/dpn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/dpn.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/vgg.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/vgg.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_base.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_base.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/senet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/senet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/densenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/densenet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/xception.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/xception.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/c3d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/c3d.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/mobilenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/mobilenet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_sknet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_sknet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/efficientnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/efficientnet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/inceptionv4.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/inceptionv4.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_gernet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_gernet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_regnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_regnet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_res2net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_res2net.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_resnest.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_resnest.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/resnext.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/resnext.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_preprocessing.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_preprocessing.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/mix_transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/mix_transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_universal.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_universal.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/mobilenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/mobilenet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/mobilenetv2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/mobilenetv2.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/shufflenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/shufflenet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/squeezenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/squeezenet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/inceptionresnetv2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/inceptionresnetv2.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_efficientnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_efficientnet.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_mobilenetv3.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_mobilenetv3.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/shufflenetv2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/shufflenetv2.cpython-38.pyc
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import SegmentationModel
2 |
3 | from .modules import (
4 | Conv2dReLU,
5 | Attention,
6 | )
7 |
8 | from .heads import (
9 | SegmentationHead,
10 | ClassificationHead,
11 | )
12 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | efficientnet_pytorch==0.7.1
2 | einops==0.6.0
3 | inplace_abn==1.1.0
4 | matplotlib==3.5.1
5 | numpy==1.23.5
6 | pretrainedmodels==0.7.4
7 | prettytable==3.6.0
8 | torch
9 | torchtext==0.12.0
10 | torchvision==0.12.0
11 |
12 | setuptools
13 | timm
14 | tqdm
15 | transformers
16 |
17 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from . import train
4 | from . import losses
5 | from . import metrics
6 |
7 | warnings.warn(
8 | "`smp.utils` module is deprecated and will be removed in future releases.",
9 | DeprecationWarning,
10 | )
11 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/_preprocessing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def preprocess_input(x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs):
5 |
6 | if input_space == "BGR":
7 | x = x[..., ::-1].copy()
8 |
9 | if input_range is not None:
10 | if x.max() > 1 and input_range[1] == 1:
11 | x = x / 255.0
12 |
13 | if mean is not None:
14 | mean = np.array(mean)
15 | x = x - mean
16 |
17 | if std is not None:
18 | std = np.array(std)
19 | x = x / std
20 |
21 | return x
22 |
--------------------------------------------------------------------------------
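A minimal usage sketch for `preprocess_input` above (not part of the repository); the HWC test image and the ImageNet-style mean/std values are chosen purely for illustration:

import numpy as np
from wama_modules.thirdparty_lib.SMP_qubvel.encoders._preprocessing import preprocess_input

# hypothetical RGB image in HWC layout with values in [0, 255]
img = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.float32)

# scale to [0, 1] (because input_range[1] == 1 and img.max() > 1), then normalize per channel
x = preprocess_input(
    img,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    input_space="RGB",
    input_range=[0, 1],
)
print(x.shape, float(x.mean()))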
/wama_modules/thirdparty_lib/MedicalNet_Tencent/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .models import resnet
4 |
5 |
6 | def generate_model(model_depth):
7 | if model_depth == 10:
8 | model = resnet.resnet10()
9 | elif model_depth == 18:
10 | model = resnet.resnet18()
11 | elif model_depth == 34:
12 | model = resnet.resnet34()
13 | elif model_depth == 50:
14 | model = resnet.resnet50()
15 | elif model_depth == 101:
16 | model = resnet.resnet101()
17 | elif model_depth == 152:
18 | model = resnet.resnet152()
19 | elif model_depth == 200:
20 | model = resnet.resnet200()
21 |     else:
22 |         raise ValueError('model_depth must be one of 10, 18, 34, 50, 101, 152, 200')
23 |     return model
24 | 
--------------------------------------------------------------------------------
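A small sketch of calling `generate_model` above. Note two assumptions: the `models/resnet.py` source must be importable (the tree above only shows its compiled cache), and loading the actual MedicalNet pretrained weights is a separate step not shown here.

from wama_modules.thirdparty_lib.MedicalNet_Tencent.model import generate_model

# depth must be one of 10, 18, 34, 50, 101, 152, 200 (see generate_model above)
backbone = generate_model(model_depth=50)
print(sum(p.numel() for p in backbone.parameters()), 'parameters')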
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | import io
3 | import os
4 | import sys
5 |
6 | here = os.path.abspath(os.path.dirname(__file__))
7 |
8 | # What packages are required for this module to be executed?
9 | try:
10 | with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f:
11 | REQUIRED = f.read().split("\n")
12 | except Exception:
13 | REQUIRED = []
14 |
15 |
16 | setup(
17 | name='aini_modules',
18 | version='0.0.1',
19 | description='Enjoy~',
20 | author='wamawama',
21 | author_email='wmy19970215@gmail.com',
22 | python_requires=">=3.6.0",
23 | url='https://github.com/WAMAWAMA/wama_modules',
24 | packages=find_packages(exclude=("demo", "docs", "images")),
25 | install_requires=REQUIRED,
26 | license="MIT",
27 | )
28 |
29 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/base/initialization.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def initialize_decoder(module):
5 | for m in module.modules():
6 |
7 | if isinstance(m, nn.Conv2d):
8 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
9 | if m.bias is not None:
10 | nn.init.constant_(m.bias, 0)
11 |
12 | elif isinstance(m, nn.BatchNorm2d):
13 | nn.init.constant_(m.weight, 1)
14 | nn.init.constant_(m.bias, 0)
15 |
16 | elif isinstance(m, nn.Linear):
17 | nn.init.xavier_uniform_(m.weight)
18 | if m.bias is not None:
19 | nn.init.constant_(m.bias, 0)
20 |
21 |
22 | def initialize_head(module):
23 | for m in module.modules():
24 | if isinstance(m, (nn.Linear, nn.Conv2d)):
25 | nn.init.xavier_uniform_(m.weight)
26 | if m.bias is not None:
27 | nn.init.constant_(m.bias, 0)
28 |
--------------------------------------------------------------------------------
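As a short illustration, the two helpers above work on any `nn.Module`; the toy decoder and head below are made up for the example and are not part of the repository:

import torch.nn as nn
from wama_modules.thirdparty_lib.SMP_qubvel.base.initialization import (
    initialize_decoder,
    initialize_head,
)

toy_decoder = nn.Sequential(nn.Conv2d(64, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU())
toy_head = nn.Conv2d(32, 2, 3, padding=1)

initialize_decoder(toy_decoder)  # kaiming init for convs, constant init for BatchNorm
initialize_head(toy_head)        # xavier init for the final conv/linear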
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022, 2023 the wama_modules Project
2 | All rights reserved.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/base/heads.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .modules import Activation
3 |
4 |
5 | class SegmentationHead(nn.Sequential):
6 | def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
7 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
8 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
9 | activation = Activation(activation)
10 | super().__init__(conv2d, upsampling, activation)
11 |
12 |
13 | class ClassificationHead(nn.Sequential):
14 | def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
15 | if pooling not in ("max", "avg"):
16 | raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
17 | pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1)
18 | flatten = nn.Flatten()
19 | dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
20 | linear = nn.Linear(in_channels, classes, bias=True)
21 | activation = Activation(activation)
22 | super().__init__(pool, flatten, dropout, linear, activation)
23 |
--------------------------------------------------------------------------------
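A self-contained sketch of the two heads above; the channel sizes and spatial resolution are arbitrary example values:

import torch
from wama_modules.thirdparty_lib.SMP_qubvel.base.heads import SegmentationHead, ClassificationHead

feat = torch.randn(2, 16, 64, 64)  # e.g. a decoder output: (batch, channels, H, W)
seg_head = SegmentationHead(in_channels=16, out_channels=3, upsampling=2)
cls_head = ClassificationHead(in_channels=16, classes=5)

print(seg_head(feat).shape)  # torch.Size([2, 3, 128, 128]) after 2x bilinear upsampling
print(cls_head(feat).shape)  # torch.Size([2, 5]) after global pooling + linear layer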
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_universal.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import torch.nn as nn
3 |
4 |
5 | class TimmUniversalEncoder(nn.Module):
6 | def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32):
7 | super().__init__()
8 | kwargs = dict(
9 | in_chans=in_channels,
10 | features_only=True,
11 | output_stride=output_stride,
12 | pretrained=pretrained,
13 | out_indices=tuple(range(depth)),
14 | )
15 |
16 | # not all models support output stride argument, drop it by default
17 | if output_stride == 32:
18 | kwargs.pop("output_stride")
19 |
20 | self.model = timm.create_model(name, **kwargs)
21 |
22 | self._in_channels = in_channels
23 | self._out_channels = [
24 | in_channels,
25 | ] + self.model.feature_info.channels()
26 | self._depth = depth
27 | self._output_stride = output_stride
28 |
29 | def forward(self, x):
30 | features = self.model(x)
31 | features = [
32 | x,
33 | ] + features
34 | return features
35 |
36 | @property
37 | def out_channels(self):
38 | return self._out_channels
39 |
40 | @property
41 | def output_stride(self):
42 | return min(self._output_stride, 2**self._depth)
43 |
--------------------------------------------------------------------------------
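A quick sketch of `TimmUniversalEncoder` above; `resnet18` is just an example timm model name, and `pretrained=False` avoids a weight download:

import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders.timm_universal import TimmUniversalEncoder

encoder = TimmUniversalEncoder('resnet18', pretrained=False, in_channels=3, depth=5)
feats = encoder(torch.randn(2, 3, 224, 224))
print(encoder.out_channels)         # input channels followed by each stage's channel count
print([f.shape[1] for f in feats])  # channels of the returned multi-scale feature maps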
/demo/Demo0_VGG_SingleLabelClassification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import VGGEncoder
4 | from wama_modules.Head import ClassificationHead
5 | from wama_modules.BaseModule import GlobalMaxPool
6 |
7 |
8 | class Model(nn.Module):
9 | def __init__(self, in_channel, label_category_dict, dim=2):
10 | super().__init__()
11 | # encoder
12 | f_channel_list = [64, 128, 256, 512]
13 | self.encoder = VGGEncoder(
14 | in_channel,
15 | stage_output_channels=f_channel_list,
16 | blocks=[1, 2, 3, 4],
17 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
18 | dim=dim)
19 | # cls head
20 | self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
21 | self.pooling = GlobalMaxPool()
22 |
23 | def forward(self, x):
24 | f = self.encoder(x)
25 | logits = self.cls_head(self.pooling(f[-1]))
26 | return logits
27 |
28 |
29 | if __name__ == '__main__':
30 | x = torch.ones([2, 1, 64, 64, 64])
31 | category_num = 1
32 | label_category_dict = dict(is_malignant=category_num)
33 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
34 | logits = model(x)
35 | print('single-label predicted logits')
36 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
37 |
38 | # output 👇
39 | # single-label predicted logits
40 | # logits of is_malignant : torch.Size([2, 1])
41 |
--------------------------------------------------------------------------------
/demo/multi_label/generate_multilabel_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | generate example multi_label dataset
3 | """
4 | import numpy as np
5 |
6 |
7 | label_category_dict = dict(
8 | water=2, # binary class label
9 | milk=2, # binary class label
10 | cow=2, # binary class label
11 | big_white_fish=2, # binary class label
12 | grass=2, # binary class label
13 | weather=5, # 5-class label
14 | )
15 | label_name = list(label_category_dict.keys())
16 | print('label num :',len(label_name))
17 | _ = [print('-'*4, key, ': class num = ', label_category_dict[key]) for key in label_category_dict.keys()]
18 |
19 | img_channel = 2
20 | dataset = [
21 | dict( # case1
22 | img_1D=np.ones([128, img_channel]), # which is also a 1D signal
23 | img_2D=np.ones([128, 128, img_channel]),
24 | img_3D=np.ones([64, 64, 64, img_channel]),
25 | label_value=[1, 1, 1, 0, 0, 1], # 0=negative, 1=positive
26 | label_known=[1, 1, 1, 1, 1, 1], # 0=unknown/missing, 1=known
27 | ),
28 | dict( # case2
29 | img_1D=np.ones([128, img_channel]),
30 | img_2D=np.ones([128, 128, img_channel]),
31 | img_3D=np.ones([64, 64, 64, img_channel]),
32 | label_value=[1, 0, 0, 0, 1, 1],
33 | label_known=[1, 1, 1, 0, 0, 0],
34 | ),
35 | dict( # case3
36 | img_1D=np.ones([128, img_channel]),
37 | img_2D=np.ones([128, 128, img_channel]),
38 | img_3D=np.ones([64, 64, 64, img_channel]),
39 | label_value=[1, 1, 0, 1, 1, 0],
40 | label_known=[1, 1, 0, 1, 0, 0],
41 | ),
42 | ]
43 |
--------------------------------------------------------------------------------
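The `label_known` flags above mark missing labels; a common way to use them is to mask the per-label loss so unknown labels contribute nothing to training. A minimal sketch of that idea (this helper is not part of the repository, and the multi-class `weather` label is treated as binary here purely for the sketch):

import torch

def masked_bce_loss(logits, label_value, label_known):
    """BCE over binary labels, ignoring positions where label_known == 0."""
    target = torch.as_tensor(label_value, dtype=torch.float32)
    mask = torch.as_tensor(label_known, dtype=torch.float32)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        logits, target, reduction='none')
    return (loss * mask).sum() / mask.sum().clamp(min=1.0)

# toy logits for the 6 labels of one case (case2 above)
logits = torch.zeros(6)
case = dict(label_value=[1, 0, 0, 0, 1, 1], label_known=[1, 1, 1, 0, 0, 0])
print(masked_bce_loss(logits, case['label_value'], case['label_known']))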
/demo/Demo1_ResNet_SingleLabelClassification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import ResNetEncoder
4 | from wama_modules.Head import ClassificationHead
5 | from wama_modules.BaseModule import GlobalMaxPool
6 |
7 |
8 | class Model(nn.Module):
9 | def __init__(self, in_channel, label_category_dict, dim=2):
10 | super().__init__()
11 | # encoder
12 | f_channel_list = [64, 128, 256, 512]
13 | self.encoder = ResNetEncoder(
14 | in_channel,
15 | stage_output_channels=f_channel_list,
16 | stage_middle_channels=f_channel_list,
17 | blocks=[1, 2, 3, 4],
18 | type='131',
19 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
20 | dim=dim)
21 | # cls head
22 | self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
23 | self.pooling = GlobalMaxPool()
24 |
25 | def forward(self, x):
26 | f = self.encoder(x)
27 | logits = self.cls_head(self.pooling(f[-1]))
28 | return logits
29 |
30 |
31 | if __name__ == '__main__':
32 | x = torch.ones([2, 1, 64, 64, 64])
33 | label_category_dict = dict(is_malignant=4)
34 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
35 | logits = model(x)
36 | print('single-label predicted logits')
37 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
38 |
39 | # output 👇
40 | # single-label predicted logits
41 | # logits of is_malignant : torch.Size([2, 4])
42 |
--------------------------------------------------------------------------------
/demo/Demo2_ResNet_MultiLabelClassification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import ResNetEncoder
4 | from wama_modules.Head import ClassificationHead
5 | from wama_modules.BaseModule import GlobalMaxPool
6 |
7 |
8 | class Model(nn.Module):
9 | def __init__(self, in_channel, label_category_dict, dim=2):
10 | super().__init__()
11 | # encoder
12 | f_channel_list = [64, 128, 256, 512]
13 | self.encoder = ResNetEncoder(
14 | in_channel,
15 | stage_output_channels=f_channel_list,
16 | stage_middle_channels=f_channel_list,
17 | blocks=[1, 2, 3, 4],
18 | type='131',
19 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
20 | dim=dim)
21 | # cls head
22 | self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
23 |
24 | self.pooling = GlobalMaxPool()
25 |
26 | def forward(self, x):
27 | f = self.encoder(x)
28 | logits = self.cls_head(self.pooling(f[-1]))
29 | return logits
30 |
31 |
32 | if __name__ == '__main__':
33 | x = torch.ones([2, 1, 64, 64, 64])
34 | label_category_dict = dict(shape=4, color=3, other=13)
35 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
36 | logits = model(x)
37 | print('multi_label predicted logits')
38 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
39 |
40 | # out
41 | # multi_label predicted logits
42 | # logits of shape : torch.Size([2, 4])
43 | # logits of color : torch.Size([2, 3])
44 | # logits of other : torch.Size([2, 13])
45 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from . import base
4 | from . import functional as F
5 | from ..base.modules import Activation
6 |
7 |
8 | class JaccardLoss(base.Loss):
9 | def __init__(self, eps=1.0, activation=None, ignore_channels=None, **kwargs):
10 | super().__init__(**kwargs)
11 | self.eps = eps
12 | self.activation = Activation(activation)
13 | self.ignore_channels = ignore_channels
14 |
15 | def forward(self, y_pr, y_gt):
16 | y_pr = self.activation(y_pr)
17 | return 1 - F.jaccard(
18 | y_pr,
19 | y_gt,
20 | eps=self.eps,
21 | threshold=None,
22 | ignore_channels=self.ignore_channels,
23 | )
24 |
25 |
26 | class DiceLoss(base.Loss):
27 | def __init__(self, eps=1.0, beta=1.0, activation=None, ignore_channels=None, **kwargs):
28 | super().__init__(**kwargs)
29 | self.eps = eps
30 | self.beta = beta
31 | self.activation = Activation(activation)
32 | self.ignore_channels = ignore_channels
33 |
34 | def forward(self, y_pr, y_gt):
35 | y_pr = self.activation(y_pr)
36 | return 1 - F.f_score(
37 | y_pr,
38 | y_gt,
39 | beta=self.beta,
40 | eps=self.eps,
41 | threshold=None,
42 | ignore_channels=self.ignore_channels,
43 | )
44 |
45 |
46 | class L1Loss(nn.L1Loss, base.Loss):
47 | pass
48 |
49 |
50 | class MSELoss(nn.MSELoss, base.Loss):
51 | pass
52 |
53 |
54 | class CrossEntropyLoss(nn.CrossEntropyLoss, base.Loss):
55 | pass
56 |
57 |
58 | class NLLLoss(nn.NLLLoss, base.Loss):
59 | pass
60 |
61 |
62 | class BCELoss(nn.BCELoss, base.Loss):
63 | pass
64 |
65 |
66 | class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, base.Loss):
67 | pass
68 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/utils/meter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Meter(object):
5 | """Meters provide a way to keep track of important statistics in an online manner.
6 | This class is abstract, but provides a standard interface for all meters to follow.
7 | """
8 |
9 | def reset(self):
10 | """Reset the meter to default settings."""
11 | pass
12 |
13 | def add(self, value):
14 | """Log a new value to the meter
15 | Args:
16 | value: Next result to include.
17 | """
18 | pass
19 |
20 | def value(self):
21 | """Get the value of the meter in the current state."""
22 | pass
23 |
24 |
25 | class AverageValueMeter(Meter):
26 | def __init__(self):
27 | super(AverageValueMeter, self).__init__()
28 | self.reset()
29 | self.val = 0
30 |
31 | def add(self, value, n=1):
32 | self.val = value
33 | self.sum += value
34 | self.var += value * value
35 | self.n += n
36 |
37 | if self.n == 0:
38 | self.mean, self.std = np.nan, np.nan
39 | elif self.n == 1:
40 | self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy
41 | self.std = np.inf
42 | self.mean_old = self.mean
43 | self.m_s = 0.0
44 | else:
45 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)
46 | self.m_s += (value - self.mean_old) * (value - self.mean)
47 | self.mean_old = self.mean
48 | self.std = np.sqrt(self.m_s / (self.n - 1.0))
49 |
50 | def value(self):
51 | return self.mean, self.std
52 |
53 | def reset(self):
54 | self.n = 0
55 | self.sum = 0.0
56 | self.var = 0.0
57 | self.val = 0.0
58 | self.mean = np.nan
59 | self.mean_old = 0.0
60 | self.m_s = 0.0
61 | self.std = np.nan
62 |
--------------------------------------------------------------------------------
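A tiny usage sketch of `AverageValueMeter` above, e.g. for tracking a running loss (importing it also triggers the deprecation warning from `utils/__init__.py`):

from wama_modules.thirdparty_lib.SMP_qubvel.utils.meter import AverageValueMeter

meter = AverageValueMeter()
for batch_loss in [1.0, 3.0]:
    meter.add(batch_loss)
mean, std = meter.value()
print(mean, std)  # running mean (2.0) and running standard deviation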
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
6 | """Change first convolution layer input channels.
7 | In case:
8 | in_channels == 1 or in_channels == 2 -> reuse original weights
9 | in_channels > 3 -> make random kaiming normal initialization
10 | """
11 |
12 | # get first conv
13 | for module in model.modules():
14 | if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
15 | break
16 |
17 | weight = module.weight.detach()
18 | module.in_channels = new_in_channels
19 |
20 | if not pretrained:
21 | module.weight = nn.parameter.Parameter(
22 | torch.Tensor(module.out_channels, new_in_channels // module.groups, *module.kernel_size)
23 | )
24 | module.reset_parameters()
25 |
26 | elif new_in_channels == 1:
27 | new_weight = weight.sum(1, keepdim=True)
28 | module.weight = nn.parameter.Parameter(new_weight)
29 |
30 | else:
31 | new_weight = torch.Tensor(module.out_channels, new_in_channels // module.groups, *module.kernel_size)
32 |
33 | for i in range(new_in_channels):
34 | new_weight[:, i] = weight[:, i % default_in_channels]
35 |
36 | new_weight = new_weight * (default_in_channels / new_in_channels)
37 | module.weight = nn.parameter.Parameter(new_weight)
38 |
39 |
40 | def replace_strides_with_dilation(module, dilation_rate):
41 | """Patch Conv2d modules replacing strides with dilation"""
42 | for mod in module.modules():
43 | if isinstance(mod, nn.Conv2d):
44 | mod.stride = (1, 1)
45 | mod.dilation = (dilation_rate, dilation_rate)
46 | kh, kw = mod.kernel_size
47 | mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate)
48 |
49 |             # Workaround (kludge) for EfficientNet's static padding
50 | if hasattr(mod, "static_padding"):
51 | mod.static_padding = nn.Identity()
52 |
--------------------------------------------------------------------------------
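A hedged example of `patch_first_conv` above, here applied to a torchvision ResNet-18 so it accepts single-channel input; torchvision is an assumption on my part, but any model whose first conv has 3 input channels would do:

import torch
import torchvision
from wama_modules.thirdparty_lib.SMP_qubvel.encoders._utils import patch_first_conv

model = torchvision.models.resnet18(pretrained=False)
patch_first_conv(model, new_in_channels=1, pretrained=True)  # sums the RGB kernel weights into one channel
out = model(torch.randn(2, 1, 224, 224))
print(out.shape)  # torch.Size([2, 1000])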
/demo/Demo3_ResNetUnet_SingleLabelSegmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import ResNetEncoder
4 | from wama_modules.Decoder import UNet_decoder
5 | from wama_modules.Head import SegmentationHead
6 | from wama_modules.utils import resizeTensor
7 |
8 |
9 | class Model(nn.Module):
10 | def __init__(self, in_channel, label_category_dict, dim=2):
11 | super().__init__()
12 | # encoder
13 | Encoder_f_channel_list = [64, 128, 256, 512]
14 | self.encoder = ResNetEncoder(
15 | in_channel,
16 | stage_output_channels=Encoder_f_channel_list,
17 | stage_middle_channels=Encoder_f_channel_list,
18 | blocks=[1, 2, 3, 4],
19 | type='131',
20 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
21 | dim=dim)
22 | # decoder
23 | Decoder_f_channel_list = [32, 64, 128]
24 | self.decoder = UNet_decoder(
25 | in_channels_list=Encoder_f_channel_list,
26 | skip_connection=[False, True, True],
27 | out_channels_list=Decoder_f_channel_list,
28 | dim=dim)
29 | # seg head
30 | self.seg_head = SegmentationHead(
31 | label_category_dict,
32 | Decoder_f_channel_list[0],
33 | dim=dim)
34 |
35 | def forward(self, x):
36 | multi_scale_f1 = self.encoder(x)
37 | multi_scale_f2 = self.decoder(multi_scale_f1)
38 | f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:])
39 | logits = self.seg_head(f_for_seg)
40 | return logits
41 |
42 |
43 | if __name__ == '__main__':
44 | x = torch.ones([2, 1, 128, 128, 128])
45 | label_category_dict = dict(organ=3)
46 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
47 | logits = model(x)
48 |     print('single-label predicted logits')
49 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
50 |
51 | # out
52 |     # single-label predicted logits
53 | # logits of organ : torch.Size([2, 3, 128, 128, 128])
54 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/base/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from . import initialization as init
3 |
4 |
5 | class SegmentationModel(torch.nn.Module):
6 | def initialize(self):
7 | init.initialize_decoder(self.decoder)
8 | init.initialize_head(self.segmentation_head)
9 | if self.classification_head is not None:
10 | init.initialize_head(self.classification_head)
11 |
12 | def check_input_shape(self, x):
13 |
14 | h, w = x.shape[-2:]
15 | output_stride = self.encoder.output_stride
16 | if h % output_stride != 0 or w % output_stride != 0:
17 | new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h
18 | new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w
19 | raise RuntimeError(
20 | f"Wrong input shape height={h}, width={w}. Expected image height and width "
21 | f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
22 | )
23 |
24 | def forward(self, x):
25 |         """Sequentially pass `x` through the model's encoder, decoder and heads"""
26 |
27 | self.check_input_shape(x)
28 |
29 | features = self.encoder(x)
30 | decoder_output = self.decoder(*features)
31 |
32 | masks = self.segmentation_head(decoder_output)
33 |
34 | if self.classification_head is not None:
35 | labels = self.classification_head(features[-1])
36 | return masks, labels
37 |
38 | return masks
39 |
40 | @torch.no_grad()
41 | def predict(self, x):
42 | """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
43 |
44 | Args:
45 | x: 4D torch tensor with shape (batch_size, channels, height, width)
46 |
47 | Return:
48 | prediction: 4D torch tensor with shape (batch_size, classes, height, width)
49 |
50 | """
51 | if self.training:
52 | self.eval()
53 |
54 | x = self.forward(x)
55 |
56 | return x
57 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/utils/base.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch.nn as nn
3 |
4 |
5 | class BaseObject(nn.Module):
6 | def __init__(self, name=None):
7 | super().__init__()
8 | self._name = name
9 |
10 | @property
11 | def __name__(self):
12 | if self._name is None:
13 | name = self.__class__.__name__
14 | s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
15 | return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
16 | else:
17 | return self._name
18 |
19 |
20 | class Metric(BaseObject):
21 | pass
22 |
23 |
24 | class Loss(BaseObject):
25 | def __add__(self, other):
26 | if isinstance(other, Loss):
27 | return SumOfLosses(self, other)
28 | else:
29 | raise ValueError("Loss should be inherited from `Loss` class")
30 |
31 | def __radd__(self, other):
32 | return self.__add__(other)
33 |
34 | def __mul__(self, value):
35 | if isinstance(value, (int, float)):
36 | return MultipliedLoss(self, value)
37 | else:
38 |             raise ValueError("Loss multiplier should be an `int` or `float` value")
39 |
40 | def __rmul__(self, other):
41 | return self.__mul__(other)
42 |
43 |
44 | class SumOfLosses(Loss):
45 | def __init__(self, l1, l2):
46 | name = "{} + {}".format(l1.__name__, l2.__name__)
47 | super().__init__(name=name)
48 | self.l1 = l1
49 | self.l2 = l2
50 |
51 | def __call__(self, *inputs):
52 | return self.l1.forward(*inputs) + self.l2.forward(*inputs)
53 |
54 |
55 | class MultipliedLoss(Loss):
56 | def __init__(self, loss, multiplier):
57 |
58 | # resolve name
59 | if len(loss.__name__.split("+")) > 1:
60 | name = "{} * ({})".format(multiplier, loss.__name__)
61 | else:
62 | name = "{} * {}".format(multiplier, loss.__name__)
63 | super().__init__(name=name)
64 | self.loss = loss
65 | self.multiplier = multiplier
66 |
67 | def __call__(self, *inputs):
68 | return self.multiplier * self.loss.forward(*inputs)
69 |
--------------------------------------------------------------------------------
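The operator overloads above let losses be combined with `+` and scaled with `*`; a small sketch using the losses defined in `utils/losses.py` (the tensors and the 0.5 weight are example values only):

import torch
from wama_modules.thirdparty_lib.SMP_qubvel.utils import losses

total_loss = losses.DiceLoss() + 0.5 * losses.BCEWithLogitsLoss()
print(total_loss.__name__)  # "dice_loss + 0.5 * bce_with_logits_loss"

y_pred = torch.rand(2, 1, 32, 32)   # raw scores / probabilities, depending on the loss
y_true = (torch.rand(2, 1, 32, 32) > 0.5).float()
print(total_loss(y_pred, y_true))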
/demo/Demo4_ResNetUnet_MultiLabelSegmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import ResNetEncoder
4 | from wama_modules.Decoder import UNet_decoder
5 | from wama_modules.Head import SegmentationHead
6 | from wama_modules.utils import resizeTensor
7 |
8 |
9 | class Model(nn.Module):
10 | def __init__(self, in_channel, label_category_dict, dim=2):
11 | super().__init__()
12 | # encoder
13 | Encoder_f_channel_list = [64, 128, 256, 512]
14 | self.encoder = ResNetEncoder(
15 | in_channel,
16 | stage_output_channels=Encoder_f_channel_list,
17 | stage_middle_channels=Encoder_f_channel_list,
18 | blocks=[1, 2, 3, 4],
19 | type='131',
20 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
21 | dim=dim)
22 | # decoder
23 | Decoder_f_channel_list = [32, 64, 128]
24 | self.decoder = UNet_decoder(
25 | in_channels_list=Encoder_f_channel_list,
26 | skip_connection=[False, True, True],
27 | out_channels_list=Decoder_f_channel_list,
28 | dim=dim)
29 | # seg head
30 | self.seg_head = SegmentationHead(
31 | label_category_dict,
32 | Decoder_f_channel_list[0],
33 | dim=dim)
34 |
35 | def forward(self, x):
36 | multi_scale_f1 = self.encoder(x)
37 | multi_scale_f2 = self.decoder(multi_scale_f1)
38 | f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:])
39 | logits = self.seg_head(f_for_seg)
40 | return logits
41 |
42 |
43 | if __name__ == '__main__':
44 | x = torch.ones([2, 1, 128, 128, 128])
45 | label_category_dict = dict(organ=3, tumor=4)
46 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
47 | logits = model(x)
48 | print('multi_label predicted logits')
49 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
50 |
51 | # out
52 | # multi_label predicted logits
53 | # logits of organ : torch.Size([2, 3, 128, 128, 128])
54 | # logits of tumor : torch.Size([2, 4, 128, 128, 128])
55 |
56 |
57 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import List
4 | from collections import OrderedDict
5 |
6 | from . import _utils as utils
7 |
8 |
9 | class EncoderMixin:
10 | """Add encoder functionality such as:
11 | - output channels specification of feature tensors (produced by encoder)
12 | - patching first convolution for arbitrary input channels
13 | """
14 |
15 | _output_stride = 32
16 |
17 | @property
18 | def out_channels(self):
19 | """Return channels dimensions for each tensor of forward output of encoder"""
20 | return self._out_channels[: self._depth + 1]
21 |
22 | @property
23 | def output_stride(self):
24 | return min(self._output_stride, 2**self._depth)
25 |
26 | def set_in_channels(self, in_channels, pretrained=True):
27 | """Change first convolution channels"""
28 | if in_channels == 3:
29 | return
30 |
31 | self._in_channels = in_channels
32 | if self._out_channels[0] == 3:
33 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
34 |
35 | utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)
36 |
37 | def get_stages(self):
38 | """Override it in your implementation"""
39 | raise NotImplementedError
40 |
41 | def make_dilated(self, output_stride):
42 |
43 | if output_stride == 16:
44 | stage_list = [
45 | 5,
46 | ]
47 | dilation_list = [
48 | 2,
49 | ]
50 |
51 | elif output_stride == 8:
52 | stage_list = [4, 5]
53 | dilation_list = [2, 4]
54 |
55 | else:
56 | raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))
57 |
58 | self._output_stride = output_stride
59 |
60 | stages = self.get_stages()
61 | for stage_indx, dilation_rate in zip(stage_list, dilation_list):
62 | utils.replace_strides_with_dilation(
63 | module=stages[stage_indx],
64 | dilation_rate=dilation_rate,
65 | )
66 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/C3D_yyuanad/c3d.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class C3D(nn.Module):
7 | """
 8 |     C3D backbone without the classification layers; forward() returns a list of multi-scale feature maps.
9 | """
10 |
11 | def __init__(self,):
12 | super(C3D, self).__init__()
13 |
14 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
15 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
16 |
17 | self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
18 | self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
19 |
20 | self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
21 | self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
22 | self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
23 |
24 | self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
25 | self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
26 | self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
27 |
28 | self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
29 | self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
30 | self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))
31 |
32 | self.relu = nn.ReLU()
33 |
34 | def forward(self, x):
35 | f_list = []
36 | h = self.relu(self.conv1(x))
37 | h = self.pool1(h)
38 | f_list.append(h)
39 |
40 | h = self.relu(self.conv2(h))
41 | h = self.pool2(h)
42 | f_list.append(h)
43 |
44 | h = self.relu(self.conv3a(h))
45 | h = self.relu(self.conv3b(h))
46 | h = self.pool3(h)
47 | f_list.append(h)
48 |
49 | h = self.relu(self.conv4a(h))
50 | h = self.relu(self.conv4b(h))
51 | h = self.pool4(h)
52 | f_list.append(h)
53 |
54 | h = self.relu(self.conv5a(h))
55 | h = self.relu(self.conv5b(h))
56 | h = self.pool5(h)
57 | f_list.append(h)
58 |
59 | return f_list
60 |
61 |
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
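A short sketch of using this C3D variant as a multi-scale 3D feature extractor; the clip size 3x16x112x112 follows the usual C3D convention and is only an example:

import torch
from wama_modules.thirdparty_lib.C3D_yyuanad.c3d import C3D

backbone = C3D()
clip = torch.randn(1, 3, 16, 112, 112)  # (batch, RGB, frames, H, W)
with torch.no_grad():
    feature_maps = backbone(clip)
print([tuple(f.shape) for f in feature_maps])  # five feature maps, one per pooling stage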
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/xception.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch.nn as nn
3 |
4 | from pretrainedmodels.models.xception import pretrained_settings
5 | from pretrainedmodels.models.xception import Xception
6 |
7 | from ._base import EncoderMixin
8 |
9 |
10 | class XceptionEncoder(Xception, EncoderMixin):
11 | def __init__(self, out_channels, *args, depth=5, **kwargs):
12 | super().__init__(*args, **kwargs)
13 |
14 | self._out_channels = out_channels
15 | self._depth = depth
16 | self._in_channels = 3
17 |
18 | # modify padding to maintain output shape
19 | self.conv1.padding = (1, 1)
20 | self.conv2.padding = (1, 1)
21 |
22 | del self.fc
23 |
24 | def make_dilated(self, *args, **kwargs):
25 | raise ValueError(
26 | "Xception encoder does not support dilated mode " "due to pooling operation for downsampling!"
27 | )
28 |
29 | def get_stages(self):
30 | return [
31 | nn.Identity(),
32 | nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu),
33 | self.block1,
34 | self.block2,
35 | nn.Sequential(
36 | self.block3,
37 | self.block4,
38 | self.block5,
39 | self.block6,
40 | self.block7,
41 | self.block8,
42 | self.block9,
43 | self.block10,
44 | self.block11,
45 | ),
46 | nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4),
47 | ]
48 |
49 | def forward(self, x):
50 | stages = self.get_stages()
51 |
52 | features = []
53 | for i in range(self._depth + 1):
54 | x = stages[i](x)
55 | features.append(x)
56 |
57 | return features
58 |
59 | def load_state_dict(self, state_dict):
60 | # remove linear
61 | state_dict.pop("fc.bias", None)
62 | state_dict.pop("fc.weight", None)
63 |
64 | super().load_state_dict(state_dict)
65 |
66 |
67 | xception_encoders = {
68 | "xception": {
69 | "encoder": XceptionEncoder,
70 | "pretrained_settings": pretrained_settings["xception"],
71 | "params": {
72 | "out_channels": (3, 64, 128, 256, 728, 2048),
73 | },
74 | },
75 | }
76 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/C3D_jfzhang95/c3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class C3D(nn.Module):
6 | """
7 | The C3D network.
8 | """
9 |
10 | def __init__(self):
11 | super(C3D, self).__init__()
12 |
13 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
14 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
15 |
16 | self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
17 | self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
18 |
19 | self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
20 | self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
21 | self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
22 |
23 | self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
24 | self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
25 | self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
26 |
27 | self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
28 | self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
29 | self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))
30 |
31 | self.relu = torch.relu
32 |
33 | def forward(self, x):
34 | f_list = []
35 | x = self.relu(self.conv1(x))
36 | x = self.pool1(x)
37 | f_list.append(x)
38 | x = self.relu(self.conv2(x))
39 | x = self.pool2(x)
40 | f_list.append(x)
41 | x = self.relu(self.conv3a(x))
42 | x = self.relu(self.conv3b(x))
43 | x = self.pool3(x)
44 | f_list.append(x)
45 | x = self.relu(self.conv4a(x))
46 | x = self.relu(self.conv4b(x))
47 | x = self.pool4(x)
48 | f_list.append(x)
49 | x = self.relu(self.conv5a(x))
50 | x = self.relu(self.conv5b(x))
51 | x = self.pool5(x)
52 | f_list.append(x)
53 |
54 | return f_list
55 |
56 | def __init_weight(self):
57 | for m in self.modules():
58 | if isinstance(m, nn.Conv3d):
59 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
60 | # m.weight.data.normal_(0, math.sqrt(2. / n))
61 | torch.nn.init.kaiming_normal_(m.weight)
62 | elif isinstance(m, nn.BatchNorm3d):
63 | m.weight.data.fill_(1)
64 | m.bias.data.zero_()
65 |
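A quick sketch of what this trimmed C3D returns: a pyramid of five feature maps rather than class logits. The shapes in the comments are what the pooling layout above should produce for a 16-frame 112x112 clip:

import torch
from wama_modules.thirdparty_lib.C3D_jfzhang95.c3d import C3D

model = C3D()
clip = torch.randn(1, 3, 16, 112, 112)   # (batch, RGB, frames, H, W)
features = model(clip)                   # list of 5 multi-scale features
_ = [print(f.shape) for f in features]
# expected, given the pooling layout above:
# torch.Size([1, 64, 16, 56, 56])
# torch.Size([1, 128, 8, 28, 28])
# torch.Size([1, 256, 4, 14, 14])
# torch.Size([1, 512, 2, 7, 7])
# torch.Size([1, 512, 1, 4, 4])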
--------------------------------------------------------------------------------
/demo/Demo6_UnetwithFPN_segmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import ResNetEncoder
4 | from wama_modules.Decoder import UNet_decoder
5 | from wama_modules.Head import SegmentationHead
6 | from wama_modules.utils import resizeTensor
7 | from wama_modules.Neck import FPN
8 |
9 |
10 | class Model(nn.Module):
11 | def __init__(self, in_channel, label_category_dict, dim=2):
12 | super().__init__()
13 | # encoder
14 | Encoder_f_channel_list = [64, 128, 256, 512]
15 | self.encoder = ResNetEncoder(
16 | in_channel,
17 | stage_output_channels=Encoder_f_channel_list,
18 | stage_middle_channels=Encoder_f_channel_list,
19 | blocks=[1, 2, 3, 4],
20 | type='131',
21 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
22 | dim=dim)
23 |
24 | # neck
25 | FPN_output_channel = 256
26 | FPN_channels = [FPN_output_channel]*len(Encoder_f_channel_list)
27 | self.neck = FPN(in_channels_list=Encoder_f_channel_list,
28 | c1=FPN_output_channel//2,
29 | c2=FPN_output_channel,
30 | mode='AddSmall2Big',
31 | dim=dim,)
32 |
33 | # decoder
34 | Decoder_f_channel_list = [32, 64, 128]
35 | self.decoder = UNet_decoder(
36 | in_channels_list=FPN_channels,
37 | skip_connection=[True, True, True],
38 | out_channels_list=Decoder_f_channel_list,
39 | dim=dim)
40 | # seg head
41 | self.seg_head = SegmentationHead(
42 | label_category_dict,
43 | Decoder_f_channel_list[0],
44 | dim=dim)
45 |
46 | def forward(self, x):
47 | multi_scale_encoder = self.encoder(x)
48 | multi_scale_neck = self.neck(multi_scale_encoder)
49 | multi_scale_decoder = self.decoder(multi_scale_neck)
50 | f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
51 | logits = self.seg_head(f_for_seg)
52 | return logits
53 |
54 |
55 | if __name__ == '__main__':
56 | x = torch.ones([2, 1, 128, 128, 128])
57 | label_category_dict = dict(organ=3, tumor=4)
58 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
59 | logits = model(x)
60 | print('multi_label predicted logits')
61 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
62 |
63 | # out
64 | # multi_label predicted logits
65 | # logits of organ : torch.Size([2, 3, 128, 128, 128])
66 | # logits of tumor : torch.Size([2, 4, 128, 128, 128])
67 |
68 |
69 |
--------------------------------------------------------------------------------
/demo/_explain_how2useVitAsNeck.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import ViTConfig, ViTModel
3 | from wama_modules.utils import load_weights, tensor2array
4 |
5 |
6 | m = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
7 | f = m(torch.ones([1, 3, 224, 224]), output_hidden_states=True)
8 | f_last = f.last_hidden_state
9 | print(f_last.shape)
10 | f_cls_token = (torch.squeeze(f_last[:,0])).data.cpu().numpy()
11 | f_cls_token = list(f_cls_token)
12 |
13 |
14 | configuration = m.config
15 | configuration.image_size = [16, 8]
16 | configuration.patch_size = [1, 1]
17 | configuration.num_channels = 1
18 | configuration.encoder_stride = 1  # only used by the MAE decoder, otherwise this parameter is ignored
19 | m1 = ViTModel(configuration, add_pooling_layer=False)
20 |
21 | f = m1(torch.ones([2, 1, 16, 8]), output_hidden_states=True)
22 |
23 |
24 | f_list = f.hidden_states  # note: Swin-style transformers should use reshaped_hidden_states instead
25 | _ = [print(i.shape) for i in f_list]
26 |
27 | f_last = f.last_hidden_state
28 | f_last = f_last[:, 1:]
29 | f_last = f_last.permute(0, 2, 1)
30 | f_last = f_last.reshape(f_last.shape[0], f_last.shape[1], configuration.image_size[0], configuration.image_size[1])
31 | print('spatial f_last:', f_last.shape)
32 |
33 |
34 | # reload weights
35 |
36 | m = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
37 | weights = m.state_dict()
38 | weights['embeddings.position_embeddings'] = m1.state_dict()['embeddings.position_embeddings']
39 | weights['embeddings.patch_embeddings.projection.weight'] = m1.state_dict()['embeddings.patch_embeddings.projection.weight']
40 | weights['embeddings.patch_embeddings.projection.bias'] = m1.state_dict()['embeddings.patch_embeddings.projection.bias']
41 |
42 |
43 | m1 = load_weights(m1, weights)
44 |
45 |
46 | # test: spatial visualization
47 | m1 = ViTModel(configuration, add_pooling_layer=False)
48 |
49 | input = torch.ones([2, 1, 16, 8])*100
50 | input[:,:,8:] = input[:,:,8:]*0.
51 | input[:,:,:3] = input[:,:,:3]*0.
52 | input[:,:,:,:3] = input[:,:,:,:3]*0.
53 | f = m1(input, output_hidden_states=True)
54 | f_last = f.last_hidden_state
55 | f_last = f_last[:, 1:]
56 | f_last = f_last.permute(0, 2, 1)
57 | f_last = f_last.reshape(f_last.shape[0], f_last.shape[1], configuration.image_size[0], configuration.image_size[1])
58 | print('spatial f_last:', f_last.shape)
59 | print(f_last.max())
60 | print(f_last.min())
61 |
62 | def tensor2numpy(tensor):
63 | return tensor.data.cpu().numpy()
64 | import numpy as np
65 | def mat2gray(image):
66 | """
67 |     Normalization function (linear min-max normalization)
68 | :param image: ndarray
69 | :return:
70 | """
71 | # as dtype = np.float32
72 | image = image.astype(np.float32)
73 | image = (image - np.min(image)) / (np.max(image)-np.min(image)+ 1e-14)
74 | return image
75 |
76 | import matplotlib.pyplot as plt
77 | def show2D(img):
78 | plt.imshow(img)
79 | plt.show()
80 |
81 |
82 | # the two image should be aligned in space
83 | show2D(tensor2numpy(f_last[0,0]))
84 | show2D(tensor2numpy(input[0,0]))
85 |
86 |
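The token-to-feature-map reshape above is the piece most often copied into other models; below is a small hypothetical helper (not part of wama_modules) that wraps it, assuming patch_size=[1, 1] so that one token corresponds to one spatial location:

import torch

def vit_tokens_to_feature_map(last_hidden_state, grid_size):
    """(B, 1 + H*W, C) ViT output with a leading class token -> (B, C, H, W) feature map."""
    tokens = last_hidden_state[:, 1:]   # drop the class token
    tokens = tokens.permute(0, 2, 1)    # (B, C, H*W)
    return tokens.reshape(tokens.shape[0], tokens.shape[1], grid_size[0], grid_size[1])

# e.g. with the model above: f_map = vit_tokens_to_feature_map(f.last_hidden_state, configuration.image_size)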
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/mobilenet.py:
--------------------------------------------------------------------------------
1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2 |
3 | Attributes:
4 |
5 | _out_channels (list of int): specify number of channels for each encoder feature tensor
6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8 |
9 | Methods:
10 |
11 | forward(self, x: torch.Tensor)
12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14 | with resolution same as input `x` tensor).
15 |
16 | Input: `x` with shape (1, 3, 64, 64)
17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20 |
21 | also should support number of features according to specified depth, e.g. if depth = 5,
22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24 | """
25 |
26 | import torchvision
27 | import torch.nn as nn
28 |
29 | from ._base import EncoderMixin
30 |
31 |
32 | class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):
33 | def __init__(self, out_channels, depth=5, **kwargs):
34 | super().__init__(**kwargs)
35 | self._depth = depth
36 | self._out_channels = out_channels
37 | self._in_channels = 3
38 | del self.classifier
39 |
40 | def get_stages(self):
41 | return [
42 | nn.Identity(),
43 | self.features[:2],
44 | self.features[2:4],
45 | self.features[4:7],
46 | self.features[7:14],
47 | self.features[14:],
48 | ]
49 |
50 | def forward(self, x):
51 | stages = self.get_stages()
52 |
53 | features = []
54 | for i in range(self._depth + 1):
55 | x = stages[i](x)
56 | features.append(x)
57 |
58 | return features
59 |
60 | def load_state_dict(self, state_dict, **kwargs):
61 | state_dict.pop("classifier.1.bias", None)
62 | state_dict.pop("classifier.1.weight", None)
63 | super().load_state_dict(state_dict, **kwargs)
64 |
65 |
66 | mobilenet_encoders = {
67 | "mobilenet_v2": {
68 | "encoder": MobileNetV2Encoder,
69 | "pretrained_settings": {
70 | "imagenet": {
71 | "mean": [0.485, 0.456, 0.406],
72 | "std": [0.229, 0.224, 0.225],
73 | "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
74 | "input_space": "RGB",
75 | "input_range": [0, 1],
76 | },
77 | },
78 | "params": {
79 | "out_channels": (3, 16, 24, 32, 96, 1280),
80 | },
81 | },
82 | }
83 |
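A minimal sketch of the encoder contract described in the module docstring, using the MobileNetV2 encoder above (channel counts follow out_channels; the spatial sizes in the comment assume a 64x64 input and the current torchvision MobileNetV2 layout):

import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders.mobilenet import mobilenet_encoders

entry = mobilenet_encoders["mobilenet_v2"]
encoder = entry["encoder"](**entry["params"])   # MobileNetV2Encoder(out_channels=(3, 16, 24, 32, 96, 1280))
features = encoder(torch.randn(1, 3, 64, 64))   # depth=5 -> 6 feature maps
_ = [print(f.shape) for f in features]
# expected: (1, 3, 64, 64), (1, 16, 32, 32), (1, 24, 16, 16),
#           (1, 32, 8, 8), (1, 96, 4, 4), (1, 1280, 2, 2)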
--------------------------------------------------------------------------------
/demo/Demo5_MultiTask_SegAndCls.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import ResNetEncoder
4 | from wama_modules.Decoder import UNet_decoder
5 | from wama_modules.Head import SegmentationHead, ClassificationHead
6 | from wama_modules.utils import resizeTensor
7 | from wama_modules.BaseModule import GlobalMaxPool
8 |
9 |
10 | class Model(nn.Module):
11 | def __init__(self,
12 | in_channel,
13 | seg_label_category_dict,
14 | cls_label_category_dict,
15 | dim=2):
16 | super().__init__()
17 | # encoder
18 | Encoder_f_channel_list = [64, 128, 256, 512]
19 | self.encoder = ResNetEncoder(
20 | in_channel,
21 | stage_output_channels=Encoder_f_channel_list,
22 | stage_middle_channels=Encoder_f_channel_list,
23 | blocks=[1, 2, 3, 4],
24 | type='131',
25 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
26 | dim=dim)
27 | # decoder
28 | Decoder_f_channel_list = [32, 64, 128]
29 | self.decoder = UNet_decoder(
30 | in_channels_list=Encoder_f_channel_list,
31 | skip_connection=[False, True, True],
32 | out_channels_list=Decoder_f_channel_list,
33 | dim=dim)
34 | # seg head
35 | self.seg_head = SegmentationHead(
36 | seg_label_category_dict,
37 | Decoder_f_channel_list[0],
38 | dim=dim)
39 | # cls head
40 | self.cls_head = ClassificationHead(cls_label_category_dict, Encoder_f_channel_list[-1])
41 |
42 | # pooling
43 | self.pooling = GlobalMaxPool()
44 |
45 | def forward(self, x):
46 | # get encoder features
47 | multi_scale_encoder = self.encoder(x)
48 | # get decoder features
49 | multi_scale_decoder = self.decoder(multi_scale_encoder)
50 | # perform segmentation
51 | f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
52 | seg_logits = self.seg_head(f_for_seg)
53 | # perform classification
54 | cls_logits = self.cls_head(self.pooling(multi_scale_encoder[-1]))
55 | return seg_logits, cls_logits
56 |
57 | if __name__ == '__main__':
58 | x = torch.ones([2, 1, 128, 128, 128])
59 | seg_label_category_dict = dict(organ=3, tumor=2)
60 | cls_label_category_dict = dict(shape=4, color=3, other=13)
61 | model = Model(
62 | in_channel=1,
63 | cls_label_category_dict=cls_label_category_dict,
64 | seg_label_category_dict=seg_label_category_dict,
65 | dim=3)
66 | seg_logits, cls_logits = model(x)
67 | print('multi_label predicted logits')
68 | _ = [print('seg logits of ', key, ':', seg_logits[key].shape) for key in seg_logits.keys()]
69 | print('-'*30)
70 | _ = [print('cls logits of ', key, ':', cls_logits[key].shape) for key in cls_logits.keys()]
71 |
72 | # out
73 | # multi_label predicted logits
74 | # seg logits of organ : torch.Size([2, 3, 128, 128, 128])
75 | # seg logits of tumor : torch.Size([2, 2, 128, 128, 128])
76 | # ------------------------------
77 | # cls logits of shape : torch.Size([2, 4])
78 | # cls logits of color : torch.Size([2, 3])
79 | # cls logits of other : torch.Size([2, 13])
80 |
81 |
82 |
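Continuing the __main__ block above, the seg_logits and cls_logits dicts plug straight into per-label losses. This is only an illustrative sketch: the all-zero targets and the CrossEntropyLoss choice are placeholders, not part of the demo:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# hypothetical integer targets matching the label_category_dicts defined above
seg_targets = {k: torch.zeros([2, 128, 128, 128], dtype=torch.long) for k in seg_label_category_dict}
cls_targets = {k: torch.zeros([2], dtype=torch.long) for k in cls_label_category_dict}

loss = sum(criterion(seg_logits[k], seg_targets[k]) for k in seg_logits) \
     + sum(criterion(cls_logits[k], cls_targets[k]) for k in cls_logits)
loss.backward()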
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from . import base
2 | from . import functional as F
3 | from ..base.modules import Activation
4 |
5 |
6 | class IoU(base.Metric):
7 | __name__ = "iou_score"
8 |
9 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
10 | super().__init__(**kwargs)
11 | self.eps = eps
12 | self.threshold = threshold
13 | self.activation = Activation(activation)
14 | self.ignore_channels = ignore_channels
15 |
16 | def forward(self, y_pr, y_gt):
17 | y_pr = self.activation(y_pr)
18 | return F.iou(
19 | y_pr,
20 | y_gt,
21 | eps=self.eps,
22 | threshold=self.threshold,
23 | ignore_channels=self.ignore_channels,
24 | )
25 |
26 |
27 | class Fscore(base.Metric):
28 | def __init__(self, beta=1, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
29 | super().__init__(**kwargs)
30 | self.eps = eps
31 | self.beta = beta
32 | self.threshold = threshold
33 | self.activation = Activation(activation)
34 | self.ignore_channels = ignore_channels
35 |
36 | def forward(self, y_pr, y_gt):
37 | y_pr = self.activation(y_pr)
38 | return F.f_score(
39 | y_pr,
40 | y_gt,
41 | eps=self.eps,
42 | beta=self.beta,
43 | threshold=self.threshold,
44 | ignore_channels=self.ignore_channels,
45 | )
46 |
47 |
48 | class Accuracy(base.Metric):
49 | def __init__(self, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
50 | super().__init__(**kwargs)
51 | self.threshold = threshold
52 | self.activation = Activation(activation)
53 | self.ignore_channels = ignore_channels
54 |
55 | def forward(self, y_pr, y_gt):
56 | y_pr = self.activation(y_pr)
57 | return F.accuracy(
58 | y_pr,
59 | y_gt,
60 | threshold=self.threshold,
61 | ignore_channels=self.ignore_channels,
62 | )
63 |
64 |
65 | class Recall(base.Metric):
66 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
67 | super().__init__(**kwargs)
68 | self.eps = eps
69 | self.threshold = threshold
70 | self.activation = Activation(activation)
71 | self.ignore_channels = ignore_channels
72 |
73 | def forward(self, y_pr, y_gt):
74 | y_pr = self.activation(y_pr)
75 | return F.recall(
76 | y_pr,
77 | y_gt,
78 | eps=self.eps,
79 | threshold=self.threshold,
80 | ignore_channels=self.ignore_channels,
81 | )
82 |
83 |
84 | class Precision(base.Metric):
85 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
86 | super().__init__(**kwargs)
87 | self.eps = eps
88 | self.threshold = threshold
89 | self.activation = Activation(activation)
90 | self.ignore_channels = ignore_channels
91 |
92 | def forward(self, y_pr, y_gt):
93 | y_pr = self.activation(y_pr)
94 | return F.precision(
95 | y_pr,
96 | y_gt,
97 | eps=self.eps,
98 | threshold=self.threshold,
99 | ignore_channels=self.ignore_channels,
100 | )
101 |
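A toy check of the IoU metric above; the probabilities are binarized with the 0.5 threshold before scoring, so the result should be roughly 0.5 (up to the eps smoothing):

import torch
from wama_modules.thirdparty_lib.SMP_qubvel.utils.metrics import IoU

metric = IoU(threshold=0.5)
pr = torch.tensor([[[[0.9, 0.8], [0.1, 0.2]]]])   # (B, C, H, W) predicted probabilities
gt = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])   # binary ground truth
print(metric(pr, gt))                             # intersection 1, union 2 -> ~0.5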
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/inceptionresnetv2.py:
--------------------------------------------------------------------------------
1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2 |
3 | Attributes:
4 |
5 | _out_channels (list of int): specify number of channels for each encoder feature tensor
6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8 |
9 | Methods:
10 |
11 | forward(self, x: torch.Tensor)
12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14 | with resolution same as input `x` tensor).
15 |
16 | Input: `x` with shape (1, 3, 64, 64)
17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20 |
21 | also should support number of features according to specified depth, e.g. if depth = 5,
22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24 | """
25 |
26 | import torch.nn as nn
27 | from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2
28 | from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings
29 |
30 | from ._base import EncoderMixin
31 |
32 |
33 | class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin):
34 | def __init__(self, out_channels, depth=5, **kwargs):
35 | super().__init__(**kwargs)
36 |
37 | self._out_channels = out_channels
38 | self._depth = depth
39 | self._in_channels = 3
40 |
41 | # correct paddings
42 | for m in self.modules():
43 | if isinstance(m, nn.Conv2d):
44 | if m.kernel_size == (3, 3):
45 | m.padding = (1, 1)
46 | if isinstance(m, nn.MaxPool2d):
47 | m.padding = (1, 1)
48 |
49 | # remove linear layers
50 | del self.avgpool_1a
51 | del self.last_linear
52 |
53 | def make_dilated(self, *args, **kwargs):
54 | raise ValueError(
55 | "InceptionResnetV2 encoder does not support dilated mode " "due to pooling operation for downsampling!"
56 | )
57 |
58 | def get_stages(self):
59 | return [
60 | nn.Identity(),
61 | nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b),
62 | nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a),
63 | nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat),
64 | nn.Sequential(self.mixed_6a, self.repeat_1),
65 | nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b),
66 | ]
67 |
68 | def forward(self, x):
69 |
70 | stages = self.get_stages()
71 |
72 | features = []
73 | for i in range(self._depth + 1):
74 | x = stages[i](x)
75 | features.append(x)
76 |
77 | return features
78 |
79 | def load_state_dict(self, state_dict, **kwargs):
80 | state_dict.pop("last_linear.bias", None)
81 | state_dict.pop("last_linear.weight", None)
82 | super().load_state_dict(state_dict, **kwargs)
83 |
84 |
85 | inceptionresnetv2_encoders = {
86 | "inceptionresnetv2": {
87 | "encoder": InceptionResNetV2Encoder,
88 | "pretrained_settings": pretrained_settings["inceptionresnetv2"],
89 | "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000},
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/inceptionv4.py:
--------------------------------------------------------------------------------
1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2 |
3 | Attributes:
4 |
5 | _out_channels (list of int): specify number of channels for each encoder feature tensor
6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8 |
9 | Methods:
10 |
11 | forward(self, x: torch.Tensor)
12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14 | with resolution same as input `x` tensor).
15 |
16 | Input: `x` with shape (1, 3, 64, 64)
17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20 |
21 | also should support number of features according to specified depth, e.g. if depth = 5,
22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24 | """
25 |
26 | import torch.nn as nn
27 | from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d
28 | from pretrainedmodels.models.inceptionv4 import pretrained_settings
29 |
30 | from ._base import EncoderMixin
31 |
32 |
33 | class InceptionV4Encoder(InceptionV4, EncoderMixin):
34 | def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
35 | super().__init__(**kwargs)
36 | self._stage_idxs = stage_idxs
37 | self._out_channels = out_channels
38 | self._depth = depth
39 | self._in_channels = 3
40 |
41 | # correct paddings
42 | for m in self.modules():
43 | if isinstance(m, nn.Conv2d):
44 | if m.kernel_size == (3, 3):
45 | m.padding = (1, 1)
46 | if isinstance(m, nn.MaxPool2d):
47 | m.padding = (1, 1)
48 |
49 | # remove linear layers
50 | del self.last_linear
51 |
52 | def make_dilated(self, stage_list, dilation_list):
53 | raise ValueError(
54 | "InceptionV4 encoder does not support dilated mode " "due to pooling operation for downsampling!"
55 | )
56 |
57 | def get_stages(self):
58 | return [
59 | nn.Identity(),
60 | self.features[: self._stage_idxs[0]],
61 | self.features[self._stage_idxs[0] : self._stage_idxs[1]],
62 | self.features[self._stage_idxs[1] : self._stage_idxs[2]],
63 | self.features[self._stage_idxs[2] : self._stage_idxs[3]],
64 | self.features[self._stage_idxs[3] :],
65 | ]
66 |
67 | def forward(self, x):
68 |
69 | stages = self.get_stages()
70 |
71 | features = []
72 | for i in range(self._depth + 1):
73 | x = stages[i](x)
74 | features.append(x)
75 |
76 | return features
77 |
78 | def load_state_dict(self, state_dict, **kwargs):
79 | state_dict.pop("last_linear.bias", None)
80 | state_dict.pop("last_linear.weight", None)
81 | super().load_state_dict(state_dict, **kwargs)
82 |
83 |
84 | inceptionv4_encoders = {
85 | "inceptionv4": {
86 | "encoder": InceptionV4Encoder,
87 | "pretrained_settings": pretrained_settings["inceptionv4"],
88 | "params": {
89 | "stage_idxs": (3, 5, 9, 15),
90 | "out_channels": (3, 64, 192, 384, 1024, 1536),
91 | "num_classes": 1001,
92 | },
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/utils/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from tqdm import tqdm as tqdm
4 | from .meter import AverageValueMeter
5 |
6 |
7 | class Epoch:
8 | def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True):
9 | self.model = model
10 | self.loss = loss
11 | self.metrics = metrics
12 | self.stage_name = stage_name
13 | self.verbose = verbose
14 | self.device = device
15 |
16 | self._to_device()
17 |
18 | def _to_device(self):
19 | self.model.to(self.device)
20 | self.loss.to(self.device)
21 | for metric in self.metrics:
22 | metric.to(self.device)
23 |
24 | def _format_logs(self, logs):
25 | str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()]
26 | s = ", ".join(str_logs)
27 | return s
28 |
29 | def batch_update(self, x, y):
30 | raise NotImplementedError
31 |
32 | def on_epoch_start(self):
33 | pass
34 |
35 | def run(self, dataloader):
36 |
37 | self.on_epoch_start()
38 |
39 | logs = {}
40 | loss_meter = AverageValueMeter()
41 | metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}
42 |
43 | with tqdm(
44 | dataloader,
45 | desc=self.stage_name,
46 | file=sys.stdout,
47 | disable=not (self.verbose),
48 | ) as iterator:
49 | for x, y in iterator:
50 | x, y = x.to(self.device), y.to(self.device)
51 | loss, y_pred = self.batch_update(x, y)
52 |
53 | # update loss logs
54 | loss_value = loss.cpu().detach().numpy()
55 | loss_meter.add(loss_value)
56 | loss_logs = {self.loss.__name__: loss_meter.mean}
57 | logs.update(loss_logs)
58 |
59 | # update metrics logs
60 | for metric_fn in self.metrics:
61 | metric_value = metric_fn(y_pred, y).cpu().detach().numpy()
62 | metrics_meters[metric_fn.__name__].add(metric_value)
63 | metrics_logs = {k: v.mean for k, v in metrics_meters.items()}
64 | logs.update(metrics_logs)
65 |
66 | if self.verbose:
67 | s = self._format_logs(logs)
68 | iterator.set_postfix_str(s)
69 |
70 | return logs
71 |
72 |
73 | class TrainEpoch(Epoch):
74 | def __init__(self, model, loss, metrics, optimizer, device="cpu", verbose=True):
75 | super().__init__(
76 | model=model,
77 | loss=loss,
78 | metrics=metrics,
79 | stage_name="train",
80 | device=device,
81 | verbose=verbose,
82 | )
83 | self.optimizer = optimizer
84 |
85 | def on_epoch_start(self):
86 | self.model.train()
87 |
88 | def batch_update(self, x, y):
89 | self.optimizer.zero_grad()
90 | prediction = self.model.forward(x)
91 | loss = self.loss(prediction, y)
92 | loss.backward()
93 | self.optimizer.step()
94 | return loss, prediction
95 |
96 |
97 | class ValidEpoch(Epoch):
98 | def __init__(self, model, loss, metrics, device="cpu", verbose=True):
99 | super().__init__(
100 | model=model,
101 | loss=loss,
102 | metrics=metrics,
103 | stage_name="valid",
104 | device=device,
105 | verbose=verbose,
106 | )
107 |
108 | def on_epoch_start(self):
109 | self.model.eval()
110 |
111 | def batch_update(self, x, y):
112 | with torch.no_grad():
113 | prediction = self.model.forward(x)
114 | loss = self.loss(prediction, y)
115 | return loss, prediction
116 |
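A hedged sketch of how these runners are typically wired up. It assumes the vendored utils package mirrors upstream segmentation_models_pytorch (i.e. losses.DiceLoss exists and exposes a __name__), and uses placeholder data and a toy one-layer model:

import torch
from torch.utils.data import DataLoader, TensorDataset
from wama_modules.thirdparty_lib.SMP_qubvel.utils import losses, metrics, train

# toy binary-segmentation data: 8 images, 1 channel, 64x64
images = torch.rand(8, 1, 64, 64)
masks = (torch.rand(8, 1, 64, 64) > 0.5).float()
loader = DataLoader(TensorDataset(images, masks), batch_size=4)

model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 3, padding=1), torch.nn.Sigmoid())
loss = losses.DiceLoss()                    # assumed to mirror upstream SMP
iou = metrics.IoU(threshold=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_epoch = train.TrainEpoch(model, loss=loss, metrics=[iou], optimizer=optimizer, device="cpu")
valid_epoch = train.ValidEpoch(model, loss=loss, metrics=[iou], device="cpu")

for _ in range(2):
    print(train_epoch.run(loader))   # dict of running loss / metric means
    print(valid_epoch.run(loader))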
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_sknet.py:
--------------------------------------------------------------------------------
1 | from ._base import EncoderMixin
2 | from timm.models.resnet import ResNet
3 | from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic
4 | import torch.nn as nn
5 |
6 |
7 | class SkNetEncoder(ResNet, EncoderMixin):
8 | def __init__(self, out_channels, depth=5, **kwargs):
9 | super().__init__(**kwargs)
10 | self._depth = depth
11 | self._out_channels = out_channels
12 | self._in_channels = 3
13 |
14 | del self.fc
15 | del self.global_pool
16 |
17 | def get_stages(self):
18 | return [
19 | nn.Identity(),
20 | nn.Sequential(self.conv1, self.bn1, self.act1),
21 | nn.Sequential(self.maxpool, self.layer1),
22 | self.layer2,
23 | self.layer3,
24 | self.layer4,
25 | ]
26 |
27 | def forward(self, x):
28 | stages = self.get_stages()
29 |
30 | features = []
31 | for i in range(self._depth + 1):
32 | x = stages[i](x)
33 | features.append(x)
34 |
35 | return features
36 |
37 | def load_state_dict(self, state_dict, **kwargs):
38 | state_dict.pop("fc.bias", None)
39 | state_dict.pop("fc.weight", None)
40 | super().load_state_dict(state_dict, **kwargs)
41 |
42 |
43 | sknet_weights = {
44 | "timm-skresnet18": {
45 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth", # noqa
46 | },
47 | "timm-skresnet34": {
48 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth", # noqa
49 | },
50 | "timm-skresnext50_32x4d": {
51 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth", # noqa
52 | },
53 | }
54 |
55 | pretrained_settings = {}
56 | for model_name, sources in sknet_weights.items():
57 | pretrained_settings[model_name] = {}
58 | for source_name, source_url in sources.items():
59 | pretrained_settings[model_name][source_name] = {
60 | "url": source_url,
61 | "input_size": [3, 224, 224],
62 | "input_range": [0, 1],
63 | "mean": [0.485, 0.456, 0.406],
64 | "std": [0.229, 0.224, 0.225],
65 | "num_classes": 1000,
66 | }
67 |
68 | timm_sknet_encoders = {
69 | "timm-skresnet18": {
70 | "encoder": SkNetEncoder,
71 | "pretrained_settings": pretrained_settings["timm-skresnet18"],
72 | "params": {
73 | "out_channels": (3, 64, 64, 128, 256, 512),
74 | "block": SelectiveKernelBasic,
75 | "layers": [2, 2, 2, 2],
76 | "zero_init_last_bn": False,
77 | "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}},
78 | },
79 | },
80 | "timm-skresnet34": {
81 | "encoder": SkNetEncoder,
82 | "pretrained_settings": pretrained_settings["timm-skresnet34"],
83 | "params": {
84 | "out_channels": (3, 64, 64, 128, 256, 512),
85 | "block": SelectiveKernelBasic,
86 | "layers": [3, 4, 6, 3],
87 | "zero_init_last_bn": False,
88 | "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}},
89 | },
90 | },
91 | "timm-skresnext50_32x4d": {
92 | "encoder": SkNetEncoder,
93 | "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"],
94 | "params": {
95 | "out_channels": (3, 64, 256, 512, 1024, 2048),
96 | "block": SelectiveKernelBottleneck,
97 | "layers": [3, 4, 6, 3],
98 | "zero_init_last_bn": False,
99 | "cardinality": 32,
100 | "base_width": 4,
101 | },
102 | },
103 | }
104 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/c3d.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the c3d implementation with batch norm.
3 |
4 | References
5 | ----------
6 | [1] Tran, Du, et al. "Learning spatiotemporal features with 3d convolutional networks."
7 | Proceedings of the IEEE international conference on computer vision. 2015.
8 | """
9 |
10 | import math
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | import torch.nn.functional as F
15 | from torch.autograd import Variable
16 | from functools import partial
17 |
18 |
19 | class C3D(nn.Module):
20 | def __init__(self,):
21 |
22 | super(C3D, self).__init__()
23 | self.group1 = nn.Sequential(
24 | nn.Conv3d(3, 64, kernel_size=3, padding=1),
25 | nn.BatchNorm3d(64),
26 | nn.ReLU(),
27 | nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2)))
28 | self.group2 = nn.Sequential(
29 | nn.Conv3d(64, 128, kernel_size=3, padding=1),
30 | nn.BatchNorm3d(128),
31 | nn.ReLU(),
32 | nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
33 | self.group3 = nn.Sequential(
34 | nn.Conv3d(128, 256, kernel_size=3, padding=1),
35 | nn.BatchNorm3d(256),
36 | nn.ReLU(),
37 | nn.Conv3d(256, 256, kernel_size=3, padding=1),
38 | nn.BatchNorm3d(256),
39 | nn.ReLU(),
40 | nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
41 | self.group4 = nn.Sequential(
42 | nn.Conv3d(256, 512, kernel_size=3, padding=1),
43 | nn.BatchNorm3d(512),
44 | nn.ReLU(),
45 | nn.Conv3d(512, 512, kernel_size=3, padding=1),
46 | nn.BatchNorm3d(512),
47 | nn.ReLU(),
48 | nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
49 | self.group5 = nn.Sequential(
50 | nn.Conv3d(512, 512, kernel_size=3, padding=1),
51 | nn.BatchNorm3d(512),
52 | nn.ReLU(),
53 | nn.Conv3d(512, 512, kernel_size=3, padding=1),
54 | nn.BatchNorm3d(512),
55 | nn.ReLU(),
56 | nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)))
57 |
58 | def forward(self, x):
59 | f_list = []
60 | out = self.group1(x)
61 | f_list.append(out)
62 | out = self.group2(out)
63 | f_list.append(out)
64 | out = self.group3(out)
65 | f_list.append(out)
66 | out = self.group4(out)
67 | f_list.append(out)
68 | out = self.group5(out)
69 | f_list.append(out)
70 | return f_list
71 |
72 |
73 | def get_fine_tuning_parameters(model, ft_portion):
74 | if ft_portion == "complete":
75 | return model.parameters()
76 |
77 | elif ft_portion == "last_layer":
78 | ft_module_names = []
79 | ft_module_names.append('fc')
80 |
81 | parameters = []
82 | for k, v in model.named_parameters():
83 | for ft_module in ft_module_names:
84 | if ft_module in k:
85 | parameters.append({'params': v})
86 | break
87 | else:
88 | parameters.append({'params': v, 'lr': 0.0})
89 | return parameters
90 |
91 | else:
92 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected")
93 |
94 |
95 | def get_model(**kwargs):
96 | """
97 | Returns the model.
98 | """
99 | model = C3D(**kwargs)
100 | return model
101 |
102 |
103 | if __name__ == '__main__':
104 |     model = get_model()  # this trimmed C3D takes no constructor arguments
105 | model = model.cuda()
106 | model = nn.DataParallel(model, device_ids=None)
107 | print(model)
108 |
109 | input_var = Variable(torch.randn(8, 3, 16, 112, 112))
110 | output = model(input_var)
111 |     _ = [print(f.shape) for f in output]  # forward returns a list of multi-scale features
112 |
--------------------------------------------------------------------------------
/wama_modules/Head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wama_modules.BaseModule import *
3 |
4 |
5 | class ClassificationHead(nn.Module):
6 | """
7 | Head for single or multiple label classification task
8 | """
9 |
10 | def __init__(self, label_category_dict, in_channel, bias=True):
11 | super().__init__()
12 | self.classification_head = torch.nn.ModuleDict({})
13 | for key in label_category_dict.keys():
14 | self.classification_head[key] = torch.nn.Linear(in_channel, label_category_dict[key], bias=bias)
15 |
16 | def forward(self, f):
17 | """
18 | # demo: an example of fruit classification task
19 |
20 | f = torch.ones([3, 512]) # from encoder
21 | label_category_dict = dict(
22 | shape=4,
23 | color=3,
24 | rotten=2,
25 | sweet=2,
26 | sour=2,
27 | )
28 | cls_head = ClassificationHead(label_category_dict, 512)
29 | logits = cls_head(f)
30 | _ = [print('logits of ', key,':' ,logits[key].shape) for key in logits.keys()]
31 |
32 |
33 |         # element-wise forward is also supported, by passing a dict or list as f
34 | label_category_dict = dict(
35 | shape=4,
36 | color=3,
37 | rotten=2,
38 | sweet=2,
39 | sour=2,
40 | )
41 |
42 | # dict input for element-wised FC
43 | f = {}
44 | for key in label_category_dict.keys():
45 | f[key] = torch.ones([3, 512]) # from encoder
46 | cls_head = ClassificationHead(label_category_dict, 512)
47 | logits = cls_head(f)
48 | _ = [print('logits of ', key,':' ,logits[key].shape) for key in logits.keys()]
49 |
50 | # list input for element-wised FC
51 | f = []
52 | for key in label_category_dict.keys():
53 | f.append(torch.ones([3, 512])) # from encoder
54 | cls_head = ClassificationHead(label_category_dict, 512)
55 | logits = cls_head(f)
56 | _ = [print('logits of ', key,':' ,logits[key].shape) for key in logits.keys()]
57 |
58 | """
59 | logits = {}
60 | if isinstance(f,dict):
61 | print('dict element-wised forward')
62 | for key in self.classification_head.keys():
63 | logits[key] = self.classification_head[key](f[key])
64 | elif isinstance(f,list):
65 | print('list element-wised forward')
66 | for key_index, key in enumerate(self.classification_head.keys()):
67 | logits[key] = self.classification_head[key](f[key_index])
68 | else:
69 | for key in self.classification_head.keys():
70 | logits[key] = self.classification_head[key](f)
71 | return logits
72 |
73 |
74 | class SegmentationHead(nn.Module):
75 | """Head for single or multiple label segmentation task"""
76 |
77 | def __init__(self, label_category_dict, in_channel, bias=True, dim=2):
78 | super().__init__()
79 | self.segmentatin_head = torch.nn.ModuleDict({})
80 | for key in label_category_dict.keys():
81 | self.segmentatin_head[key] = MakeConv(in_channel, label_category_dict[key], 3, padding=1, stride=1, dim=dim,
82 | bias=bias)
83 |
84 | def forward(self, f):
85 | """
86 | # demo 2D
87 |
88 | f = torch.ones([3, 512, 128, 128]) # from decoder or encoder
89 | label_category_dict = dict(
90 | organ=14, # 14 kinds of organ
91 | tumor=3, # 3 kinds of tumor
92 | )
93 | seg_head = SegmentationHead(label_category_dict, 512, dim=2)
94 | seg_logits = seg_head(f)
95 | _ = [print('segmentation_logits of ', key,':' ,seg_logits[key].shape) for key in seg_logits.keys()]
96 |
97 | """
98 | logits = {}
99 | for key in self.segmentatin_head.keys():
100 | logits[key] = self.segmentatin_head[key](f)
101 | return logits
102 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | '''MobileNet in PyTorch.
2 |
3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
4 | for more details.
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def conv_bn(inp, oup, stride):
12 | return nn.Sequential(
13 | nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False),
14 | nn.BatchNorm3d(oup),
15 | nn.ReLU(inplace=True)
16 | )
17 |
18 |
19 | class Block(nn.Module):
20 | '''Depthwise conv + Pointwise conv'''
21 | def __init__(self, in_planes, out_planes, stride=1):
22 | super(Block, self).__init__()
23 | self.conv1 = nn.Conv3d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False)
24 | self.bn1 = nn.BatchNorm3d(in_planes)
25 | self.conv2 = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
26 | self.bn2 = nn.BatchNorm3d(out_planes)
27 |
28 | def forward(self, x):
29 | out = F.relu(self.bn1(self.conv1(x)))
30 | out = F.relu(self.bn2(self.conv2(out)))
31 | return out
32 |
33 |
34 | class MobileNet(nn.Module):
35 | def __init__(self, width_mult=1.):
36 | super(MobileNet, self).__init__()
37 |
38 | input_channel = 32
39 | last_channel = 1024
40 | input_channel = int(input_channel * width_mult)
41 | cfg = [
42 | # c, n, s
43 | [64, 1, (2,2,2)],
44 | [128, 2, (2,2,2)],
45 | [256, 2, (2,2,2)],
46 | [512, 6, (2,2,2)],
47 | [1024, 2, (1,1,1)],
48 | ]
49 |
50 | self.features = [conv_bn(3, input_channel, (1,2,2))]
51 | # building inverted residual blocks
52 | for c, n, s in cfg:
53 | output_channel = int(c * width_mult)
54 | for i in range(n):
55 | stride = s if i == 0 else 1
56 | self.features.append(Block(input_channel, output_channel, stride))
57 | input_channel = output_channel
58 | # make it nn.Sequential
59 | self.features = nn.Sequential(*self.features)
60 |
61 |
62 |
63 | def forward(self, x):
64 | f_list = []
65 | for i in range(len(self.features)):
66 | x = self.features[i](x)
67 | f_list.append(x)
68 |
69 |         # keep only the last feature of each channel stage (plus the first and final features)
70 | f_list_ = []
71 | for i, f in enumerate(f_list):
72 | if i == 0 or i == len(f_list)-1:
73 | f_list_.append(f)
74 | elif f.shape[1] != f_list[i+1].shape[1]:
75 | f_list_.append(f)
76 |
77 | return f_list_
78 |
79 |
80 | def get_fine_tuning_parameters(model, ft_portion):
81 | if ft_portion == "complete":
82 | return model.parameters()
83 |
84 | elif ft_portion == "last_layer":
85 | ft_module_names = []
86 | ft_module_names.append('classifier')
87 |
88 | parameters = []
89 | for k, v in model.named_parameters():
90 | for ft_module in ft_module_names:
91 | if ft_module in k:
92 | parameters.append({'params': v})
93 | break
94 | else:
95 | parameters.append({'params': v, 'lr': 0.0})
96 | return parameters
97 |
98 | else:
99 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected")
100 |
101 |
102 | def get_model(**kwargs):
103 | """
104 | Returns the model.
105 | """
106 | model = MobileNet(**kwargs)
107 | return model
108 |
109 |
110 |
111 | if __name__ == '__main__':
112 |     model = get_model(width_mult=1.)  # this MobileNet variant only accepts width_mult
113 | model = model.cuda()
114 | model = nn.DataParallel(model, device_ids=None)
115 | print(model)
116 |
117 |     input_var = torch.randn(8, 3, 16, 112, 112)  # Variable is not imported here and is deprecated
118 | output = model(input_var)
119 |     _ = [print(f.shape) for f in output]  # forward returns a list of multi-scale features
120 |
--------------------------------------------------------------------------------
/demo/Demo7_2D_TransUnet_Segmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import ResNetEncoder
4 | from wama_modules.Decoder import UNet_decoder
5 | from wama_modules.Head import SegmentationHead
6 | from wama_modules.utils import resizeTensor
7 | from transformers import ViTModel
8 | from wama_modules.utils import load_weights, tmp_class
9 |
10 |
11 | class TransUNet(nn.Module):
12 | def __init__(self, in_channel, label_category_dict, dim=2):
13 | super().__init__()
14 |
15 | # encoder
16 | Encoder_f_channel_list = [64, 128, 256, 512]
17 | self.encoder = ResNetEncoder(
18 | in_channel,
19 | stage_output_channels=Encoder_f_channel_list,
20 | stage_middle_channels=Encoder_f_channel_list,
21 | blocks=[1, 2, 3, 4],
22 | type='131',
23 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
24 | dim=dim)
25 |
26 | # neck
27 | neck_out_channel = 768
28 | transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
29 | configuration = transformer.config
30 | self.trans_downsample_size = configuration.image_size = [8, 8]
31 | configuration.patch_size = [1, 1]
32 | configuration.num_channels = Encoder_f_channel_list[-1]
33 |         configuration.encoder_stride = 1  # only used by the MAE decoder, otherwise this parameter is ignored
34 | self.neck = ViTModel(configuration, add_pooling_layer=False)
35 |
36 | pretrained_weights = transformer.state_dict()
37 | pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[
38 | 'embeddings.position_embeddings']
39 | pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[
40 | 'embeddings.patch_embeddings.projection.weight']
41 | pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[
42 | 'embeddings.patch_embeddings.projection.bias']
43 | self.neck = load_weights(self.neck, pretrained_weights) # reload pretrained weights
44 |
45 | # decoder
46 | Decoder_f_channel_list = [32, 64, 128]
47 | self.decoder = UNet_decoder(
48 | in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel],
49 | skip_connection=[True, True, True],
50 | out_channels_list=Decoder_f_channel_list,
51 | dim=dim)
52 |
53 | # seg head
54 | self.seg_head = SegmentationHead(
55 | label_category_dict,
56 | Decoder_f_channel_list[0],
57 | dim=dim)
58 |
59 | def forward(self, x):
60 | # encoder forward
61 | multi_scale_encoder = self.encoder(x)
62 |
63 | # neck forward
64 | f_neck = self.neck(resizeTensor(multi_scale_encoder[-1], size=self.trans_downsample_size))
65 | f_neck = f_neck.last_hidden_state
66 | f_neck = f_neck[:, 1:] # remove class token
67 | f_neck = f_neck.permute(0, 2, 1)
68 | f_neck = f_neck.reshape(
69 | f_neck.shape[0],
70 | f_neck.shape[1],
71 | self.trans_downsample_size[0],
72 | self.trans_downsample_size[1]
73 | ) # reshape
74 | f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:])
75 | multi_scale_encoder[-1] = f_neck
76 |
77 | # decoder forward
78 | multi_scale_decoder = self.decoder(multi_scale_encoder)
79 | f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
80 |
81 | # seg_head forward
82 | logits = self.seg_head(f_for_seg)
83 | return logits
84 |
85 |
86 | if __name__ == '__main__':
87 | x = torch.ones([2, 1, 256, 256])
88 | label_category_dict = dict(organ=3, tumor=4)
89 | model = TransUNet(in_channel=1, label_category_dict=label_category_dict, dim=2)
90 | with torch.no_grad():
91 | logits = model(x)
92 | print('multi_label predicted logits')
93 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
94 |
95 | # out
96 | # multi_label predicted logits
97 | # logits of organ : torch.Size([2, 3, 256, 256])
98 | # logits of tumor : torch.Size([2, 4, 256, 256])
99 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/base/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | try:
5 | from inplace_abn import InPlaceABN
6 | except ImportError:
7 | InPlaceABN = None
8 |
9 |
10 | class Conv2dReLU(nn.Sequential):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | padding=0,
17 | stride=1,
18 | use_batchnorm=True,
19 | ):
20 |
21 | if use_batchnorm == "inplace" and InPlaceABN is None:
22 | raise RuntimeError(
23 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
24 | + "To install see: https://github.com/mapillary/inplace_abn"
25 | )
26 |
27 | conv = nn.Conv2d(
28 | in_channels,
29 | out_channels,
30 | kernel_size,
31 | stride=stride,
32 | padding=padding,
33 | bias=not (use_batchnorm),
34 | )
35 | relu = nn.ReLU(inplace=True)
36 |
37 | if use_batchnorm == "inplace":
38 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
39 | relu = nn.Identity()
40 |
41 | elif use_batchnorm and use_batchnorm != "inplace":
42 | bn = nn.BatchNorm2d(out_channels)
43 |
44 | else:
45 | bn = nn.Identity()
46 |
47 | super(Conv2dReLU, self).__init__(conv, bn, relu)
48 |
49 |
50 | class SCSEModule(nn.Module):
51 | def __init__(self, in_channels, reduction=16):
52 | super().__init__()
53 | self.cSE = nn.Sequential(
54 | nn.AdaptiveAvgPool2d(1),
55 | nn.Conv2d(in_channels, in_channels // reduction, 1),
56 | nn.ReLU(inplace=True),
57 | nn.Conv2d(in_channels // reduction, in_channels, 1),
58 | nn.Sigmoid(),
59 | )
60 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
61 |
62 | def forward(self, x):
63 | return x * self.cSE(x) + x * self.sSE(x)
64 |
65 |
66 | class ArgMax(nn.Module):
67 | def __init__(self, dim=None):
68 | super().__init__()
69 | self.dim = dim
70 |
71 | def forward(self, x):
72 | return torch.argmax(x, dim=self.dim)
73 |
74 |
75 | class Clamp(nn.Module):
76 | def __init__(self, min=0, max=1):
77 | super().__init__()
78 | self.min, self.max = min, max
79 |
80 | def forward(self, x):
81 | return torch.clamp(x, self.min, self.max)
82 |
83 |
84 | class Activation(nn.Module):
85 | def __init__(self, name, **params):
86 |
87 | super().__init__()
88 |
89 | if name is None or name == "identity":
90 | self.activation = nn.Identity(**params)
91 | elif name == "sigmoid":
92 | self.activation = nn.Sigmoid()
93 | elif name == "softmax2d":
94 | self.activation = nn.Softmax(dim=1, **params)
95 | elif name == "softmax":
96 | self.activation = nn.Softmax(**params)
97 | elif name == "logsoftmax":
98 | self.activation = nn.LogSoftmax(**params)
99 | elif name == "tanh":
100 | self.activation = nn.Tanh()
101 | elif name == "argmax":
102 | self.activation = ArgMax(**params)
103 | elif name == "argmax2d":
104 | self.activation = ArgMax(dim=1, **params)
105 | elif name == "clamp":
106 | self.activation = Clamp(**params)
107 | elif callable(name):
108 | self.activation = name(**params)
109 | else:
110 | raise ValueError(
111 | f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/"
112 | f"argmax/argmax2d/clamp/None; got {name}"
113 | )
114 |
115 | def forward(self, x):
116 | return self.activation(x)
117 |
118 |
119 | class Attention(nn.Module):
120 | def __init__(self, name, **params):
121 | super().__init__()
122 |
123 | if name is None:
124 | self.attention = nn.Identity(**params)
125 | elif name == "scse":
126 | self.attention = SCSEModule(**params)
127 | else:
128 | raise ValueError("Attention {} is not implemented".format(name))
129 |
130 | def forward(self, x):
131 | return self.attention(x)
132 |
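A small sketch of chaining the blocks above; the names and arguments are the ones defined in this file, and "softmax2d" is simply a softmax over the channel dimension:

import torch
from wama_modules.thirdparty_lib.SMP_qubvel.base.modules import Conv2dReLU, Activation, Attention

x = torch.randn(2, 16, 32, 32)
block = Conv2dReLU(16, 32, kernel_size=3, padding=1, use_batchnorm=True)  # Conv2d -> BatchNorm2d -> ReLU
attn = Attention("scse", in_channels=32)                                  # concurrent spatial & channel SE
act = Activation("softmax2d")                                             # softmax over the channel dim

y = act(attn(block(x)))
print(y.shape)   # torch.Size([2, 32, 32, 32])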
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/utils/functional.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def _take_channels(*xs, ignore_channels=None):
5 | if ignore_channels is None:
6 | return xs
7 | else:
8 | channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels]
9 | xs = [torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) for x in xs]
10 | return xs
11 |
12 |
13 | def _threshold(x, threshold=None):
14 | if threshold is not None:
15 | return (x > threshold).type(x.dtype)
16 | else:
17 | return x
18 |
19 |
20 | def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
21 | """Calculate Intersection over Union between ground truth and prediction
22 | Args:
23 | pr (torch.Tensor): predicted tensor
24 | gt (torch.Tensor): ground truth tensor
25 | eps (float): epsilon to avoid zero division
26 | threshold: threshold for outputs binarization
27 | Returns:
28 | float: IoU (Jaccard) score
29 | """
30 |
31 | pr = _threshold(pr, threshold=threshold)
32 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)
33 |
34 | intersection = torch.sum(gt * pr)
35 | union = torch.sum(gt) + torch.sum(pr) - intersection + eps
36 | return (intersection + eps) / union
37 |
38 |
39 | jaccard = iou
40 |
41 |
42 | def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None):
43 | """Calculate F-score between ground truth and prediction
44 | Args:
45 | pr (torch.Tensor): predicted tensor
46 | gt (torch.Tensor): ground truth tensor
47 | beta (float): positive constant
48 | eps (float): epsilon to avoid zero division
49 | threshold: threshold for outputs binarization
50 | Returns:
51 | float: F score
52 | """
53 |
54 | pr = _threshold(pr, threshold=threshold)
55 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)
56 |
57 | tp = torch.sum(gt * pr)
58 | fp = torch.sum(pr) - tp
59 | fn = torch.sum(gt) - tp
60 |
61 | score = ((1 + beta**2) * tp + eps) / ((1 + beta**2) * tp + beta**2 * fn + fp + eps)
62 |
63 | return score
64 |
65 |
66 | def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
67 | """Calculate accuracy score between ground truth and prediction
68 | Args:
69 | pr (torch.Tensor): predicted tensor
70 | gt (torch.Tensor): ground truth tensor
71 | eps (float): epsilon to avoid zero division
72 | threshold: threshold for outputs binarization
73 | Returns:
74 |         float: accuracy score
75 | """
76 | pr = _threshold(pr, threshold=threshold)
77 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)
78 |
79 | tp = torch.sum(gt == pr, dtype=pr.dtype)
80 | score = tp / gt.view(-1).shape[0]
81 | return score
82 |
83 |
84 | def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
85 | """Calculate precision score between ground truth and prediction
86 | Args:
87 | pr (torch.Tensor): predicted tensor
88 | gt (torch.Tensor): ground truth tensor
89 | eps (float): epsilon to avoid zero division
90 | threshold: threshold for outputs binarization
91 | Returns:
92 | float: precision score
93 | """
94 |
95 | pr = _threshold(pr, threshold=threshold)
96 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)
97 |
98 | tp = torch.sum(gt * pr)
99 | fp = torch.sum(pr) - tp
100 |
101 | score = (tp + eps) / (tp + fp + eps)
102 |
103 | return score
104 |
105 |
106 | def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
107 | """Calculate Recall between ground truth and prediction
108 | Args:
109 | pr (torch.Tensor): A list of predicted elements
110 | gt (torch.Tensor): A list of elements that are to be predicted
111 | eps (float): epsilon to avoid zero division
112 | threshold: threshold for outputs binarization
113 | Returns:
114 | float: recall score
115 | """
116 |
117 | pr = _threshold(pr, threshold=threshold)
118 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)
119 |
120 | tp = torch.sum(gt * pr)
121 | fn = torch.sum(gt) - tp
122 |
123 | score = (tp + eps) / (tp + fn + eps)
124 |
125 | return score
126 |
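Toy tensors to sanity-check the functional metrics above (the eps terms make the results approximate):

import torch
from wama_modules.thirdparty_lib.SMP_qubvel.utils import functional as F

pr = torch.tensor([[[[1.0, 1.0], [0.0, 0.0]]]])   # prediction
gt = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])   # ground truth

print(F.iou(pr, gt))        # tp=1, union=2       -> ~0.5
print(F.precision(pr, gt))  # tp=1, fp=1          -> ~0.5
print(F.recall(pr, gt))     # tp=1, fn=0          -> ~1.0
print(F.accuracy(pr, gt))   # 3 of 4 pixels match -> 0.75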
--------------------------------------------------------------------------------
/demo/Demo8_3D_TransUnet_Segmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from wama_modules.Encoder import ResNetEncoder
4 | from wama_modules.Decoder import UNet_decoder
5 | from wama_modules.Head import SegmentationHead
6 | from wama_modules.utils import resizeTensor
7 | from transformers import ViTModel
8 | from wama_modules.utils import load_weights, tmp_class
9 |
10 |
11 | class TransUnet(nn.Module):
12 | def __init__(self, in_channel, label_category_dict, dim=2):
13 | super().__init__()
14 |
15 | # encoder
16 | Encoder_f_channel_list = [64, 128, 256, 512]
17 | self.encoder = ResNetEncoder(
18 | in_channel,
19 | stage_output_channels=Encoder_f_channel_list,
20 | stage_middle_channels=Encoder_f_channel_list,
21 | blocks=[1, 2, 3, 4],
22 | type='131',
23 | downsample_ration=[0.5, 0.5, 0.5, 0.5],
24 | dim=dim)
25 |
26 | # neck
27 | neck_out_channel = 768
28 | transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
29 | configuration = transformer.config
30 | self.trans_size_3D = [8, 8, 4]
31 | self.trans_size = configuration.image_size = [
32 | self.trans_size_3D[0], self.trans_size_3D[1]*self.trans_size_3D[2]
33 | ]
34 | configuration.patch_size = [1, 1]
35 | configuration.num_channels = Encoder_f_channel_list[-1]
36 |         configuration.encoder_stride = 1  # only used by the MAE decoder, otherwise this parameter is ignored
37 | self.neck = ViTModel(configuration, add_pooling_layer=False)
38 |
39 | pretrained_weights = transformer.state_dict()
40 | pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[
41 | 'embeddings.position_embeddings']
42 | pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[
43 | 'embeddings.patch_embeddings.projection.weight']
44 | pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[
45 | 'embeddings.patch_embeddings.projection.bias']
46 | self.neck = load_weights(self.neck, pretrained_weights) # reload pretrained weights
47 |
48 | # decoder
49 | Decoder_f_channel_list = [32, 64, 128]
50 | self.decoder = UNet_decoder(
51 | in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel],
52 | skip_connection=[True, True, True],
53 | out_channels_list=Decoder_f_channel_list,
54 | dim=dim)
55 |
56 | # seg head
57 | self.seg_head = SegmentationHead(
58 | label_category_dict,
59 | Decoder_f_channel_list[0],
60 | dim=dim)
61 |
62 | def forward(self, x):
63 | # encoder forward
64 | multi_scale_encoder = self.encoder(x)
65 |
66 | # neck forward
67 | neck_input = resizeTensor(multi_scale_encoder[-1], size=self.trans_size_3D)
68 | neck_input = neck_input.reshape(neck_input.shape[0], neck_input.shape[1], *self.trans_size) # 3D to 2D
69 | f_neck = self.neck(neck_input)
70 | f_neck = f_neck.last_hidden_state
71 | f_neck = f_neck[:, 1:] # remove class token
72 | f_neck = f_neck.permute(0, 2, 1)
73 | f_neck = f_neck.reshape(
74 | f_neck.shape[0],
75 | f_neck.shape[1],
76 | self.trans_size[0],
77 | self.trans_size[1]
78 | ) # reshape
79 | f_neck = f_neck.reshape(f_neck.shape[0], f_neck.shape[1], *self.trans_size_3D) # 2D to 3D
80 | f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:])
81 | multi_scale_encoder[-1] = f_neck
82 |
83 | # decoder forward
84 | multi_scale_decoder = self.decoder(multi_scale_encoder)
85 | f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
86 |
87 | # seg_head forward
88 | logits = self.seg_head(f_for_seg)
89 | return logits
90 |
91 |
92 | if __name__ == '__main__':
93 | x = torch.ones([2, 1, 128, 128, 96])
94 | label_category_dict = dict(organ=3, tumor=4)
95 | model = TransUnet(in_channel=1, label_category_dict=label_category_dict, dim=3)
96 | with torch.no_grad():
97 | logits = model(x)
98 | print('multi_label predicted logits')
99 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
100 |
101 | # out
102 | # multi_label predicted logits
103 | # logits of organ : torch.Size([2, 3, 128, 128, 96])
104 | # logits of tumor : torch.Size([2, 4, 128, 128, 96])
105 |
106 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import functools
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | from .resnet import resnet_encoders
6 | from .dpn import dpn_encoders
7 | from .vgg import vgg_encoders
8 | from .senet import senet_encoders
9 | from .densenet import densenet_encoders
10 | from .inceptionresnetv2 import inceptionresnetv2_encoders
11 | from .inceptionv4 import inceptionv4_encoders
12 | from .efficientnet import efficient_net_encoders
13 | from .mobilenet import mobilenet_encoders
14 | from .xception import xception_encoders
15 | from .timm_efficientnet import timm_efficientnet_encoders
16 | from .timm_resnest import timm_resnest_encoders
17 | from .timm_res2net import timm_res2net_encoders
18 | from .timm_regnet import timm_regnet_encoders
19 | from .timm_sknet import timm_sknet_encoders
20 | from .timm_mobilenetv3 import timm_mobilenetv3_encoders
21 | from .timm_gernet import timm_gernet_encoders
22 | from .mix_transformer import mix_transformer_encoders
23 |
24 | from .timm_universal import TimmUniversalEncoder
25 |
26 | from ._preprocessing import preprocess_input
27 |
28 | encoders = {}
29 | encoders.update(resnet_encoders)
30 | encoders.update(dpn_encoders)
31 | encoders.update(vgg_encoders)
32 | encoders.update(senet_encoders)
33 | encoders.update(densenet_encoders)
34 | encoders.update(inceptionresnetv2_encoders)
35 | encoders.update(inceptionv4_encoders)
36 | encoders.update(efficient_net_encoders)
37 | encoders.update(mobilenet_encoders)
38 | encoders.update(xception_encoders)
39 | encoders.update(timm_efficientnet_encoders)
40 | encoders.update(timm_resnest_encoders)
41 | encoders.update(timm_res2net_encoders)
42 | encoders.update(timm_regnet_encoders)
43 | encoders.update(timm_sknet_encoders)
44 | encoders.update(timm_mobilenetv3_encoders)
45 | encoders.update(timm_gernet_encoders)
46 | encoders.update(mix_transformer_encoders)
47 |
48 |
49 | def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
50 |
51 | if name.startswith("tu-"):
52 | name = name[3:]
53 | encoder = TimmUniversalEncoder(
54 | name=name,
55 | in_channels=in_channels,
56 | depth=depth,
57 | output_stride=output_stride,
58 | pretrained=weights is not None,
59 | **kwargs,
60 | )
61 | return encoder
62 |
63 | try:
64 | Encoder = encoders[name]["encoder"]
65 | except KeyError:
66 | raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
67 |
68 | params = encoders[name]["params"]
69 | params.update(depth=depth)
70 | encoder = Encoder(**params)
71 |
72 | if weights is not None:
73 | try:
74 | settings = encoders[name]["pretrained_settings"][weights]
75 | except KeyError:
76 | raise KeyError(
77 | "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
78 | weights,
79 | name,
80 | list(encoders[name]["pretrained_settings"].keys()),
81 | )
82 | )
83 | encoder.load_state_dict(model_zoo.load_url(settings["url"]))
84 |
85 | encoder.set_in_channels(in_channels, pretrained=weights is not None)
86 | if output_stride != 32:
87 | encoder.make_dilated(output_stride)
88 |
89 | return encoder
90 |
91 |
92 | def get_encoder_names():
93 | return list(encoders.keys())
94 |
95 |
96 | def get_preprocessing_params(encoder_name, pretrained="imagenet"):
97 |
98 | if encoder_name.startswith("tu-"):
99 | encoder_name = encoder_name[3:]
100 | if encoder_name not in timm.models.registry._model_has_pretrained:
101 | raise ValueError(f"{encoder_name} does not have pretrained weights and preprocessing parameters")
102 | settings = timm.models.registry._model_default_cfgs[encoder_name]
103 | else:
104 | all_settings = encoders[encoder_name]["pretrained_settings"]
105 | if pretrained not in all_settings.keys():
106 | raise ValueError("Available pretrained options {}".format(all_settings.keys()))
107 | settings = all_settings[pretrained]
108 |
109 | formatted_settings = {}
110 | formatted_settings["input_space"] = settings.get("input_space", "RGB")
111 | formatted_settings["input_range"] = list(settings.get("input_range", [0, 1]))
112 | formatted_settings["mean"] = list(settings.get("mean"))
113 | formatted_settings["std"] = list(settings.get("std"))
114 |
115 | return formatted_settings
116 |
117 |
118 | def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
119 | params = get_preprocessing_params(encoder_name, pretrained=pretrained)
120 | return functools.partial(preprocess_input, **params)
121 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_gernet.py:
--------------------------------------------------------------------------------
1 | from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet
2 |
3 | from ._base import EncoderMixin
4 | import torch.nn as nn
5 |
6 |
7 | class GERNetEncoder(ByobNet, EncoderMixin):
8 | def __init__(self, out_channels, depth=5, **kwargs):
9 | super().__init__(**kwargs)
10 | self._depth = depth
11 | self._out_channels = out_channels
12 | self._in_channels = 3
13 |
14 | del self.head
15 |
16 | def get_stages(self):
17 | return [
18 | nn.Identity(),
19 | self.stem,
20 | self.stages[0],
21 | self.stages[1],
22 | self.stages[2],
23 | nn.Sequential(self.stages[3], self.stages[4], self.final_conv),
24 | ]
25 |
26 | def forward(self, x):
27 | stages = self.get_stages()
28 |
29 | features = []
30 | for i in range(self._depth + 1):
31 | x = stages[i](x)
32 | features.append(x)
33 |
34 | return features
35 |
36 | def load_state_dict(self, state_dict, **kwargs):
37 | state_dict.pop("head.fc.weight", None)
38 | state_dict.pop("head.fc.bias", None)
39 | super().load_state_dict(state_dict, **kwargs)
40 |
41 |
42 | regnet_weights = {
43 | "timm-gernet_s": {
44 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth", # noqa
45 | },
46 | "timm-gernet_m": {
47 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth", # noqa
48 | },
49 | "timm-gernet_l": {
50 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth", # noqa
51 | },
52 | }
53 |
54 | pretrained_settings = {}
55 | for model_name, sources in regnet_weights.items():
56 | pretrained_settings[model_name] = {}
57 | for source_name, source_url in sources.items():
58 | pretrained_settings[model_name][source_name] = {
59 | "url": source_url,
60 | "input_range": [0, 1],
61 | "mean": [0.485, 0.456, 0.406],
62 | "std": [0.229, 0.224, 0.225],
63 | "num_classes": 1000,
64 | }
65 |
66 | timm_gernet_encoders = {
67 | "timm-gernet_s": {
68 | "encoder": GERNetEncoder,
69 | "pretrained_settings": pretrained_settings["timm-gernet_s"],
70 | "params": {
71 | "out_channels": (3, 13, 48, 48, 384, 1920),
72 | "cfg": ByoModelCfg(
73 | blocks=(
74 | ByoBlockCfg(type="basic", d=1, c=48, s=2, gs=0, br=1.0),
75 | ByoBlockCfg(type="basic", d=3, c=48, s=2, gs=0, br=1.0),
76 | ByoBlockCfg(type="bottle", d=7, c=384, s=2, gs=0, br=1 / 4),
77 | ByoBlockCfg(type="bottle", d=2, c=560, s=2, gs=1, br=3.0),
78 | ByoBlockCfg(type="bottle", d=1, c=256, s=1, gs=1, br=3.0),
79 | ),
80 | stem_chs=13,
81 | stem_pool=None,
82 | num_features=1920,
83 | ),
84 | },
85 | },
86 | "timm-gernet_m": {
87 | "encoder": GERNetEncoder,
88 | "pretrained_settings": pretrained_settings["timm-gernet_m"],
89 | "params": {
90 | "out_channels": (3, 32, 128, 192, 640, 2560),
91 | "cfg": ByoModelCfg(
92 | blocks=(
93 | ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0),
94 | ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0),
95 | ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4),
96 | ByoBlockCfg(type="bottle", d=4, c=640, s=2, gs=1, br=3.0),
97 | ByoBlockCfg(type="bottle", d=1, c=640, s=1, gs=1, br=3.0),
98 | ),
99 | stem_chs=32,
100 | stem_pool=None,
101 | num_features=2560,
102 | ),
103 | },
104 | },
105 | "timm-gernet_l": {
106 | "encoder": GERNetEncoder,
107 | "pretrained_settings": pretrained_settings["timm-gernet_l"],
108 | "params": {
109 | "out_channels": (3, 32, 128, 192, 640, 2560),
110 | "cfg": ByoModelCfg(
111 | blocks=(
112 | ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0),
113 | ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0),
114 | ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4),
115 | ByoBlockCfg(type="bottle", d=5, c=640, s=2, gs=1, br=3.0),
116 | ByoBlockCfg(type="bottle", d=4, c=640, s=1, gs=1, br=3.0),
117 | ),
118 | stem_chs=32,
119 | stem_pool=None,
120 | num_features=2560,
121 | ),
122 | },
123 | },
124 | }
125 |
--------------------------------------------------------------------------------
/Document_allmodules.md:
--------------------------------------------------------------------------------
1 |
2 | # All modules and functions
3 |
4 | ## 1 `wama_modules.BaseModule`
5 |
6 | ### 1.1 Pooling
7 | - `GlobalAvgPool` Global average pooling
8 | - `GlobalMaxPool` Global maximum pooling
9 | - `GlobalMaxAvgPool` Global max-avg pooling: output = (GlobalAvgPool + GlobalMaxPool) / 2
10 |
11 |
12 | Click here to see demo code
13 |
14 | ```python
15 | """ demo """
16 | # import libs
17 | import torch
18 | from wama_modules.BaseModule import GlobalAvgPool, GlobalMaxPool, GlobalMaxAvgPool
19 |
20 | # make tensor
21 | inputs1D = torch.ones([3,12,13]) # 1D
22 | inputs2D = torch.ones([3,12,13,13]) # 2D
23 | inputs3D = torch.ones([3,12,13,13,13]) # 3D
24 |
25 | # build layer
26 | GAP = GlobalAvgPool()
27 | GMP = GlobalMaxPool()
28 | GAMP = GlobalMaxAvgPool()
29 |
30 | # test GAP & GMP & GAMP
31 | print(inputs1D.shape, GAP(inputs1D).shape)
32 | print(inputs2D.shape, GAP(inputs2D).shape)
33 | print(inputs3D.shape, GAP(inputs3D).shape)
34 |
35 | print(inputs1D.shape, GMP(inputs1D).shape)
36 | print(inputs2D.shape, GMP(inputs2D).shape)
37 | print(inputs3D.shape, GMP(inputs3D).shape)
38 |
39 | print(inputs1D.shape, GAMP(inputs1D).shape)
40 | print(inputs2D.shape, GAMP(inputs2D).shape)
41 | print(inputs3D.shape, GAMP(inputs3D).shape)
42 | ```
43 |
44 |
45 |
46 | ### 1.2 Norm&Activation
47 | - `customLayerNorm` a custom implementation of layer normalization
48 | - `MakeNorm` make normalization layer, includes BN / GN / IN / LN
49 | - `MakeActive` make activation layer, includes Relu / LeakyRelu
50 | - `MakeConv` make 1D / 2D / 3D convolutional layer
51 |
52 |
53 | Click here to see demo code
54 |
55 | ```python
56 | """ demo """
57 | ```
58 |
59 |
60 |
61 |
62 | ### 1.3 Conv
63 | - `ConvNormActive` 'Convolution→Normalization→Activation', used in VGG or ResNet
64 | - `NormActiveConv` 'Normalization→Activation→Convolution', used in DenseNet
65 | - `VGGBlock` the basic module in VGG
66 | - `VGGStage` a VGGStage consists of several VGGBlocks
67 | - `ResBlock` the basic module in ResNet
68 | - `ResStage` a ResStage consists of several ResBlocks
69 | - `DenseLayer` the basic module in DenseNet
70 | - `DenseBlock` a DenseBlock consists of several DenseLayers
71 |
72 |
73 | Click here to see demo code
74 |
75 | ```python
76 | """ demo """
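# a minimal sketch for ConvNormActive only; the argument names below (kernel_size, norm, active,
# gn_c, dim, padding) follow the calls used in wama_modules.Decoder and wama_modules.Neck,
# see wama_modules.BaseModule for the other blocks (VGGBlock, ResBlock, DenseLayer, ...)
import torch
from wama_modules.BaseModule import ConvNormActive

inputs2D = torch.ones([3, 16, 32, 32])      # [batch, channel, H, W]
inputs3D = torch.ones([3, 16, 32, 32, 32])  # [batch, channel, D, H, W]

conv2d = ConvNormActive(16, 32, kernel_size=3, norm='bn', active='relu', gn_c=8, dim=2, padding=1)
conv3d = ConvNormActive(16, 32, kernel_size=3, norm='bn', active='relu', gn_c=8, dim=3, padding=1)

print(conv2d(inputs2D).shape)  # expected: torch.Size([3, 32, 32, 32])
print(conv3d(inputs3D).shape)  # expected: torch.Size([3, 32, 32, 32, 32])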
77 | ```
78 |
79 |
80 | ## 2 `wama_modules.utils`
81 | - `resizeTensor` scale torch tensor, similar to scipy's zoom
82 | - `tensor2array` transform tensor to ndarray
83 | - `load_weights` load torch weights and print loading details (missing keys and matched keys)
84 |
85 |
86 | Click here to see demo code
87 |
88 | ```python
89 | """ demo """
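# demo for resizeTensor / tensor2array / load_weights (signatures taken from wama_modules/utils.py)
import torch
from wama_modules.utils import resizeTensor, tensor2array, load_weights

# resizeTensor: works on 1D [bz,c,l], 2D [bz,c,h,w] and 3D [bz,c,d,h,w] tensors
x = torch.ones([3, 1, 256, 256, 256])
print(resizeTensor(x, scale_factor=[0.5, 0.5, 0.5]).shape)  # torch.Size([3, 1, 128, 128, 128])
print(resizeTensor(x, size=[64, 64, 64]).shape)             # torch.Size([3, 1, 64, 64, 64])

# tensor2array: torch tensor -> numpy ndarray
print(type(tensor2array(x)))  # <class 'numpy.ndarray'>

# load_weights: load a (partially matching) state_dict and print which keys were matched
model = torch.nn.Conv2d(3, 8, kernel_size=3)
model = load_weights(model, model.state_dict())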
90 | ```
91 |
92 |
93 |
94 | ## 3 `wama_modules.Attention`
95 | - `SCSEModule`
96 | - `NonLocal`
97 |
98 |
99 | Click here to see demo code
100 |
101 | ```python
102 | """ demo """
103 | ```
104 |
105 |
106 |
107 | ## 4 `wama_modules.Encoder`
108 | - `VGGEncoder`
109 | - `ResNetEncoder`
110 | - `DenseNetEncoder`
111 | - `???`
112 |
113 |
114 | Click here to see demo code
115 |
116 | ```python
117 | """ demo """
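# a minimal sketch for ResNetEncoder; the arguments mirror the 3D TransUnet demo in this repo
# (VGGEncoder / DenseNetEncoder follow the same "list of stage channels" style, check
# wama_modules.Encoder for their exact signatures)
import torch
from wama_modules.Encoder import ResNetEncoder

x = torch.ones([2, 1, 64, 64, 64])  # 3D input, [batch, channel, D, H, W]
encoder = ResNetEncoder(
    1,                                          # input channels
    stage_output_channels=[64, 128, 256, 512],  # from shallow to deep
    stage_middle_channels=[64, 128, 256, 512],
    blocks=[1, 2, 3, 4],
    type='131',
    downsample_ration=[0.5, 0.5, 0.5, 0.5],
    dim=3)
f_list = encoder(x)  # multi-scale feature maps, from shallow to deep
_ = [print(i.shape) for i in f_list]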
118 | ```
119 |
120 |
121 |
122 | ## 5 `wama_modules.Decoder`
123 | - `UNet_decoder`
124 |
125 |
126 | Click here to see demo code
127 |
128 | ```python
129 | """ demo """
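# demo taken from the UNet_decoder docstring (wama_modules/Decoder.py), 2D case
import torch
from wama_modules.Decoder import UNet_decoder

f_list = [
    torch.ones([3, 64, 128, 128]),
    torch.ones([3, 128, 64, 64]),
    torch.ones([3, 256, 32, 32]),
    torch.ones([3, 512, 8, 8]),
]  # multi-scale encoder features, from shallow to deep

decoder = UNet_decoder(
    in_channels_list=[64, 128, 256, 512],  # from shallow to deep
    skip_connection=[False, True, True],   # from shallow to deep
    out_channels_list=[12, 13, 14],        # from shallow to deep
    norm='bn',
    gn_c=8,
    dim=2)

decoder_f_list = decoder(f_list)  # decoder_f_list[0] can be fed to a segmentation head
_ = [print(i.shape) for i in decoder_f_list]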
130 | ```
131 |
132 |
133 |
134 | ## 6 `wama_modules.Neck`
135 | - `FPN`
136 |
137 |
138 | Click here to see demo code
139 |
140 | ```python
141 | """ demo """
142 | import torch
143 | from wama_modules.Neck import FPN
144 |
145 | # make multi-scale feature maps
146 | featuremaps = [
147 | torch.ones([3,16,32,32,32]),
148 | torch.ones([3,32,24,24,24]),
149 | torch.ones([3,64,16,16,16]),
150 | torch.ones([3,128,8,8,8]),
151 | ]
152 |
153 | # build FPN
154 | fpn_AddSmall2Big = FPN(in_channels_list=[16, 32, 64, 128],
155 | c1=128,
156 | c2=256,
157 | active='relu',
158 | norm='bn',
159 | gn_c=8,
160 | mode='AddSmall2Big',
161 | dim=3,)
162 | fpn_AddBig2Small = FPN(in_channels_list=[16, 32, 64, 128],
163 | c1=128,
164 | c2=256,
165 | active='relu',
166 | norm='bn',
167 | gn_c=8,
168 | mode='AddBig2Small', # Add big size feature to small size feature, for classification
169 | dim=3,)
170 |
171 | # forward
172 | f_listA = fpn_AddSmall2Big(featuremaps)
173 | f_listB = fpn_AddBig2Small(featuremaps)
174 | _ = [print(i.shape) for i in featuremaps]
175 | _ = [print(i.shape) for i in f_listA]
176 | _ = [print(i.shape) for i in f_listB]
177 | ```
178 |
179 |
180 |
181 | ## 7 `wama_modules.Transformer`
182 | - `FeedForward`
183 | - `MultiHeadAttention`
184 | - `TransformerEncoderLayer`
185 | - `TransformerDecoderLayer`
186 |
187 |
188 | Click here to see demo code
189 |
190 | ```python
191 | """ demo """
192 | ```
193 |
194 |
--------------------------------------------------------------------------------
/wama_modules/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import pickle
6 |
7 |
8 | class tmp_class():
9 | """
10 | for debug
11 | """
12 | def __init__(self,):
13 | super().__init__()
14 |
15 |
16 | def resizeTensor(x, scale_factor=None, size=None):
17 | """
18 |     resize a 1D / 2D / 3D tensor (1D → signal, 2D → image, 3D → volume)
19 |
20 | :param x: 1D [bz,c,l] 2D [bz,c,w,h] 3D [bz,c,w,h,l]
21 | :param scale_factor: 1D [2.,] 2D [2.,2.,] 3D [3.,3.,3.,]
22 |     :param size: e.g. 2D [256, 256] or torch.ones([256,256]).shape
23 | :return:
24 |
25 | # 1D demo:
26 | x = torch.ones([3,1,256])
27 | y = torch.ones([3,1,128])
28 | x1 = resizeTensor(x, scale_factor=[2.,])
29 | print(x1.shape)
30 | x1 = resizeTensor(x, size=y.shape[-1:])
31 | print(x1.shape)
32 |
33 | # 2D demo:
34 | x = torch.ones([3,1,256,256])
35 | y = torch.ones([3,1,256,128])
36 | x1 = resizeTensor(x, scale_factor=[2.,2.])
37 | print(x1.shape)
38 | x1 = resizeTensor(x, size=y.shape[-2:])
39 | print(x1.shape)
40 |
41 | # 3D demo:
42 | x = torch.ones([3,1,256,256,256])
43 | y = torch.ones([3,1,256,128,128])
44 | x1 = resizeTensor(x, scale_factor=[2.,2.,2.])
45 | print(x1.shape)
46 | x1 = resizeTensor(x, size=y.shape[-3:])
47 | print(x1.shape)
48 |
49 | """
50 | if len(x.shape) == 3:
51 | return F.interpolate(x, scale_factor=scale_factor, size=size,
52 | mode='linear',
53 | align_corners=True)
54 | if len(x.shape) == 4:
55 | return F.interpolate(x, scale_factor=scale_factor, size=size,
56 | mode='bicubic',
57 | align_corners=True)
58 | elif len(x.shape) == 5:
59 | return F.interpolate(x, scale_factor=scale_factor, size=size,
60 | mode='trilinear',
61 | align_corners=True)
62 |
63 |
64 | def tensor2array(tensor):
65 | return tensor.data.cpu().numpy()
66 |
67 |
68 | def load_weights(model, state_dict, drop_modelDOT=False, silence=False):
69 | if drop_modelDOT:
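        # drop the leading 'module.' prefix (added by nn.DataParallel checkpoints) from every key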
70 | new_dict = {}
71 | for k, v in state_dict.items():
72 | new_dict[k[7:]] = v
73 | state_dict = new_dict
74 | net_dict = model.state_dict() # model dict
75 | pretrain_dict = {k: v for k, v in state_dict.items()} # pretrain dict
76 | InPretrain_InModel_dict = {k: v for k, v in state_dict.items() if k in net_dict.keys()}
77 | InPretrain_NotInModel_dict = {k: v for k, v in state_dict.items() if k not in net_dict.keys()}
78 | NotInPretrain_InModel_dict = {k: v for k, v in net_dict.items() if k not in state_dict.keys()}
79 | if not silence:
80 | print('-' * 200)
81 | print('keys ( Current model,C ) ', len(net_dict.keys()), net_dict.keys())
82 | print('keys ( Pre-trained ,P ) ', len(pretrain_dict.keys()), pretrain_dict.keys())
83 | print('keys ( In C & In P ) ', len(InPretrain_InModel_dict.keys()), InPretrain_InModel_dict.keys())
84 | print('keys ( NoIn C & In P ) ', len(InPretrain_NotInModel_dict.keys()), InPretrain_NotInModel_dict.keys())
85 | print('keys ( In C & NoIn P ) ', len(NotInPretrain_InModel_dict.keys()), NotInPretrain_InModel_dict.keys())
86 | print('-' * 200)
87 | print('Pretrained keys :', len(InPretrain_InModel_dict.keys()), InPretrain_InModel_dict.keys())
88 | print('Non-Pretrained keys:', len(NotInPretrain_InModel_dict.keys()), NotInPretrain_InModel_dict.keys())
89 | print('-' * 200)
90 | net_dict.update(InPretrain_InModel_dict)
91 | model.load_state_dict(net_dict)
92 | return model
93 |
94 |
95 | def MaxMinNorm(array, FirstDimBATCH=True):
96 |     """
97 |     Linear min-max normalization to [0, 1], for numpy ndarray or torch tensor.
98 |     :param array: ndarray or tensor
99 |     :param FirstDimBATCH: bool, if True, normalize every sample along the first (batch) dim separately
100 |     :return: normalized array / tensor with the same shape
101 |
102 |     # demo for numpy ndarray
103 |     # x = np.random.rand(3, 16, 16); y = MaxMinNorm(x)
104 |
105 |     # demo for torch tensor
106 |     # x = torch.rand(3, 16, 16); y = MaxMinNorm(x)
107 |     """
108 |     # minimal implementation (assumption): the same linear scaling as mat2gray below
109 |     eps = 1e-14
110 |     if FirstDimBATCH:
111 |         normed = [(s - s.min()) / (s.max() - s.min() + eps) for s in array]
112 |         return np.stack(normed) if isinstance(array, np.ndarray) else torch.stack(normed)
113 |     return (array - array.min()) / (array.max() - array.min() + eps)
114 |
115 |
116 | def mat2gray(image):
117 | """
118 |     Normalization function (linear min-max normalization)
119 | :param image: ndarray
120 | :return:
121 | """
122 | # as dtype = np.float32
123 | image = image.astype(np.float32)
124 | image = (image - np.min(image)) / (np.max(image) - np.min(image) + 1e-14)
125 | return image
126 |
127 |
128 | def save_as_pkl(save_path, obj):
129 | data_output = open(save_path, 'wb')
130 | pickle.dump(obj, data_output)
131 | data_output.close()
132 |
133 | def load_from_pkl(load_path):
134 | data_input = open(load_path, 'rb')
135 | read_data = pickle.load(data_input)
136 | data_input.close()
137 | return read_data
138 |
139 |
140 | import matplotlib.pyplot as plt
141 | def show2D(img):
142 | plt.imshow(img)
143 | plt.show()
144 |
145 | # try:
146 | # from mayavi import mlab
147 | # def show3D(img3D):
148 | # vol = mlab.pipeline.volume(mlab.pipeline.scalar_field(img3D), name='3-d ultrasound ')
149 | # mlab.colorbar(orientation='vertical')
150 | # mlab.show()
151 | # except:
152 | # pass
153 |
--------------------------------------------------------------------------------
/wama_modules/Decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from wama_modules.utils import tensor2array, resizeTensor, tmp_class
5 | from wama_modules.BaseModule import ConvNormActive
6 |
7 |
8 | class UNet_decoder(nn.Module):
9 | def __init__(self,
10 | in_channels_list=[64, 128, 256, 512], # from shallow to deep
11 | skip_connection=[True, True, True], # from shallow to deep
12 | out_channels_list=[12, 13, 14], # from shallow to deep
13 | norm='bn',
14 | gn_c=8,
15 | dim=2,
16 | ):
17 | super().__init__()
18 | self._skip_connection = skip_connection[::-1] # from deep to shallow
19 | _skip_channels_list = in_channels_list[-2::-1] # from deep to shallow [256, 128, 64]
20 | _out_channels_list = out_channels_list[::-1] # from deep to shallow [14, 13, 12]
21 | _in_conv_list = [in_channels_list[-1]] + _out_channels_list[:-1] # from deep to shallow
22 | self.docoder_conv_list = nn.ModuleList([])
23 | for stage, _out_channels in enumerate(_out_channels_list):
24 | if self._skip_connection[stage]:
25 | _in_channel = _in_conv_list[stage] + _skip_channels_list[stage]
26 | else:
27 | _in_channel = _in_conv_list[stage]
28 | _out_channel = _out_channels_list[stage]
29 | self.docoder_conv_list.append(
30 | nn.Sequential(
31 | ConvNormActive(_in_channel, _out_channel, kernel_size=3, norm=norm, gn_c=gn_c, dim=dim),
32 | ConvNormActive(_out_channel, _out_channel, kernel_size=3, norm=norm, gn_c=gn_c, dim=dim),
33 | )
34 | )
35 |
36 | def forward(self, f_list):
37 | """
38 | :return: decoder_f_list, feature list from shallow to deep, and decoder_f_list[0] can be used for seg head
39 |
40 | # demo
41 |
42 | # 1D -------------------------------------------------------------
43 | f_list = [
44 | torch.ones([3,64,128]),
45 | torch.ones([3,128,64]),
46 | torch.ones([3,256,32]),
47 | torch.ones([3,512,8]),
48 | ]
49 |
50 | decoder = UNet_decoder(
51 | in_channels_list=[64, 128, 256, 512], # from shallow to deep
52 | skip_connection=[False, True, True], # from shallow to deep
53 | out_channels_list=[12, 13, 14], # from shallow to deep
54 | norm='bn',
55 | gn_c=8,
56 | dim=1
57 | )
58 |
59 | decoder_f_list = decoder(f_list)
60 | _ = [print(i.shape) for i in decoder_f_list]
61 |
62 | # 2D -------------------------------------------------------------
63 | f_list = [
64 | torch.ones([3,64,128,128]),
65 | torch.ones([3,128,64,64]),
66 | torch.ones([3,256,32,32]),
67 | torch.ones([3,512,8,8]),
68 | ]
69 |
70 | decoder = UNet_decoder(
71 | in_channels_list=[64, 128, 256, 512], # from shallow to deep
72 | skip_connection=[False, True, True], # from shallow to deep
73 | out_channels_list=[12, 13, 14], # from shallow to deep
74 | norm='bn',
75 | gn_c=8,
76 | dim=2
77 | )
78 |
79 | decoder_f_list = decoder(f_list)
80 | _ = [print(i.shape) for i in decoder_f_list]
81 |
82 | # 3D -------------------------------------------------------------
83 | f_list = [
84 | torch.ones([3,64,128,128,128]),
85 | torch.ones([3,128,64,64,64]),
86 | torch.ones([3,256,32,32,32]),
87 | torch.ones([3,512,8,8,8]),
88 | ]
89 |
90 | decoder = UNet_decoder(
91 | in_channels_list=[64, 128, 256, 512], # from shallow to deep
92 | skip_connection=[False, True, True], # from shallow to deep
93 | out_channels_list=[12, 13, 14], # from shallow to deep
94 | norm='bn',
95 | gn_c=8,
96 | dim=3
97 | )
98 |
99 | decoder_f_list = decoder(f_list)
100 | _ = [print(i.shape) for i in decoder_f_list]
101 |
102 | """
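        # walk from the deepest (smallest) feature to the shallowest: upsample, optionally concatenate the skip feature, then conv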
103 | _f_list = f_list[::-1]
104 | feature = _f_list[0]
105 | _f_list = _f_list[1:]
106 | decoder_f_list = []
107 | for stage, conv in enumerate(self.docoder_conv_list):
108 | if self._skip_connection[stage]:
109 | _in_feature = torch.cat([resizeTensor(feature, size=_f_list[stage].shape[2:]), _f_list[stage]], 1)
110 | else:
111 | _in_feature = resizeTensor(feature, size=_f_list[stage].shape[2:])
112 | feature = conv(_in_feature)
113 | decoder_f_list.append(feature)
114 | decoder_f_list = decoder_f_list[::-1]
115 | return decoder_f_list # from shallow to deep, and decoder_f_list[0] can be used for seg head
116 |
117 | # psp
118 |
119 |
120 | # deeplabv3+
121 | # try this https://blog.csdn.net/m0_51436734/article/details/124073901
122 |
123 |
124 | # NestedUNet(Unet++)
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/squeezenet.py:
--------------------------------------------------------------------------------
1 | '''SqueezeNet in PyTorch.
2 |
3 | See the paper "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" for more details.
4 | '''
5 |
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.init as init
10 | import torch.nn.functional as F
11 | from torch.autograd import Variable
12 | from functools import partial
13 |
14 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
15 |
16 |
17 | class Fire(nn.Module):
18 |
19 | def __init__(self, inplanes, squeeze_planes,
20 | expand1x1_planes, expand3x3_planes,
21 | use_bypass=False):
22 | super(Fire, self).__init__()
23 | self.use_bypass = use_bypass
24 | self.inplanes = inplanes
25 | self.relu = nn.ReLU(inplace=True)
26 | self.squeeze = nn.Conv3d(inplanes, squeeze_planes, kernel_size=1)
27 | self.squeeze_bn = nn.BatchNorm3d(squeeze_planes)
28 | self.expand1x1 = nn.Conv3d(squeeze_planes, expand1x1_planes,
29 | kernel_size=1)
30 | self.expand1x1_bn = nn.BatchNorm3d(expand1x1_planes)
31 | self.expand3x3 = nn.Conv3d(squeeze_planes, expand3x3_planes,
32 | kernel_size=3, padding=1)
33 | self.expand3x3_bn = nn.BatchNorm3d(expand3x3_planes)
34 |
35 | def forward(self, x):
36 | out = self.squeeze(x)
37 | out = self.squeeze_bn(out)
38 | out = self.relu(out)
39 |
40 | out1 = self.expand1x1(out)
41 | out1 = self.expand1x1_bn(out1)
42 |
43 | out2 = self.expand3x3(out)
44 | out2 = self.expand3x3_bn(out2)
45 |
46 | out = torch.cat([out1, out2], 1)
47 | if self.use_bypass:
48 | out += x
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class SqueezeNet(nn.Module):
55 |
56 | def __init__(self,):
57 | super(SqueezeNet, self).__init__()
58 | # if version not in [1.0, 1.1]:
59 | # raise ValueError("Unsupported SqueezeNet version {version}:"
60 | # "1.0 or 1.1 expected".format(version=version))
61 |
62 | # if version == 1.1:
63 | if True:
64 | self.features = nn.Sequential(
65 | nn.Conv3d(3, 64, kernel_size=3, stride=(1,2,2), padding=(1,1,1)), # 0
66 | nn.BatchNorm3d(64), # 1
67 | nn.ReLU(inplace=True), # 2
68 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1), # todo 3
69 | Fire(64, 16, 64, 64), # 4
70 | Fire(128, 16, 64, 64, use_bypass=True), # 5
71 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1), # todo 6
72 | Fire(128, 32, 128, 128), # 7
73 | Fire(256, 32, 128, 128, use_bypass=True), # 8
74 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1), # todo 9
75 | Fire(256, 48, 192, 192), # 10
76 | Fire(384, 48, 192, 192, use_bypass=True), # 11
77 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1), # todo 12
78 | Fire(384, 64, 256, 256), # 13
79 | Fire(512, 64, 256, 256, use_bypass=True), # todo 14
80 | )
81 |         # Final convolution is initialized differently from the rest
82 |
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv3d):
85 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
86 | elif isinstance(m, nn.BatchNorm3d):
87 | m.weight.data.fill_(1)
88 | m.bias.data.zero_()
89 |
90 |
91 | def forward(self, x):
92 | f_list = []
93 | for i in range(len(self.features)):
94 | x = self.features[i](x)
95 | if i in [3,6,9,12,14]:
96 | f_list.append(x)
97 | return f_list
98 |
99 |
100 | def get_fine_tuning_parameters(model, ft_portion):
101 | if ft_portion == "complete":
102 | return model.parameters()
103 |
104 | elif ft_portion == "last_layer":
105 | ft_module_names = []
106 | ft_module_names.append('classifier')
107 |
108 | parameters = []
109 | for k, v in model.named_parameters():
110 | for ft_module in ft_module_names:
111 | if ft_module in k:
112 | parameters.append({'params': v})
113 | break
114 | else:
115 | parameters.append({'params': v, 'lr': 0.0})
116 | return parameters
117 |
118 | else:
119 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected")
120 |
121 |
122 | def get_model(**kwargs):
123 | """
124 | Returns the model.
125 | """
126 | model = SqueezeNet(**kwargs)
127 | return model
128 |
129 |
130 | if __name__ == '__main__':
131 |     model = SqueezeNet()  # this trimmed encoder version takes no constructor arguments
132 | model = model.cuda()
133 | model = nn.DataParallel(model, device_ids=None)
134 | print(model)
135 |
136 | input_var = Variable(torch.randn(8, 3, 16, 112, 112))
137 | output = model(input_var)
138 | print(output.shape)
139 |
--------------------------------------------------------------------------------
/wama_modules/Neck.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from wama_modules.BaseModule import ConvNormActive
4 | from wama_modules.utils import resizeTensor
5 |
6 |
7 | class FPN(nn.Module):
8 | def __init__(self,
9 | in_channels_list=[16, 32, 64, 128],
10 | c1=128,
11 | c2=256,
12 | active='relu',
13 | norm='bn',
14 | gn_c=8,
15 |                  mode='AddSmall2Big',  # AddSmall2Big or AddBig2Small (much better for classification tasks)
16 | dim=2,
17 | ):
18 | super().__init__()
19 | self.mode = mode
20 |
21 | self.conv1_list = nn.ModuleList([
22 | ConvNormActive(in_channels, c1, kernel_size=1, norm=norm, active=active, gn_c=gn_c, dim=dim, padding=0)
23 | for in_channels in in_channels_list
24 | ])
25 | self.conv2_list = nn.ModuleList([
26 | ConvNormActive(c1, c2, kernel_size=3, norm=norm, active=active, gn_c=gn_c, dim=dim, padding=1)
27 | for _ in range(len(in_channels_list))
28 | ])
29 |
30 | def forward(self, x_list):
31 | """
32 | :param x_list: multi scale feature maps, from shallow(big size) to deep(small size)
33 | :return:
34 |
35 | # demo
36 | # 1D
37 | x_list = [
38 | torch.ones([3,16,32]),
39 | torch.ones([3,32,24]),
40 | torch.ones([3,64,16]),
41 | torch.ones([3,128,8]),
42 | ]
43 | fpn = FPN(in_channels_list=[16, 32, 64, 128],
44 | c1=128,
45 | c2=256,
46 | active='relu',
47 | norm='bn',
48 | gn_c=8,
49 | mode='AddSmall2Big',
50 | dim=1,)
51 | fpn = FPN(in_channels_list=[16, 32, 64, 128],
52 | c1=128,
53 | c2=256,
54 | active='relu',
55 | norm='bn',
56 | gn_c=8,
57 |                   mode='AddBig2Small', # reverse, for classification
58 | dim=1,)
59 | f_list = fpn(x_list)
60 | _ = [print(i.shape) for i in x_list]
61 | _ = [print(i.shape) for i in f_list]
62 |
63 | # 2D
64 | x_list = [
65 | torch.ones([3,16,32,32]),
66 | torch.ones([3,32,24,24]),
67 | torch.ones([3,64,16,16]),
68 | torch.ones([3,128,8,8]),
69 | ]
70 | fpn = FPN(in_channels_list=[16, 32, 64, 128],
71 | c1=128,
72 | c2=256,
73 | active='relu',
74 | norm='bn',
75 | gn_c=8,
76 | mode='AddSmall2Big',
77 | dim=2,)
78 | fpn = FPN(in_channels_list=[16, 32, 64, 128],
79 | c1=128,
80 | c2=256,
81 | active='relu',
82 | norm='bn',
83 | gn_c=8,
84 |                   mode='AddBig2Small', # reverse, for classification
85 | dim=2,)
86 | f_list = fpn(x_list)
87 | _ = [print(i.shape) for i in x_list]
88 | _ = [print(i.shape) for i in f_list]
89 |
90 |
91 | # 3D
92 | x_list = [
93 | torch.ones([3,16,32,32,32]),
94 | torch.ones([3,32,24,24,24]),
95 | torch.ones([3,64,16,16,16]),
96 | torch.ones([3,128,8,8,8]),
97 | ]
98 | fpn = FPN(in_channels_list=[16, 32, 64, 128],
99 | c1=128,
100 | c2=256,
101 | active='relu',
102 | norm='bn',
103 | gn_c=8,
104 | mode='AddSmall2Big',
105 | dim=3,)
106 | fpn = FPN(in_channels_list=[16, 32, 64, 128],
107 | c1=128,
108 | c2=256,
109 | active='relu',
110 | norm='bn',
111 | gn_c=8,
112 |                   mode='AddBig2Small', # reverse, for classification
113 | dim=3,)
114 | f_list = fpn(x_list)
115 | _ = [print(i.shape) for i in x_list]
116 | _ = [print(i.shape) for i in f_list]
117 |
118 | """
119 | f_list = [self.conv1_list[index](f) for index, f in enumerate(x_list)]
120 |
121 | if self.mode == 'AddSmall2Big':
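            # top-down pathway (as in the original FPN): upsample the deeper feature and add it to the next shallower one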
122 | f_list_2 = []
123 | x = f_list[-1]
124 | f_list_2.append(x)
125 | for index in range(len(f_list)-1):
126 | # print(f_list[-(index+2)].shape[2:])
127 | x = f_list[-(index + 2)] + resizeTensor(x, size=f_list[-(index+2)].shape[2:])
128 | f_list_2.append(x)
129 | f_list_2 = f_list_2[::-1]
130 | elif self.mode == 'AddBig2Small':
131 | f_list_2 = []
132 | x = f_list[0]
133 | f_list_2.append(x)
134 | for index in range(len(f_list) - 1):
135 | # print(f_list[index + 1].shape[2:])
136 | x = f_list[index + 1] + resizeTensor(x, size=f_list[index + 1].shape[2:])
137 | f_list_2.append(x)
138 |
139 | return_list = [self.conv2_list[index](f) for index, f in enumerate(f_list_2)]
140 |
141 | return return_list
142 |
143 |
144 |
145 |
146 | # UCtrans's (TODO)
147 |
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/densenet.py:
--------------------------------------------------------------------------------
1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2 |
3 | Attributes:
4 |
5 | _out_channels (list of int): specify number of channels for each encoder feature tensor
6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8 |
9 | Methods:
10 |
11 | forward(self, x: torch.Tensor)
12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14 | with resolution same as input `x` tensor).
15 |
16 | Input: `x` with shape (1, 3, 64, 64)
17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20 |
21 | also should support number of features according to specified depth, e.g. if depth = 5,
22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24 | """
25 |
26 | import re
27 | import torch.nn as nn
28 |
29 | from pretrainedmodels.models.torchvision_models import pretrained_settings
30 | from torchvision.models.densenet import DenseNet
31 |
32 | from ._base import EncoderMixin
33 |
34 |
35 | class TransitionWithSkip(nn.Module):
36 | def __init__(self, module):
37 | super().__init__()
38 | self.module = module
39 |
40 | def forward(self, x):
41 | for module in self.module:
42 | x = module(x)
43 | if isinstance(module, nn.ReLU):
44 | skip = x
45 | return x, skip
46 |
47 |
48 | class DenseNetEncoder(DenseNet, EncoderMixin):
49 | def __init__(self, out_channels, depth=5, **kwargs):
50 | super().__init__(**kwargs)
51 | self._out_channels = out_channels
52 | self._depth = depth
53 | self._in_channels = 3
54 | del self.classifier
55 |
56 | def make_dilated(self, *args, **kwargs):
57 | raise ValueError("DenseNet encoders do not support dilated mode " "due to pooling operation for downsampling!")
58 |
59 | def get_stages(self):
60 | return [
61 | nn.Identity(),
62 | nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0),
63 | nn.Sequential(
64 | self.features.pool0,
65 | self.features.denseblock1,
66 | TransitionWithSkip(self.features.transition1),
67 | ),
68 | nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)),
69 | nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)),
70 | nn.Sequential(self.features.denseblock4, self.features.norm5),
71 | ]
72 |
73 | def forward(self, x):
74 |
75 | stages = self.get_stages()
76 |
77 | features = []
78 | for i in range(self._depth + 1):
79 | x = stages[i](x)
80 | if isinstance(x, (list, tuple)):
81 | x, skip = x
82 | features.append(skip)
83 | else:
84 | features.append(x)
85 |
86 | return features
87 |
88 | def load_state_dict(self, state_dict):
89 | pattern = re.compile(
90 | r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
91 | )
92 | for key in list(state_dict.keys()):
93 | res = pattern.match(key)
94 | if res:
95 | new_key = res.group(1) + res.group(2)
96 | state_dict[new_key] = state_dict[key]
97 | del state_dict[key]
98 |
99 | # remove linear
100 | state_dict.pop("classifier.bias", None)
101 | state_dict.pop("classifier.weight", None)
102 |
103 | super().load_state_dict(state_dict)
104 |
105 |
106 | densenet_encoders = {
107 | "densenet121": {
108 | "encoder": DenseNetEncoder,
109 | "pretrained_settings": pretrained_settings["densenet121"],
110 | "params": {
111 | "out_channels": (3, 64, 256, 512, 1024, 1024),
112 | "num_init_features": 64,
113 | "growth_rate": 32,
114 | "block_config": (6, 12, 24, 16),
115 | },
116 | },
117 | "densenet169": {
118 | "encoder": DenseNetEncoder,
119 | "pretrained_settings": pretrained_settings["densenet169"],
120 | "params": {
121 | "out_channels": (3, 64, 256, 512, 1280, 1664),
122 | "num_init_features": 64,
123 | "growth_rate": 32,
124 | "block_config": (6, 12, 32, 32),
125 | },
126 | },
127 | "densenet201": {
128 | "encoder": DenseNetEncoder,
129 | "pretrained_settings": pretrained_settings["densenet201"],
130 | "params": {
131 | "out_channels": (3, 64, 256, 512, 1792, 1920),
132 | "num_init_features": 64,
133 | "growth_rate": 32,
134 | "block_config": (6, 12, 48, 32),
135 | },
136 | },
137 | "densenet161": {
138 | "encoder": DenseNetEncoder,
139 | "pretrained_settings": pretrained_settings["densenet161"],
140 | "params": {
141 | "out_channels": (3, 96, 384, 768, 2112, 2208),
142 | "num_init_features": 96,
143 | "growth_rate": 48,
144 | "block_config": (6, 12, 36, 24),
145 | },
146 | },
147 | }
148 |
--------------------------------------------------------------------------------
/demo/Demo_BilinearPooling.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 |
7 | import torch.fft as afft
8 |
9 |
10 | class CompactBilinearPooling(nn.Module):
11 | """
12 | from https://github.com/DeepInsight-PCALab/CompactBilinearPooling-Pytorch
13 |
14 | Compute compact bilinear pooling over two bottom inputs.
15 | Args:
16 | output_dim: output dimension for compact bilinear pooling.
17 | sum_pool: (Optional) If True, sum the output along height and width
18 | dimensions and return output shape [batch_size, output_dim].
19 | Otherwise return [batch_size, height, width, output_dim].
20 | Default: True.
21 | rand_h_1: (Optional) an 1D numpy array containing indices in interval
22 | `[0, output_dim)`. Automatically generated from `seed_h_1`
23 | if is None.
24 | rand_s_1: (Optional) an 1D numpy array of 1 and -1, having the same shape
25 | as `rand_h_1`. Automatically generated from `seed_s_1` if is
26 | None.
27 | rand_h_2: (Optional) an 1D numpy array containing indices in interval
28 | `[0, output_dim)`. Automatically generated from `seed_h_2`
29 | if is None.
30 | rand_s_2: (Optional) an 1D numpy array of 1 and -1, having the same shape
31 | as `rand_h_2`. Automatically generated from `seed_s_2` if is
32 | None.
33 | """
34 |
35 | def __init__(self, input_dim1, input_dim2, output_dim,
36 | sum_pool=True, cuda=True,
37 | rand_h_1=None, rand_s_1=None, rand_h_2=None, rand_s_2=None):
38 | super(CompactBilinearPooling, self).__init__()
39 | self.input_dim1 = input_dim1
40 | self.input_dim2 = input_dim2
41 | self.output_dim = output_dim
42 | self.sum_pool = sum_pool
43 |
44 | if rand_h_1 is None:
45 | np.random.seed(1)
46 | rand_h_1 = np.random.randint(output_dim, size=self.input_dim1)
47 | if rand_s_1 is None:
48 | np.random.seed(3)
49 | rand_s_1 = 2 * np.random.randint(2, size=self.input_dim1) - 1
50 |
51 | self.sparse_sketch_matrix1 = Variable(self.generate_sketch_matrix(
52 | rand_h_1, rand_s_1, self.output_dim))
53 |
54 | if rand_h_2 is None:
55 | np.random.seed(5)
56 | rand_h_2 = np.random.randint(output_dim, size=self.input_dim2)
57 | if rand_s_2 is None:
58 | np.random.seed(7)
59 | rand_s_2 = 2 * np.random.randint(2, size=self.input_dim2) - 1
60 |
61 | self.sparse_sketch_matrix2 = Variable(self.generate_sketch_matrix(
62 | rand_h_2, rand_s_2, self.output_dim))
63 |
64 | if cuda:
65 | self.sparse_sketch_matrix1 = self.sparse_sketch_matrix1.cuda()
66 | self.sparse_sketch_matrix2 = self.sparse_sketch_matrix2.cuda()
67 |
68 | def forward(self, bottom1, bottom2):
69 | """
70 | bottom1: 1st input, 4D Tensor of shape [batch_size, input_dim1, height, width].
71 | bottom2: 2nd input, 4D Tensor of shape [batch_size, input_dim2, height, width].
72 | """
73 | assert bottom1.size(1) == self.input_dim1 and \
74 | bottom2.size(1) == self.input_dim2
75 |
76 | batch_size, _, height, width = bottom1.size()
77 |
78 | bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)
79 | bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2)
80 |
81 | sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1)
82 | sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2)
83 |
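        # compact bilinear pooling combines the two count sketches by circular convolution, computed as an element-wise product in the FFT domain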
84 | fft1 = afft.fft(sketch_1)
85 | fft2 = afft.fft(sketch_2)
86 |
87 | fft_product = fft1 * fft2
88 |
89 | cbp_flat = afft.ifft(fft_product).real
90 |
91 | cbp = cbp_flat.view(batch_size, height, width, self.output_dim)
92 |
93 | if self.sum_pool:
94 | cbp = cbp.sum(dim=1).sum(dim=1)
95 |
96 | return cbp
97 |
98 | @staticmethod
99 | def generate_sketch_matrix(rand_h, rand_s, output_dim):
100 | """
101 | Return a sparse matrix used for tensor sketch operation in compact bilinear
102 | pooling
103 | Args:
104 | rand_h: an 1D numpy array containing indices in interval `[0, output_dim)`.
105 | rand_s: an 1D numpy array of 1 and -1, having the same shape as `rand_h`.
106 | output_dim: the output dimensions of compact bilinear pooling.
107 | Returns:
108 | a sparse matrix of shape [input_dim, output_dim] for tensor sketch.
109 | """
110 |
111 | # Generate a sparse matrix for tensor count sketch
112 | rand_h = rand_h.astype(np.int64)
113 | rand_s = rand_s.astype(np.float32)
114 | assert(rand_h.ndim == 1 and rand_s.ndim ==
115 | 1 and len(rand_h) == len(rand_s))
116 | assert(np.all(rand_h >= 0) and np.all(rand_h < output_dim))
117 |
118 | input_dim = len(rand_h)
119 | indices = np.concatenate((np.arange(input_dim)[..., np.newaxis],
120 | rand_h[..., np.newaxis]), axis=1)
121 | indices = torch.from_numpy(indices)
122 | rand_s = torch.from_numpy(rand_s)
123 | sparse_sketch_matrix = torch.sparse.FloatTensor(
124 | indices.t(), rand_s, torch.Size([input_dim, output_dim]))
125 | return sparse_sketch_matrix.to_dense()
126 |
127 |
128 | if __name__ == '__main__':
129 |
130 | bottom1 = Variable(torch.randn(3, 512, 14, 14))
131 | bottom2 = Variable(torch.randn(3, 128, 14, 14))
132 |
133 | layer = CompactBilinearPooling(512, 128, 512, cuda=False)
134 | layer.train()
135 |
136 | out = layer(bottom1, bottom2)
137 | print(out.shape)
138 |
139 |
140 |
141 |
142 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/vgg.py:
--------------------------------------------------------------------------------
1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2 |
3 | Attributes:
4 |
5 | _out_channels (list of int): specify number of channels for each encoder feature tensor
6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8 |
9 | Methods:
10 |
11 | forward(self, x: torch.Tensor)
12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14 | with resolution same as input `x` tensor).
15 |
16 | Input: `x` with shape (1, 3, 64, 64)
17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20 |
21 | also should support number of features according to specified depth, e.g. if depth = 5,
22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24 | """
25 |
26 | import torch.nn as nn
27 | from torchvision.models.vgg import VGG
28 | from torchvision.models.vgg import make_layers
29 | from pretrainedmodels.models.torchvision_models import pretrained_settings
30 |
31 | from ._base import EncoderMixin
32 |
33 | # fmt: off
34 | cfg = {
35 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
36 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
37 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
38 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
39 | }
40 | # fmt: on
41 |
42 |
43 | class VGGEncoder(VGG, EncoderMixin):
44 | def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs):
45 | super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs)
46 | self._out_channels = out_channels
47 | self._depth = depth
48 | self._in_channels = 3
49 | del self.classifier
50 |
51 | def make_dilated(self, *args, **kwargs):
52 | raise ValueError("'VGG' models do not support dilated mode due to Max Pooling" " operations for downsampling!")
53 |
54 | def get_stages(self):
55 | stages = []
56 | stage_modules = []
57 | for module in self.features:
58 | if isinstance(module, nn.MaxPool2d):
59 | stages.append(nn.Sequential(*stage_modules))
60 | stage_modules = []
61 | stage_modules.append(module)
62 | stages.append(nn.Sequential(*stage_modules))
63 | return stages
64 |
65 | def forward(self, x):
66 | stages = self.get_stages()
67 |
68 | features = []
69 | for i in range(self._depth + 1):
70 | x = stages[i](x)
71 | features.append(x)
72 |
73 | return features
74 |
75 | def load_state_dict(self, state_dict, **kwargs):
76 | keys = list(state_dict.keys())
77 | for k in keys:
78 | if k.startswith("classifier"):
79 | state_dict.pop(k, None)
80 | super().load_state_dict(state_dict, **kwargs)
81 |
82 |
83 | vgg_encoders = {
84 | "vgg11": {
85 | "encoder": VGGEncoder,
86 | "pretrained_settings": pretrained_settings["vgg11"],
87 | "params": {
88 | "out_channels": (64, 128, 256, 512, 512, 512),
89 | "config": cfg["A"],
90 | "batch_norm": False,
91 | },
92 | },
93 | "vgg11_bn": {
94 | "encoder": VGGEncoder,
95 | "pretrained_settings": pretrained_settings["vgg11_bn"],
96 | "params": {
97 | "out_channels": (64, 128, 256, 512, 512, 512),
98 | "config": cfg["A"],
99 | "batch_norm": True,
100 | },
101 | },
102 | "vgg13": {
103 | "encoder": VGGEncoder,
104 | "pretrained_settings": pretrained_settings["vgg13"],
105 | "params": {
106 | "out_channels": (64, 128, 256, 512, 512, 512),
107 | "config": cfg["B"],
108 | "batch_norm": False,
109 | },
110 | },
111 | "vgg13_bn": {
112 | "encoder": VGGEncoder,
113 | "pretrained_settings": pretrained_settings["vgg13_bn"],
114 | "params": {
115 | "out_channels": (64, 128, 256, 512, 512, 512),
116 | "config": cfg["B"],
117 | "batch_norm": True,
118 | },
119 | },
120 | "vgg16": {
121 | "encoder": VGGEncoder,
122 | "pretrained_settings": pretrained_settings["vgg16"],
123 | "params": {
124 | "out_channels": (64, 128, 256, 512, 512, 512),
125 | "config": cfg["D"],
126 | "batch_norm": False,
127 | },
128 | },
129 | "vgg16_bn": {
130 | "encoder": VGGEncoder,
131 | "pretrained_settings": pretrained_settings["vgg16_bn"],
132 | "params": {
133 | "out_channels": (64, 128, 256, 512, 512, 512),
134 | "config": cfg["D"],
135 | "batch_norm": True,
136 | },
137 | },
138 | "vgg19": {
139 | "encoder": VGGEncoder,
140 | "pretrained_settings": pretrained_settings["vgg19"],
141 | "params": {
142 | "out_channels": (64, 128, 256, 512, 512, 512),
143 | "config": cfg["E"],
144 | "batch_norm": False,
145 | },
146 | },
147 | "vgg19_bn": {
148 | "encoder": VGGEncoder,
149 | "pretrained_settings": pretrained_settings["vgg19_bn"],
150 | "params": {
151 | "out_channels": (64, 128, 256, 512, 512, 512),
152 | "config": cfg["E"],
153 | "batch_norm": True,
154 | },
155 | },
156 | }
157 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/shufflenet.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNet in PyTorch.
2 |
3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 |
11 | def conv_bn(inp, oup, stride):
12 | return nn.Sequential(
13 | nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False),
14 | nn.BatchNorm3d(oup),
15 | nn.ReLU(inplace=True)
16 | )
17 |
18 |
19 | def channel_shuffle(x, groups):
20 |     '''Channel shuffle (3D): [N,C,D,H,W] -> [N,g,C/g,D,H,W] -> [N,C/g,g,D,H,W] -> [N,C,D,H,W]'''
21 | batchsize, num_channels, depth, height, width = x.data.size()
22 | channels_per_group = num_channels // groups
23 | # reshape
24 | x = x.view(batchsize, groups,
25 | channels_per_group, depth, height, width)
26 | #permute
27 | x = x.permute(0,2,1,3,4,5).contiguous()
28 | # flatten
29 | x = x.view(batchsize, num_channels, depth, height, width)
30 | return x
31 |
32 |
33 |
34 | class Bottleneck(nn.Module):
35 | def __init__(self, in_planes, out_planes, stride, groups):
36 | super(Bottleneck, self).__init__()
37 | self.stride = stride
38 | self.groups = groups
39 | mid_planes = out_planes//4
40 | if self.stride == 2:
41 | out_planes = out_planes - in_planes
42 | g = 1 if in_planes==24 else groups
43 | self.conv1 = nn.Conv3d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
44 | self.bn1 = nn.BatchNorm3d(mid_planes)
45 | self.conv2 = nn.Conv3d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
46 | self.bn2 = nn.BatchNorm3d(mid_planes)
47 | self.conv3 = nn.Conv3d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
48 | self.bn3 = nn.BatchNorm3d(out_planes)
49 | self.relu = nn.ReLU(inplace=True)
50 |
51 | if stride == 2:
52 | self.shortcut = nn.AvgPool3d(kernel_size=(2,3,3), stride=2, padding=(0,1,1))
53 |
54 |
55 | def forward(self, x):
56 | out = self.relu(self.bn1(self.conv1(x)))
57 | out = channel_shuffle(out, self.groups)
58 | out = self.bn2(self.conv2(out))
59 | out = self.bn3(self.conv3(out))
60 |
61 | if self.stride == 2:
62 | out = self.relu(torch.cat([out, self.shortcut(x)], 1))
63 | else:
64 | out = self.relu(out + x)
65 |
66 | return out
67 |
68 |
69 | class ShuffleNet(nn.Module):
70 | def __init__(self,
71 | groups=3,
72 | width_mult=1,
73 | num_classes=400):
74 | super(ShuffleNet, self).__init__()
75 | self.num_classes = num_classes
76 | self.groups = groups
77 | num_blocks = [4,8,4]
78 |
79 | # index 0 is invalid and should never be called.
80 | # only used for indexing convenience.
81 | if groups == 1:
82 | out_planes = [24, 144, 288, 567]
83 | elif groups == 2:
84 | out_planes = [24, 200, 400, 800]
85 | elif groups == 3:
86 | out_planes = [24, 240, 480, 960]
87 | elif groups == 4:
88 | out_planes = [24, 272, 544, 1088]
89 | elif groups == 8:
90 | out_planes = [24, 384, 768, 1536]
91 | else:
92 | raise ValueError(
93 | """{} groups is not supported for
94 | 1x1 Grouped Convolutions""".format(groups))
95 | out_planes = [int(i * width_mult) for i in out_planes]
96 | self.in_planes = out_planes[0]
97 | self.conv1 = conv_bn(3, self.in_planes, stride=(1,2,2))
98 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
99 | self.layer1 = self._make_layer(out_planes[1], num_blocks[0], self.groups)
100 | self.layer2 = self._make_layer(out_planes[2], num_blocks[1], self.groups)
101 | self.layer3 = self._make_layer(out_planes[3], num_blocks[2], self.groups)
102 |
103 | def _make_layer(self, out_planes, num_blocks, groups):
104 | layers = []
105 | for i in range(num_blocks):
106 | stride = 2 if i == 0 else 1
107 | layers.append(Bottleneck(self.in_planes, out_planes, stride=stride, groups=groups))
108 | self.in_planes = out_planes
109 | return nn.Sequential(*layers)
110 |
111 | def forward(self, x):
112 | f_list = []
113 | out = self.conv1(x)
114 | out = self.maxpool(out)
115 | f_list.append(out)
116 | out = self.layer1(out)
117 | f_list.append(out)
118 | out = self.layer2(out)
119 | f_list.append(out)
120 | out = self.layer3(out)
121 | f_list.append(out)
122 | return f_list
123 |
124 | def get_fine_tuning_parameters(model, ft_portion):
125 | if ft_portion == "complete":
126 | return model.parameters()
127 |
128 | elif ft_portion == "last_layer":
129 | ft_module_names = []
130 | ft_module_names.append('classifier')
131 |
132 | parameters = []
133 | for k, v in model.named_parameters():
134 | for ft_module in ft_module_names:
135 | if ft_module in k:
136 | parameters.append({'params': v})
137 | break
138 | else:
139 | parameters.append({'params': v, 'lr': 0.0})
140 | return parameters
141 |
142 | else:
143 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected")
144 |
145 |
146 | def get_model(**kwargs):
147 | """
148 | Returns the model.
149 | """
150 | model = ShuffleNet(**kwargs)
151 | return model
152 |
153 |
154 | if __name__ == "__main__":
155 | model = get_model(groups=3, num_classes=600, width_mult=1)
156 | model = model.cuda()
157 | model = nn.DataParallel(model, device_ids=None)
158 | print(model)
159 |
160 | input_var = Variable(torch.randn(8, 3, 16, 112, 112))
161 | output = model(input_var)
162 | print(output.shape)
163 |
164 |
165 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/VC3D_kenshohara/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import math
6 | from functools import partial
7 |
8 | __all__ = ['WideResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101']
9 |
10 |
11 | def conv3x3x3(in_planes, out_planes, stride=1):
12 | # 3x3x3 convolution with padding
13 | return nn.Conv3d(in_planes, out_planes, kernel_size=3,
14 | stride=stride, padding=1, bias=False)
15 |
16 |
17 | def downsample_basic_block(x, planes, stride):
18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride)
19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1),
20 | out.size(2), out.size(3),
21 | out.size(4)).zero_()
22 | if isinstance(out.data, torch.cuda.FloatTensor):
23 | zero_pads = zero_pads.cuda()
24 |
25 | out = Variable(torch.cat([out.data, zero_pads], dim=1))
26 |
27 | return out
28 |
29 |
30 | class WideBottleneck(nn.Module):
31 | expansion = 2
32 |
33 | def __init__(self, inplanes, planes, stride=1, downsample=None):
34 | super(WideBottleneck, self).__init__()
35 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
36 | self.bn1 = nn.BatchNorm3d(planes)
37 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
38 | padding=1, bias=False)
39 | self.bn2 = nn.BatchNorm3d(planes)
40 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
41 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.downsample = downsample
44 | self.stride = stride
45 |
46 | def forward(self, x):
47 | residual = x
48 |
49 | out = self.conv1(x)
50 | out = self.bn1(out)
51 | out = self.relu(out)
52 |
53 | out = self.conv2(out)
54 | out = self.bn2(out)
55 | out = self.relu(out)
56 |
57 | out = self.conv3(out)
58 | out = self.bn3(out)
59 |
60 | if self.downsample is not None:
61 | residual = self.downsample(x)
62 |
63 | out += residual
64 | out = self.relu(out)
65 |
66 | return out
67 |
68 |
69 | class WideResNet(nn.Module):
70 |
71 | def __init__(self, block, layers, sample_size = None, sample_duration = None, k=2, shortcut_type='B', num_classes=400, last_fc=True):
72 | self.last_fc = last_fc
73 |
74 | self.inplanes = 64
75 | super(WideResNet, self).__init__()
76 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2),
77 | padding=(3, 3, 3), bias=False)
78 | self.bn1 = nn.BatchNorm3d(64)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
81 | self.layer1 = self._make_layer(block, 64 * k, layers[0], shortcut_type)
82 | self.layer2 = self._make_layer(block, 128 * k, layers[1], shortcut_type, stride=2)
83 | self.layer3 = self._make_layer(block, 256 * k, layers[2], shortcut_type, stride=2)
84 | self.layer4 = self._make_layer(block, 512 * k, layers[3], shortcut_type, stride=2)
85 | # last_duration = math.ceil(sample_duration / 16)
86 | # last_size = math.ceil(sample_size / 32)
87 | # self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1)
88 | # self.fc = nn.Linear(512 * k * block.expansion, num_classes)
89 |
90 | for m in self.modules():
91 | if isinstance(m, nn.Conv3d):
92 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
93 | m.weight.data.normal_(0, math.sqrt(2. / n))
94 | elif isinstance(m, nn.BatchNorm3d):
95 | m.weight.data.fill_(1)
96 | m.bias.data.zero_()
97 |
98 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
99 | downsample = None
100 | if stride != 1 or self.inplanes != planes * block.expansion:
101 | if shortcut_type == 'A':
102 | downsample = partial(downsample_basic_block,
103 | planes=planes * block.expansion,
104 | stride=stride)
105 | else:
106 | downsample = nn.Sequential(
107 | nn.Conv3d(self.inplanes, planes * block.expansion,
108 | kernel_size=1, stride=stride, bias=False),
109 | nn.BatchNorm3d(planes * block.expansion)
110 | )
111 |
112 | layers = []
113 | layers.append(block(self.inplanes, planes, stride, downsample))
114 | self.inplanes = planes * block.expansion
115 | for i in range(1, blocks):
116 | layers.append(block(self.inplanes, planes))
117 |
118 | return nn.Sequential(*layers)
119 |
120 | def forward(self, x):
121 | f_list = []
122 | x = self.conv1(x)
123 | x = self.bn1(x)
124 | x = self.relu(x)
125 | x = self.maxpool(x)
126 | f_list.append(x)
127 |
128 | x = self.layer1(x)
129 | f_list.append(x)
130 | x = self.layer2(x)
131 | f_list.append(x)
132 | x = self.layer3(x)
133 | f_list.append(x)
134 | x = self.layer4(x)
135 | f_list.append(x)
136 |
137 | return f_list
138 |
139 | def get_fine_tuning_parameters(model, ft_begin_index):
140 | if ft_begin_index == 0:
141 | return model.parameters()
142 |
143 | ft_module_names = []
144 | for i in range(ft_begin_index, 5):
145 |         ft_module_names.append('layer{}'.format(i))
146 | ft_module_names.append('fc')
147 |
148 | parameters = []
149 | for k, v in model.named_parameters():
150 | for ft_module in ft_module_names:
151 | if ft_module in k:
152 | parameters.append({'params': v})
153 | break
154 | else:
155 | parameters.append({'params': v, 'lr': 0.0})
156 |
157 | return parameters
158 |
159 | def generate_model():
160 |     """Constructs a WideResNet-50 model (width factor k=2).
161 | """
162 | model = WideResNet(WideBottleneck, [3, 4, 6, 3])
163 | return model
164 |
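165 |
166 | # Minimal usage sketch: with the pooling/fc head commented out above, forward()
167 | # returns a list of five multi-scale 3D feature maps (stem output + layer1..layer4)
168 | # instead of logits.
169 | if __name__ == "__main__":
170 |     model = generate_model()  # WideResNet-50 with width factor k=2
171 |     x = torch.randn(2, 3, 16, 112, 112)  # (batch, channels, frames, H, W)
172 |     for f in model(x):
173 |         print(f.shape)  # channel dims: 64, 256, 512, 1024, 2048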
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | '''MobilenetV2 in PyTorch.
2 |
3 | See the paper "MobileNetV2: Inverted Residuals and Linear Bottlenecks" for more details.
4 | '''
5 | import torch
6 | import math
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 |
10 |
11 |
12 |
13 | def conv_bn(inp, oup, stride):
14 | return nn.Sequential(
15 | nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False),
16 | nn.BatchNorm3d(oup),
17 | nn.ReLU6(inplace=True)
18 | )
19 |
20 |
21 | def conv_1x1x1_bn(inp, oup):
22 | return nn.Sequential(
23 | nn.Conv3d(inp, oup, 1, 1, 0, bias=False),
24 | nn.BatchNorm3d(oup),
25 | nn.ReLU6(inplace=True)
26 | )
27 |
28 |
29 | class InvertedResidual(nn.Module):
30 | def __init__(self, inp, oup, stride, expand_ratio):
31 | super(InvertedResidual, self).__init__()
32 | self.stride = stride
33 |
34 | hidden_dim = round(inp * expand_ratio)
35 | self.use_res_connect = self.stride == (1,1,1) and inp == oup
36 |
37 | if expand_ratio == 1:
38 | self.conv = nn.Sequential(
39 | # dw
40 | nn.Conv3d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
41 | nn.BatchNorm3d(hidden_dim),
42 | nn.ReLU6(inplace=True),
43 | # pw-linear
44 | nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False),
45 | nn.BatchNorm3d(oup),
46 | )
47 | else:
48 | self.conv = nn.Sequential(
49 | # pw
50 | nn.Conv3d(inp, hidden_dim, 1, 1, 0, bias=False),
51 | nn.BatchNorm3d(hidden_dim),
52 | nn.ReLU6(inplace=True),
53 | # dw
54 | nn.Conv3d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
55 | nn.BatchNorm3d(hidden_dim),
56 | nn.ReLU6(inplace=True),
57 | # pw-linear
58 | nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False),
59 | nn.BatchNorm3d(oup),
60 | )
61 |
62 | def forward(self, x):
63 | if self.use_res_connect:
64 | return x + self.conv(x)
65 | else:
66 | return self.conv(x)
67 |
68 |
69 | class MobileNetV2(nn.Module):
70 | def __init__(self, width_mult=1.):
71 | super(MobileNetV2, self).__init__()
72 | block = InvertedResidual
73 | input_channel = 32
74 | last_channel = 1280
75 | interverted_residual_setting = [
76 |             # t (expansion factor), c (output channels), n (number of blocks), s (stride)
77 | [1, 16, 1, (1,1,1)],
78 | [6, 24, 2, (2,2,2)],
79 | [6, 32, 3, (2,2,2)],
80 | [6, 64, 4, (2,2,2)],
81 | [6, 96, 3, (1,1,1)],
82 | [6, 160, 3, (2,2,2)],
83 | [6, 320, 1, (1,1,1)],
84 | ]
85 |
86 | # building first layer
87 | input_channel = int(input_channel * width_mult)
88 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
89 | self.features = [conv_bn(3, input_channel, (1,2,2))]
90 | # building inverted residual blocks
91 | for t, c, n, s in interverted_residual_setting:
92 | output_channel = int(c * width_mult)
93 | for i in range(n):
94 | stride = s if i == 0 else (1,1,1)
95 | self.features.append(block(input_channel, output_channel, stride, expand_ratio=t))
96 | input_channel = output_channel
97 | # building last several layers
98 | self.features.append(conv_1x1x1_bn(input_channel, self.last_channel))
99 | # make it nn.Sequential
100 | self.features = nn.Sequential(*self.features)
101 |
102 |
103 |
104 | def forward(self, x):
105 | f_list = []
106 | for i in range(len(self.features)):
107 | x = self.features[i](x)
108 | f_list.append(x)
109 |
110 |         # keep only the stem output, the final feature, and stage-boundary features (those whose channels and spatial size both change at the next step)
111 | f_list_ = []
112 | for i, f in enumerate(f_list):
113 | if i == 0 or i == len(f_list)-1:
114 | f_list_.append(f)
115 | elif f.shape[1] != f_list[i+1].shape[1] and f.shape[2:] != f_list[i+1].shape[2:]:
116 | f_list_.append(f)
117 |
118 | return f_list_
119 |
120 |
121 | def _initialize_weights(self):
122 | for m in self.modules():
123 | if isinstance(m, nn.Conv3d):
124 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
125 | m.weight.data.normal_(0, math.sqrt(2. / n))
126 | if m.bias is not None:
127 | m.bias.data.zero_()
128 | elif isinstance(m, nn.BatchNorm3d):
129 | m.weight.data.fill_(1)
130 | m.bias.data.zero_()
131 | elif isinstance(m, nn.Linear):
132 | n = m.weight.size(1)
133 | m.weight.data.normal_(0, 0.01)
134 | m.bias.data.zero_()
135 |
136 |
137 | def get_fine_tuning_parameters(model, ft_portion):
138 | if ft_portion == "complete":
139 | return model.parameters()
140 |
141 | elif ft_portion == "last_layer":
142 | ft_module_names = []
143 | ft_module_names.append('classifier')
144 |
145 | parameters = []
146 | for k, v in model.named_parameters():
147 | for ft_module in ft_module_names:
148 | if ft_module in k:
149 | parameters.append({'params': v})
150 | break
151 | else:
152 | parameters.append({'params': v, 'lr': 0.0})
153 | return parameters
154 |
155 | else:
156 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected")
157 |
158 |
159 | def get_model(**kwargs):
160 | """
161 | Returns the model.
162 | """
163 | model = MobileNetV2(**kwargs)
164 | return model
165 |
166 |
167 | if __name__ == "__main__":
168 |     model = get_model(width_mult=1.)  # MobileNetV2.__init__ only accepts width_mult
169 | model = model.cuda()
170 | model = nn.DataParallel(model, device_ids=None)
171 | print(model)
172 |
173 |
174 | input_var = Variable(torch.randn(8, 3, 16, 112, 112))
175 |     features = model(input_var)
176 |     print([f.shape for f in features])
177 |
178 |
179 |
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_res2net.py:
--------------------------------------------------------------------------------
1 | from ._base import EncoderMixin
2 | from timm.models.resnet import ResNet
3 | from timm.models.res2net import Bottle2neck
4 | import torch.nn as nn
5 |
6 |
7 | class Res2NetEncoder(ResNet, EncoderMixin):
8 | def __init__(self, out_channels, depth=5, **kwargs):
9 | super().__init__(**kwargs)
10 | self._depth = depth
11 | self._out_channels = out_channels
12 | self._in_channels = 3
13 |
14 | del self.fc
15 | del self.global_pool
16 |
17 | def get_stages(self):
18 | return [
19 | nn.Identity(),
20 | nn.Sequential(self.conv1, self.bn1, self.act1),
21 | nn.Sequential(self.maxpool, self.layer1),
22 | self.layer2,
23 | self.layer3,
24 | self.layer4,
25 | ]
26 |
27 | def make_dilated(self, *args, **kwargs):
28 | raise ValueError("Res2Net encoders do not support dilated mode")
29 |
30 | def forward(self, x):
31 | stages = self.get_stages()
32 |
33 | features = []
34 | for i in range(self._depth + 1):
35 | x = stages[i](x)
36 | features.append(x)
37 |
38 | return features
39 |
40 | def load_state_dict(self, state_dict, **kwargs):
41 | state_dict.pop("fc.bias", None)
42 | state_dict.pop("fc.weight", None)
43 | super().load_state_dict(state_dict, **kwargs)
44 |
45 |
46 | res2net_weights = {
47 | "timm-res2net50_26w_4s": {
48 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth", # noqa
49 | },
50 | "timm-res2net50_48w_2s": {
51 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth", # noqa
52 | },
53 | "timm-res2net50_14w_8s": {
54 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth", # noqa
55 | },
56 | "timm-res2net50_26w_6s": {
57 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth", # noqa
58 | },
59 | "timm-res2net50_26w_8s": {
60 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth", # noqa
61 | },
62 | "timm-res2net101_26w_4s": {
63 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth", # noqa
64 | },
65 | "timm-res2next50": {
66 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth", # noqa
67 | },
68 | }
69 |
70 | pretrained_settings = {}
71 | for model_name, sources in res2net_weights.items():
72 | pretrained_settings[model_name] = {}
73 | for source_name, source_url in sources.items():
74 | pretrained_settings[model_name][source_name] = {
75 | "url": source_url,
76 | "input_size": [3, 224, 224],
77 | "input_range": [0, 1],
78 | "mean": [0.485, 0.456, 0.406],
79 | "std": [0.229, 0.224, 0.225],
80 | "num_classes": 1000,
81 | }
82 |
83 |
84 | timm_res2net_encoders = {
85 | "timm-res2net50_26w_4s": {
86 | "encoder": Res2NetEncoder,
87 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"],
88 | "params": {
89 | "out_channels": (3, 64, 256, 512, 1024, 2048),
90 | "block": Bottle2neck,
91 | "layers": [3, 4, 6, 3],
92 | "base_width": 26,
93 | "block_args": {"scale": 4},
94 | },
95 | },
96 | "timm-res2net101_26w_4s": {
97 | "encoder": Res2NetEncoder,
98 | "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"],
99 | "params": {
100 | "out_channels": (3, 64, 256, 512, 1024, 2048),
101 | "block": Bottle2neck,
102 | "layers": [3, 4, 23, 3],
103 | "base_width": 26,
104 | "block_args": {"scale": 4},
105 | },
106 | },
107 | "timm-res2net50_26w_6s": {
108 | "encoder": Res2NetEncoder,
109 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"],
110 | "params": {
111 | "out_channels": (3, 64, 256, 512, 1024, 2048),
112 | "block": Bottle2neck,
113 | "layers": [3, 4, 6, 3],
114 | "base_width": 26,
115 | "block_args": {"scale": 6},
116 | },
117 | },
118 | "timm-res2net50_26w_8s": {
119 | "encoder": Res2NetEncoder,
120 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"],
121 | "params": {
122 | "out_channels": (3, 64, 256, 512, 1024, 2048),
123 | "block": Bottle2neck,
124 | "layers": [3, 4, 6, 3],
125 | "base_width": 26,
126 | "block_args": {"scale": 8},
127 | },
128 | },
129 | "timm-res2net50_48w_2s": {
130 | "encoder": Res2NetEncoder,
131 | "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"],
132 | "params": {
133 | "out_channels": (3, 64, 256, 512, 1024, 2048),
134 | "block": Bottle2neck,
135 | "layers": [3, 4, 6, 3],
136 | "base_width": 48,
137 | "block_args": {"scale": 2},
138 | },
139 | },
140 | "timm-res2net50_14w_8s": {
141 | "encoder": Res2NetEncoder,
142 | "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"],
143 | "params": {
144 | "out_channels": (3, 64, 256, 512, 1024, 2048),
145 | "block": Bottle2neck,
146 | "layers": [3, 4, 6, 3],
147 | "base_width": 14,
148 | "block_args": {"scale": 8},
149 | },
150 | },
151 | "timm-res2next50": {
152 | "encoder": Res2NetEncoder,
153 | "pretrained_settings": pretrained_settings["timm-res2next50"],
154 | "params": {
155 | "out_channels": (3, 64, 256, 512, 1024, 2048),
156 | "block": Bottle2neck,
157 | "layers": [3, 4, 6, 3],
158 | "base_width": 4,
159 | "cardinality": 8,
160 | "block_args": {"scale": 4},
161 | },
162 | },
163 | }
164 |
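165 |
166 | # Rough usage sketch (assumes the timm version providing the ResNet/Bottle2neck
167 | # imports above; `torch` itself is assumed importable for the dummy input):
168 | #
169 | #   cfg = timm_res2net_encoders["timm-res2net50_26w_4s"]
170 | #   encoder = cfg["encoder"](depth=5, **cfg["params"])
171 | #   features = encoder(torch.randn(1, 3, 224, 224))
172 | #   # depth=5 -> 6 feature maps, from full input resolution down to 1/32, with
173 | #   # channel dims cfg["params"]["out_channels"] = (3, 64, 256, 512, 1024, 2048)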
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/senet.py:
--------------------------------------------------------------------------------
1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2 |
3 | Attributes:
4 |
5 | _out_channels (list of int): specify number of channels for each encoder feature tensor
6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8 |
9 | Methods:
10 |
11 | forward(self, x: torch.Tensor)
12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14 | with resolution same as input `x` tensor).
15 |
16 | Input: `x` with shape (1, 3, 64, 64)
17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20 |
21 | also should support number of features according to specified depth, e.g. if depth = 5,
22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24 | """
25 |
26 | import torch.nn as nn
27 |
28 | from pretrainedmodels.models.senet import (
29 | SENet,
30 | SEBottleneck,
31 | SEResNetBottleneck,
32 | SEResNeXtBottleneck,
33 | pretrained_settings,
34 | )
35 | from ._base import EncoderMixin
36 |
37 |
38 | class SENetEncoder(SENet, EncoderMixin):
39 | def __init__(self, out_channels, depth=5, **kwargs):
40 | super().__init__(**kwargs)
41 |
42 | self._out_channels = out_channels
43 | self._depth = depth
44 | self._in_channels = 3
45 |
46 | del self.last_linear
47 | del self.avg_pool
48 |
49 | def get_stages(self):
50 | return [
51 | nn.Identity(),
52 | self.layer0[:-1],
53 | nn.Sequential(self.layer0[-1], self.layer1),
54 | self.layer2,
55 | self.layer3,
56 | self.layer4,
57 | ]
58 |
59 | def forward(self, x):
60 | stages = self.get_stages()
61 |
62 | features = []
63 | for i in range(self._depth + 1):
64 | x = stages[i](x)
65 | features.append(x)
66 |
67 | return features
68 |
69 | def load_state_dict(self, state_dict, **kwargs):
70 | state_dict.pop("last_linear.bias", None)
71 | state_dict.pop("last_linear.weight", None)
72 | super().load_state_dict(state_dict, **kwargs)
73 |
74 |
75 | senet_encoders = {
76 | "senet154": {
77 | "encoder": SENetEncoder,
78 | "pretrained_settings": pretrained_settings["senet154"],
79 | "params": {
80 | "out_channels": (3, 128, 256, 512, 1024, 2048),
81 | "block": SEBottleneck,
82 | "dropout_p": 0.2,
83 | "groups": 64,
84 | "layers": [3, 8, 36, 3],
85 | "num_classes": 1000,
86 | "reduction": 16,
87 | },
88 | },
89 | "se_resnet50": {
90 | "encoder": SENetEncoder,
91 | "pretrained_settings": pretrained_settings["se_resnet50"],
92 | "params": {
93 | "out_channels": (3, 64, 256, 512, 1024, 2048),
94 | "block": SEResNetBottleneck,
95 | "layers": [3, 4, 6, 3],
96 | "downsample_kernel_size": 1,
97 | "downsample_padding": 0,
98 | "dropout_p": None,
99 | "groups": 1,
100 | "inplanes": 64,
101 | "input_3x3": False,
102 | "num_classes": 1000,
103 | "reduction": 16,
104 | },
105 | },
106 | "se_resnet101": {
107 | "encoder": SENetEncoder,
108 | "pretrained_settings": pretrained_settings["se_resnet101"],
109 | "params": {
110 | "out_channels": (3, 64, 256, 512, 1024, 2048),
111 | "block": SEResNetBottleneck,
112 | "layers": [3, 4, 23, 3],
113 | "downsample_kernel_size": 1,
114 | "downsample_padding": 0,
115 | "dropout_p": None,
116 | "groups": 1,
117 | "inplanes": 64,
118 | "input_3x3": False,
119 | "num_classes": 1000,
120 | "reduction": 16,
121 | },
122 | },
123 | "se_resnet152": {
124 | "encoder": SENetEncoder,
125 | "pretrained_settings": pretrained_settings["se_resnet152"],
126 | "params": {
127 | "out_channels": (3, 64, 256, 512, 1024, 2048),
128 | "block": SEResNetBottleneck,
129 | "layers": [3, 8, 36, 3],
130 | "downsample_kernel_size": 1,
131 | "downsample_padding": 0,
132 | "dropout_p": None,
133 | "groups": 1,
134 | "inplanes": 64,
135 | "input_3x3": False,
136 | "num_classes": 1000,
137 | "reduction": 16,
138 | },
139 | },
140 | "se_resnext50_32x4d": {
141 | "encoder": SENetEncoder,
142 | "pretrained_settings": pretrained_settings["se_resnext50_32x4d"],
143 | "params": {
144 | "out_channels": (3, 64, 256, 512, 1024, 2048),
145 | "block": SEResNeXtBottleneck,
146 | "layers": [3, 4, 6, 3],
147 | "downsample_kernel_size": 1,
148 | "downsample_padding": 0,
149 | "dropout_p": None,
150 | "groups": 32,
151 | "inplanes": 64,
152 | "input_3x3": False,
153 | "num_classes": 1000,
154 | "reduction": 16,
155 | },
156 | },
157 | "se_resnext101_32x4d": {
158 | "encoder": SENetEncoder,
159 | "pretrained_settings": pretrained_settings["se_resnext101_32x4d"],
160 | "params": {
161 | "out_channels": (3, 64, 256, 512, 1024, 2048),
162 | "block": SEResNeXtBottleneck,
163 | "layers": [3, 4, 23, 3],
164 | "downsample_kernel_size": 1,
165 | "downsample_padding": 0,
166 | "dropout_p": None,
167 | "groups": 32,
168 | "inplanes": 64,
169 | "input_3x3": False,
170 | "num_classes": 1000,
171 | "reduction": 16,
172 | },
173 | },
174 | }
175 |
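176 |
177 | # Rough usage sketch (assumes `pretrainedmodels` is installed, as imported above):
178 | #
179 | #   cfg = senet_encoders["se_resnext50_32x4d"]
180 | #   encoder = cfg["encoder"](**cfg["params"])   # SENetEncoder, depth defaults to 5
181 | #   features = encoder(x)                       # x: a (N, 3, 64, 64) tensor
182 | #   # 6 feature maps with spatial sizes 64, 32, 16, 8, 4, 2 and channel dims
183 | #   # (3, 64, 256, 512, 1024, 2048), matching the contract in the module docstring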
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/dpn.py:
--------------------------------------------------------------------------------
1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2 |
3 | Attributes:
4 |
5 | _out_channels (list of int): specify number of channels for each encoder feature tensor
6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8 |
9 | Methods:
10 |
11 | forward(self, x: torch.Tensor)
12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14 | with resolution same as input `x` tensor).
15 |
16 | Input: `x` with shape (1, 3, 64, 64)
17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20 |
21 | also should support number of features according to specified depth, e.g. if depth = 5,
22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24 | """
25 |
26 | import torch
27 | import torch.nn as nn
28 | import torch.nn.functional as F
29 |
30 | from pretrainedmodels.models.dpn import DPN
31 | from pretrainedmodels.models.dpn import pretrained_settings
32 |
33 | from ._base import EncoderMixin
34 |
35 |
36 | class DPNEncoder(DPN, EncoderMixin):
37 | def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
38 | super().__init__(**kwargs)
39 | self._stage_idxs = stage_idxs
40 | self._depth = depth
41 | self._out_channels = out_channels
42 | self._in_channels = 3
43 |
44 | del self.last_linear
45 |
46 | def get_stages(self):
47 | return [
48 | nn.Identity(),
49 | nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act),
50 | nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]),
51 | self.features[self._stage_idxs[0] : self._stage_idxs[1]],
52 | self.features[self._stage_idxs[1] : self._stage_idxs[2]],
53 | self.features[self._stage_idxs[2] : self._stage_idxs[3]],
54 | ]
55 |
56 | def forward(self, x):
57 |
58 | stages = self.get_stages()
59 |
60 | features = []
61 | for i in range(self._depth + 1):
62 | x = stages[i](x)
63 | if isinstance(x, (list, tuple)):
64 | features.append(F.relu(torch.cat(x, dim=1), inplace=True))
65 | else:
66 | features.append(x)
67 |
68 | return features
69 |
70 | def load_state_dict(self, state_dict, **kwargs):
71 | state_dict.pop("last_linear.bias", None)
72 | state_dict.pop("last_linear.weight", None)
73 | super().load_state_dict(state_dict, **kwargs)
74 |
75 |
76 | dpn_encoders = {
77 | "dpn68": {
78 | "encoder": DPNEncoder,
79 | "pretrained_settings": pretrained_settings["dpn68"],
80 | "params": {
81 | "stage_idxs": (4, 8, 20, 24),
82 | "out_channels": (3, 10, 144, 320, 704, 832),
83 | "groups": 32,
84 | "inc_sec": (16, 32, 32, 64),
85 | "k_r": 128,
86 | "k_sec": (3, 4, 12, 3),
87 | "num_classes": 1000,
88 | "num_init_features": 10,
89 | "small": True,
90 | "test_time_pool": True,
91 | },
92 | },
93 | "dpn68b": {
94 | "encoder": DPNEncoder,
95 | "pretrained_settings": pretrained_settings["dpn68b"],
96 | "params": {
97 | "stage_idxs": (4, 8, 20, 24),
98 | "out_channels": (3, 10, 144, 320, 704, 832),
99 | "b": True,
100 | "groups": 32,
101 | "inc_sec": (16, 32, 32, 64),
102 | "k_r": 128,
103 | "k_sec": (3, 4, 12, 3),
104 | "num_classes": 1000,
105 | "num_init_features": 10,
106 | "small": True,
107 | "test_time_pool": True,
108 | },
109 | },
110 | "dpn92": {
111 | "encoder": DPNEncoder,
112 | "pretrained_settings": pretrained_settings["dpn92"],
113 | "params": {
114 | "stage_idxs": (4, 8, 28, 32),
115 | "out_channels": (3, 64, 336, 704, 1552, 2688),
116 | "groups": 32,
117 | "inc_sec": (16, 32, 24, 128),
118 | "k_r": 96,
119 | "k_sec": (3, 4, 20, 3),
120 | "num_classes": 1000,
121 | "num_init_features": 64,
122 | "test_time_pool": True,
123 | },
124 | },
125 | "dpn98": {
126 | "encoder": DPNEncoder,
127 | "pretrained_settings": pretrained_settings["dpn98"],
128 | "params": {
129 | "stage_idxs": (4, 10, 30, 34),
130 | "out_channels": (3, 96, 336, 768, 1728, 2688),
131 | "groups": 40,
132 | "inc_sec": (16, 32, 32, 128),
133 | "k_r": 160,
134 | "k_sec": (3, 6, 20, 3),
135 | "num_classes": 1000,
136 | "num_init_features": 96,
137 | "test_time_pool": True,
138 | },
139 | },
140 | "dpn107": {
141 | "encoder": DPNEncoder,
142 | "pretrained_settings": pretrained_settings["dpn107"],
143 | "params": {
144 | "stage_idxs": (5, 13, 33, 37),
145 | "out_channels": (3, 128, 376, 1152, 2432, 2688),
146 | "groups": 50,
147 | "inc_sec": (20, 64, 64, 128),
148 | "k_r": 200,
149 | "k_sec": (4, 8, 20, 3),
150 | "num_classes": 1000,
151 | "num_init_features": 128,
152 | "test_time_pool": True,
153 | },
154 | },
155 | "dpn131": {
156 | "encoder": DPNEncoder,
157 | "pretrained_settings": pretrained_settings["dpn131"],
158 | "params": {
159 | "stage_idxs": (5, 13, 41, 45),
160 | "out_channels": (3, 128, 352, 832, 1984, 2688),
161 | "groups": 40,
162 | "inc_sec": (16, 32, 32, 128),
163 | "k_r": 160,
164 | "k_sec": (4, 8, 28, 3),
165 | "num_classes": 1000,
166 | "num_init_features": 128,
167 | "test_time_pool": True,
168 | },
169 | },
170 | }
171 |
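172 |
173 | # Rough sketch of building a pretrained encoder from the registry above (network
174 | # access is assumed for the weight download):
175 | #
176 | #   cfg = dpn_encoders["dpn68"]
177 | #   encoder = cfg["encoder"](**cfg["params"])
178 | #   url = cfg["pretrained_settings"]["imagenet"]["url"]
179 | #   encoder.load_state_dict(torch.hub.load_state_dict_from_url(url))
180 | #   # the overridden load_state_dict() drops the "last_linear.*" classifier keys
181 | #   features = encoder(torch.randn(1, 3, 64, 64))
182 | #   # DPN stages emit (residual, dense) tuple outputs, which forward() concatenates
183 | #   # along the channel dim and passes through ReLU before appending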
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_mobilenetv3.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import numpy as np
3 | import torch.nn as nn
4 |
5 | from ._base import EncoderMixin
6 |
7 |
8 | def _make_divisible(x, divisible_by=8):
9 | return int(np.ceil(x * 1.0 / divisible_by) * divisible_by)
10 |
11 |
12 | class MobileNetV3Encoder(nn.Module, EncoderMixin):
13 | def __init__(self, model_name, width_mult, depth=5, **kwargs):
14 | super().__init__()
15 | if "large" not in model_name and "small" not in model_name:
16 | raise ValueError("MobileNetV3 wrong model name {}".format(model_name))
17 |
18 | self._mode = "small" if "small" in model_name else "large"
19 | self._depth = depth
20 | self._out_channels = self._get_channels(self._mode, width_mult)
21 | self._in_channels = 3
22 |
23 | # minimal models replace hardswish with relu
24 | self.model = timm.create_model(
25 | model_name=model_name,
26 | scriptable=True, # torch.jit scriptable
27 | exportable=True, # onnx export
28 | features_only=True,
29 | )
30 |
31 | def _get_channels(self, mode, width_mult):
32 | if mode == "small":
33 | channels = [16, 16, 24, 48, 576]
34 | else:
35 | channels = [16, 24, 40, 112, 960]
36 | channels = [
37 | 3,
38 | ] + [_make_divisible(x * width_mult) for x in channels]
39 | return tuple(channels)
40 |
41 | def get_stages(self):
42 | if self._mode == "small":
43 | return [
44 | nn.Identity(),
45 | nn.Sequential(
46 | self.model.conv_stem,
47 | self.model.bn1,
48 | self.model.act1,
49 | ),
50 | self.model.blocks[0],
51 | self.model.blocks[1],
52 | self.model.blocks[2:4],
53 | self.model.blocks[4:],
54 | ]
55 | elif self._mode == "large":
56 | return [
57 | nn.Identity(),
58 | nn.Sequential(
59 | self.model.conv_stem,
60 | self.model.bn1,
61 | self.model.act1,
62 | self.model.blocks[0],
63 | ),
64 | self.model.blocks[1],
65 | self.model.blocks[2],
66 | self.model.blocks[3:5],
67 | self.model.blocks[5:],
68 | ]
69 | else:
70 |             raise ValueError("MobileNetV3 mode should be small or large, got {}".format(self._mode))
71 |
72 | def forward(self, x):
73 | stages = self.get_stages()
74 |
75 | features = []
76 | for i in range(self._depth + 1):
77 | x = stages[i](x)
78 | features.append(x)
79 |
80 | return features
81 |
82 | def load_state_dict(self, state_dict, **kwargs):
83 | state_dict.pop("conv_head.weight", None)
84 | state_dict.pop("conv_head.bias", None)
85 | state_dict.pop("classifier.weight", None)
86 | state_dict.pop("classifier.bias", None)
87 | self.model.load_state_dict(state_dict, **kwargs)
88 |
89 |
90 | mobilenetv3_weights = {
91 | "tf_mobilenetv3_large_075": {
92 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth" # noqa
93 | },
94 | "tf_mobilenetv3_large_100": {
95 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth" # noqa
96 | },
97 | "tf_mobilenetv3_large_minimal_100": {
98 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth" # noqa
99 | },
100 | "tf_mobilenetv3_small_075": {
101 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth" # noqa
102 | },
103 | "tf_mobilenetv3_small_100": {
104 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth" # noqa
105 | },
106 | "tf_mobilenetv3_small_minimal_100": {
107 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth" # noqa
108 | },
109 | }
110 |
111 | pretrained_settings = {}
112 | for model_name, sources in mobilenetv3_weights.items():
113 | pretrained_settings[model_name] = {}
114 | for source_name, source_url in sources.items():
115 | pretrained_settings[model_name][source_name] = {
116 | "url": source_url,
117 | "input_range": [0, 1],
118 | "mean": [0.485, 0.456, 0.406],
119 | "std": [0.229, 0.224, 0.225],
120 | "input_space": "RGB",
121 | }
122 |
123 |
124 | timm_mobilenetv3_encoders = {
125 | "timm-mobilenetv3_large_075": {
126 | "encoder": MobileNetV3Encoder,
127 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_075"],
128 | "params": {"model_name": "tf_mobilenetv3_large_075", "width_mult": 0.75},
129 | },
130 | "timm-mobilenetv3_large_100": {
131 | "encoder": MobileNetV3Encoder,
132 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_100"],
133 | "params": {"model_name": "tf_mobilenetv3_large_100", "width_mult": 1.0},
134 | },
135 | "timm-mobilenetv3_large_minimal_100": {
136 | "encoder": MobileNetV3Encoder,
137 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_minimal_100"],
138 | "params": {"model_name": "tf_mobilenetv3_large_minimal_100", "width_mult": 1.0},
139 | },
140 | "timm-mobilenetv3_small_075": {
141 | "encoder": MobileNetV3Encoder,
142 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_075"],
143 | "params": {"model_name": "tf_mobilenetv3_small_075", "width_mult": 0.75},
144 | },
145 | "timm-mobilenetv3_small_100": {
146 | "encoder": MobileNetV3Encoder,
147 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_100"],
148 | "params": {"model_name": "tf_mobilenetv3_small_100", "width_mult": 1.0},
149 | },
150 | "timm-mobilenetv3_small_minimal_100": {
151 | "encoder": MobileNetV3Encoder,
152 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_minimal_100"],
153 | "params": {"model_name": "tf_mobilenetv3_small_minimal_100", "width_mult": 1.0},
154 | },
155 | }
156 |
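157 |
158 | # Rough sketch (assumes a timm version compatible with the attribute access in
159 | # get_stages above): width_mult scales the per-stage channel widths and
160 | # _make_divisible rounds each up to a multiple of 8, e.g.
161 | #
162 | #   enc = MobileNetV3Encoder("tf_mobilenetv3_large_075", width_mult=0.75)
163 | #   enc._out_channels  # (3, 16, 24, 32, 88, 720): (16, 24, 40, 112, 960) * 0.75,
164 | #                      # rounded up to multiples of 8, with 3 input channels prepended
165 | #   features = enc(torch.randn(1, 3, 224, 224))  # depth=5 -> 6 feature maps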
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/efficientnet.py:
--------------------------------------------------------------------------------
1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2 |
3 | Attributes:
4 |
5 | _out_channels (list of int): specify number of channels for each encoder feature tensor
6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8 |
9 | Methods:
10 |
11 | forward(self, x: torch.Tensor)
12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14 | with resolution same as input `x` tensor).
15 |
16 | Input: `x` with shape (1, 3, 64, 64)
17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20 |
21 | also should support number of features according to specified depth, e.g. if depth = 5,
22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24 | """
25 | import torch.nn as nn
26 | from efficientnet_pytorch import EfficientNet
27 | from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params
28 |
29 | from ._base import EncoderMixin
30 |
31 |
32 | class EfficientNetEncoder(EfficientNet, EncoderMixin):
33 | def __init__(self, stage_idxs, out_channels, model_name, depth=5):
34 |
35 | blocks_args, global_params = get_model_params(model_name, override_params=None)
36 | super().__init__(blocks_args, global_params)
37 |
38 | self._stage_idxs = stage_idxs
39 | self._out_channels = out_channels
40 | self._depth = depth
41 | self._in_channels = 3
42 |
43 | del self._fc
44 |
45 | def get_stages(self):
46 | return [
47 | nn.Identity(),
48 | nn.Sequential(self._conv_stem, self._bn0, self._swish),
49 | self._blocks[: self._stage_idxs[0]],
50 | self._blocks[self._stage_idxs[0] : self._stage_idxs[1]],
51 | self._blocks[self._stage_idxs[1] : self._stage_idxs[2]],
52 | self._blocks[self._stage_idxs[2] :],
53 | ]
54 |
55 | def forward(self, x):
56 | stages = self.get_stages()
57 |
58 | block_number = 0.0
59 | drop_connect_rate = self._global_params.drop_connect_rate
60 |
61 | features = []
62 | for i in range(self._depth + 1):
63 |
64 | # Identity and Sequential stages
65 | if i < 2:
66 | x = stages[i](x)
67 |
68 | # Block stages need drop_connect rate
69 | else:
70 | for module in stages[i]:
71 | drop_connect = drop_connect_rate * block_number / len(self._blocks)
72 | block_number += 1.0
73 | x = module(x, drop_connect)
74 |
75 | features.append(x)
76 |
77 | return features
78 |
79 | def load_state_dict(self, state_dict, **kwargs):
80 | state_dict.pop("_fc.bias", None)
81 | state_dict.pop("_fc.weight", None)
82 | super().load_state_dict(state_dict, **kwargs)
83 |
84 |
85 | def _get_pretrained_settings(encoder):
86 | pretrained_settings = {
87 | "imagenet": {
88 | "mean": [0.485, 0.456, 0.406],
89 | "std": [0.229, 0.224, 0.225],
90 | "url": url_map[encoder],
91 | "input_space": "RGB",
92 | "input_range": [0, 1],
93 | },
94 | "advprop": {
95 | "mean": [0.5, 0.5, 0.5],
96 | "std": [0.5, 0.5, 0.5],
97 | "url": url_map_advprop[encoder],
98 | "input_space": "RGB",
99 | "input_range": [0, 1],
100 | },
101 | }
102 | return pretrained_settings
103 |
104 |
105 | efficient_net_encoders = {
106 | "efficientnet-b0": {
107 | "encoder": EfficientNetEncoder,
108 | "pretrained_settings": _get_pretrained_settings("efficientnet-b0"),
109 | "params": {
110 | "out_channels": (3, 32, 24, 40, 112, 320),
111 | "stage_idxs": (3, 5, 9, 16),
112 | "model_name": "efficientnet-b0",
113 | },
114 | },
115 | "efficientnet-b1": {
116 | "encoder": EfficientNetEncoder,
117 | "pretrained_settings": _get_pretrained_settings("efficientnet-b1"),
118 | "params": {
119 | "out_channels": (3, 32, 24, 40, 112, 320),
120 | "stage_idxs": (5, 8, 16, 23),
121 | "model_name": "efficientnet-b1",
122 | },
123 | },
124 | "efficientnet-b2": {
125 | "encoder": EfficientNetEncoder,
126 | "pretrained_settings": _get_pretrained_settings("efficientnet-b2"),
127 | "params": {
128 | "out_channels": (3, 32, 24, 48, 120, 352),
129 | "stage_idxs": (5, 8, 16, 23),
130 | "model_name": "efficientnet-b2",
131 | },
132 | },
133 | "efficientnet-b3": {
134 | "encoder": EfficientNetEncoder,
135 | "pretrained_settings": _get_pretrained_settings("efficientnet-b3"),
136 | "params": {
137 | "out_channels": (3, 40, 32, 48, 136, 384),
138 | "stage_idxs": (5, 8, 18, 26),
139 | "model_name": "efficientnet-b3",
140 | },
141 | },
142 | "efficientnet-b4": {
143 | "encoder": EfficientNetEncoder,
144 | "pretrained_settings": _get_pretrained_settings("efficientnet-b4"),
145 | "params": {
146 | "out_channels": (3, 48, 32, 56, 160, 448),
147 | "stage_idxs": (6, 10, 22, 32),
148 | "model_name": "efficientnet-b4",
149 | },
150 | },
151 | "efficientnet-b5": {
152 | "encoder": EfficientNetEncoder,
153 | "pretrained_settings": _get_pretrained_settings("efficientnet-b5"),
154 | "params": {
155 | "out_channels": (3, 48, 40, 64, 176, 512),
156 | "stage_idxs": (8, 13, 27, 39),
157 | "model_name": "efficientnet-b5",
158 | },
159 | },
160 | "efficientnet-b6": {
161 | "encoder": EfficientNetEncoder,
162 | "pretrained_settings": _get_pretrained_settings("efficientnet-b6"),
163 | "params": {
164 | "out_channels": (3, 56, 40, 72, 200, 576),
165 | "stage_idxs": (9, 15, 31, 45),
166 | "model_name": "efficientnet-b6",
167 | },
168 | },
169 | "efficientnet-b7": {
170 | "encoder": EfficientNetEncoder,
171 | "pretrained_settings": _get_pretrained_settings("efficientnet-b7"),
172 | "params": {
173 | "out_channels": (3, 64, 48, 80, 224, 640),
174 | "stage_idxs": (11, 18, 38, 55),
175 | "model_name": "efficientnet-b7",
176 | },
177 | },
178 | }
179 |
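180 |
181 | # Rough usage sketch (assumes efficientnet_pytorch is installed, as imported above;
182 | # `torch` itself is assumed importable for the dummy input):
183 | #
184 | #   cfg = efficient_net_encoders["efficientnet-b0"]
185 | #   encoder = cfg["encoder"](**cfg["params"])
186 | #   features = encoder(torch.randn(1, 3, 64, 64))
187 | #   [f.shape[1] for f in features]  # [3, 32, 24, 40, 112, 320] == cfg["params"]["out_channels"]
188 | #   # "stage_idxs" (3, 5, 9, 16) are the indices where self._blocks is split into stages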
--------------------------------------------------------------------------------
/wama_modules/thirdparty_lib/VC3D_kenshohara/resnext.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import math
6 | from functools import partial
7 |
8 | __all__ = ['ResNeXt', 'resnet50', 'resnet101', 'resnet152', 'generate_model']
9 |
10 |
11 | def conv3x3x3(in_planes, out_planes, stride=1):
12 | # 3x3x3 convolution with padding
13 | return nn.Conv3d(in_planes, out_planes, kernel_size=3,
14 | stride=stride, padding=1, bias=False)
15 |
16 |
17 | def downsample_basic_block(x, planes, stride):
18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride)
19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1),
20 | out.size(2), out.size(3),
21 | out.size(4)).zero_()
22 | if isinstance(out.data, torch.cuda.FloatTensor):
23 | zero_pads = zero_pads.cuda()
24 |
25 | out = Variable(torch.cat([out.data, zero_pads], dim=1))
26 |
27 | return out
28 |
29 |
30 | class ResNeXtBottleneck(nn.Module):
31 | expansion = 2
32 |
33 | def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
34 | super(ResNeXtBottleneck, self).__init__()
35 | mid_planes = cardinality * int(planes / 32)
36 | self.conv1 = nn.Conv3d(inplanes, mid_planes, kernel_size=1, bias=False)
37 | self.bn1 = nn.BatchNorm3d(mid_planes)
38 | self.conv2 = nn.Conv3d(mid_planes, mid_planes, kernel_size=3, stride=stride,
39 | padding=1, groups=cardinality, bias=False)
40 | self.bn2 = nn.BatchNorm3d(mid_planes)
41 | self.conv3 = nn.Conv3d(mid_planes, planes * self.expansion, kernel_size=1, bias=False)
42 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
43 | self.relu = nn.ReLU(inplace=True)
44 | self.downsample = downsample
45 | self.stride = stride
46 |
47 | def forward(self, x):
48 | residual = x
49 |
50 | out = self.conv1(x)
51 | out = self.bn1(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv2(out)
55 | out = self.bn2(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv3(out)
59 | out = self.bn3(out)
60 |
61 | if self.downsample is not None:
62 | residual = self.downsample(x)
63 |
64 | out += residual
65 | out = self.relu(out)
66 |
67 | return out
68 |
69 |
70 | class ResNeXt(nn.Module):
71 |
72 | def __init__(self, block, layers, sample_size = None, sample_duration = None, shortcut_type='B', cardinality=32, num_classes=400, last_fc=True):
73 | self.last_fc = last_fc
74 |
75 | self.inplanes = 64
76 | super(ResNeXt, self).__init__()
77 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2),
78 | padding=(3, 3, 3), bias=False)
79 | self.bn1 = nn.BatchNorm3d(64)
80 | self.relu = nn.ReLU(inplace=True)
81 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
82 | self.layer1 = self._make_layer(block, 128, layers[0], shortcut_type, cardinality)
83 | self.layer2 = self._make_layer(block, 256, layers[1], shortcut_type, cardinality, stride=2)
84 | self.layer3 = self._make_layer(block, 512, layers[2], shortcut_type, cardinality, stride=2)
85 | self.layer4 = self._make_layer(block, 1024, layers[3], shortcut_type, cardinality, stride=2)
86 | # last_duration = math.ceil(sample_duration / 16)
87 | # last_size = math.ceil(sample_size / 32)
88 | # self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1)
89 | # self.fc = nn.Linear(cardinality * 32 * block.expansion, num_classes)
90 |
91 | for m in self.modules():
92 | if isinstance(m, nn.Conv3d):
93 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
94 | m.weight.data.normal_(0, math.sqrt(2. / n))
95 | elif isinstance(m, nn.BatchNorm3d):
96 | m.weight.data.fill_(1)
97 | m.bias.data.zero_()
98 |
99 | def _make_layer(self, block, planes, blocks, shortcut_type, cardinality, stride=1):
100 | downsample = None
101 | if stride != 1 or self.inplanes != planes * block.expansion:
102 | if shortcut_type == 'A':
103 | downsample = partial(downsample_basic_block,
104 | planes=planes * block.expansion,
105 | stride=stride)
106 | else:
107 | downsample = nn.Sequential(
108 | nn.Conv3d(self.inplanes, planes * block.expansion,
109 | kernel_size=1, stride=stride, bias=False),
110 | nn.BatchNorm3d(planes * block.expansion)
111 | )
112 |
113 | layers = []
114 | layers.append(block(self.inplanes, planes, cardinality, stride, downsample))
115 | self.inplanes = planes * block.expansion
116 | for i in range(1, blocks):
117 | layers.append(block(self.inplanes, planes, cardinality))
118 |
119 | return nn.Sequential(*layers)
120 |
121 | def forward(self, x):
122 | f_list = []
123 | x = self.conv1(x)
124 | x = self.bn1(x)
125 | x = self.relu(x)
126 | x = self.maxpool(x)
127 | f_list.append(x)
128 |
129 | x = self.layer1(x)
130 | f_list.append(x)
131 | x = self.layer2(x)
132 | f_list.append(x)
133 | x = self.layer3(x)
134 | f_list.append(x)
135 | x = self.layer4(x)
136 | f_list.append(x)
137 | return f_list
138 |
139 | def get_fine_tuning_parameters(model, ft_begin_index):
140 | if ft_begin_index == 0:
141 | return model.parameters()
142 |
143 | ft_module_names = []
144 | for i in range(ft_begin_index, 5):
145 |         ft_module_names.append('layer{}'.format(i))
146 | ft_module_names.append('fc')
147 |
148 | parameters = []
149 | for k, v in model.named_parameters():
150 | for ft_module in ft_module_names:
151 | if ft_module in k:
152 | parameters.append({'params': v})
153 | break
154 | else:
155 | parameters.append({'params': v, 'lr': 0.0})
156 |
157 | return parameters
158 |
159 | def resnet50(**kwargs):
160 |     """Constructs a ResNeXt-50 model.
161 | """
162 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 6, 3], **kwargs)
163 | return model
164 |
165 | def resnet101(**kwargs):
166 |     """Constructs a ResNeXt-101 model.
167 | """
168 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs)
169 | return model
170 |
171 | def resnet152(**kwargs):
172 |     """Constructs a ResNeXt-152 model.
173 | """
174 | model = ResNeXt(ResNeXtBottleneck, [3, 8, 36, 3], **kwargs)
175 | return model
176 |
177 | def generate_model(model_depth):
178 | if model_depth == 50:
179 | model = resnet50()
180 | elif model_depth == 101:
181 | model = resnet101()
182 | elif model_depth == 152:
183 | model = resnet152()
184 | else:
185 | model = None
186 | return model
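187 |
188 | # Minimal usage sketch: forward() returns the five stage feature maps (stem output +
189 | # layer1..layer4) rather than logits; generate_model() supports depths 50, 101 and 152
190 | # and returns None otherwise.
191 | if __name__ == "__main__":
192 |     model = generate_model(101)  # ResNeXt-101 with cardinality 32
193 |     for f in model(torch.randn(2, 3, 16, 112, 112)):
194 |         print(f.shape)  # channel dims: 64, 256, 512, 1024, 2048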
--------------------------------------------------------------------------------