├── demo ├── Demo_MultiTask_MMoE.py ├── Demo_MultiTask_PLE.py ├── Demo_SeBlockwithResnet_MultiInstanseClassification.py ├── Demo_dualbranchnet_forMultiInput_classification.py ├── Demo0_VGG_SingleLabelClassification.py ├── multi_label │ └── generate_multilabel_dataset.py ├── Demo1_ResNet_SingleLabelClassification.py ├── Demo2_ResNet_MultiLabelClassification.py ├── Demo3_ResNetUnet_SingleLabelSegmentation.py ├── Demo4_ResNetUnet_MultiLabelSegmentation.py ├── Demo6_UnetwithFPN_segmentation.py ├── _explain_how2useVitAsNeck.py ├── Demo5_MultiTask_SegAndCls.py ├── Demo7_2D_TransUnet_Segmentation.py ├── Demo8_3D_TransUnet_Segmentation.py └── Demo_BilinearPooling.py ├── wama_modules ├── __init__.py ├── thirdparty_lib │ ├── __init__.py │ ├── C3D_jfzhang95 │ │ ├── __init__.py │ │ └── c3d.py │ ├── C3D_yyuanad │ │ ├── __init__.py │ │ └── c3d.py │ ├── MedicalNet_Tencent │ │ ├── __init__.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── __pycache__ │ │ │ │ ├── resnet.cpython-38.pyc │ │ │ │ └── __init__.cpython-38.pyc │ │ └── model.py │ ├── VC3D_kenshohara │ │ ├── __init__.py │ │ ├── wide_resnet.py │ │ └── resnext.py │ ├── Efficient3D_okankop │ │ ├── __init__.py │ │ └── models │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── c3d.cpython-38.pyc │ │ │ ├── resnet.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── resnext.cpython-38.pyc │ │ │ ├── mobilenet.cpython-38.pyc │ │ │ ├── mobilenetv2.cpython-38.pyc │ │ │ ├── shufflenet.cpython-38.pyc │ │ │ ├── squeezenet.cpython-38.pyc │ │ │ └── shufflenetv2.cpython-38.pyc │ │ │ ├── c3d.py │ │ │ ├── mobilenet.py │ │ │ ├── squeezenet.py │ │ │ ├── shufflenet.py │ │ │ └── mobilenetv2.py │ ├── ResNets3D_kenshohara │ │ └── __init__.py │ ├── get_model.py │ └── SMP_qubvel │ │ ├── __init__.py │ │ ├── __version__.py │ │ ├── encoders │ │ ├── __pycache__ │ │ │ ├── dpn.cpython-38.pyc │ │ │ ├── vgg.cpython-38.pyc │ │ │ ├── _base.cpython-38.pyc │ │ │ ├── senet.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── _utils.cpython-38.pyc │ │ │ ├── densenet.cpython-38.pyc │ │ │ ├── resnet.cpython-38.pyc │ │ │ ├── xception.cpython-38.pyc │ │ │ ├── mobilenet.cpython-38.pyc │ │ │ ├── timm_sknet.cpython-38.pyc │ │ │ ├── efficientnet.cpython-38.pyc │ │ │ ├── inceptionv4.cpython-38.pyc │ │ │ ├── timm_gernet.cpython-38.pyc │ │ │ ├── timm_regnet.cpython-38.pyc │ │ │ ├── timm_res2net.cpython-38.pyc │ │ │ ├── timm_resnest.cpython-38.pyc │ │ │ ├── _preprocessing.cpython-38.pyc │ │ │ ├── mix_transformer.cpython-38.pyc │ │ │ ├── timm_universal.cpython-38.pyc │ │ │ ├── inceptionresnetv2.cpython-38.pyc │ │ │ ├── timm_efficientnet.cpython-38.pyc │ │ │ └── timm_mobilenetv3.cpython-38.pyc │ │ ├── _preprocessing.py │ │ ├── timm_universal.py │ │ ├── _utils.py │ │ ├── _base.py │ │ ├── xception.py │ │ ├── mobilenet.py │ │ ├── inceptionresnetv2.py │ │ ├── inceptionv4.py │ │ ├── timm_sknet.py │ │ ├── __init__.py │ │ ├── timm_gernet.py │ │ ├── densenet.py │ │ ├── vgg.py │ │ ├── timm_res2net.py │ │ ├── senet.py │ │ ├── dpn.py │ │ ├── timm_mobilenetv3.py │ │ └── efficientnet.py │ │ ├── base │ │ ├── __init__.py │ │ ├── initialization.py │ │ ├── heads.py │ │ ├── model.py │ │ └── modules.py │ │ └── utils │ │ ├── __init__.py │ │ ├── losses.py │ │ ├── meter.py │ │ ├── base.py │ │ ├── metrics.py │ │ ├── train.py │ │ └── functional.py ├── Head.py ├── utils.py ├── Decoder.py └── Neck.py ├── make_rqs.py ├── images └── transUnet.png ├── requirements.txt ├── setup.py ├── LICENSE └── Document_allmodules.md /demo/Demo_MultiTask_MMoE.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/Demo_MultiTask_PLE.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/C3D_jfzhang95/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/C3D_yyuanad/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/Demo_SeBlockwithResnet_MultiInstanseClassification.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/Demo_dualbranchnet_forMultiInput_classification.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/MedicalNet_Tencent/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/VC3D_kenshohara/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/ResNets3D_kenshohara/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /make_rqs.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.system('pipreqs ./ --encoding=utf8 --force') -------------------------------------------------------------------------------- /images/transUnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/images/transUnet.png -------------------------------------------------------------------------------- 
/wama_modules/thirdparty_lib/get_model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/__init__.py: -------------------------------------------------------------------------------- 1 | from . import encoders 2 | from .__version__ import __version__ 3 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 3, 0) 2 | 3 | __version__ = ".".join(map(str, VERSION)) 4 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/dpn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/dpn.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/vgg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/vgg.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_base.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/senet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/senet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_utils.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/densenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/densenet.cpython-38.pyc -------------------------------------------------------------------------------- 
/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/xception.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/xception.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/c3d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/c3d.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/mobilenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/mobilenet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_sknet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_sknet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/efficientnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/efficientnet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/inceptionv4.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/inceptionv4.cpython-38.pyc 
-------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_gernet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_gernet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_regnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_regnet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_res2net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_res2net.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_resnest.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_resnest.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/resnext.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/resnext.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/MedicalNet_Tencent/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_preprocessing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/_preprocessing.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/mix_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/mix_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_universal.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_universal.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/mobilenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/mobilenet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/mobilenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/mobilenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/shufflenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/shufflenet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/squeezenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/squeezenet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/inceptionresnetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/inceptionresnetv2.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_efficientnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_efficientnet.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_mobilenetv3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/SMP_qubvel/encoders/__pycache__/timm_mobilenetv3.cpython-38.pyc -------------------------------------------------------------------------------- 
/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/shufflenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WAMAWAMA/WAMA_Modules/HEAD/wama_modules/thirdparty_lib/Efficient3D_okankop/models/__pycache__/shufflenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SegmentationModel 2 | 3 | from .modules import ( 4 | Conv2dReLU, 5 | Attention, 6 | ) 7 | 8 | from .heads import ( 9 | SegmentationHead, 10 | ClassificationHead, 11 | ) 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | efficientnet_pytorch==0.7.1 2 | einops==0.6.0 3 | inplace_abn==1.1.0 4 | matplotlib==3.5.1 5 | numpy==1.23.5 6 | pretrainedmodels==0.7.4 7 | prettytable==3.6.0 8 | torch 9 | torchtext==0.12.0 10 | torchvision==0.12.0 11 | 12 | setuptools 13 | timm 14 | tqdm 15 | transformers 16 | 17 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from . import train 4 | from . import losses 5 | from . import metrics 6 | 7 | warnings.warn( 8 | "`smp.utils` module is deprecated and will be removed in future releases.", 9 | DeprecationWarning, 10 | ) 11 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def preprocess_input(x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs): 5 | 6 | if input_space == "BGR": 7 | x = x[..., ::-1].copy() 8 | 9 | if input_range is not None: 10 | if x.max() > 1 and input_range[1] == 1: 11 | x = x / 255.0 12 | 13 | if mean is not None: 14 | mean = np.array(mean) 15 | x = x - mean 16 | 17 | if std is not None: 18 | std = np.array(std) 19 | x = x / std 20 | 21 | return x 22 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/MedicalNet_Tencent/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .models import resnet 4 | 5 | 6 | def generate_model(model_depth): 7 | if model_depth == 10: 8 | model = resnet.resnet10() 9 | elif model_depth == 18: 10 | model = resnet.resnet18() 11 | elif model_depth == 34: 12 | model = resnet.resnet34() 13 | elif model_depth == 50: 14 | model = resnet.resnet50() 15 | elif model_depth == 101: 16 | model = resnet.resnet101() 17 | elif model_depth == 152: 18 | model = resnet.resnet152() 19 | elif model_depth == 200: 20 | model = resnet.resnet200() 21 | return model 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | import io 3 | import os 4 | import sys 5 | 6 | here = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | # What packages are required for this module to be executed? 
9 | try: 10 | with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f: 11 | REQUIRED = f.read().split("\n") 12 | except: 13 | REQUIRED = [] 14 | 15 | 16 | setup( 17 | name='aini_modules', 18 | version='0.0.1', 19 | description='Enjoy~', 20 | author='wamawama', 21 | author_email='wmy19970215@gmail.com', 22 | python_requires=">=3.6.0", 23 | url='https://github.com/WAMAWAMA/wama_modules', 24 | packages=find_packages(exclude=("demo", "docs", "images")), 25 | install_requires=REQUIRED, 26 | license="MIT", 27 | ) 28 | 29 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/base/initialization.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def initialize_decoder(module): 5 | for m in module.modules(): 6 | 7 | if isinstance(m, nn.Conv2d): 8 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 9 | if m.bias is not None: 10 | nn.init.constant_(m.bias, 0) 11 | 12 | elif isinstance(m, nn.BatchNorm2d): 13 | nn.init.constant_(m.weight, 1) 14 | nn.init.constant_(m.bias, 0) 15 | 16 | elif isinstance(m, nn.Linear): 17 | nn.init.xavier_uniform_(m.weight) 18 | if m.bias is not None: 19 | nn.init.constant_(m.bias, 0) 20 | 21 | 22 | def initialize_head(module): 23 | for m in module.modules(): 24 | if isinstance(m, (nn.Linear, nn.Conv2d)): 25 | nn.init.xavier_uniform_(m.weight) 26 | if m.bias is not None: 27 | nn.init.constant_(m.bias, 0) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022, 2023 the wama_modules Project 2 | All rights reserved. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/base/heads.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .modules import Activation 3 | 4 | 5 | class SegmentationHead(nn.Sequential): 6 | def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1): 7 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 8 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 9 | activation = Activation(activation) 10 | super().__init__(conv2d, upsampling, activation) 11 | 12 | 13 | class ClassificationHead(nn.Sequential): 14 | def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None): 15 | if pooling not in ("max", "avg"): 16 | raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling)) 17 | pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1) 18 | flatten = nn.Flatten() 19 | dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() 20 | linear = nn.Linear(in_channels, classes, bias=True) 21 | activation = Activation(activation) 22 | super().__init__(pool, flatten, dropout, linear, activation) 23 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_universal.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch.nn as nn 3 | 4 | 5 | class TimmUniversalEncoder(nn.Module): 6 | def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32): 7 | super().__init__() 8 | kwargs = dict( 9 | in_chans=in_channels, 10 | features_only=True, 11 | output_stride=output_stride, 12 | pretrained=pretrained, 13 | out_indices=tuple(range(depth)), 14 | ) 15 | 16 | # not all models support output stride argument, drop it by default 17 | if output_stride == 32: 18 | kwargs.pop("output_stride") 19 | 20 | self.model = timm.create_model(name, **kwargs) 21 | 22 | self._in_channels = in_channels 23 | self._out_channels = [ 24 | in_channels, 25 | ] + self.model.feature_info.channels() 26 | self._depth = depth 27 | self._output_stride = output_stride 28 | 29 | def forward(self, x): 30 | features = self.model(x) 31 | features = [ 32 | x, 33 | ] + features 34 | return features 35 | 36 | @property 37 | def out_channels(self): 38 | return self._out_channels 39 | 40 | @property 41 | def output_stride(self): 42 | return min(self._output_stride, 2**self._depth) 43 | -------------------------------------------------------------------------------- /demo/Demo0_VGG_SingleLabelClassification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import VGGEncoder 4 | from wama_modules.Head import ClassificationHead 5 | from wama_modules.BaseModule import GlobalMaxPool 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, in_channel, label_category_dict, dim=2): 10 | super().__init__() 11 | # encoder 12 | f_channel_list = [64, 128, 256, 512] 13 | self.encoder = VGGEncoder( 14 | in_channel, 15 | stage_output_channels=f_channel_list, 16 | blocks=[1, 2, 3, 4], 17 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 18 | dim=dim) 19 | # cls head 20 | self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1]) 
21 | self.pooling = GlobalMaxPool() 22 | 23 | def forward(self, x): 24 | f = self.encoder(x) 25 | logits = self.cls_head(self.pooling(f[-1])) 26 | return logits 27 | 28 | 29 | if __name__ == '__main__': 30 | x = torch.ones([2, 1, 64, 64, 64]) 31 | category_num = 1 32 | label_category_dict = dict(is_malignant=category_num) 33 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3) 34 | logits = model(x) 35 | print('single-label predicted logits') 36 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()] 37 | 38 | # output 👇 39 | # single-label predicted logits 40 | # logits of is_malignant : torch.Size([2, 1]) 41 | -------------------------------------------------------------------------------- /demo/multi_label/generate_multilabel_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate example multi_label dataset 3 | """ 4 | import numpy as np 5 | 6 | 7 | label_category_dict = dict( 8 | water=2, # binary class label 9 | milk=2, # binary class label 10 | cow=2, # binary class label 11 | big_white_fish=2, # binary class label 12 | grass=2, # binary class label 13 | weather=5, # 5-class label 14 | ) 15 | label_name = list(label_category_dict.keys()) 16 | print('label num :',len(label_name)) 17 | _ = [print('-'*4, key, ': class num = ', label_category_dict[key]) for key in label_category_dict.keys()] 18 | 19 | img_channel = 2 20 | dataset = [ 21 | dict( # case1 22 | img_1D=np.ones([128, img_channel]), # which is also a 1D signal 23 | img_2D=np.ones([128, 128, img_channel]), 24 | img_3D=np.ones([64, 64, 64, img_channel]), 25 | label_value=[1, 1, 1, 0, 0, 1], # 0=negative, 1=positive 26 | label_known=[1, 1, 1, 1, 1, 1], # 0=unknown/missing, 1=known 27 | ), 28 | dict( # case2 29 | img_1D=np.ones([128, img_channel]), 30 | img_2D=np.ones([128, 128, img_channel]), 31 | img_3D=np.ones([64, 64, 64, img_channel]), 32 | label_value=[1, 0, 0, 0, 1, 1], 33 | label_known=[1, 1, 1, 0, 0, 0], 34 | ), 35 | dict( # case3 36 | img_1D=np.ones([128, img_channel]), 37 | img_2D=np.ones([128, 128, img_channel]), 38 | img_3D=np.ones([64, 64, 64, img_channel]), 39 | label_value=[1, 1, 0, 1, 1, 0], 40 | label_known=[1, 1, 0, 1, 0, 0], 41 | ), 42 | ] 43 | -------------------------------------------------------------------------------- /demo/Demo1_ResNet_SingleLabelClassification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import ResNetEncoder 4 | from wama_modules.Head import ClassificationHead 5 | from wama_modules.BaseModule import GlobalMaxPool 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, in_channel, label_category_dict, dim=2): 10 | super().__init__() 11 | # encoder 12 | f_channel_list = [64, 128, 256, 512] 13 | self.encoder = ResNetEncoder( 14 | in_channel, 15 | stage_output_channels=f_channel_list, 16 | stage_middle_channels=f_channel_list, 17 | blocks=[1, 2, 3, 4], 18 | type='131', 19 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 20 | dim=dim) 21 | # cls head 22 | self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1]) 23 | self.pooling = GlobalMaxPool() 24 | 25 | def forward(self, x): 26 | f = self.encoder(x) 27 | logits = self.cls_head(self.pooling(f[-1])) 28 | return logits 29 | 30 | 31 | if __name__ == '__main__': 32 | x = torch.ones([2, 1, 64, 64, 64]) 33 | label_category_dict = dict(is_malignant=4) 34 | model = Model(in_channel=1,
label_category_dict=label_category_dict, dim=3) 35 | logits = model(x) 36 | print('single-label predicted logits') 37 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()] 38 | 39 | # output 👇 40 | # single-label predicted logits 41 | # logits of is_malignant : torch.Size([2, 4]) 42 | -------------------------------------------------------------------------------- /demo/Demo2_ResNet_MultiLabelClassification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import ResNetEncoder 4 | from wama_modules.Head import ClassificationHead 5 | from wama_modules.BaseModule import GlobalMaxPool 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, in_channel, label_category_dict, dim=2): 10 | super().__init__() 11 | # encoder 12 | f_channel_list = [64, 128, 256, 512] 13 | self.encoder = ResNetEncoder( 14 | in_channel, 15 | stage_output_channels=f_channel_list, 16 | stage_middle_channels=f_channel_list, 17 | blocks=[1, 2, 3, 4], 18 | type='131', 19 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 20 | dim=dim) 21 | # cls head 22 | self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1]) 23 | 24 | self.pooling = GlobalMaxPool() 25 | 26 | def forward(self, x): 27 | f = self.encoder(x) 28 | logits = self.cls_head(self.pooling(f[-1])) 29 | return logits 30 | 31 | 32 | if __name__ == '__main__': 33 | x = torch.ones([2, 1, 64, 64, 64]) 34 | label_category_dict = dict(shape=4, color=3, other=13) 35 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3) 36 | logits = model(x) 37 | print('multi_label predicted logits') 38 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()] 39 | 40 | # out 41 | # multi_label predicted logits 42 | # logits of shape : torch.Size([2, 4]) 43 | # logits of color : torch.Size([2, 3]) 44 | # logits of other : torch.Size([2, 13]) 45 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from . import base 4 | from . 
import functional as F 5 | from ..base.modules import Activation 6 | 7 | 8 | class JaccardLoss(base.Loss): 9 | def __init__(self, eps=1.0, activation=None, ignore_channels=None, **kwargs): 10 | super().__init__(**kwargs) 11 | self.eps = eps 12 | self.activation = Activation(activation) 13 | self.ignore_channels = ignore_channels 14 | 15 | def forward(self, y_pr, y_gt): 16 | y_pr = self.activation(y_pr) 17 | return 1 - F.jaccard( 18 | y_pr, 19 | y_gt, 20 | eps=self.eps, 21 | threshold=None, 22 | ignore_channels=self.ignore_channels, 23 | ) 24 | 25 | 26 | class DiceLoss(base.Loss): 27 | def __init__(self, eps=1.0, beta=1.0, activation=None, ignore_channels=None, **kwargs): 28 | super().__init__(**kwargs) 29 | self.eps = eps 30 | self.beta = beta 31 | self.activation = Activation(activation) 32 | self.ignore_channels = ignore_channels 33 | 34 | def forward(self, y_pr, y_gt): 35 | y_pr = self.activation(y_pr) 36 | return 1 - F.f_score( 37 | y_pr, 38 | y_gt, 39 | beta=self.beta, 40 | eps=self.eps, 41 | threshold=None, 42 | ignore_channels=self.ignore_channels, 43 | ) 44 | 45 | 46 | class L1Loss(nn.L1Loss, base.Loss): 47 | pass 48 | 49 | 50 | class MSELoss(nn.MSELoss, base.Loss): 51 | pass 52 | 53 | 54 | class CrossEntropyLoss(nn.CrossEntropyLoss, base.Loss): 55 | pass 56 | 57 | 58 | class NLLLoss(nn.NLLLoss, base.Loss): 59 | pass 60 | 61 | 62 | class BCELoss(nn.BCELoss, base.Loss): 63 | pass 64 | 65 | 66 | class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, base.Loss): 67 | pass 68 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/utils/meter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Meter(object): 5 | """Meters provide a way to keep track of important statistics in an online manner. 6 | This class is abstract, but provides a standard interface for all meters to follow. 7 | """ 8 | 9 | def reset(self): 10 | """Reset the meter to default settings.""" 11 | pass 12 | 13 | def add(self, value): 14 | """Log a new value to the meter 15 | Args: 16 | value: Next result to include. 
17 | """ 18 | pass 19 | 20 | def value(self): 21 | """Get the value of the meter in the current state.""" 22 | pass 23 | 24 | 25 | class AverageValueMeter(Meter): 26 | def __init__(self): 27 | super(AverageValueMeter, self).__init__() 28 | self.reset() 29 | self.val = 0 30 | 31 | def add(self, value, n=1): 32 | self.val = value 33 | self.sum += value 34 | self.var += value * value 35 | self.n += n 36 | 37 | if self.n == 0: 38 | self.mean, self.std = np.nan, np.nan 39 | elif self.n == 1: 40 | self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy 41 | self.std = np.inf 42 | self.mean_old = self.mean 43 | self.m_s = 0.0 44 | else: 45 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) 46 | self.m_s += (value - self.mean_old) * (value - self.mean) 47 | self.mean_old = self.mean 48 | self.std = np.sqrt(self.m_s / (self.n - 1.0)) 49 | 50 | def value(self): 51 | return self.mean, self.std 52 | 53 | def reset(self): 54 | self.n = 0 55 | self.sum = 0.0 56 | self.var = 0.0 57 | self.val = 0.0 58 | self.mean = np.nan 59 | self.mean_old = 0.0 60 | self.m_s = 0.0 61 | self.std = np.nan 62 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True): 6 | """Change first convolution layer input channels. 7 | In case: 8 | in_channels == 1 or in_channels == 2 -> reuse original weights 9 | in_channels > 3 -> make random kaiming normal initialization 10 | """ 11 | 12 | # get first conv 13 | for module in model.modules(): 14 | if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: 15 | break 16 | 17 | weight = module.weight.detach() 18 | module.in_channels = new_in_channels 19 | 20 | if not pretrained: 21 | module.weight = nn.parameter.Parameter( 22 | torch.Tensor(module.out_channels, new_in_channels // module.groups, *module.kernel_size) 23 | ) 24 | module.reset_parameters() 25 | 26 | elif new_in_channels == 1: 27 | new_weight = weight.sum(1, keepdim=True) 28 | module.weight = nn.parameter.Parameter(new_weight) 29 | 30 | else: 31 | new_weight = torch.Tensor(module.out_channels, new_in_channels // module.groups, *module.kernel_size) 32 | 33 | for i in range(new_in_channels): 34 | new_weight[:, i] = weight[:, i % default_in_channels] 35 | 36 | new_weight = new_weight * (default_in_channels / new_in_channels) 37 | module.weight = nn.parameter.Parameter(new_weight) 38 | 39 | 40 | def replace_strides_with_dilation(module, dilation_rate): 41 | """Patch Conv2d modules replacing strides with dilation""" 42 | for mod in module.modules(): 43 | if isinstance(mod, nn.Conv2d): 44 | mod.stride = (1, 1) 45 | mod.dilation = (dilation_rate, dilation_rate) 46 | kh, kw = mod.kernel_size 47 | mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) 48 | 49 | # Kostyl for EfficientNet 50 | if hasattr(mod, "static_padding"): 51 | mod.static_padding = nn.Identity() 52 | -------------------------------------------------------------------------------- /demo/Demo3_ResNetUnet_SingleLabelSegmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import ResNetEncoder 4 | from wama_modules.Decoder import UNet_decoder 5 | from wama_modules.Head import 
SegmentationHead 6 | from wama_modules.utils import resizeTensor 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self, in_channel, label_category_dict, dim=2): 11 | super().__init__() 12 | # encoder 13 | Encoder_f_channel_list = [64, 128, 256, 512] 14 | self.encoder = ResNetEncoder( 15 | in_channel, 16 | stage_output_channels=Encoder_f_channel_list, 17 | stage_middle_channels=Encoder_f_channel_list, 18 | blocks=[1, 2, 3, 4], 19 | type='131', 20 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 21 | dim=dim) 22 | # decoder 23 | Decoder_f_channel_list = [32, 64, 128] 24 | self.decoder = UNet_decoder( 25 | in_channels_list=Encoder_f_channel_list, 26 | skip_connection=[False, True, True], 27 | out_channels_list=Decoder_f_channel_list, 28 | dim=dim) 29 | # seg head 30 | self.seg_head = SegmentationHead( 31 | label_category_dict, 32 | Decoder_f_channel_list[0], 33 | dim=dim) 34 | 35 | def forward(self, x): 36 | multi_scale_f1 = self.encoder(x) 37 | multi_scale_f2 = self.decoder(multi_scale_f1) 38 | f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:]) 39 | logits = self.seg_head(f_for_seg) 40 | return logits 41 | 42 | 43 | if __name__ == '__main__': 44 | x = torch.ones([2, 1, 128, 128, 128]) 45 | label_category_dict = dict(organ=3) 46 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3) 47 | logits = model(x) 48 | print('multi_label predicted logits') 49 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()] 50 | 51 | # out 52 | # multi_label predicted logits 53 | # logits of organ : torch.Size([2, 3, 128, 128, 128]) 54 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/base/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import initialization as init 3 | 4 | 5 | class SegmentationModel(torch.nn.Module): 6 | def initialize(self): 7 | init.initialize_decoder(self.decoder) 8 | init.initialize_head(self.segmentation_head) 9 | if self.classification_head is not None: 10 | init.initialize_head(self.classification_head) 11 | 12 | def check_input_shape(self, x): 13 | 14 | h, w = x.shape[-2:] 15 | output_stride = self.encoder.output_stride 16 | if h % output_stride != 0 or w % output_stride != 0: 17 | new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h 18 | new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w 19 | raise RuntimeError( 20 | f"Wrong input shape height={h}, width={w}. Expected image height and width " 21 | f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})." 22 | ) 23 | 24 | def forward(self, x): 25 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 26 | 27 | self.check_input_shape(x) 28 | 29 | features = self.encoder(x) 30 | decoder_output = self.decoder(*features) 31 | 32 | masks = self.segmentation_head(decoder_output) 33 | 34 | if self.classification_head is not None: 35 | labels = self.classification_head(features[-1]) 36 | return masks, labels 37 | 38 | return masks 39 | 40 | @torch.no_grad() 41 | def predict(self, x): 42 | """Inference method. 
Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()` 43 | 44 | Args: 45 | x: 4D torch tensor with shape (batch_size, channels, height, width) 46 | 47 | Return: 48 | prediction: 4D torch tensor with shape (batch_size, classes, height, width) 49 | 50 | """ 51 | if self.training: 52 | self.eval() 53 | 54 | x = self.forward(x) 55 | 56 | return x 57 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/utils/base.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | 4 | 5 | class BaseObject(nn.Module): 6 | def __init__(self, name=None): 7 | super().__init__() 8 | self._name = name 9 | 10 | @property 11 | def __name__(self): 12 | if self._name is None: 13 | name = self.__class__.__name__ 14 | s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) 15 | return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() 16 | else: 17 | return self._name 18 | 19 | 20 | class Metric(BaseObject): 21 | pass 22 | 23 | 24 | class Loss(BaseObject): 25 | def __add__(self, other): 26 | if isinstance(other, Loss): 27 | return SumOfLosses(self, other) 28 | else: 29 | raise ValueError("Loss should be inherited from `Loss` class") 30 | 31 | def __radd__(self, other): 32 | return self.__add__(other) 33 | 34 | def __mul__(self, value): 35 | if isinstance(value, (int, float)): 36 | return MultipliedLoss(self, value) 37 | else: 38 | raise ValueError("Loss should be inherited from `BaseLoss` class") 39 | 40 | def __rmul__(self, other): 41 | return self.__mul__(other) 42 | 43 | 44 | class SumOfLosses(Loss): 45 | def __init__(self, l1, l2): 46 | name = "{} + {}".format(l1.__name__, l2.__name__) 47 | super().__init__(name=name) 48 | self.l1 = l1 49 | self.l2 = l2 50 | 51 | def __call__(self, *inputs): 52 | return self.l1.forward(*inputs) + self.l2.forward(*inputs) 53 | 54 | 55 | class MultipliedLoss(Loss): 56 | def __init__(self, loss, multiplier): 57 | 58 | # resolve name 59 | if len(loss.__name__.split("+")) > 1: 60 | name = "{} * ({})".format(multiplier, loss.__name__) 61 | else: 62 | name = "{} * {}".format(multiplier, loss.__name__) 63 | super().__init__(name=name) 64 | self.loss = loss 65 | self.multiplier = multiplier 66 | 67 | def __call__(self, *inputs): 68 | return self.multiplier * self.loss.forward(*inputs) 69 | -------------------------------------------------------------------------------- /demo/Demo4_ResNetUnet_MultiLabelSegmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import ResNetEncoder 4 | from wama_modules.Decoder import UNet_decoder 5 | from wama_modules.Head import SegmentationHead 6 | from wama_modules.utils import resizeTensor 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self, in_channel, label_category_dict, dim=2): 11 | super().__init__() 12 | # encoder 13 | Encoder_f_channel_list = [64, 128, 256, 512] 14 | self.encoder = ResNetEncoder( 15 | in_channel, 16 | stage_output_channels=Encoder_f_channel_list, 17 | stage_middle_channels=Encoder_f_channel_list, 18 | blocks=[1, 2, 3, 4], 19 | type='131', 20 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 21 | dim=dim) 22 | # decoder 23 | Decoder_f_channel_list = [32, 64, 128] 24 | self.decoder = UNet_decoder( 25 | in_channels_list=Encoder_f_channel_list, 26 | skip_connection=[False, True, True], 27 | out_channels_list=Decoder_f_channel_list, 28 | dim=dim) 29 | # seg head 30 | 
self.seg_head = SegmentationHead( 31 | label_category_dict, 32 | Decoder_f_channel_list[0], 33 | dim=dim) 34 | 35 | def forward(self, x): 36 | multi_scale_f1 = self.encoder(x) 37 | multi_scale_f2 = self.decoder(multi_scale_f1) 38 | f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:]) 39 | logits = self.seg_head(f_for_seg) 40 | return logits 41 | 42 | 43 | if __name__ == '__main__': 44 | x = torch.ones([2, 1, 128, 128, 128]) 45 | label_category_dict = dict(organ=3, tumor=4) 46 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3) 47 | logits = model(x) 48 | print('multi_label predicted logits') 49 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()] 50 | 51 | # out 52 | # multi_label predicted logits 53 | # logits of organ : torch.Size([2, 3, 128, 128, 128]) 54 | # logits of tumor : torch.Size([2, 4, 128, 128, 128]) 55 | 56 | 57 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List 4 | from collections import OrderedDict 5 | 6 | from . import _utils as utils 7 | 8 | 9 | class EncoderMixin: 10 | """Add encoder functionality such as: 11 | - output channels specification of feature tensors (produced by encoder) 12 | - patching first convolution for arbitrary input channels 13 | """ 14 | 15 | _output_stride = 32 16 | 17 | @property 18 | def out_channels(self): 19 | """Return channels dimensions for each tensor of forward output of encoder""" 20 | return self._out_channels[: self._depth + 1] 21 | 22 | @property 23 | def output_stride(self): 24 | return min(self._output_stride, 2**self._depth) 25 | 26 | def set_in_channels(self, in_channels, pretrained=True): 27 | """Change first convolution channels""" 28 | if in_channels == 3: 29 | return 30 | 31 | self._in_channels = in_channels 32 | if self._out_channels[0] == 3: 33 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) 34 | 35 | utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained) 36 | 37 | def get_stages(self): 38 | """Override it in your implementation""" 39 | raise NotImplementedError 40 | 41 | def make_dilated(self, output_stride): 42 | 43 | if output_stride == 16: 44 | stage_list = [ 45 | 5, 46 | ] 47 | dilation_list = [ 48 | 2, 49 | ] 50 | 51 | elif output_stride == 8: 52 | stage_list = [4, 5] 53 | dilation_list = [2, 4] 54 | 55 | else: 56 | raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride)) 57 | 58 | self._output_stride = output_stride 59 | 60 | stages = self.get_stages() 61 | for stage_indx, dilation_rate in zip(stage_list, dilation_list): 62 | utils.replace_strides_with_dilation( 63 | module=stages[stage_indx], 64 | dilation_rate=dilation_rate, 65 | ) 66 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/C3D_yyuanad/c3d.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class C3D(nn.Module): 7 | """ 8 | nb_classes: nb_classes in classification task, 101 for UCF101 dataset 9 | """ 10 | 11 | def __init__(self,): 12 | super(C3D, self).__init__() 13 | 14 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 15 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 16 | 
17 | self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 18 | self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 19 | 20 | self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 21 | self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 22 | self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 23 | 24 | self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 25 | self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 26 | self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 27 | 28 | self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 29 | self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 30 | self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)) 31 | 32 | self.relu = nn.ReLU() 33 | 34 | def forward(self, x): 35 | f_list = [] 36 | h = self.relu(self.conv1(x)) 37 | h = self.pool1(h) 38 | f_list.append(h) 39 | 40 | h = self.relu(self.conv2(h)) 41 | h = self.pool2(h) 42 | f_list.append(h) 43 | 44 | h = self.relu(self.conv3a(h)) 45 | h = self.relu(self.conv3b(h)) 46 | h = self.pool3(h) 47 | f_list.append(h) 48 | 49 | h = self.relu(self.conv4a(h)) 50 | h = self.relu(self.conv4b(h)) 51 | h = self.pool4(h) 52 | f_list.append(h) 53 | 54 | h = self.relu(self.conv5a(h)) 55 | h = self.relu(self.conv5b(h)) 56 | h = self.pool5(h) 57 | f_list.append(h) 58 | 59 | return f_list 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/xception.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | 4 | from pretrainedmodels.models.xception import pretrained_settings 5 | from pretrainedmodels.models.xception import Xception 6 | 7 | from ._base import EncoderMixin 8 | 9 | 10 | class XceptionEncoder(Xception, EncoderMixin): 11 | def __init__(self, out_channels, *args, depth=5, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | 14 | self._out_channels = out_channels 15 | self._depth = depth 16 | self._in_channels = 3 17 | 18 | # modify padding to maintain output shape 19 | self.conv1.padding = (1, 1) 20 | self.conv2.padding = (1, 1) 21 | 22 | del self.fc 23 | 24 | def make_dilated(self, *args, **kwargs): 25 | raise ValueError( 26 | "Xception encoder does not support dilated mode " "due to pooling operation for downsampling!" 
27 | ) 28 | 29 | def get_stages(self): 30 | return [ 31 | nn.Identity(), 32 | nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu), 33 | self.block1, 34 | self.block2, 35 | nn.Sequential( 36 | self.block3, 37 | self.block4, 38 | self.block5, 39 | self.block6, 40 | self.block7, 41 | self.block8, 42 | self.block9, 43 | self.block10, 44 | self.block11, 45 | ), 46 | nn.Sequential(self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4), 47 | ] 48 | 49 | def forward(self, x): 50 | stages = self.get_stages() 51 | 52 | features = [] 53 | for i in range(self._depth + 1): 54 | x = stages[i](x) 55 | features.append(x) 56 | 57 | return features 58 | 59 | def load_state_dict(self, state_dict): 60 | # remove linear 61 | state_dict.pop("fc.bias", None) 62 | state_dict.pop("fc.weight", None) 63 | 64 | super().load_state_dict(state_dict) 65 | 66 | 67 | xception_encoders = { 68 | "xception": { 69 | "encoder": XceptionEncoder, 70 | "pretrained_settings": pretrained_settings["xception"], 71 | "params": { 72 | "out_channels": (3, 64, 128, 256, 728, 2048), 73 | }, 74 | }, 75 | } 76 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/C3D_jfzhang95/c3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class C3D(nn.Module): 6 | """ 7 | The C3D network. 8 | """ 9 | 10 | def __init__(self): 11 | super(C3D, self).__init__() 12 | 13 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 14 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 15 | 16 | self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 17 | self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 18 | 19 | self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 20 | self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 21 | self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 22 | 23 | self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 24 | self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 25 | self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 26 | 27 | self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 28 | self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)) 29 | self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)) 30 | 31 | self.relu = torch.relu 32 | 33 | def forward(self, x): 34 | f_list = [] 35 | x = self.relu(self.conv1(x)) 36 | x = self.pool1(x) 37 | f_list.append(x) 38 | x = self.relu(self.conv2(x)) 39 | x = self.pool2(x) 40 | f_list.append(x) 41 | x = self.relu(self.conv3a(x)) 42 | x = self.relu(self.conv3b(x)) 43 | x = self.pool3(x) 44 | f_list.append(x) 45 | x = self.relu(self.conv4a(x)) 46 | x = self.relu(self.conv4b(x)) 47 | x = self.pool4(x) 48 | f_list.append(x) 49 | x = self.relu(self.conv5a(x)) 50 | x = self.relu(self.conv5b(x)) 51 | x = self.pool5(x) 52 | f_list.append(x) 53 | 54 | return f_list 55 | 56 | def __init_weight(self): 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv3d): 59 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 60 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 61 | torch.nn.init.kaiming_normal_(m.weight) 62 | elif isinstance(m, nn.BatchNorm3d): 63 | m.weight.data.fill_(1) 64 | m.bias.data.zero_() 65 | -------------------------------------------------------------------------------- /demo/Demo6_UnetwithFPN_segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import ResNetEncoder 4 | from wama_modules.Decoder import UNet_decoder 5 | from wama_modules.Head import SegmentationHead 6 | from wama_modules.utils import resizeTensor 7 | from wama_modules.Neck import FPN 8 | 9 | 10 | class Model(nn.Module): 11 | def __init__(self, in_channel, label_category_dict, dim=2): 12 | super().__init__() 13 | # encoder 14 | Encoder_f_channel_list = [64, 128, 256, 512] 15 | self.encoder = ResNetEncoder( 16 | in_channel, 17 | stage_output_channels=Encoder_f_channel_list, 18 | stage_middle_channels=Encoder_f_channel_list, 19 | blocks=[1, 2, 3, 4], 20 | type='131', 21 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 22 | dim=dim) 23 | 24 | # neck 25 | FPN_output_channel = 256 26 | FPN_channels = [FPN_output_channel]*len(Encoder_f_channel_list) 27 | self.neck = FPN(in_channels_list=Encoder_f_channel_list, 28 | c1=FPN_output_channel//2, 29 | c2=FPN_output_channel, 30 | mode='AddSmall2Big', 31 | dim=dim,) 32 | 33 | # decoder 34 | Decoder_f_channel_list = [32, 64, 128] 35 | self.decoder = UNet_decoder( 36 | in_channels_list=FPN_channels, 37 | skip_connection=[True, True, True], 38 | out_channels_list=Decoder_f_channel_list, 39 | dim=dim) 40 | # seg head 41 | self.seg_head = SegmentationHead( 42 | label_category_dict, 43 | Decoder_f_channel_list[0], 44 | dim=dim) 45 | 46 | def forward(self, x): 47 | multi_scale_encoder = self.encoder(x) 48 | multi_scale_neck = self.neck(multi_scale_encoder) 49 | multi_scale_decoder = self.decoder(multi_scale_neck) 50 | f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:]) 51 | logits = self.seg_head(f_for_seg) 52 | return logits 53 | 54 | 55 | if __name__ == '__main__': 56 | x = torch.ones([2, 1, 128, 128, 128]) 57 | label_category_dict = dict(organ=3, tumor=4) 58 | model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3) 59 | logits = model(x) 60 | print('multi_label predicted logits') 61 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()] 62 | 63 | # out 64 | # multi_label predicted logits 65 | # logits of organ : torch.Size([2, 3, 128, 128, 128]) 66 | # logits of tumor : torch.Size([2, 4, 128, 128, 128]) 67 | 68 | 69 | -------------------------------------------------------------------------------- /demo/_explain_how2useVitAsNeck.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import ViTConfig, ViTModel 3 | from wama_modules.utils import load_weights, tensor2array 4 | 5 | 6 | m = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k') 7 | f = m(torch.ones([1, 3, 224, 224]), output_hidden_states=True) 8 | f_last = f.last_hidden_state 9 | print(f_last.shape) 10 | f_cls_token = (torch.squeeze(f_last[:,0])).data.cpu().numpy() 11 | f_cls_token = list(f_cls_token) 12 | 13 | 14 | configuration = m.config 15 | configuration.image_size = [16, 8] 16 | configuration.patch_size = [1, 1] 17 | configuration.num_channels = 1 18 | configuration.encoder_stride = 1 # just for MAE decoder, otherwise this paramater is not used 19 | m1 = ViTModel(configuration, add_pooling_layer=False) 20 | 21 | f = 
m1(torch.ones([2, 1, 16, 8]), output_hidden_states=True) 22 | 23 | 24 | f_list = f.hidden_states # for transformers that expose reshaped_hidden_states (e.g. Swin), use that field instead 25 | _ = [print(i.shape) for i in f_list] 26 | 27 | f_last = f.last_hidden_state 28 | f_last = f_last[:, 1:] 29 | f_last = f_last.permute(0, 2, 1) 30 | f_last = f_last.reshape(f_last.shape[0], f_last.shape[1], configuration.image_size[0], configuration.image_size[1]) 31 | print('spatial f_last:', f_last.shape) 32 | 33 | 34 | # reload weights 35 | 36 | m = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k') 37 | weights = m.state_dict() 38 | weights['embeddings.position_embeddings'] = m1.state_dict()['embeddings.position_embeddings'] 39 | weights['embeddings.patch_embeddings.projection.weight'] = m1.state_dict()['embeddings.patch_embeddings.projection.weight'] 40 | weights['embeddings.patch_embeddings.projection.bias'] = m1.state_dict()['embeddings.patch_embeddings.projection.bias'] 41 | 42 | 43 | m1 = load_weights(m1, weights) 44 | 45 | 46 | # test: spatial visualization 47 | m1 = ViTModel(configuration, add_pooling_layer=False) # note: this re-creates m1 with freshly initialized weights 48 | 49 | input = torch.ones([2, 1, 16, 8])*100 50 | input[:,:,8:] = input[:,:,8:]*0. 51 | input[:,:,:3] = input[:,:,:3]*0. 52 | input[:,:,:,:3] = input[:,:,:,:3]*0. 53 | f = m1(input, output_hidden_states=True) 54 | f_last = f.last_hidden_state 55 | f_last = f_last[:, 1:] 56 | f_last = f_last.permute(0, 2, 1) 57 | f_last = f_last.reshape(f_last.shape[0], f_last.shape[1], configuration.image_size[0], configuration.image_size[1]) 58 | print('spatial f_last:', f_last.shape) 59 | print(f_last.max()) 60 | print(f_last.min()) 61 | 62 | def tensor2numpy(tensor): 63 | return tensor.data.cpu().numpy() 64 | import numpy as np 65 | def mat2gray(image): 66 | """ 67 | Normalization function (linear normalization) 68 | :param image: ndarray 69 | :return: normalized ndarray as np.float32, values in [0, 1] 70 | """ 71 | # cast to np.float32 before normalizing 72 | image = image.astype(np.float32) 73 | image = (image - np.min(image)) / (np.max(image)-np.min(image)+ 1e-14) 74 | return image 75 | 76 | import matplotlib.pyplot as plt 77 | def show2D(img): 78 | plt.imshow(img) 79 | plt.show() 80 | 81 | 82 | # the two images should be spatially aligned 83 | show2D(tensor2numpy(f_last[0,0])) 84 | show2D(tensor2numpy(input[0,0])) 85 | 86 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/mobilenet.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g.
if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torchvision 27 | import torch.nn as nn 28 | 29 | from ._base import EncoderMixin 30 | 31 | 32 | class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): 33 | def __init__(self, out_channels, depth=5, **kwargs): 34 | super().__init__(**kwargs) 35 | self._depth = depth 36 | self._out_channels = out_channels 37 | self._in_channels = 3 38 | del self.classifier 39 | 40 | def get_stages(self): 41 | return [ 42 | nn.Identity(), 43 | self.features[:2], 44 | self.features[2:4], 45 | self.features[4:7], 46 | self.features[7:14], 47 | self.features[14:], 48 | ] 49 | 50 | def forward(self, x): 51 | stages = self.get_stages() 52 | 53 | features = [] 54 | for i in range(self._depth + 1): 55 | x = stages[i](x) 56 | features.append(x) 57 | 58 | return features 59 | 60 | def load_state_dict(self, state_dict, **kwargs): 61 | state_dict.pop("classifier.1.bias", None) 62 | state_dict.pop("classifier.1.weight", None) 63 | super().load_state_dict(state_dict, **kwargs) 64 | 65 | 66 | mobilenet_encoders = { 67 | "mobilenet_v2": { 68 | "encoder": MobileNetV2Encoder, 69 | "pretrained_settings": { 70 | "imagenet": { 71 | "mean": [0.485, 0.456, 0.406], 72 | "std": [0.229, 0.224, 0.225], 73 | "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", 74 | "input_space": "RGB", 75 | "input_range": [0, 1], 76 | }, 77 | }, 78 | "params": { 79 | "out_channels": (3, 16, 24, 32, 96, 1280), 80 | }, 81 | }, 82 | } 83 | -------------------------------------------------------------------------------- /demo/Demo5_MultiTask_SegAndCls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import ResNetEncoder 4 | from wama_modules.Decoder import UNet_decoder 5 | from wama_modules.Head import SegmentationHead, ClassificationHead 6 | from wama_modules.utils import resizeTensor 7 | from wama_modules.BaseModule import GlobalMaxPool 8 | 9 | 10 | class Model(nn.Module): 11 | def __init__(self, 12 | in_channel, 13 | seg_label_category_dict, 14 | cls_label_category_dict, 15 | dim=2): 16 | super().__init__() 17 | # encoder 18 | Encoder_f_channel_list = [64, 128, 256, 512] 19 | self.encoder = ResNetEncoder( 20 | in_channel, 21 | stage_output_channels=Encoder_f_channel_list, 22 | stage_middle_channels=Encoder_f_channel_list, 23 | blocks=[1, 2, 3, 4], 24 | type='131', 25 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 26 | dim=dim) 27 | # decoder 28 | Decoder_f_channel_list = [32, 64, 128] 29 | self.decoder = UNet_decoder( 30 | in_channels_list=Encoder_f_channel_list, 31 | skip_connection=[False, True, True], 32 | out_channels_list=Decoder_f_channel_list, 33 | dim=dim) 34 | # seg head 35 | self.seg_head = SegmentationHead( 36 | seg_label_category_dict, 37 | Decoder_f_channel_list[0], 38 | dim=dim) 39 | # cls head 40 | self.cls_head = ClassificationHead(cls_label_category_dict, Encoder_f_channel_list[-1]) 41 | 42 | # pooling 43 | self.pooling = GlobalMaxPool() 44 | 45 | def forward(self, x): 46 | # get encoder features 47 | multi_scale_encoder = self.encoder(x) 48 | # get decoder features 49 | multi_scale_decoder = self.decoder(multi_scale_encoder) 50 | # perform segmentation 51 | f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:]) 52 | seg_logits = self.seg_head(f_for_seg) 53 
| # perform classification 54 | cls_logits = self.cls_head(self.pooling(multi_scale_encoder[-1])) 55 | return seg_logits, cls_logits 56 | 57 | if __name__ == '__main__': 58 | x = torch.ones([2, 1, 128, 128, 128]) 59 | seg_label_category_dict = dict(organ=3, tumor=2) 60 | cls_label_category_dict = dict(shape=4, color=3, other=13) 61 | model = Model( 62 | in_channel=1, 63 | cls_label_category_dict=cls_label_category_dict, 64 | seg_label_category_dict=seg_label_category_dict, 65 | dim=3) 66 | seg_logits, cls_logits = model(x) 67 | print('multi_label predicted logits') 68 | _ = [print('seg logits of ', key, ':', seg_logits[key].shape) for key in seg_logits.keys()] 69 | print('-'*30) 70 | _ = [print('cls logits of ', key, ':', cls_logits[key].shape) for key in cls_logits.keys()] 71 | 72 | # out 73 | # multi_label predicted logits 74 | # seg logits of organ : torch.Size([2, 3, 128, 128, 128]) 75 | # seg logits of tumor : torch.Size([2, 2, 128, 128, 128]) 76 | # ------------------------------ 77 | # cls logits of shape : torch.Size([2, 4]) 78 | # cls logits of color : torch.Size([2, 3]) 79 | # cls logits of other : torch.Size([2, 13]) 80 | 81 | 82 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | from . import functional as F 3 | from ..base.modules import Activation 4 | 5 | 6 | class IoU(base.Metric): 7 | __name__ = "iou_score" 8 | 9 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 10 | super().__init__(**kwargs) 11 | self.eps = eps 12 | self.threshold = threshold 13 | self.activation = Activation(activation) 14 | self.ignore_channels = ignore_channels 15 | 16 | def forward(self, y_pr, y_gt): 17 | y_pr = self.activation(y_pr) 18 | return F.iou( 19 | y_pr, 20 | y_gt, 21 | eps=self.eps, 22 | threshold=self.threshold, 23 | ignore_channels=self.ignore_channels, 24 | ) 25 | 26 | 27 | class Fscore(base.Metric): 28 | def __init__(self, beta=1, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 29 | super().__init__(**kwargs) 30 | self.eps = eps 31 | self.beta = beta 32 | self.threshold = threshold 33 | self.activation = Activation(activation) 34 | self.ignore_channels = ignore_channels 35 | 36 | def forward(self, y_pr, y_gt): 37 | y_pr = self.activation(y_pr) 38 | return F.f_score( 39 | y_pr, 40 | y_gt, 41 | eps=self.eps, 42 | beta=self.beta, 43 | threshold=self.threshold, 44 | ignore_channels=self.ignore_channels, 45 | ) 46 | 47 | 48 | class Accuracy(base.Metric): 49 | def __init__(self, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 50 | super().__init__(**kwargs) 51 | self.threshold = threshold 52 | self.activation = Activation(activation) 53 | self.ignore_channels = ignore_channels 54 | 55 | def forward(self, y_pr, y_gt): 56 | y_pr = self.activation(y_pr) 57 | return F.accuracy( 58 | y_pr, 59 | y_gt, 60 | threshold=self.threshold, 61 | ignore_channels=self.ignore_channels, 62 | ) 63 | 64 | 65 | class Recall(base.Metric): 66 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 67 | super().__init__(**kwargs) 68 | self.eps = eps 69 | self.threshold = threshold 70 | self.activation = Activation(activation) 71 | self.ignore_channels = ignore_channels 72 | 73 | def forward(self, y_pr, y_gt): 74 | y_pr = self.activation(y_pr) 75 | return F.recall( 76 | y_pr, 77 | y_gt, 78 | 
eps=self.eps, 79 | threshold=self.threshold, 80 | ignore_channels=self.ignore_channels, 81 | ) 82 | 83 | 84 | class Precision(base.Metric): 85 | def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 86 | super().__init__(**kwargs) 87 | self.eps = eps 88 | self.threshold = threshold 89 | self.activation = Activation(activation) 90 | self.ignore_channels = ignore_channels 91 | 92 | def forward(self, y_pr, y_gt): 93 | y_pr = self.activation(y_pr) 94 | return F.precision( 95 | y_pr, 96 | y_gt, 97 | eps=self.eps, 98 | threshold=self.threshold, 99 | ignore_channels=self.ignore_channels, 100 | ) 101 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/inceptionresnetv2.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch.nn as nn 27 | from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2 28 | from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings 29 | 30 | from ._base import EncoderMixin 31 | 32 | 33 | class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin): 34 | def __init__(self, out_channels, depth=5, **kwargs): 35 | super().__init__(**kwargs) 36 | 37 | self._out_channels = out_channels 38 | self._depth = depth 39 | self._in_channels = 3 40 | 41 | # correct paddings 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | if m.kernel_size == (3, 3): 45 | m.padding = (1, 1) 46 | if isinstance(m, nn.MaxPool2d): 47 | m.padding = (1, 1) 48 | 49 | # remove linear layers 50 | del self.avgpool_1a 51 | del self.last_linear 52 | 53 | def make_dilated(self, *args, **kwargs): 54 | raise ValueError( 55 | "InceptionResnetV2 encoder does not support dilated mode " "due to pooling operation for downsampling!" 
56 | ) 57 | 58 | def get_stages(self): 59 | return [ 60 | nn.Identity(), 61 | nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b), 62 | nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a), 63 | nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat), 64 | nn.Sequential(self.mixed_6a, self.repeat_1), 65 | nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b), 66 | ] 67 | 68 | def forward(self, x): 69 | 70 | stages = self.get_stages() 71 | 72 | features = [] 73 | for i in range(self._depth + 1): 74 | x = stages[i](x) 75 | features.append(x) 76 | 77 | return features 78 | 79 | def load_state_dict(self, state_dict, **kwargs): 80 | state_dict.pop("last_linear.bias", None) 81 | state_dict.pop("last_linear.weight", None) 82 | super().load_state_dict(state_dict, **kwargs) 83 | 84 | 85 | inceptionresnetv2_encoders = { 86 | "inceptionresnetv2": { 87 | "encoder": InceptionResNetV2Encoder, 88 | "pretrained_settings": pretrained_settings["inceptionresnetv2"], 89 | "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000}, 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/inceptionv4.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
24 | """ 25 | 26 | import torch.nn as nn 27 | from pretrainedmodels.models.inceptionv4 import InceptionV4, BasicConv2d 28 | from pretrainedmodels.models.inceptionv4 import pretrained_settings 29 | 30 | from ._base import EncoderMixin 31 | 32 | 33 | class InceptionV4Encoder(InceptionV4, EncoderMixin): 34 | def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): 35 | super().__init__(**kwargs) 36 | self._stage_idxs = stage_idxs 37 | self._out_channels = out_channels 38 | self._depth = depth 39 | self._in_channels = 3 40 | 41 | # correct paddings 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | if m.kernel_size == (3, 3): 45 | m.padding = (1, 1) 46 | if isinstance(m, nn.MaxPool2d): 47 | m.padding = (1, 1) 48 | 49 | # remove linear layers 50 | del self.last_linear 51 | 52 | def make_dilated(self, stage_list, dilation_list): 53 | raise ValueError( 54 | "InceptionV4 encoder does not support dilated mode " "due to pooling operation for downsampling!" 55 | ) 56 | 57 | def get_stages(self): 58 | return [ 59 | nn.Identity(), 60 | self.features[: self._stage_idxs[0]], 61 | self.features[self._stage_idxs[0] : self._stage_idxs[1]], 62 | self.features[self._stage_idxs[1] : self._stage_idxs[2]], 63 | self.features[self._stage_idxs[2] : self._stage_idxs[3]], 64 | self.features[self._stage_idxs[3] :], 65 | ] 66 | 67 | def forward(self, x): 68 | 69 | stages = self.get_stages() 70 | 71 | features = [] 72 | for i in range(self._depth + 1): 73 | x = stages[i](x) 74 | features.append(x) 75 | 76 | return features 77 | 78 | def load_state_dict(self, state_dict, **kwargs): 79 | state_dict.pop("last_linear.bias", None) 80 | state_dict.pop("last_linear.weight", None) 81 | super().load_state_dict(state_dict, **kwargs) 82 | 83 | 84 | inceptionv4_encoders = { 85 | "inceptionv4": { 86 | "encoder": InceptionV4Encoder, 87 | "pretrained_settings": pretrained_settings["inceptionv4"], 88 | "params": { 89 | "stage_idxs": (3, 5, 9, 15), 90 | "out_channels": (3, 64, 192, 384, 1024, 1536), 91 | "num_classes": 1001, 92 | }, 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/utils/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from tqdm import tqdm as tqdm 4 | from .meter import AverageValueMeter 5 | 6 | 7 | class Epoch: 8 | def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True): 9 | self.model = model 10 | self.loss = loss 11 | self.metrics = metrics 12 | self.stage_name = stage_name 13 | self.verbose = verbose 14 | self.device = device 15 | 16 | self._to_device() 17 | 18 | def _to_device(self): 19 | self.model.to(self.device) 20 | self.loss.to(self.device) 21 | for metric in self.metrics: 22 | metric.to(self.device) 23 | 24 | def _format_logs(self, logs): 25 | str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()] 26 | s = ", ".join(str_logs) 27 | return s 28 | 29 | def batch_update(self, x, y): 30 | raise NotImplementedError 31 | 32 | def on_epoch_start(self): 33 | pass 34 | 35 | def run(self, dataloader): 36 | 37 | self.on_epoch_start() 38 | 39 | logs = {} 40 | loss_meter = AverageValueMeter() 41 | metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics} 42 | 43 | with tqdm( 44 | dataloader, 45 | desc=self.stage_name, 46 | file=sys.stdout, 47 | disable=not (self.verbose), 48 | ) as iterator: 49 | for x, y in iterator: 50 | x, y = x.to(self.device), y.to(self.device) 51 | 
loss, y_pred = self.batch_update(x, y) 52 | 53 | # update loss logs 54 | loss_value = loss.cpu().detach().numpy() 55 | loss_meter.add(loss_value) 56 | loss_logs = {self.loss.__name__: loss_meter.mean} 57 | logs.update(loss_logs) 58 | 59 | # update metrics logs 60 | for metric_fn in self.metrics: 61 | metric_value = metric_fn(y_pred, y).cpu().detach().numpy() 62 | metrics_meters[metric_fn.__name__].add(metric_value) 63 | metrics_logs = {k: v.mean for k, v in metrics_meters.items()} 64 | logs.update(metrics_logs) 65 | 66 | if self.verbose: 67 | s = self._format_logs(logs) 68 | iterator.set_postfix_str(s) 69 | 70 | return logs 71 | 72 | 73 | class TrainEpoch(Epoch): 74 | def __init__(self, model, loss, metrics, optimizer, device="cpu", verbose=True): 75 | super().__init__( 76 | model=model, 77 | loss=loss, 78 | metrics=metrics, 79 | stage_name="train", 80 | device=device, 81 | verbose=verbose, 82 | ) 83 | self.optimizer = optimizer 84 | 85 | def on_epoch_start(self): 86 | self.model.train() 87 | 88 | def batch_update(self, x, y): 89 | self.optimizer.zero_grad() 90 | prediction = self.model.forward(x) 91 | loss = self.loss(prediction, y) 92 | loss.backward() 93 | self.optimizer.step() 94 | return loss, prediction 95 | 96 | 97 | class ValidEpoch(Epoch): 98 | def __init__(self, model, loss, metrics, device="cpu", verbose=True): 99 | super().__init__( 100 | model=model, 101 | loss=loss, 102 | metrics=metrics, 103 | stage_name="valid", 104 | device=device, 105 | verbose=verbose, 106 | ) 107 | 108 | def on_epoch_start(self): 109 | self.model.eval() 110 | 111 | def batch_update(self, x, y): 112 | with torch.no_grad(): 113 | prediction = self.model.forward(x) 114 | loss = self.loss(prediction, y) 115 | return loss, prediction 116 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_sknet.py: -------------------------------------------------------------------------------- 1 | from ._base import EncoderMixin 2 | from timm.models.resnet import ResNet 3 | from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic 4 | import torch.nn as nn 5 | 6 | 7 | class SkNetEncoder(ResNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.fc 15 | del self.global_pool 16 | 17 | def get_stages(self): 18 | return [ 19 | nn.Identity(), 20 | nn.Sequential(self.conv1, self.bn1, self.act1), 21 | nn.Sequential(self.maxpool, self.layer1), 22 | self.layer2, 23 | self.layer3, 24 | self.layer4, 25 | ] 26 | 27 | def forward(self, x): 28 | stages = self.get_stages() 29 | 30 | features = [] 31 | for i in range(self._depth + 1): 32 | x = stages[i](x) 33 | features.append(x) 34 | 35 | return features 36 | 37 | def load_state_dict(self, state_dict, **kwargs): 38 | state_dict.pop("fc.bias", None) 39 | state_dict.pop("fc.weight", None) 40 | super().load_state_dict(state_dict, **kwargs) 41 | 42 | 43 | sknet_weights = { 44 | "timm-skresnet18": { 45 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth", # noqa 46 | }, 47 | "timm-skresnet34": { 48 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth", # noqa 49 | }, 50 | "timm-skresnext50_32x4d": { 51 | "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth", # noqa 52 | }, 53 | } 54 | 55 | pretrained_settings = {} 56 | for model_name, sources in sknet_weights.items(): 57 | pretrained_settings[model_name] = {} 58 | for source_name, source_url in sources.items(): 59 | pretrained_settings[model_name][source_name] = { 60 | "url": source_url, 61 | "input_size": [3, 224, 224], 62 | "input_range": [0, 1], 63 | "mean": [0.485, 0.456, 0.406], 64 | "std": [0.229, 0.224, 0.225], 65 | "num_classes": 1000, 66 | } 67 | 68 | timm_sknet_encoders = { 69 | "timm-skresnet18": { 70 | "encoder": SkNetEncoder, 71 | "pretrained_settings": pretrained_settings["timm-skresnet18"], 72 | "params": { 73 | "out_channels": (3, 64, 64, 128, 256, 512), 74 | "block": SelectiveKernelBasic, 75 | "layers": [2, 2, 2, 2], 76 | "zero_init_last_bn": False, 77 | "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}}, 78 | }, 79 | }, 80 | "timm-skresnet34": { 81 | "encoder": SkNetEncoder, 82 | "pretrained_settings": pretrained_settings["timm-skresnet34"], 83 | "params": { 84 | "out_channels": (3, 64, 64, 128, 256, 512), 85 | "block": SelectiveKernelBasic, 86 | "layers": [3, 4, 6, 3], 87 | "zero_init_last_bn": False, 88 | "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}}, 89 | }, 90 | }, 91 | "timm-skresnext50_32x4d": { 92 | "encoder": SkNetEncoder, 93 | "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"], 94 | "params": { 95 | "out_channels": (3, 64, 256, 512, 1024, 2048), 96 | "block": SelectiveKernelBottleneck, 97 | "layers": [3, 4, 6, 3], 98 | "zero_init_last_bn": False, 99 | "cardinality": 32, 100 | "base_width": 4, 101 | }, 102 | }, 103 | } 104 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/c3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the c3d implementation with batch norm. 3 | 4 | References 5 | ---------- 6 | [1] Tran, Du, et al. "Learning spatiotemporal features with 3d convolutional networks." 7 | Proceedings of the IEEE international conference on computer vision. 2015. 
8 | """ 9 | 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable 16 | from functools import partial 17 | 18 | 19 | class C3D(nn.Module): 20 | def __init__(self,): 21 | 22 | super(C3D, self).__init__() 23 | self.group1 = nn.Sequential( 24 | nn.Conv3d(3, 64, kernel_size=3, padding=1), 25 | nn.BatchNorm3d(64), 26 | nn.ReLU(), 27 | nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2))) 28 | self.group2 = nn.Sequential( 29 | nn.Conv3d(64, 128, kernel_size=3, padding=1), 30 | nn.BatchNorm3d(128), 31 | nn.ReLU(), 32 | nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))) 33 | self.group3 = nn.Sequential( 34 | nn.Conv3d(128, 256, kernel_size=3, padding=1), 35 | nn.BatchNorm3d(256), 36 | nn.ReLU(), 37 | nn.Conv3d(256, 256, kernel_size=3, padding=1), 38 | nn.BatchNorm3d(256), 39 | nn.ReLU(), 40 | nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))) 41 | self.group4 = nn.Sequential( 42 | nn.Conv3d(256, 512, kernel_size=3, padding=1), 43 | nn.BatchNorm3d(512), 44 | nn.ReLU(), 45 | nn.Conv3d(512, 512, kernel_size=3, padding=1), 46 | nn.BatchNorm3d(512), 47 | nn.ReLU(), 48 | nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))) 49 | self.group5 = nn.Sequential( 50 | nn.Conv3d(512, 512, kernel_size=3, padding=1), 51 | nn.BatchNorm3d(512), 52 | nn.ReLU(), 53 | nn.Conv3d(512, 512, kernel_size=3, padding=1), 54 | nn.BatchNorm3d(512), 55 | nn.ReLU(), 56 | nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))) 57 | 58 | def forward(self, x): 59 | f_list = [] 60 | out = self.group1(x) 61 | f_list.append(out) 62 | out = self.group2(out) 63 | f_list.append(out) 64 | out = self.group3(out) 65 | f_list.append(out) 66 | out = self.group4(out) 67 | f_list.append(out) 68 | out = self.group5(out) 69 | f_list.append(out) 70 | return f_list 71 | 72 | 73 | def get_fine_tuning_parameters(model, ft_portion): 74 | if ft_portion == "complete": 75 | return model.parameters() 76 | 77 | elif ft_portion == "last_layer": 78 | ft_module_names = [] 79 | ft_module_names.append('fc') 80 | 81 | parameters = [] 82 | for k, v in model.named_parameters(): 83 | for ft_module in ft_module_names: 84 | if ft_module in k: 85 | parameters.append({'params': v}) 86 | break 87 | else: 88 | parameters.append({'params': v, 'lr': 0.0}) 89 | return parameters 90 | 91 | else: 92 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected") 93 | 94 | 95 | def get_model(**kwargs): 96 | """ 97 | Returns the model. 
98 | """ 99 | model = C3D(**kwargs) 100 | return model 101 | 102 | 103 | if __name__ == '__main__': 104 | model = get_model(sample_size = 112, sample_duration = 16, num_classes=600) 105 | model = model.cuda() 106 | model = nn.DataParallel(model, device_ids=None) 107 | print(model) 108 | 109 | input_var = Variable(torch.randn(8, 3, 16, 112, 112)) 110 | output = model(input_var) 111 | print(output.shape) 112 | -------------------------------------------------------------------------------- /wama_modules/Head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wama_modules.BaseModule import * 3 | 4 | 5 | class ClassificationHead(nn.Module): 6 | """ 7 | Head for single or multiple label classification task 8 | """ 9 | 10 | def __init__(self, label_category_dict, in_channel, bias=True): 11 | super().__init__() 12 | self.classification_head = torch.nn.ModuleDict({}) 13 | for key in label_category_dict.keys(): 14 | self.classification_head[key] = torch.nn.Linear(in_channel, label_category_dict[key], bias=bias) 15 | 16 | def forward(self, f): 17 | """ 18 | # demo: an example of fruit classification task 19 | 20 | f = torch.ones([3, 512]) # from encoder 21 | label_category_dict = dict( 22 | shape=4, 23 | color=3, 24 | rotten=2, 25 | sweet=2, 26 | sour=2, 27 | ) 28 | cls_head = ClassificationHead(label_category_dict, 512) 29 | logits = cls_head(f) 30 | _ = [print('logits of ', key,':' ,logits[key].shape) for key in logits.keys()] 31 | 32 | 33 | # support element-wise performing, by transfer dict or list to f 34 | label_category_dict = dict( 35 | shape=4, 36 | color=3, 37 | rotten=2, 38 | sweet=2, 39 | sour=2, 40 | ) 41 | 42 | # dict input for element-wised FC 43 | f = {} 44 | for key in label_category_dict.keys(): 45 | f[key] = torch.ones([3, 512]) # from encoder 46 | cls_head = ClassificationHead(label_category_dict, 512) 47 | logits = cls_head(f) 48 | _ = [print('logits of ', key,':' ,logits[key].shape) for key in logits.keys()] 49 | 50 | # list input for element-wised FC 51 | f = [] 52 | for key in label_category_dict.keys(): 53 | f.append(torch.ones([3, 512])) # from encoder 54 | cls_head = ClassificationHead(label_category_dict, 512) 55 | logits = cls_head(f) 56 | _ = [print('logits of ', key,':' ,logits[key].shape) for key in logits.keys()] 57 | 58 | """ 59 | logits = {} 60 | if isinstance(f,dict): 61 | print('dict element-wised forward') 62 | for key in self.classification_head.keys(): 63 | logits[key] = self.classification_head[key](f[key]) 64 | elif isinstance(f,list): 65 | print('list element-wised forward') 66 | for key_index, key in enumerate(self.classification_head.keys()): 67 | logits[key] = self.classification_head[key](f[key_index]) 68 | else: 69 | for key in self.classification_head.keys(): 70 | logits[key] = self.classification_head[key](f) 71 | return logits 72 | 73 | 74 | class SegmentationHead(nn.Module): 75 | """Head for single or multiple label segmentation task""" 76 | 77 | def __init__(self, label_category_dict, in_channel, bias=True, dim=2): 78 | super().__init__() 79 | self.segmentatin_head = torch.nn.ModuleDict({}) 80 | for key in label_category_dict.keys(): 81 | self.segmentatin_head[key] = MakeConv(in_channel, label_category_dict[key], 3, padding=1, stride=1, dim=dim, 82 | bias=bias) 83 | 84 | def forward(self, f): 85 | """ 86 | # demo 2D 87 | 88 | f = torch.ones([3, 512, 128, 128]) # from decoder or encoder 89 | label_category_dict = dict( 90 | organ=14, # 14 kinds of organ 91 | tumor=3, # 3 kinds of tumor 92 
| ) 93 | seg_head = SegmentationHead(label_category_dict, 512, dim=2) 94 | seg_logits = seg_head(f) 95 | _ = [print('segmentation_logits of ', key,':' ,seg_logits[key].shape) for key in seg_logits.keys()] 96 | 97 | """ 98 | logits = {} 99 | for key in self.segmentatin_head.keys(): 100 | logits[key] = self.segmentatin_head[key](f) 101 | return logits 102 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def conv_bn(inp, oup, stride): 12 | return nn.Sequential( 13 | nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False), 14 | nn.BatchNorm3d(oup), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | 19 | class Block(nn.Module): 20 | '''Depthwise conv + Pointwise conv''' 21 | def __init__(self, in_planes, out_planes, stride=1): 22 | super(Block, self).__init__() 23 | self.conv1 = nn.Conv3d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 24 | self.bn1 = nn.BatchNorm3d(in_planes) 25 | self.conv2 = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 26 | self.bn2 = nn.BatchNorm3d(out_planes) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(self.conv1(x))) 30 | out = F.relu(self.bn2(self.conv2(out))) 31 | return out 32 | 33 | 34 | class MobileNet(nn.Module): 35 | def __init__(self, width_mult=1.): 36 | super(MobileNet, self).__init__() 37 | 38 | input_channel = 32 39 | last_channel = 1024 40 | input_channel = int(input_channel * width_mult) 41 | cfg = [ 42 | # c, n, s 43 | [64, 1, (2,2,2)], 44 | [128, 2, (2,2,2)], 45 | [256, 2, (2,2,2)], 46 | [512, 6, (2,2,2)], 47 | [1024, 2, (1,1,1)], 48 | ] 49 | 50 | self.features = [conv_bn(3, input_channel, (1,2,2))] 51 | # building inverted residual blocks 52 | for c, n, s in cfg: 53 | output_channel = int(c * width_mult) 54 | for i in range(n): 55 | stride = s if i == 0 else 1 56 | self.features.append(Block(input_channel, output_channel, stride)) 57 | input_channel = output_channel 58 | # make it nn.Sequential 59 | self.features = nn.Sequential(*self.features) 60 | 61 | 62 | 63 | def forward(self, x): 64 | f_list = [] 65 | for i in range(len(self.features)): 66 | x = self.features[i](x) 67 | f_list.append(x) 68 | 69 | # keep last f 70 | f_list_ = [] 71 | for i, f in enumerate(f_list): 72 | if i == 0 or i == len(f_list)-1: 73 | f_list_.append(f) 74 | elif f.shape[1] != f_list[i+1].shape[1]: 75 | f_list_.append(f) 76 | 77 | return f_list_ 78 | 79 | 80 | def get_fine_tuning_parameters(model, ft_portion): 81 | if ft_portion == "complete": 82 | return model.parameters() 83 | 84 | elif ft_portion == "last_layer": 85 | ft_module_names = [] 86 | ft_module_names.append('classifier') 87 | 88 | parameters = [] 89 | for k, v in model.named_parameters(): 90 | for ft_module in ft_module_names: 91 | if ft_module in k: 92 | parameters.append({'params': v}) 93 | break 94 | else: 95 | parameters.append({'params': v, 'lr': 0.0}) 96 | return parameters 97 | 98 | else: 99 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected") 100 | 101 | 102 | def get_model(**kwargs): 103 | """ 104 | Returns the model. 
105 | """ 106 | model = MobileNet(**kwargs) 107 | return model 108 | 109 | 110 | 111 | if __name__ == '__main__': 112 | model = get_model(num_classes=600, sample_size = 112, width_mult=1.) 113 | model = model.cuda() 114 | model = nn.DataParallel(model, device_ids=None) 115 | print(model) 116 | 117 | input_var = Variable(torch.randn(8, 3, 16, 112, 112)) 118 | output = model(input_var) 119 | print(output.shape) 120 | -------------------------------------------------------------------------------- /demo/Demo7_2D_TransUnet_Segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import ResNetEncoder 4 | from wama_modules.Decoder import UNet_decoder 5 | from wama_modules.Head import SegmentationHead 6 | from wama_modules.utils import resizeTensor 7 | from transformers import ViTModel 8 | from wama_modules.utils import load_weights, tmp_class 9 | 10 | 11 | class TransUNet(nn.Module): 12 | def __init__(self, in_channel, label_category_dict, dim=2): 13 | super().__init__() 14 | 15 | # encoder 16 | Encoder_f_channel_list = [64, 128, 256, 512] 17 | self.encoder = ResNetEncoder( 18 | in_channel, 19 | stage_output_channels=Encoder_f_channel_list, 20 | stage_middle_channels=Encoder_f_channel_list, 21 | blocks=[1, 2, 3, 4], 22 | type='131', 23 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 24 | dim=dim) 25 | 26 | # neck 27 | neck_out_channel = 768 28 | transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k') 29 | configuration = transformer.config 30 | self.trans_downsample_size = configuration.image_size = [8, 8] 31 | configuration.patch_size = [1, 1] 32 | configuration.num_channels = Encoder_f_channel_list[-1] 33 | configuration.encoder_stride = 1 # just for MAE decoder, otherwise this paramater is not used 34 | self.neck = ViTModel(configuration, add_pooling_layer=False) 35 | 36 | pretrained_weights = transformer.state_dict() 37 | pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[ 38 | 'embeddings.position_embeddings'] 39 | pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[ 40 | 'embeddings.patch_embeddings.projection.weight'] 41 | pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[ 42 | 'embeddings.patch_embeddings.projection.bias'] 43 | self.neck = load_weights(self.neck, pretrained_weights) # reload pretrained weights 44 | 45 | # decoder 46 | Decoder_f_channel_list = [32, 64, 128] 47 | self.decoder = UNet_decoder( 48 | in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel], 49 | skip_connection=[True, True, True], 50 | out_channels_list=Decoder_f_channel_list, 51 | dim=dim) 52 | 53 | # seg head 54 | self.seg_head = SegmentationHead( 55 | label_category_dict, 56 | Decoder_f_channel_list[0], 57 | dim=dim) 58 | 59 | def forward(self, x): 60 | # encoder forward 61 | multi_scale_encoder = self.encoder(x) 62 | 63 | # neck forward 64 | f_neck = self.neck(resizeTensor(multi_scale_encoder[-1], size=self.trans_downsample_size)) 65 | f_neck = f_neck.last_hidden_state 66 | f_neck = f_neck[:, 1:] # remove class token 67 | f_neck = f_neck.permute(0, 2, 1) 68 | f_neck = f_neck.reshape( 69 | f_neck.shape[0], 70 | f_neck.shape[1], 71 | self.trans_downsample_size[0], 72 | self.trans_downsample_size[1] 73 | ) # reshape 74 | f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:]) 75 | multi_scale_encoder[-1] = f_neck 76 | 77 | # decoder forward 78 | 
multi_scale_decoder = self.decoder(multi_scale_encoder) 79 | f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:]) 80 | 81 | # seg_head forward 82 | logits = self.seg_head(f_for_seg) 83 | return logits 84 | 85 | 86 | if __name__ == '__main__': 87 | x = torch.ones([2, 1, 256, 256]) 88 | label_category_dict = dict(organ=3, tumor=4) 89 | model = TransUNet(in_channel=1, label_category_dict=label_category_dict, dim=2) 90 | with torch.no_grad(): 91 | logits = model(x) 92 | print('multi_label predicted logits') 93 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()] 94 | 95 | # out 96 | # multi_label predicted logits 97 | # logits of organ : torch.Size([2, 3, 256, 256]) 98 | # logits of tumor : torch.Size([2, 4, 256, 256]) 99 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/base/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | try: 5 | from inplace_abn import InPlaceABN 6 | except ImportError: 7 | InPlaceABN = None 8 | 9 | 10 | class Conv2dReLU(nn.Sequential): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | padding=0, 17 | stride=1, 18 | use_batchnorm=True, 19 | ): 20 | 21 | if use_batchnorm == "inplace" and InPlaceABN is None: 22 | raise RuntimeError( 23 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " 24 | + "To install see: https://github.com/mapillary/inplace_abn" 25 | ) 26 | 27 | conv = nn.Conv2d( 28 | in_channels, 29 | out_channels, 30 | kernel_size, 31 | stride=stride, 32 | padding=padding, 33 | bias=not (use_batchnorm), 34 | ) 35 | relu = nn.ReLU(inplace=True) 36 | 37 | if use_batchnorm == "inplace": 38 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) 39 | relu = nn.Identity() 40 | 41 | elif use_batchnorm and use_batchnorm != "inplace": 42 | bn = nn.BatchNorm2d(out_channels) 43 | 44 | else: 45 | bn = nn.Identity() 46 | 47 | super(Conv2dReLU, self).__init__(conv, bn, relu) 48 | 49 | 50 | class SCSEModule(nn.Module): 51 | def __init__(self, in_channels, reduction=16): 52 | super().__init__() 53 | self.cSE = nn.Sequential( 54 | nn.AdaptiveAvgPool2d(1), 55 | nn.Conv2d(in_channels, in_channels // reduction, 1), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(in_channels // reduction, in_channels, 1), 58 | nn.Sigmoid(), 59 | ) 60 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 61 | 62 | def forward(self, x): 63 | return x * self.cSE(x) + x * self.sSE(x) 64 | 65 | 66 | class ArgMax(nn.Module): 67 | def __init__(self, dim=None): 68 | super().__init__() 69 | self.dim = dim 70 | 71 | def forward(self, x): 72 | return torch.argmax(x, dim=self.dim) 73 | 74 | 75 | class Clamp(nn.Module): 76 | def __init__(self, min=0, max=1): 77 | super().__init__() 78 | self.min, self.max = min, max 79 | 80 | def forward(self, x): 81 | return torch.clamp(x, self.min, self.max) 82 | 83 | 84 | class Activation(nn.Module): 85 | def __init__(self, name, **params): 86 | 87 | super().__init__() 88 | 89 | if name is None or name == "identity": 90 | self.activation = nn.Identity(**params) 91 | elif name == "sigmoid": 92 | self.activation = nn.Sigmoid() 93 | elif name == "softmax2d": 94 | self.activation = nn.Softmax(dim=1, **params) 95 | elif name == "softmax": 96 | self.activation = nn.Softmax(**params) 97 | elif name == "logsoftmax": 98 | self.activation = nn.LogSoftmax(**params) 99 | elif 
name == "tanh": 100 | self.activation = nn.Tanh() 101 | elif name == "argmax": 102 | self.activation = ArgMax(**params) 103 | elif name == "argmax2d": 104 | self.activation = ArgMax(dim=1, **params) 105 | elif name == "clamp": 106 | self.activation = Clamp(**params) 107 | elif callable(name): 108 | self.activation = name(**params) 109 | else: 110 | raise ValueError( 111 | f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/" 112 | f"argmax/argmax2d/clamp/None; got {name}" 113 | ) 114 | 115 | def forward(self, x): 116 | return self.activation(x) 117 | 118 | 119 | class Attention(nn.Module): 120 | def __init__(self, name, **params): 121 | super().__init__() 122 | 123 | if name is None: 124 | self.attention = nn.Identity(**params) 125 | elif name == "scse": 126 | self.attention = SCSEModule(**params) 127 | else: 128 | raise ValueError("Attention {} is not implemented".format(name)) 129 | 130 | def forward(self, x): 131 | return self.attention(x) 132 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/utils/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _take_channels(*xs, ignore_channels=None): 5 | if ignore_channels is None: 6 | return xs 7 | else: 8 | channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels] 9 | xs = [torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) for x in xs] 10 | return xs 11 | 12 | 13 | def _threshold(x, threshold=None): 14 | if threshold is not None: 15 | return (x > threshold).type(x.dtype) 16 | else: 17 | return x 18 | 19 | 20 | def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 21 | """Calculate Intersection over Union between ground truth and prediction 22 | Args: 23 | pr (torch.Tensor): predicted tensor 24 | gt (torch.Tensor): ground truth tensor 25 | eps (float): epsilon to avoid zero division 26 | threshold: threshold for outputs binarization 27 | Returns: 28 | float: IoU (Jaccard) score 29 | """ 30 | 31 | pr = _threshold(pr, threshold=threshold) 32 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 33 | 34 | intersection = torch.sum(gt * pr) 35 | union = torch.sum(gt) + torch.sum(pr) - intersection + eps 36 | return (intersection + eps) / union 37 | 38 | 39 | jaccard = iou 40 | 41 | 42 | def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None): 43 | """Calculate F-score between ground truth and prediction 44 | Args: 45 | pr (torch.Tensor): predicted tensor 46 | gt (torch.Tensor): ground truth tensor 47 | beta (float): positive constant 48 | eps (float): epsilon to avoid zero division 49 | threshold: threshold for outputs binarization 50 | Returns: 51 | float: F score 52 | """ 53 | 54 | pr = _threshold(pr, threshold=threshold) 55 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 56 | 57 | tp = torch.sum(gt * pr) 58 | fp = torch.sum(pr) - tp 59 | fn = torch.sum(gt) - tp 60 | 61 | score = ((1 + beta**2) * tp + eps) / ((1 + beta**2) * tp + beta**2 * fn + fp + eps) 62 | 63 | return score 64 | 65 | 66 | def accuracy(pr, gt, threshold=0.5, ignore_channels=None): 67 | """Calculate accuracy score between ground truth and prediction 68 | Args: 69 | pr (torch.Tensor): predicted tensor 70 | gt (torch.Tensor): ground truth tensor 71 | eps (float): epsilon to avoid zero division 72 | threshold: threshold for outputs binarization 73 | Returns: 74 | float: precision score 75 | """ 76 | pr = 
_threshold(pr, threshold=threshold) 77 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 78 | 79 | tp = torch.sum(gt == pr, dtype=pr.dtype) 80 | score = tp / gt.view(-1).shape[0] 81 | return score 82 | 83 | 84 | def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 85 | """Calculate precision score between ground truth and prediction 86 | Args: 87 | pr (torch.Tensor): predicted tensor 88 | gt (torch.Tensor): ground truth tensor 89 | eps (float): epsilon to avoid zero division 90 | threshold: threshold for outputs binarization 91 | Returns: 92 | float: precision score 93 | """ 94 | 95 | pr = _threshold(pr, threshold=threshold) 96 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 97 | 98 | tp = torch.sum(gt * pr) 99 | fp = torch.sum(pr) - tp 100 | 101 | score = (tp + eps) / (tp + fp + eps) 102 | 103 | return score 104 | 105 | 106 | def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 107 | """Calculate Recall between ground truth and prediction 108 | Args: 109 | pr (torch.Tensor): A list of predicted elements 110 | gt (torch.Tensor): A list of elements that are to be predicted 111 | eps (float): epsilon to avoid zero division 112 | threshold: threshold for outputs binarization 113 | Returns: 114 | float: recall score 115 | """ 116 | 117 | pr = _threshold(pr, threshold=threshold) 118 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 119 | 120 | tp = torch.sum(gt * pr) 121 | fn = torch.sum(gt) - tp 122 | 123 | score = (tp + eps) / (tp + fn + eps) 124 | 125 | return score 126 | -------------------------------------------------------------------------------- /demo/Demo8_3D_TransUnet_Segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from wama_modules.Encoder import ResNetEncoder 4 | from wama_modules.Decoder import UNet_decoder 5 | from wama_modules.Head import SegmentationHead 6 | from wama_modules.utils import resizeTensor 7 | from transformers import ViTModel 8 | from wama_modules.utils import load_weights, tmp_class 9 | 10 | 11 | class TransUnet(nn.Module): 12 | def __init__(self, in_channel, label_category_dict, dim=2): 13 | super().__init__() 14 | 15 | # encoder 16 | Encoder_f_channel_list = [64, 128, 256, 512] 17 | self.encoder = ResNetEncoder( 18 | in_channel, 19 | stage_output_channels=Encoder_f_channel_list, 20 | stage_middle_channels=Encoder_f_channel_list, 21 | blocks=[1, 2, 3, 4], 22 | type='131', 23 | downsample_ration=[0.5, 0.5, 0.5, 0.5], 24 | dim=dim) 25 | 26 | # neck 27 | neck_out_channel = 768 28 | transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k') 29 | configuration = transformer.config 30 | self.trans_size_3D = [8, 8, 4] 31 | self.trans_size = configuration.image_size = [ 32 | self.trans_size_3D[0], self.trans_size_3D[1]*self.trans_size_3D[2] 33 | ] 34 | configuration.patch_size = [1, 1] 35 | configuration.num_channels = Encoder_f_channel_list[-1] 36 | configuration.encoder_stride = 1 # just for MAE decoder, otherwise this paramater is not used 37 | self.neck = ViTModel(configuration, add_pooling_layer=False) 38 | 39 | pretrained_weights = transformer.state_dict() 40 | pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[ 41 | 'embeddings.position_embeddings'] 42 | pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[ 43 | 'embeddings.patch_embeddings.projection.weight'] 44 | 
pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[ 45 | 'embeddings.patch_embeddings.projection.bias'] 46 | self.neck = load_weights(self.neck, pretrained_weights) # reload pretrained weights 47 | 48 | # decoder 49 | Decoder_f_channel_list = [32, 64, 128] 50 | self.decoder = UNet_decoder( 51 | in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel], 52 | skip_connection=[True, True, True], 53 | out_channels_list=Decoder_f_channel_list, 54 | dim=dim) 55 | 56 | # seg head 57 | self.seg_head = SegmentationHead( 58 | label_category_dict, 59 | Decoder_f_channel_list[0], 60 | dim=dim) 61 | 62 | def forward(self, x): 63 | # encoder forward 64 | multi_scale_encoder = self.encoder(x) 65 | 66 | # neck forward 67 | neck_input = resizeTensor(multi_scale_encoder[-1], size=self.trans_size_3D) 68 | neck_input = neck_input.reshape(neck_input.shape[0], neck_input.shape[1], *self.trans_size) # 3D to 2D 69 | f_neck = self.neck(neck_input) 70 | f_neck = f_neck.last_hidden_state 71 | f_neck = f_neck[:, 1:] # remove class token 72 | f_neck = f_neck.permute(0, 2, 1) 73 | f_neck = f_neck.reshape( 74 | f_neck.shape[0], 75 | f_neck.shape[1], 76 | self.trans_size[0], 77 | self.trans_size[1] 78 | ) # reshape 79 | f_neck = f_neck.reshape(f_neck.shape[0], f_neck.shape[1], *self.trans_size_3D) # 2D to 3D 80 | f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:]) 81 | multi_scale_encoder[-1] = f_neck 82 | 83 | # decoder forward 84 | multi_scale_decoder = self.decoder(multi_scale_encoder) 85 | f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:]) 86 | 87 | # seg_head forward 88 | logits = self.seg_head(f_for_seg) 89 | return logits 90 | 91 | 92 | if __name__ == '__main__': 93 | x = torch.ones([2, 1, 128, 128, 96]) 94 | label_category_dict = dict(organ=3, tumor=4) 95 | model = TransUnet(in_channel=1, label_category_dict=label_category_dict, dim=3) 96 | with torch.no_grad(): 97 | logits = model(x) 98 | print('multi_label predicted logits') 99 | _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()] 100 | 101 | # out 102 | # multi_label predicted logits 103 | # logits of organ : torch.Size([2, 3, 128, 128, 96]) 104 | # logits of tumor : torch.Size([2, 4, 128, 128, 96]) 105 | 106 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import functools 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from .resnet import resnet_encoders 6 | from .dpn import dpn_encoders 7 | from .vgg import vgg_encoders 8 | from .senet import senet_encoders 9 | from .densenet import densenet_encoders 10 | from .inceptionresnetv2 import inceptionresnetv2_encoders 11 | from .inceptionv4 import inceptionv4_encoders 12 | from .efficientnet import efficient_net_encoders 13 | from .mobilenet import mobilenet_encoders 14 | from .xception import xception_encoders 15 | from .timm_efficientnet import timm_efficientnet_encoders 16 | from .timm_resnest import timm_resnest_encoders 17 | from .timm_res2net import timm_res2net_encoders 18 | from .timm_regnet import timm_regnet_encoders 19 | from .timm_sknet import timm_sknet_encoders 20 | from .timm_mobilenetv3 import timm_mobilenetv3_encoders 21 | from .timm_gernet import timm_gernet_encoders 22 | from .mix_transformer import mix_transformer_encoders 23 | 24 | from .timm_universal import TimmUniversalEncoder 25 | 26 
| from ._preprocessing import preprocess_input 27 | 28 | encoders = {} 29 | encoders.update(resnet_encoders) 30 | encoders.update(dpn_encoders) 31 | encoders.update(vgg_encoders) 32 | encoders.update(senet_encoders) 33 | encoders.update(densenet_encoders) 34 | encoders.update(inceptionresnetv2_encoders) 35 | encoders.update(inceptionv4_encoders) 36 | encoders.update(efficient_net_encoders) 37 | encoders.update(mobilenet_encoders) 38 | encoders.update(xception_encoders) 39 | encoders.update(timm_efficientnet_encoders) 40 | encoders.update(timm_resnest_encoders) 41 | encoders.update(timm_res2net_encoders) 42 | encoders.update(timm_regnet_encoders) 43 | encoders.update(timm_sknet_encoders) 44 | encoders.update(timm_mobilenetv3_encoders) 45 | encoders.update(timm_gernet_encoders) 46 | encoders.update(mix_transformer_encoders) 47 | 48 | 49 | def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): 50 | 51 | if name.startswith("tu-"): 52 | name = name[3:] 53 | encoder = TimmUniversalEncoder( 54 | name=name, 55 | in_channels=in_channels, 56 | depth=depth, 57 | output_stride=output_stride, 58 | pretrained=weights is not None, 59 | **kwargs, 60 | ) 61 | return encoder 62 | 63 | try: 64 | Encoder = encoders[name]["encoder"] 65 | except KeyError: 66 | raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))) 67 | 68 | params = encoders[name]["params"] 69 | params.update(depth=depth) 70 | encoder = Encoder(**params) 71 | 72 | if weights is not None: 73 | try: 74 | settings = encoders[name]["pretrained_settings"][weights] 75 | except KeyError: 76 | raise KeyError( 77 | "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format( 78 | weights, 79 | name, 80 | list(encoders[name]["pretrained_settings"].keys()), 81 | ) 82 | ) 83 | encoder.load_state_dict(model_zoo.load_url(settings["url"])) 84 | 85 | encoder.set_in_channels(in_channels, pretrained=weights is not None) 86 | if output_stride != 32: 87 | encoder.make_dilated(output_stride) 88 | 89 | return encoder 90 | 91 | 92 | def get_encoder_names(): 93 | return list(encoders.keys()) 94 | 95 | 96 | def get_preprocessing_params(encoder_name, pretrained="imagenet"): 97 | 98 | if encoder_name.startswith("tu-"): 99 | encoder_name = encoder_name[3:] 100 | if encoder_name not in timm.models.registry._model_has_pretrained: 101 | raise ValueError(f"{encoder_name} does not have pretrained weights and preprocessing parameters") 102 | settings = timm.models.registry._model_default_cfgs[encoder_name] 103 | else: 104 | all_settings = encoders[encoder_name]["pretrained_settings"] 105 | if pretrained not in all_settings.keys(): 106 | raise ValueError("Available pretrained options {}".format(all_settings.keys())) 107 | settings = all_settings[pretrained] 108 | 109 | formatted_settings = {} 110 | formatted_settings["input_space"] = settings.get("input_space", "RGB") 111 | formatted_settings["input_range"] = list(settings.get("input_range", [0, 1])) 112 | formatted_settings["mean"] = list(settings.get("mean")) 113 | formatted_settings["std"] = list(settings.get("std")) 114 | 115 | return formatted_settings 116 | 117 | 118 | def get_preprocessing_fn(encoder_name, pretrained="imagenet"): 119 | params = get_preprocessing_params(encoder_name, pretrained=pretrained) 120 | return functools.partial(preprocess_input, **params) 121 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_gernet.py: 
-------------------------------------------------------------------------------- 1 | from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet 2 | 3 | from ._base import EncoderMixin 4 | import torch.nn as nn 5 | 6 | 7 | class GERNetEncoder(ByobNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.head 15 | 16 | def get_stages(self): 17 | return [ 18 | nn.Identity(), 19 | self.stem, 20 | self.stages[0], 21 | self.stages[1], 22 | self.stages[2], 23 | nn.Sequential(self.stages[3], self.stages[4], self.final_conv), 24 | ] 25 | 26 | def forward(self, x): 27 | stages = self.get_stages() 28 | 29 | features = [] 30 | for i in range(self._depth + 1): 31 | x = stages[i](x) 32 | features.append(x) 33 | 34 | return features 35 | 36 | def load_state_dict(self, state_dict, **kwargs): 37 | state_dict.pop("head.fc.weight", None) 38 | state_dict.pop("head.fc.bias", None) 39 | super().load_state_dict(state_dict, **kwargs) 40 | 41 | 42 | regnet_weights = { 43 | "timm-gernet_s": { 44 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth", # noqa 45 | }, 46 | "timm-gernet_m": { 47 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth", # noqa 48 | }, 49 | "timm-gernet_l": { 50 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth", # noqa 51 | }, 52 | } 53 | 54 | pretrained_settings = {} 55 | for model_name, sources in regnet_weights.items(): 56 | pretrained_settings[model_name] = {} 57 | for source_name, source_url in sources.items(): 58 | pretrained_settings[model_name][source_name] = { 59 | "url": source_url, 60 | "input_range": [0, 1], 61 | "mean": [0.485, 0.456, 0.406], 62 | "std": [0.229, 0.224, 0.225], 63 | "num_classes": 1000, 64 | } 65 | 66 | timm_gernet_encoders = { 67 | "timm-gernet_s": { 68 | "encoder": GERNetEncoder, 69 | "pretrained_settings": pretrained_settings["timm-gernet_s"], 70 | "params": { 71 | "out_channels": (3, 13, 48, 48, 384, 1920), 72 | "cfg": ByoModelCfg( 73 | blocks=( 74 | ByoBlockCfg(type="basic", d=1, c=48, s=2, gs=0, br=1.0), 75 | ByoBlockCfg(type="basic", d=3, c=48, s=2, gs=0, br=1.0), 76 | ByoBlockCfg(type="bottle", d=7, c=384, s=2, gs=0, br=1 / 4), 77 | ByoBlockCfg(type="bottle", d=2, c=560, s=2, gs=1, br=3.0), 78 | ByoBlockCfg(type="bottle", d=1, c=256, s=1, gs=1, br=3.0), 79 | ), 80 | stem_chs=13, 81 | stem_pool=None, 82 | num_features=1920, 83 | ), 84 | }, 85 | }, 86 | "timm-gernet_m": { 87 | "encoder": GERNetEncoder, 88 | "pretrained_settings": pretrained_settings["timm-gernet_m"], 89 | "params": { 90 | "out_channels": (3, 32, 128, 192, 640, 2560), 91 | "cfg": ByoModelCfg( 92 | blocks=( 93 | ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), 94 | ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), 95 | ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), 96 | ByoBlockCfg(type="bottle", d=4, c=640, s=2, gs=1, br=3.0), 97 | ByoBlockCfg(type="bottle", d=1, c=640, s=1, gs=1, br=3.0), 98 | ), 99 | stem_chs=32, 100 | stem_pool=None, 101 | num_features=2560, 102 | ), 103 | }, 104 | }, 105 | "timm-gernet_l": { 106 | "encoder": GERNetEncoder, 107 | "pretrained_settings": pretrained_settings["timm-gernet_l"], 108 | "params": { 109 | "out_channels": (3, 32, 128, 192, 640, 2560), 110 | "cfg": 
ByoModelCfg( 111 | blocks=( 112 | ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), 113 | ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), 114 | ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), 115 | ByoBlockCfg(type="bottle", d=5, c=640, s=2, gs=1, br=3.0), 116 | ByoBlockCfg(type="bottle", d=4, c=640, s=1, gs=1, br=3.0), 117 | ), 118 | stem_chs=32, 119 | stem_pool=None, 120 | num_features=2560, 121 | ), 122 | }, 123 | }, 124 | } 125 | -------------------------------------------------------------------------------- /Document_allmodules.md: -------------------------------------------------------------------------------- 1 | 2 | # All modules and functions 3 | 4 | ## 1 `wama_modules.BaseModule` 5 | 6 | ### 1.1 Pooling 7 | - `GlobalAvgPool` Global average pooling 8 | - `GlobalMaxPool` Global maximum pooling 9 | - `GlobalMaxAvgPool` GlobalMaxAvgPool = (GlobalAvgPool + GlobalMaxPool) / 2. 10 | 11 |
12 | Click here to see demo code 13 | 14 | ```python 15 | """ demo """ 16 | # import libs 17 | import torch 18 | from wama_modules.BaseModule import GlobalAvgPool, GlobalMaxPool, GlobalMaxAvgPool 19 | 20 | # make tensor 21 | inputs1D = torch.ones([3,12,13]) # 1D 22 | inputs2D = torch.ones([3,12,13,13]) # 2D 23 | inputs3D = torch.ones([3,12,13,13,13]) # 3D 24 | 25 | # build layer 26 | GAP = GlobalAvgPool() 27 | GMP = GlobalMaxPool() 28 | GAMP = GlobalMaxAvgPool() 29 | 30 | # test GAP & GMP & GAMP 31 | print(inputs1D.shape, GAP(inputs1D).shape) 32 | print(inputs2D.shape, GAP(inputs2D).shape) 33 | print(inputs3D.shape, GAP(inputs3D).shape) 34 | 35 | print(inputs1D.shape, GMP(inputs1D).shape) 36 | print(inputs2D.shape, GMP(inputs2D).shape) 37 | print(inputs3D.shape, GMP(inputs3D).shape) 38 | 39 | print(inputs1D.shape, GAMP(inputs1D).shape) 40 | print(inputs2D.shape, GAMP(inputs2D).shape) 41 | print(inputs3D.shape, GAMP(inputs3D).shape) 42 | ``` 43 |
44 | 45 | 46 | ### 1.2 Norm&Activation 47 | - `customLayerNorm` a custom implementation of layer normalization 48 | - `MakeNorm` make normalization layer, includes BN / GN / IN / LN 49 | - `MakeActive` make activation layer, includes Relu / LeakyRelu 50 | - `MakeConv` make 1D / 2D / 3D convolutional layer 51 | 52 |
53 | Click here to see demo code 54 | 55 | ```python 56 | """ demo """ 57 | ``` 58 |
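The exact signatures of `customLayerNorm`, `MakeNorm`, `MakeActive` and `MakeConv` are not listed in this document, so the sketch below is only a plain-PyTorch illustration of what such factory helpers typically do (dispatching on `dim` and a `norm` string); the argument names here are assumptions, see `wama_modules/BaseModule.py` for the real API.

```python
""" illustrative sketch only (hypothetical helpers, not the real wama_modules API) """
import torch
import torch.nn as nn

def make_conv(dim, in_c, out_c, kernel_size=3, padding=1):
    # hypothetical helper: pick Conv1d / Conv2d / Conv3d by dim
    return [nn.Conv1d, nn.Conv2d, nn.Conv3d][dim - 1](in_c, out_c, kernel_size, padding=padding)

def make_norm(dim, channels, norm='bn', gn_c=8):
    # hypothetical helper: 'bn' / 'gn' / 'in' / 'ln'
    if norm == 'bn':
        return [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d][dim - 1](channels)
    if norm == 'gn':
        return nn.GroupNorm(gn_c, channels)
    if norm == 'in':
        return [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d][dim - 1](channels)
    return nn.GroupNorm(1, channels)  # 'ln': LayerNorm-like, normalize over all channels

x = torch.ones([3, 12, 13, 13, 13])  # 3D input
y = nn.ReLU(inplace=True)(make_norm(3, 24, norm='gn')(make_conv(3, 12, 24)(x)))
print(x.shape, y.shape)  # torch.Size([3, 12, 13, 13, 13]) torch.Size([3, 24, 13, 13, 13])
```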
59 | 60 | 61 | 62 | ### 1.3 Conv 63 | - `ConvNormActive` 'Convolution→Normalization→Activation', used in VGG or ResNet 64 | - `NormActiveConv` 'Normalization→Activation→Convolution', used in DenseNet 65 | - `VGGBlock` the basic module in VGG 66 | - `VGGStage` a VGGStage = a few stacked VGGBlocks 67 | - `ResBlock` the basic module in ResNet 68 | - `ResStage` a ResStage = a few stacked ResBlocks 69 | - `DenseLayer` the basic module in DenseNet 70 | - `DenseBlock` a DenseBlock = a few stacked DenseLayers 71 | 72 |
73 | Click here to see demo code 74 | 75 | ```python 76 | """ demo """ 77 | ``` 78 |
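A minimal sketch for `ConvNormActive`; its keyword arguments below follow the usage visible in `wama_modules/Decoder.py` and `wama_modules/Neck.py` later in this dump (the constructors of `VGGBlock`/`VGGStage`/`ResBlock`/`ResStage`/`DenseLayer`/`DenseBlock` are not shown there, so they are omitted here).

```python
""" demo (ConvNormActive only) """
import torch
from wama_modules.BaseModule import ConvNormActive

# 3D input [bz, c, x, y, z]
x = torch.ones([3, 16, 32, 32, 32])

# Convolution -> Normalization -> Activation, arguments as used in Decoder.py / Neck.py
conv_block = ConvNormActive(16, 32, kernel_size=3, norm='bn', active='relu', gn_c=8, dim=3, padding=1)

print(x.shape, conv_block(x).shape)
```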
79 | 80 | ## 2 `wama_modules.utils` 81 | - `resizeTensor` scale a torch tensor, similar to scipy's zoom 82 | - `tensor2array` transform a torch tensor to a numpy ndarray 83 | - `load_weights` load torch weights and print loading details (missing and matched keys) 84 | 85 |
86 | Click here to see demo code 87 | 88 | ```python 89 | """ demo """ 90 | ``` 91 |
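A minimal sketch based on the docstrings of `resizeTensor`, `tensor2array` and `load_weights` in `wama_modules/utils.py` (included later in this dump).

```python
""" demo """
import torch
from wama_modules.utils import resizeTensor, tensor2array, load_weights

# resizeTensor: rescale a 1D/2D/3D tensor by scale_factor, or to a target size
x = torch.ones([3, 1, 256, 256])
print(resizeTensor(x, scale_factor=[2., 2.]).shape)             # torch.Size([3, 1, 512, 512])
print(resizeTensor(x, size=torch.ones([128, 64]).shape).shape)  # torch.Size([3, 1, 128, 64])

# tensor2array: torch tensor -> numpy ndarray
print(type(tensor2array(x)))  # <class 'numpy.ndarray'>

# load_weights: load a (possibly partial) state_dict and print matched / missing keys
model = torch.nn.Conv2d(1, 8, kernel_size=3)
model = load_weights(model, model.state_dict())
```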
92 | 93 | 94 | ## 3 `wama_modules.Attention` 95 | - `SCSEModule` 96 | - `NonLocal` 97 | 98 |
99 | Click here to see demo code 100 | 101 | ```python 102 | """ demo """ 103 | ``` 104 |
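The constructor signatures of `SCSEModule` and `NonLocal` are not listed in this document, so the block below is only a plain-PyTorch illustration of what a channel + spatial squeeze-and-excitation (scSE) module computes; it is a hypothetical stand-in, not the wama_modules implementation.

```python
""" illustrative sketch only (2D scSE in plain PyTorch) """
import torch
import torch.nn as nn

class TinySCSE(nn.Module):  # hypothetical stand-in for an scSE-style attention block
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.cse = nn.Sequential(  # channel attention: squeeze spatially, excite channels
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1), nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1), nn.Sigmoid())
        self.sse = nn.Sequential(nn.Conv2d(channels, 1, 1), nn.Sigmoid())  # spatial attention

    def forward(self, x):
        return x * self.cse(x) + x * self.sse(x)

x = torch.ones([3, 64, 16, 16])
print(TinySCSE(64)(x).shape)  # torch.Size([3, 64, 16, 16])
```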
105 | 106 | 107 | ## 4 `wama_modules.Encoder` 108 | - `VGGEncoder` 109 | - `ResNetEncoder` 110 | - `DenseNetEncoder` 111 | - `???` 112 | 113 |
114 | Click here to see demo code 115 | 116 | ```python 117 | """ demo """ 118 | ``` 119 |
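The constructors of `VGGEncoder` / `ResNetEncoder` / `DenseNetEncoder` are not listed in this document. As an alternative that is visible in this repo, 2D encoders can also be obtained from the bundled segmentation_models.pytorch code (`wama_modules/thirdparty_lib/SMP_qubvel`), whose `get_encoder` function appears earlier in this dump; the encoder name below is just an example, and `weights=None` skips downloading pretrained weights.

```python
""" demo (via the bundled SMP 2D encoders; requires timm and pretrainedmodels) """
import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder

encoder = get_encoder('resnet18', in_channels=3, depth=5, weights=None)
x = torch.ones([2, 3, 64, 64])
features = encoder(x)  # multi-scale feature maps, from shallow to deep
_ = [print(f.shape) for f in features]
```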
120 | 121 | 122 | ## 5 `wama_modules.Decoder` 123 | - `UNet_decoder` 124 | 125 |
126 | Click here to see demo code 127 | 128 | ```python 129 | """ demo """ 130 | ``` 131 |
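A sketch adapted from the docstring of `UNet_decoder` in `wama_modules/Decoder.py` (2D case; the 1D and 3D cases only differ in `dim` and tensor shapes).

```python
""" demo """
import torch
from wama_modules.Decoder import UNet_decoder

# multi-scale encoder features, from shallow to deep
f_list = [
    torch.ones([3, 64, 128, 128]),
    torch.ones([3, 128, 64, 64]),
    torch.ones([3, 256, 32, 32]),
    torch.ones([3, 512, 8, 8]),
]

decoder = UNet_decoder(
    in_channels_list=[64, 128, 256, 512],  # from shallow to deep
    skip_connection=[False, True, True],   # from shallow to deep
    out_channels_list=[12, 13, 14],        # from shallow to deep
    norm='bn',
    gn_c=8,
    dim=2)

decoder_f_list = decoder(f_list)  # decoder_f_list[0] can be fed to a segmentation head
_ = [print(i.shape) for i in decoder_f_list]
```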
132 | 133 | 134 | ## 6 `wama_modules.Neck` 135 | - `FPN` 136 | 137 |
138 | Click here to see demo code 139 | 140 | ```python 141 | """ demo """ 142 | import torch 143 | from wama_modules.Neck import FPN 144 | 145 | # make multi-scale feature maps 146 | featuremaps = [ 147 | torch.ones([3,16,32,32,32]), 148 | torch.ones([3,32,24,24,24]), 149 | torch.ones([3,64,16,16,16]), 150 | torch.ones([3,128,8,8,8]), 151 | ] 152 | 153 | # build FPN 154 | fpn_AddSmall2Big = FPN(in_channels_list=[16, 32, 64, 128], 155 | c1=128, 156 | c2=256, 157 | active='relu', 158 | norm='bn', 159 | gn_c=8, 160 | mode='AddSmall2Big', 161 | dim=3,) 162 | fpn_AddBig2Small = FPN(in_channels_list=[16, 32, 64, 128], 163 | c1=128, 164 | c2=256, 165 | active='relu', 166 | norm='bn', 167 | gn_c=8, 168 | mode='AddBig2Small', # Add big size feature to small size feature, for classification 169 | dim=3,) 170 | 171 | # forward 172 | f_listA = fpn_AddSmall2Big(featuremaps) 173 | f_listB = fpn_AddBig2Small(featuremaps) 174 | _ = [print(i.shape) for i in featuremaps] 175 | _ = [print(i.shape) for i in f_listA] 176 | _ = [print(i.shape) for i in f_listB] 177 | ``` 178 |
179 | 180 | 181 | ## 7 `wama_modules.Transformer` 182 | - `FeedForward` 183 | - `MultiHeadAttention` 184 | - `TransformerEncoderLayer` 185 | - `TransformerDecoderLayer` 186 | 187 |
188 | Click here to see demo code 189 | 190 | ```python 191 | """ demo """ 192 | ``` 193 |
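The constructors of these transformer modules are not listed in this document; the plain-PyTorch snippet below only illustrates the token-sequence layout (`[batch, token_num, token_dim]`) that such attention and feed-forward layers operate on, and is not the wama_modules API.

```python
""" illustrative sketch only (plain PyTorch, not the wama_modules API) """
import torch
import torch.nn as nn

tokens = torch.ones([3, 50, 128])  # [bz, token_num, token_dim]

# self-attention over the token sequence
attn = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
out, attn_weights = attn(tokens, tokens, tokens)
print(out.shape)           # torch.Size([3, 50, 128])
print(attn_weights.shape)  # torch.Size([3, 50, 50])

# a per-token feed-forward block
ffn = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 128))
print(ffn(out).shape)      # torch.Size([3, 50, 128])
```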
194 | -------------------------------------------------------------------------------- /wama_modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import pickle 6 | 7 | 8 | class tmp_class(): 9 | """ 10 | for debug 11 | """ 12 | def __init__(self,): 13 | super().__init__() 14 | 15 | 16 | def resizeTensor(x, scale_factor=None, size=None): 17 | """ 18 | resize for 1D\2D\3D tensor (1D → signal 2D → image, 3D → volume) 19 | 20 | :param x: 1D [bz,c,l] 2D [bz,c,w,h] 3D [bz,c,w,h,l] 21 | :param scale_factor: 1D [2.,] 2D [2.,2.,] 3D [3.,3.,3.,] 22 | :param size: 2D [256.,256,m] or torch.ones([256,256]).shape 23 | :return: 24 | 25 | # 1D demo: 26 | x = torch.ones([3,1,256]) 27 | y = torch.ones([3,1,128]) 28 | x1 = resizeTensor(x, scale_factor=[2.,]) 29 | print(x1.shape) 30 | x1 = resizeTensor(x, size=y.shape[-1:]) 31 | print(x1.shape) 32 | 33 | # 2D demo: 34 | x = torch.ones([3,1,256,256]) 35 | y = torch.ones([3,1,256,128]) 36 | x1 = resizeTensor(x, scale_factor=[2.,2.]) 37 | print(x1.shape) 38 | x1 = resizeTensor(x, size=y.shape[-2:]) 39 | print(x1.shape) 40 | 41 | # 3D demo: 42 | x = torch.ones([3,1,256,256,256]) 43 | y = torch.ones([3,1,256,128,128]) 44 | x1 = resizeTensor(x, scale_factor=[2.,2.,2.]) 45 | print(x1.shape) 46 | x1 = resizeTensor(x, size=y.shape[-3:]) 47 | print(x1.shape) 48 | 49 | """ 50 | if len(x.shape) == 3: 51 | return F.interpolate(x, scale_factor=scale_factor, size=size, 52 | mode='linear', 53 | align_corners=True) 54 | if len(x.shape) == 4: 55 | return F.interpolate(x, scale_factor=scale_factor, size=size, 56 | mode='bicubic', 57 | align_corners=True) 58 | elif len(x.shape) == 5: 59 | return F.interpolate(x, scale_factor=scale_factor, size=size, 60 | mode='trilinear', 61 | align_corners=True) 62 | 63 | 64 | def tensor2array(tensor): 65 | return tensor.data.cpu().numpy() 66 | 67 | 68 | def load_weights(model, state_dict, drop_modelDOT=False, silence=False): 69 | if drop_modelDOT: 70 | new_dict = {} 71 | for k, v in state_dict.items(): 72 | new_dict[k[7:]] = v 73 | state_dict = new_dict 74 | net_dict = model.state_dict() # model dict 75 | pretrain_dict = {k: v for k, v in state_dict.items()} # pretrain dict 76 | InPretrain_InModel_dict = {k: v for k, v in state_dict.items() if k in net_dict.keys()} 77 | InPretrain_NotInModel_dict = {k: v for k, v in state_dict.items() if k not in net_dict.keys()} 78 | NotInPretrain_InModel_dict = {k: v for k, v in net_dict.items() if k not in state_dict.keys()} 79 | if not silence: 80 | print('-' * 200) 81 | print('keys ( Current model,C ) ', len(net_dict.keys()), net_dict.keys()) 82 | print('keys ( Pre-trained ,P ) ', len(pretrain_dict.keys()), pretrain_dict.keys()) 83 | print('keys ( In C & In P ) ', len(InPretrain_InModel_dict.keys()), InPretrain_InModel_dict.keys()) 84 | print('keys ( NoIn C & In P ) ', len(InPretrain_NotInModel_dict.keys()), InPretrain_NotInModel_dict.keys()) 85 | print('keys ( In C & NoIn P ) ', len(NotInPretrain_InModel_dict.keys()), NotInPretrain_InModel_dict.keys()) 86 | print('-' * 200) 87 | print('Pretrained keys :', len(InPretrain_InModel_dict.keys()), InPretrain_InModel_dict.keys()) 88 | print('Non-Pretrained keys:', len(NotInPretrain_InModel_dict.keys()), NotInPretrain_InModel_dict.keys()) 89 | print('-' * 200) 90 | net_dict.update(InPretrain_InModel_dict) 91 | model.load_state_dict(net_dict) 92 | return model 93 | 94 | 95 | def MaxMinNorm(array, FirstDimBATCH = 
True): 96 | """ 97 | :param array: 98 | :param FirstDimBATCH: bool, is the first dim batch? True or False 99 | :return: 100 | 101 | # demo for numpy ndarray 102 | 103 | 104 | 105 | 106 | # demo for torch tensor 107 | 108 | 109 | 110 | 111 | 112 | """ 113 | pass 114 | 115 | 116 | def mat2gray(image): 117 | """ 118 | 归一化函数(线性归一化) 119 | :param image: ndarray 120 | :return: 121 | """ 122 | # as dtype = np.float32 123 | image = image.astype(np.float32) 124 | image = (image - np.min(image)) / (np.max(image) - np.min(image) + 1e-14) 125 | return image 126 | 127 | 128 | def save_as_pkl(save_path, obj): 129 | data_output = open(save_path, 'wb') 130 | pickle.dump(obj, data_output) 131 | data_output.close() 132 | 133 | def load_from_pkl(load_path): 134 | data_input = open(load_path, 'rb') 135 | read_data = pickle.load(data_input) 136 | data_input.close() 137 | return read_data 138 | 139 | 140 | import matplotlib.pyplot as plt 141 | def show2D(img): 142 | plt.imshow(img) 143 | plt.show() 144 | 145 | # try: 146 | # from mayavi import mlab 147 | # def show3D(img3D): 148 | # vol = mlab.pipeline.volume(mlab.pipeline.scalar_field(img3D), name='3-d ultrasound ') 149 | # mlab.colorbar(orientation='vertical') 150 | # mlab.show() 151 | # except: 152 | # pass 153 | -------------------------------------------------------------------------------- /wama_modules/Decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from wama_modules.utils import tensor2array, resizeTensor, tmp_class 5 | from wama_modules.BaseModule import ConvNormActive 6 | 7 | 8 | class UNet_decoder(nn.Module): 9 | def __init__(self, 10 | in_channels_list=[64, 128, 256, 512], # from shallow to deep 11 | skip_connection=[True, True, True], # from shallow to deep 12 | out_channels_list=[12, 13, 14], # from shallow to deep 13 | norm='bn', 14 | gn_c=8, 15 | dim=2, 16 | ): 17 | super().__init__() 18 | self._skip_connection = skip_connection[::-1] # from deep to shallow 19 | _skip_channels_list = in_channels_list[-2::-1] # from deep to shallow [256, 128, 64] 20 | _out_channels_list = out_channels_list[::-1] # from deep to shallow [14, 13, 12] 21 | _in_conv_list = [in_channels_list[-1]] + _out_channels_list[:-1] # from deep to shallow 22 | self.docoder_conv_list = nn.ModuleList([]) 23 | for stage, _out_channels in enumerate(_out_channels_list): 24 | if self._skip_connection[stage]: 25 | _in_channel = _in_conv_list[stage] + _skip_channels_list[stage] 26 | else: 27 | _in_channel = _in_conv_list[stage] 28 | _out_channel = _out_channels_list[stage] 29 | self.docoder_conv_list.append( 30 | nn.Sequential( 31 | ConvNormActive(_in_channel, _out_channel, kernel_size=3, norm=norm, gn_c=gn_c, dim=dim), 32 | ConvNormActive(_out_channel, _out_channel, kernel_size=3, norm=norm, gn_c=gn_c, dim=dim), 33 | ) 34 | ) 35 | 36 | def forward(self, f_list): 37 | """ 38 | :return: decoder_f_list, feature list from shallow to deep, and decoder_f_list[0] can be used for seg head 39 | 40 | # demo 41 | 42 | # 1D ------------------------------------------------------------- 43 | f_list = [ 44 | torch.ones([3,64,128]), 45 | torch.ones([3,128,64]), 46 | torch.ones([3,256,32]), 47 | torch.ones([3,512,8]), 48 | ] 49 | 50 | decoder = UNet_decoder( 51 | in_channels_list=[64, 128, 256, 512], # from shallow to deep 52 | skip_connection=[False, True, True], # from shallow to deep 53 | out_channels_list=[12, 13, 14], # from shallow to deep 54 | norm='bn', 55 | gn_c=8, 56 | 
dim=1 57 | ) 58 | 59 | decoder_f_list = decoder(f_list) 60 | _ = [print(i.shape) for i in decoder_f_list] 61 | 62 | # 2D ------------------------------------------------------------- 63 | f_list = [ 64 | torch.ones([3,64,128,128]), 65 | torch.ones([3,128,64,64]), 66 | torch.ones([3,256,32,32]), 67 | torch.ones([3,512,8,8]), 68 | ] 69 | 70 | decoder = UNet_decoder( 71 | in_channels_list=[64, 128, 256, 512], # from shallow to deep 72 | skip_connection=[False, True, True], # from shallow to deep 73 | out_channels_list=[12, 13, 14], # from shallow to deep 74 | norm='bn', 75 | gn_c=8, 76 | dim=2 77 | ) 78 | 79 | decoder_f_list = decoder(f_list) 80 | _ = [print(i.shape) for i in decoder_f_list] 81 | 82 | # 3D ------------------------------------------------------------- 83 | f_list = [ 84 | torch.ones([3,64,128,128,128]), 85 | torch.ones([3,128,64,64,64]), 86 | torch.ones([3,256,32,32,32]), 87 | torch.ones([3,512,8,8,8]), 88 | ] 89 | 90 | decoder = UNet_decoder( 91 | in_channels_list=[64, 128, 256, 512], # from shallow to deep 92 | skip_connection=[False, True, True], # from shallow to deep 93 | out_channels_list=[12, 13, 14], # from shallow to deep 94 | norm='bn', 95 | gn_c=8, 96 | dim=3 97 | ) 98 | 99 | decoder_f_list = decoder(f_list) 100 | _ = [print(i.shape) for i in decoder_f_list] 101 | 102 | """ 103 | _f_list = f_list[::-1] 104 | feature = _f_list[0] 105 | _f_list = _f_list[1:] 106 | decoder_f_list = [] 107 | for stage, conv in enumerate(self.docoder_conv_list): 108 | if self._skip_connection[stage]: 109 | _in_feature = torch.cat([resizeTensor(feature, size=_f_list[stage].shape[2:]), _f_list[stage]], 1) 110 | else: 111 | _in_feature = resizeTensor(feature, size=_f_list[stage].shape[2:]) 112 | feature = conv(_in_feature) 113 | decoder_f_list.append(feature) 114 | decoder_f_list = decoder_f_list[::-1] 115 | return decoder_f_list # from shallow to deep, and decoder_f_list[0] can be used for seg head 116 | 117 | # psp 118 | 119 | 120 | # deeplabv3+ 121 | # try this https://blog.csdn.net/m0_51436734/article/details/124073901 122 | 123 | 124 | # NestedUNet(Unet++) 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/squeezenet.py: -------------------------------------------------------------------------------- 1 | '''SqueezeNet in PyTorch. 2 | 3 | See the paper "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" for more details. 
4 | ''' 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from functools import partial 13 | 14 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] 15 | 16 | 17 | class Fire(nn.Module): 18 | 19 | def __init__(self, inplanes, squeeze_planes, 20 | expand1x1_planes, expand3x3_planes, 21 | use_bypass=False): 22 | super(Fire, self).__init__() 23 | self.use_bypass = use_bypass 24 | self.inplanes = inplanes 25 | self.relu = nn.ReLU(inplace=True) 26 | self.squeeze = nn.Conv3d(inplanes, squeeze_planes, kernel_size=1) 27 | self.squeeze_bn = nn.BatchNorm3d(squeeze_planes) 28 | self.expand1x1 = nn.Conv3d(squeeze_planes, expand1x1_planes, 29 | kernel_size=1) 30 | self.expand1x1_bn = nn.BatchNorm3d(expand1x1_planes) 31 | self.expand3x3 = nn.Conv3d(squeeze_planes, expand3x3_planes, 32 | kernel_size=3, padding=1) 33 | self.expand3x3_bn = nn.BatchNorm3d(expand3x3_planes) 34 | 35 | def forward(self, x): 36 | out = self.squeeze(x) 37 | out = self.squeeze_bn(out) 38 | out = self.relu(out) 39 | 40 | out1 = self.expand1x1(out) 41 | out1 = self.expand1x1_bn(out1) 42 | 43 | out2 = self.expand3x3(out) 44 | out2 = self.expand3x3_bn(out2) 45 | 46 | out = torch.cat([out1, out2], 1) 47 | if self.use_bypass: 48 | out += x 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class SqueezeNet(nn.Module): 55 | 56 | def __init__(self,): 57 | super(SqueezeNet, self).__init__() 58 | # if version not in [1.0, 1.1]: 59 | # raise ValueError("Unsupported SqueezeNet version {version}:" 60 | # "1.0 or 1.1 expected".format(version=version)) 61 | 62 | # if version == 1.1: 63 | if True: 64 | self.features = nn.Sequential( 65 | nn.Conv3d(3, 64, kernel_size=3, stride=(1,2,2), padding=(1,1,1)), # 0 66 | nn.BatchNorm3d(64), # 1 67 | nn.ReLU(inplace=True), # 2 68 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1), # todo 3 69 | Fire(64, 16, 64, 64), # 4 70 | Fire(128, 16, 64, 64, use_bypass=True), # 5 71 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1), # todo 6 72 | Fire(128, 32, 128, 128), # 7 73 | Fire(256, 32, 128, 128, use_bypass=True), # 8 74 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1), # todo 9 75 | Fire(256, 48, 192, 192), # 10 76 | Fire(384, 48, 192, 192, use_bypass=True), # 11 77 | nn.MaxPool3d(kernel_size=3, stride=2, padding=1), # todo 12 78 | Fire(384, 64, 256, 256), # 13 79 | Fire(512, 64, 256, 256, use_bypass=True), # todo 14 80 | ) 81 | # Final convolution is initialized differently form the rest 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv3d): 85 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 86 | elif isinstance(m, nn.BatchNorm3d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | 90 | 91 | def forward(self, x): 92 | f_list = [] 93 | for i in range(len(self.features)): 94 | x = self.features[i](x) 95 | if i in [3,6,9,12,14]: 96 | f_list.append(x) 97 | return f_list 98 | 99 | 100 | def get_fine_tuning_parameters(model, ft_portion): 101 | if ft_portion == "complete": 102 | return model.parameters() 103 | 104 | elif ft_portion == "last_layer": 105 | ft_module_names = [] 106 | ft_module_names.append('classifier') 107 | 108 | parameters = [] 109 | for k, v in model.named_parameters(): 110 | for ft_module in ft_module_names: 111 | if ft_module in k: 112 | parameters.append({'params': v}) 113 | break 114 | else: 115 | parameters.append({'params': v, 'lr': 0.0}) 116 | return parameters 117 | 118 | else: 119 | raise 
ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected") 120 | 121 | 122 | def get_model(**kwargs): 123 | """ 124 | Returns the model. 125 | """ 126 | model = SqueezeNet(**kwargs) 127 | return model 128 | 129 | 130 | if __name__ == '__main__': 131 | model = SqueezeNet(version=1.1, sample_size = 112, sample_duration = 16, num_classes=600) 132 | model = model.cuda() 133 | model = nn.DataParallel(model, device_ids=None) 134 | print(model) 135 | 136 | input_var = Variable(torch.randn(8, 3, 16, 112, 112)) 137 | output = model(input_var) 138 | print(output.shape) 139 | -------------------------------------------------------------------------------- /wama_modules/Neck.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from wama_modules.BaseModule import ConvNormActive 4 | from wama_modules.utils import resizeTensor 5 | 6 | 7 | class FPN(nn.Module): 8 | def __init__(self, 9 | in_channels_list=[16, 32, 64, 128], 10 | c1=128, 11 | c2=256, 12 | active='relu', 13 | norm='bn', 14 | gn_c=8, 15 | mode='AddSmall2Big', # AddSmall2Big or AddBig2Small(much better or classification tasks) 16 | dim=2, 17 | ): 18 | super().__init__() 19 | self.mode = mode 20 | 21 | self.conv1_list = nn.ModuleList([ 22 | ConvNormActive(in_channels, c1, kernel_size=1, norm=norm, active=active, gn_c=gn_c, dim=dim, padding=0) 23 | for in_channels in in_channels_list 24 | ]) 25 | self.conv2_list = nn.ModuleList([ 26 | ConvNormActive(c1, c2, kernel_size=3, norm=norm, active=active, gn_c=gn_c, dim=dim, padding=1) 27 | for _ in range(len(in_channels_list)) 28 | ]) 29 | 30 | def forward(self, x_list): 31 | """ 32 | :param x_list: multi scale feature maps, from shallow(big size) to deep(small size) 33 | :return: 34 | 35 | # demo 36 | # 1D 37 | x_list = [ 38 | torch.ones([3,16,32]), 39 | torch.ones([3,32,24]), 40 | torch.ones([3,64,16]), 41 | torch.ones([3,128,8]), 42 | ] 43 | fpn = FPN(in_channels_list=[16, 32, 64, 128], 44 | c1=128, 45 | c2=256, 46 | active='relu', 47 | norm='bn', 48 | gn_c=8, 49 | mode='AddSmall2Big', 50 | dim=1,) 51 | fpn = FPN(in_channels_list=[16, 32, 64, 128], 52 | c1=128, 53 | c2=256, 54 | active='relu', 55 | norm='bn', 56 | gn_c=8, 57 | mode='AddBig2Small', # revserse, for classification 58 | dim=1,) 59 | f_list = fpn(x_list) 60 | _ = [print(i.shape) for i in x_list] 61 | _ = [print(i.shape) for i in f_list] 62 | 63 | # 2D 64 | x_list = [ 65 | torch.ones([3,16,32,32]), 66 | torch.ones([3,32,24,24]), 67 | torch.ones([3,64,16,16]), 68 | torch.ones([3,128,8,8]), 69 | ] 70 | fpn = FPN(in_channels_list=[16, 32, 64, 128], 71 | c1=128, 72 | c2=256, 73 | active='relu', 74 | norm='bn', 75 | gn_c=8, 76 | mode='AddSmall2Big', 77 | dim=2,) 78 | fpn = FPN(in_channels_list=[16, 32, 64, 128], 79 | c1=128, 80 | c2=256, 81 | active='relu', 82 | norm='bn', 83 | gn_c=8, 84 | mode='AddBig2Small', # revserse, for classification 85 | dim=2,) 86 | f_list = fpn(x_list) 87 | _ = [print(i.shape) for i in x_list] 88 | _ = [print(i.shape) for i in f_list] 89 | 90 | 91 | # 3D 92 | x_list = [ 93 | torch.ones([3,16,32,32,32]), 94 | torch.ones([3,32,24,24,24]), 95 | torch.ones([3,64,16,16,16]), 96 | torch.ones([3,128,8,8,8]), 97 | ] 98 | fpn = FPN(in_channels_list=[16, 32, 64, 128], 99 | c1=128, 100 | c2=256, 101 | active='relu', 102 | norm='bn', 103 | gn_c=8, 104 | mode='AddSmall2Big', 105 | dim=3,) 106 | fpn = FPN(in_channels_list=[16, 32, 64, 128], 107 | c1=128, 108 | c2=256, 109 | active='relu', 110 | norm='bn', 111 | gn_c=8, 112 | 
mode='AddBig2Small', # revserse, for classification 113 | dim=3,) 114 | f_list = fpn(x_list) 115 | _ = [print(i.shape) for i in x_list] 116 | _ = [print(i.shape) for i in f_list] 117 | 118 | """ 119 | f_list = [self.conv1_list[index](f) for index, f in enumerate(x_list)] 120 | 121 | if self.mode == 'AddSmall2Big': 122 | f_list_2 = [] 123 | x = f_list[-1] 124 | f_list_2.append(x) 125 | for index in range(len(f_list)-1): 126 | # print(f_list[-(index+2)].shape[2:]) 127 | x = f_list[-(index + 2)] + resizeTensor(x, size=f_list[-(index+2)].shape[2:]) 128 | f_list_2.append(x) 129 | f_list_2 = f_list_2[::-1] 130 | elif self.mode == 'AddBig2Small': 131 | f_list_2 = [] 132 | x = f_list[0] 133 | f_list_2.append(x) 134 | for index in range(len(f_list) - 1): 135 | # print(f_list[index + 1].shape[2:]) 136 | x = f_list[index + 1] + resizeTensor(x, size=f_list[index + 1].shape[2:]) 137 | f_list_2.append(x) 138 | 139 | return_list = [self.conv2_list[index](f) for index, f in enumerate(f_list_2)] 140 | 141 | return return_list 142 | 143 | 144 | 145 | 146 | # UCtrans的 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/densenet.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
24 | """ 25 | 26 | import re 27 | import torch.nn as nn 28 | 29 | from pretrainedmodels.models.torchvision_models import pretrained_settings 30 | from torchvision.models.densenet import DenseNet 31 | 32 | from ._base import EncoderMixin 33 | 34 | 35 | class TransitionWithSkip(nn.Module): 36 | def __init__(self, module): 37 | super().__init__() 38 | self.module = module 39 | 40 | def forward(self, x): 41 | for module in self.module: 42 | x = module(x) 43 | if isinstance(module, nn.ReLU): 44 | skip = x 45 | return x, skip 46 | 47 | 48 | class DenseNetEncoder(DenseNet, EncoderMixin): 49 | def __init__(self, out_channels, depth=5, **kwargs): 50 | super().__init__(**kwargs) 51 | self._out_channels = out_channels 52 | self._depth = depth 53 | self._in_channels = 3 54 | del self.classifier 55 | 56 | def make_dilated(self, *args, **kwargs): 57 | raise ValueError("DenseNet encoders do not support dilated mode " "due to pooling operation for downsampling!") 58 | 59 | def get_stages(self): 60 | return [ 61 | nn.Identity(), 62 | nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0), 63 | nn.Sequential( 64 | self.features.pool0, 65 | self.features.denseblock1, 66 | TransitionWithSkip(self.features.transition1), 67 | ), 68 | nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)), 69 | nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)), 70 | nn.Sequential(self.features.denseblock4, self.features.norm5), 71 | ] 72 | 73 | def forward(self, x): 74 | 75 | stages = self.get_stages() 76 | 77 | features = [] 78 | for i in range(self._depth + 1): 79 | x = stages[i](x) 80 | if isinstance(x, (list, tuple)): 81 | x, skip = x 82 | features.append(skip) 83 | else: 84 | features.append(x) 85 | 86 | return features 87 | 88 | def load_state_dict(self, state_dict): 89 | pattern = re.compile( 90 | r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" 91 | ) 92 | for key in list(state_dict.keys()): 93 | res = pattern.match(key) 94 | if res: 95 | new_key = res.group(1) + res.group(2) 96 | state_dict[new_key] = state_dict[key] 97 | del state_dict[key] 98 | 99 | # remove linear 100 | state_dict.pop("classifier.bias", None) 101 | state_dict.pop("classifier.weight", None) 102 | 103 | super().load_state_dict(state_dict) 104 | 105 | 106 | densenet_encoders = { 107 | "densenet121": { 108 | "encoder": DenseNetEncoder, 109 | "pretrained_settings": pretrained_settings["densenet121"], 110 | "params": { 111 | "out_channels": (3, 64, 256, 512, 1024, 1024), 112 | "num_init_features": 64, 113 | "growth_rate": 32, 114 | "block_config": (6, 12, 24, 16), 115 | }, 116 | }, 117 | "densenet169": { 118 | "encoder": DenseNetEncoder, 119 | "pretrained_settings": pretrained_settings["densenet169"], 120 | "params": { 121 | "out_channels": (3, 64, 256, 512, 1280, 1664), 122 | "num_init_features": 64, 123 | "growth_rate": 32, 124 | "block_config": (6, 12, 32, 32), 125 | }, 126 | }, 127 | "densenet201": { 128 | "encoder": DenseNetEncoder, 129 | "pretrained_settings": pretrained_settings["densenet201"], 130 | "params": { 131 | "out_channels": (3, 64, 256, 512, 1792, 1920), 132 | "num_init_features": 64, 133 | "growth_rate": 32, 134 | "block_config": (6, 12, 48, 32), 135 | }, 136 | }, 137 | "densenet161": { 138 | "encoder": DenseNetEncoder, 139 | "pretrained_settings": pretrained_settings["densenet161"], 140 | "params": { 141 | "out_channels": (3, 96, 384, 768, 2112, 2208), 142 | "num_init_features": 96, 143 | 
"growth_rate": 48, 144 | "block_config": (6, 12, 36, 24), 145 | }, 146 | }, 147 | } 148 | -------------------------------------------------------------------------------- /demo/Demo_BilinearPooling.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | 7 | import torch.fft as afft 8 | 9 | 10 | class CompactBilinearPooling(nn.Module): 11 | """ 12 | from https://github.com/DeepInsight-PCALab/CompactBilinearPooling-Pytorch 13 | 14 | Compute compact bilinear pooling over two bottom inputs. 15 | Args: 16 | output_dim: output dimension for compact bilinear pooling. 17 | sum_pool: (Optional) If True, sum the output along height and width 18 | dimensions and return output shape [batch_size, output_dim]. 19 | Otherwise return [batch_size, height, width, output_dim]. 20 | Default: True. 21 | rand_h_1: (Optional) an 1D numpy array containing indices in interval 22 | `[0, output_dim)`. Automatically generated from `seed_h_1` 23 | if is None. 24 | rand_s_1: (Optional) an 1D numpy array of 1 and -1, having the same shape 25 | as `rand_h_1`. Automatically generated from `seed_s_1` if is 26 | None. 27 | rand_h_2: (Optional) an 1D numpy array containing indices in interval 28 | `[0, output_dim)`. Automatically generated from `seed_h_2` 29 | if is None. 30 | rand_s_2: (Optional) an 1D numpy array of 1 and -1, having the same shape 31 | as `rand_h_2`. Automatically generated from `seed_s_2` if is 32 | None. 33 | """ 34 | 35 | def __init__(self, input_dim1, input_dim2, output_dim, 36 | sum_pool=True, cuda=True, 37 | rand_h_1=None, rand_s_1=None, rand_h_2=None, rand_s_2=None): 38 | super(CompactBilinearPooling, self).__init__() 39 | self.input_dim1 = input_dim1 40 | self.input_dim2 = input_dim2 41 | self.output_dim = output_dim 42 | self.sum_pool = sum_pool 43 | 44 | if rand_h_1 is None: 45 | np.random.seed(1) 46 | rand_h_1 = np.random.randint(output_dim, size=self.input_dim1) 47 | if rand_s_1 is None: 48 | np.random.seed(3) 49 | rand_s_1 = 2 * np.random.randint(2, size=self.input_dim1) - 1 50 | 51 | self.sparse_sketch_matrix1 = Variable(self.generate_sketch_matrix( 52 | rand_h_1, rand_s_1, self.output_dim)) 53 | 54 | if rand_h_2 is None: 55 | np.random.seed(5) 56 | rand_h_2 = np.random.randint(output_dim, size=self.input_dim2) 57 | if rand_s_2 is None: 58 | np.random.seed(7) 59 | rand_s_2 = 2 * np.random.randint(2, size=self.input_dim2) - 1 60 | 61 | self.sparse_sketch_matrix2 = Variable(self.generate_sketch_matrix( 62 | rand_h_2, rand_s_2, self.output_dim)) 63 | 64 | if cuda: 65 | self.sparse_sketch_matrix1 = self.sparse_sketch_matrix1.cuda() 66 | self.sparse_sketch_matrix2 = self.sparse_sketch_matrix2.cuda() 67 | 68 | def forward(self, bottom1, bottom2): 69 | """ 70 | bottom1: 1st input, 4D Tensor of shape [batch_size, input_dim1, height, width]. 71 | bottom2: 2nd input, 4D Tensor of shape [batch_size, input_dim2, height, width]. 
72 | """ 73 | assert bottom1.size(1) == self.input_dim1 and \ 74 | bottom2.size(1) == self.input_dim2 75 | 76 | batch_size, _, height, width = bottom1.size() 77 | 78 | bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1) 79 | bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2) 80 | 81 | sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1) 82 | sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2) 83 | 84 | fft1 = afft.fft(sketch_1) 85 | fft2 = afft.fft(sketch_2) 86 | 87 | fft_product = fft1 * fft2 88 | 89 | cbp_flat = afft.ifft(fft_product).real 90 | 91 | cbp = cbp_flat.view(batch_size, height, width, self.output_dim) 92 | 93 | if self.sum_pool: 94 | cbp = cbp.sum(dim=1).sum(dim=1) 95 | 96 | return cbp 97 | 98 | @staticmethod 99 | def generate_sketch_matrix(rand_h, rand_s, output_dim): 100 | """ 101 | Return a sparse matrix used for tensor sketch operation in compact bilinear 102 | pooling 103 | Args: 104 | rand_h: an 1D numpy array containing indices in interval `[0, output_dim)`. 105 | rand_s: an 1D numpy array of 1 and -1, having the same shape as `rand_h`. 106 | output_dim: the output dimensions of compact bilinear pooling. 107 | Returns: 108 | a sparse matrix of shape [input_dim, output_dim] for tensor sketch. 109 | """ 110 | 111 | # Generate a sparse matrix for tensor count sketch 112 | rand_h = rand_h.astype(np.int64) 113 | rand_s = rand_s.astype(np.float32) 114 | assert(rand_h.ndim == 1 and rand_s.ndim == 115 | 1 and len(rand_h) == len(rand_s)) 116 | assert(np.all(rand_h >= 0) and np.all(rand_h < output_dim)) 117 | 118 | input_dim = len(rand_h) 119 | indices = np.concatenate((np.arange(input_dim)[..., np.newaxis], 120 | rand_h[..., np.newaxis]), axis=1) 121 | indices = torch.from_numpy(indices) 122 | rand_s = torch.from_numpy(rand_s) 123 | sparse_sketch_matrix = torch.sparse.FloatTensor( 124 | indices.t(), rand_s, torch.Size([input_dim, output_dim])) 125 | return sparse_sketch_matrix.to_dense() 126 | 127 | 128 | if __name__ == '__main__': 129 | 130 | bottom1 = Variable(torch.randn(3, 512, 14, 14)) 131 | bottom2 = Variable(torch.randn(3, 128, 14, 14)) 132 | 133 | layer = CompactBilinearPooling(512, 128, 512, cuda=False) 134 | layer.train() 135 | 136 | out = layer(bottom1, bottom2) 137 | print(out.shape) 138 | 139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/vgg.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 
15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch.nn as nn 27 | from torchvision.models.vgg import VGG 28 | from torchvision.models.vgg import make_layers 29 | from pretrainedmodels.models.torchvision_models import pretrained_settings 30 | 31 | from ._base import EncoderMixin 32 | 33 | # fmt: off 34 | cfg = { 35 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 36 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 37 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 38 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 39 | } 40 | # fmt: on 41 | 42 | 43 | class VGGEncoder(VGG, EncoderMixin): 44 | def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs): 45 | super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs) 46 | self._out_channels = out_channels 47 | self._depth = depth 48 | self._in_channels = 3 49 | del self.classifier 50 | 51 | def make_dilated(self, *args, **kwargs): 52 | raise ValueError("'VGG' models do not support dilated mode due to Max Pooling" " operations for downsampling!") 53 | 54 | def get_stages(self): 55 | stages = [] 56 | stage_modules = [] 57 | for module in self.features: 58 | if isinstance(module, nn.MaxPool2d): 59 | stages.append(nn.Sequential(*stage_modules)) 60 | stage_modules = [] 61 | stage_modules.append(module) 62 | stages.append(nn.Sequential(*stage_modules)) 63 | return stages 64 | 65 | def forward(self, x): 66 | stages = self.get_stages() 67 | 68 | features = [] 69 | for i in range(self._depth + 1): 70 | x = stages[i](x) 71 | features.append(x) 72 | 73 | return features 74 | 75 | def load_state_dict(self, state_dict, **kwargs): 76 | keys = list(state_dict.keys()) 77 | for k in keys: 78 | if k.startswith("classifier"): 79 | state_dict.pop(k, None) 80 | super().load_state_dict(state_dict, **kwargs) 81 | 82 | 83 | vgg_encoders = { 84 | "vgg11": { 85 | "encoder": VGGEncoder, 86 | "pretrained_settings": pretrained_settings["vgg11"], 87 | "params": { 88 | "out_channels": (64, 128, 256, 512, 512, 512), 89 | "config": cfg["A"], 90 | "batch_norm": False, 91 | }, 92 | }, 93 | "vgg11_bn": { 94 | "encoder": VGGEncoder, 95 | "pretrained_settings": pretrained_settings["vgg11_bn"], 96 | "params": { 97 | "out_channels": (64, 128, 256, 512, 512, 512), 98 | "config": cfg["A"], 99 | "batch_norm": True, 100 | }, 101 | }, 102 | "vgg13": { 103 | "encoder": VGGEncoder, 104 | "pretrained_settings": pretrained_settings["vgg13"], 105 | "params": { 106 | "out_channels": (64, 128, 256, 512, 512, 512), 107 | "config": cfg["B"], 108 | "batch_norm": False, 109 | }, 110 | }, 111 | "vgg13_bn": { 112 | "encoder": VGGEncoder, 113 | "pretrained_settings": pretrained_settings["vgg13_bn"], 114 | "params": { 115 | "out_channels": (64, 128, 256, 512, 512, 512), 116 | "config": cfg["B"], 117 | "batch_norm": True, 118 | }, 119 | }, 120 | "vgg16": { 121 | "encoder": VGGEncoder, 122 | 
"pretrained_settings": pretrained_settings["vgg16"], 123 | "params": { 124 | "out_channels": (64, 128, 256, 512, 512, 512), 125 | "config": cfg["D"], 126 | "batch_norm": False, 127 | }, 128 | }, 129 | "vgg16_bn": { 130 | "encoder": VGGEncoder, 131 | "pretrained_settings": pretrained_settings["vgg16_bn"], 132 | "params": { 133 | "out_channels": (64, 128, 256, 512, 512, 512), 134 | "config": cfg["D"], 135 | "batch_norm": True, 136 | }, 137 | }, 138 | "vgg19": { 139 | "encoder": VGGEncoder, 140 | "pretrained_settings": pretrained_settings["vgg19"], 141 | "params": { 142 | "out_channels": (64, 128, 256, 512, 512, 512), 143 | "config": cfg["E"], 144 | "batch_norm": False, 145 | }, 146 | }, 147 | "vgg19_bn": { 148 | "encoder": VGGEncoder, 149 | "pretrained_settings": pretrained_settings["vgg19_bn"], 150 | "params": { 151 | "out_channels": (64, 128, 256, 512, 512, 512), 152 | "config": cfg["E"], 153 | "batch_norm": True, 154 | }, 155 | }, 156 | } 157 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | 3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | 11 | def conv_bn(inp, oup, stride): 12 | return nn.Sequential( 13 | nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False), 14 | nn.BatchNorm3d(oup), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | 19 | def channel_shuffle(x, groups): 20 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 21 | batchsize, num_channels, depth, height, width = x.data.size() 22 | channels_per_group = num_channels // groups 23 | # reshape 24 | x = x.view(batchsize, groups, 25 | channels_per_group, depth, height, width) 26 | #permute 27 | x = x.permute(0,2,1,3,4,5).contiguous() 28 | # flatten 29 | x = x.view(batchsize, num_channels, depth, height, width) 30 | return x 31 | 32 | 33 | 34 | class Bottleneck(nn.Module): 35 | def __init__(self, in_planes, out_planes, stride, groups): 36 | super(Bottleneck, self).__init__() 37 | self.stride = stride 38 | self.groups = groups 39 | mid_planes = out_planes//4 40 | if self.stride == 2: 41 | out_planes = out_planes - in_planes 42 | g = 1 if in_planes==24 else groups 43 | self.conv1 = nn.Conv3d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 44 | self.bn1 = nn.BatchNorm3d(mid_planes) 45 | self.conv2 = nn.Conv3d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 46 | self.bn2 = nn.BatchNorm3d(mid_planes) 47 | self.conv3 = nn.Conv3d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 48 | self.bn3 = nn.BatchNorm3d(out_planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | 51 | if stride == 2: 52 | self.shortcut = nn.AvgPool3d(kernel_size=(2,3,3), stride=2, padding=(0,1,1)) 53 | 54 | 55 | def forward(self, x): 56 | out = self.relu(self.bn1(self.conv1(x))) 57 | out = channel_shuffle(out, self.groups) 58 | out = self.bn2(self.conv2(out)) 59 | out = self.bn3(self.conv3(out)) 60 | 61 | if self.stride == 2: 62 | out = self.relu(torch.cat([out, self.shortcut(x)], 1)) 63 | else: 64 | out = self.relu(out + x) 65 | 66 | return out 67 | 68 | 69 | class ShuffleNet(nn.Module): 70 | def __init__(self, 71 | groups=3, 72 
| width_mult=1, 73 | num_classes=400): 74 | super(ShuffleNet, self).__init__() 75 | self.num_classes = num_classes 76 | self.groups = groups 77 | num_blocks = [4,8,4] 78 | 79 | # index 0 is invalid and should never be called. 80 | # only used for indexing convenience. 81 | if groups == 1: 82 | out_planes = [24, 144, 288, 567] 83 | elif groups == 2: 84 | out_planes = [24, 200, 400, 800] 85 | elif groups == 3: 86 | out_planes = [24, 240, 480, 960] 87 | elif groups == 4: 88 | out_planes = [24, 272, 544, 1088] 89 | elif groups == 8: 90 | out_planes = [24, 384, 768, 1536] 91 | else: 92 | raise ValueError( 93 | """{} groups is not supported for 94 | 1x1 Grouped Convolutions""".format(groups)) 95 | out_planes = [int(i * width_mult) for i in out_planes] 96 | self.in_planes = out_planes[0] 97 | self.conv1 = conv_bn(3, self.in_planes, stride=(1,2,2)) 98 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 99 | self.layer1 = self._make_layer(out_planes[1], num_blocks[0], self.groups) 100 | self.layer2 = self._make_layer(out_planes[2], num_blocks[1], self.groups) 101 | self.layer3 = self._make_layer(out_planes[3], num_blocks[2], self.groups) 102 | 103 | def _make_layer(self, out_planes, num_blocks, groups): 104 | layers = [] 105 | for i in range(num_blocks): 106 | stride = 2 if i == 0 else 1 107 | layers.append(Bottleneck(self.in_planes, out_planes, stride=stride, groups=groups)) 108 | self.in_planes = out_planes 109 | return nn.Sequential(*layers) 110 | 111 | def forward(self, x): 112 | f_list = [] 113 | out = self.conv1(x) 114 | out = self.maxpool(out) 115 | f_list.append(out) 116 | out = self.layer1(out) 117 | f_list.append(out) 118 | out = self.layer2(out) 119 | f_list.append(out) 120 | out = self.layer3(out) 121 | f_list.append(out) 122 | return f_list 123 | 124 | def get_fine_tuning_parameters(model, ft_portion): 125 | if ft_portion == "complete": 126 | return model.parameters() 127 | 128 | elif ft_portion == "last_layer": 129 | ft_module_names = [] 130 | ft_module_names.append('classifier') 131 | 132 | parameters = [] 133 | for k, v in model.named_parameters(): 134 | for ft_module in ft_module_names: 135 | if ft_module in k: 136 | parameters.append({'params': v}) 137 | break 138 | else: 139 | parameters.append({'params': v, 'lr': 0.0}) 140 | return parameters 141 | 142 | else: 143 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected") 144 | 145 | 146 | def get_model(**kwargs): 147 | """ 148 | Returns the model. 
149 | """ 150 | model = ShuffleNet(**kwargs) 151 | return model 152 | 153 | 154 | if __name__ == "__main__": 155 | model = get_model(groups=3, num_classes=600, width_mult=1) 156 | model = model.cuda() 157 | model = nn.DataParallel(model, device_ids=None) 158 | print(model) 159 | 160 | input_var = Variable(torch.randn(8, 3, 16, 112, 112)) 161 | output = model(input_var) 162 | print(output.shape) 163 | 164 | 165 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/VC3D_kenshohara/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = ['WideResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101'] 9 | 10 | 11 | def conv3x3x3(in_planes, out_planes, stride=1): 12 | # 3x3x3 convolution with padding 13 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, 14 | stride=stride, padding=1, bias=False) 15 | 16 | 17 | def downsample_basic_block(x, planes, stride): 18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), 20 | out.size(2), out.size(3), 21 | out.size(4)).zero_() 22 | if isinstance(out.data, torch.cuda.FloatTensor): 23 | zero_pads = zero_pads.cuda() 24 | 25 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 26 | 27 | return out 28 | 29 | 30 | class WideBottleneck(nn.Module): 31 | expansion = 2 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(WideBottleneck, self).__init__() 35 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 36 | self.bn1 = nn.BatchNorm3d(planes) 37 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, 38 | padding=1, bias=False) 39 | self.bn2 = nn.BatchNorm3d(planes) 40 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False) 41 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv3(out) 58 | out = self.bn3(out) 59 | 60 | if self.downsample is not None: 61 | residual = self.downsample(x) 62 | 63 | out += residual 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | 69 | class WideResNet(nn.Module): 70 | 71 | def __init__(self, block, layers, sample_size = None, sample_duration = None, k=2, shortcut_type='B', num_classes=400, last_fc=True): 72 | self.last_fc = last_fc 73 | 74 | self.inplanes = 64 75 | super(WideResNet, self).__init__() 76 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), 77 | padding=(3, 3, 3), bias=False) 78 | self.bn1 = nn.BatchNorm3d(64) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 81 | self.layer1 = self._make_layer(block, 64 * k, layers[0], shortcut_type) 82 | self.layer2 = self._make_layer(block, 128 * k, layers[1], shortcut_type, stride=2) 83 | self.layer3 = self._make_layer(block, 256 * k, layers[2], shortcut_type, stride=2) 84 | self.layer4 = self._make_layer(block, 512 * k, layers[3], shortcut_type, stride=2) 85 | # last_duration = math.ceil(sample_duration / 16) 86 | # 
last_size = math.ceil(sample_size / 32) 87 | # self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) 88 | # self.fc = nn.Linear(512 * k * block.expansion, num_classes) 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv3d): 92 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 93 | m.weight.data.normal_(0, math.sqrt(2. / n)) 94 | elif isinstance(m, nn.BatchNorm3d): 95 | m.weight.data.fill_(1) 96 | m.bias.data.zero_() 97 | 98 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | if shortcut_type == 'A': 102 | downsample = partial(downsample_basic_block, 103 | planes=planes * block.expansion, 104 | stride=stride) 105 | else: 106 | downsample = nn.Sequential( 107 | nn.Conv3d(self.inplanes, planes * block.expansion, 108 | kernel_size=1, stride=stride, bias=False), 109 | nn.BatchNorm3d(planes * block.expansion) 110 | ) 111 | 112 | layers = [] 113 | layers.append(block(self.inplanes, planes, stride, downsample)) 114 | self.inplanes = planes * block.expansion 115 | for i in range(1, blocks): 116 | layers.append(block(self.inplanes, planes)) 117 | 118 | return nn.Sequential(*layers) 119 | 120 | def forward(self, x): 121 | f_list = [] 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = self.relu(x) 125 | x = self.maxpool(x) 126 | f_list.append(x) 127 | 128 | x = self.layer1(x) 129 | f_list.append(x) 130 | x = self.layer2(x) 131 | f_list.append(x) 132 | x = self.layer3(x) 133 | f_list.append(x) 134 | x = self.layer4(x) 135 | f_list.append(x) 136 | 137 | return f_list 138 | 139 | def get_fine_tuning_parameters(model, ft_begin_index): 140 | if ft_begin_index == 0: 141 | return model.parameters() 142 | 143 | ft_module_names = [] 144 | for i in range(ft_begin_index, 5): 145 | ft_module_names.append('layer{}'.format(ft_begin_index)) 146 | ft_module_names.append('fc') 147 | 148 | parameters = [] 149 | for k, v in model.named_parameters(): 150 | for ft_module in ft_module_names: 151 | if ft_module in k: 152 | parameters.append({'params': v}) 153 | break 154 | else: 155 | parameters.append({'params': v, 'lr': 0.0}) 156 | 157 | return parameters 158 | 159 | def generate_model(): 160 | """Constructs a ResNet-50 model. 161 | """ 162 | model = WideResNet(WideBottleneck, [3, 4, 6, 3]) 163 | return model 164 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/Efficient3D_okankop/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobilenetV2 in PyTorch. 2 | 3 | See the paper "MobileNetV2: Inverted Residuals and Linear Bottlenecks" for more details. 
4 | ''' 5 | import torch 6 | import math 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | 11 | 12 | 13 | def conv_bn(inp, oup, stride): 14 | return nn.Sequential( 15 | nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False), 16 | nn.BatchNorm3d(oup), 17 | nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | def conv_1x1x1_bn(inp, oup): 22 | return nn.Sequential( 23 | nn.Conv3d(inp, oup, 1, 1, 0, bias=False), 24 | nn.BatchNorm3d(oup), 25 | nn.ReLU6(inplace=True) 26 | ) 27 | 28 | 29 | class InvertedResidual(nn.Module): 30 | def __init__(self, inp, oup, stride, expand_ratio): 31 | super(InvertedResidual, self).__init__() 32 | self.stride = stride 33 | 34 | hidden_dim = round(inp * expand_ratio) 35 | self.use_res_connect = self.stride == (1,1,1) and inp == oup 36 | 37 | if expand_ratio == 1: 38 | self.conv = nn.Sequential( 39 | # dw 40 | nn.Conv3d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 41 | nn.BatchNorm3d(hidden_dim), 42 | nn.ReLU6(inplace=True), 43 | # pw-linear 44 | nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False), 45 | nn.BatchNorm3d(oup), 46 | ) 47 | else: 48 | self.conv = nn.Sequential( 49 | # pw 50 | nn.Conv3d(inp, hidden_dim, 1, 1, 0, bias=False), 51 | nn.BatchNorm3d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # dw 54 | nn.Conv3d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 55 | nn.BatchNorm3d(hidden_dim), 56 | nn.ReLU6(inplace=True), 57 | # pw-linear 58 | nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False), 59 | nn.BatchNorm3d(oup), 60 | ) 61 | 62 | def forward(self, x): 63 | if self.use_res_connect: 64 | return x + self.conv(x) 65 | else: 66 | return self.conv(x) 67 | 68 | 69 | class MobileNetV2(nn.Module): 70 | def __init__(self, width_mult=1.): 71 | super(MobileNetV2, self).__init__() 72 | block = InvertedResidual 73 | input_channel = 32 74 | last_channel = 1280 75 | interverted_residual_setting = [ 76 | # t, c, n, s 77 | [1, 16, 1, (1,1,1)], 78 | [6, 24, 2, (2,2,2)], 79 | [6, 32, 3, (2,2,2)], 80 | [6, 64, 4, (2,2,2)], 81 | [6, 96, 3, (1,1,1)], 82 | [6, 160, 3, (2,2,2)], 83 | [6, 320, 1, (1,1,1)], 84 | ] 85 | 86 | # building first layer 87 | input_channel = int(input_channel * width_mult) 88 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 89 | self.features = [conv_bn(3, input_channel, (1,2,2))] 90 | # building inverted residual blocks 91 | for t, c, n, s in interverted_residual_setting: 92 | output_channel = int(c * width_mult) 93 | for i in range(n): 94 | stride = s if i == 0 else (1,1,1) 95 | self.features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 96 | input_channel = output_channel 97 | # building last several layers 98 | self.features.append(conv_1x1x1_bn(input_channel, self.last_channel)) 99 | # make it nn.Sequential 100 | self.features = nn.Sequential(*self.features) 101 | 102 | 103 | 104 | def forward(self, x): 105 | f_list = [] 106 | for i in range(len(self.features)): 107 | x = self.features[i](x) 108 | f_list.append(x) 109 | 110 | # keep last f 111 | f_list_ = [] 112 | for i, f in enumerate(f_list): 113 | if i == 0 or i == len(f_list)-1: 114 | f_list_.append(f) 115 | elif f.shape[1] != f_list[i+1].shape[1] and f.shape[2:] != f_list[i+1].shape[2:]: 116 | f_list_.append(f) 117 | 118 | return f_list_ 119 | 120 | 121 | def _initialize_weights(self): 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv3d): 124 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 125 | 
m.weight.data.normal_(0, math.sqrt(2. / n)) 126 | if m.bias is not None: 127 | m.bias.data.zero_() 128 | elif isinstance(m, nn.BatchNorm3d): 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | elif isinstance(m, nn.Linear): 132 | n = m.weight.size(1) 133 | m.weight.data.normal_(0, 0.01) 134 | m.bias.data.zero_() 135 | 136 | 137 | def get_fine_tuning_parameters(model, ft_portion): 138 | if ft_portion == "complete": 139 | return model.parameters() 140 | 141 | elif ft_portion == "last_layer": 142 | ft_module_names = [] 143 | ft_module_names.append('classifier') 144 | 145 | parameters = [] 146 | for k, v in model.named_parameters(): 147 | for ft_module in ft_module_names: 148 | if ft_module in k: 149 | parameters.append({'params': v}) 150 | break 151 | else: 152 | parameters.append({'params': v, 'lr': 0.0}) 153 | return parameters 154 | 155 | else: 156 | raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected") 157 | 158 | 159 | def get_model(**kwargs): 160 | """ 161 | Returns the model. 162 | """ 163 | model = MobileNetV2(**kwargs) 164 | return model 165 | 166 | 167 | if __name__ == "__main__": 168 | model = get_model(width_mult=1.)  # this MobileNetV2 variant only takes width_mult 169 | model = model.cuda() 170 | model = nn.DataParallel(model, device_ids=None) 171 | print(model) 172 | 173 | 174 | input_var = Variable(torch.randn(8, 3, 16, 112, 112)) 175 | output = model(input_var) 176 | print([f.shape for f in output])  # forward() returns a list of feature maps 177 | 178 | 179 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_res2net.py: -------------------------------------------------------------------------------- 1 | from ._base import EncoderMixin 2 | from timm.models.resnet import ResNet 3 | from timm.models.res2net import Bottle2neck 4 | import torch.nn as nn 5 | 6 | 7 | class Res2NetEncoder(ResNet, EncoderMixin): 8 | def __init__(self, out_channels, depth=5, **kwargs): 9 | super().__init__(**kwargs) 10 | self._depth = depth 11 | self._out_channels = out_channels 12 | self._in_channels = 3 13 | 14 | del self.fc 15 | del self.global_pool 16 | 17 | def get_stages(self): 18 | return [ 19 | nn.Identity(), 20 | nn.Sequential(self.conv1, self.bn1, self.act1), 21 | nn.Sequential(self.maxpool, self.layer1), 22 | self.layer2, 23 | self.layer3, 24 | self.layer4, 25 | ] 26 | 27 | def make_dilated(self, *args, **kwargs): 28 | raise ValueError("Res2Net encoders do not support dilated mode") 29 | 30 | def forward(self, x): 31 | stages = self.get_stages() 32 | 33 | features = [] 34 | for i in range(self._depth + 1): 35 | x = stages[i](x) 36 | features.append(x) 37 | 38 | return features 39 | 40 | def load_state_dict(self, state_dict, **kwargs): 41 | state_dict.pop("fc.bias", None) 42 | state_dict.pop("fc.weight", None) 43 | super().load_state_dict(state_dict, **kwargs) 44 | 45 | 46 | res2net_weights = { 47 | "timm-res2net50_26w_4s": { 48 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth", # noqa 49 | }, 50 | "timm-res2net50_48w_2s": { 51 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth", # noqa 52 | }, 53 | "timm-res2net50_14w_8s": { 54 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth", # noqa 55 | }, 56 | "timm-res2net50_26w_6s": { 57 | "imagenet":
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth", # noqa 58 | }, 59 | "timm-res2net50_26w_8s": { 60 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth", # noqa 61 | }, 62 | "timm-res2net101_26w_4s": { 63 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth", # noqa 64 | }, 65 | "timm-res2next50": { 66 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth", # noqa 67 | }, 68 | } 69 | 70 | pretrained_settings = {} 71 | for model_name, sources in res2net_weights.items(): 72 | pretrained_settings[model_name] = {} 73 | for source_name, source_url in sources.items(): 74 | pretrained_settings[model_name][source_name] = { 75 | "url": source_url, 76 | "input_size": [3, 224, 224], 77 | "input_range": [0, 1], 78 | "mean": [0.485, 0.456, 0.406], 79 | "std": [0.229, 0.224, 0.225], 80 | "num_classes": 1000, 81 | } 82 | 83 | 84 | timm_res2net_encoders = { 85 | "timm-res2net50_26w_4s": { 86 | "encoder": Res2NetEncoder, 87 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"], 88 | "params": { 89 | "out_channels": (3, 64, 256, 512, 1024, 2048), 90 | "block": Bottle2neck, 91 | "layers": [3, 4, 6, 3], 92 | "base_width": 26, 93 | "block_args": {"scale": 4}, 94 | }, 95 | }, 96 | "timm-res2net101_26w_4s": { 97 | "encoder": Res2NetEncoder, 98 | "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"], 99 | "params": { 100 | "out_channels": (3, 64, 256, 512, 1024, 2048), 101 | "block": Bottle2neck, 102 | "layers": [3, 4, 23, 3], 103 | "base_width": 26, 104 | "block_args": {"scale": 4}, 105 | }, 106 | }, 107 | "timm-res2net50_26w_6s": { 108 | "encoder": Res2NetEncoder, 109 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"], 110 | "params": { 111 | "out_channels": (3, 64, 256, 512, 1024, 2048), 112 | "block": Bottle2neck, 113 | "layers": [3, 4, 6, 3], 114 | "base_width": 26, 115 | "block_args": {"scale": 6}, 116 | }, 117 | }, 118 | "timm-res2net50_26w_8s": { 119 | "encoder": Res2NetEncoder, 120 | "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"], 121 | "params": { 122 | "out_channels": (3, 64, 256, 512, 1024, 2048), 123 | "block": Bottle2neck, 124 | "layers": [3, 4, 6, 3], 125 | "base_width": 26, 126 | "block_args": {"scale": 8}, 127 | }, 128 | }, 129 | "timm-res2net50_48w_2s": { 130 | "encoder": Res2NetEncoder, 131 | "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"], 132 | "params": { 133 | "out_channels": (3, 64, 256, 512, 1024, 2048), 134 | "block": Bottle2neck, 135 | "layers": [3, 4, 6, 3], 136 | "base_width": 48, 137 | "block_args": {"scale": 2}, 138 | }, 139 | }, 140 | "timm-res2net50_14w_8s": { 141 | "encoder": Res2NetEncoder, 142 | "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"], 143 | "params": { 144 | "out_channels": (3, 64, 256, 512, 1024, 2048), 145 | "block": Bottle2neck, 146 | "layers": [3, 4, 6, 3], 147 | "base_width": 14, 148 | "block_args": {"scale": 8}, 149 | }, 150 | }, 151 | "timm-res2next50": { 152 | "encoder": Res2NetEncoder, 153 | "pretrained_settings": pretrained_settings["timm-res2next50"], 154 | "params": { 155 | "out_channels": (3, 64, 256, 512, 1024, 2048), 156 | "block": Bottle2neck, 157 | "layers": [3, 4, 6, 3], 158 | "base_width": 4, 159 | "cardinality": 8, 160 | "block_args": {"scale": 4}, 161 | }, 162 | 
}, 163 | } 164 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/senet.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch.nn as nn 27 | 28 | from pretrainedmodels.models.senet import ( 29 | SENet, 30 | SEBottleneck, 31 | SEResNetBottleneck, 32 | SEResNeXtBottleneck, 33 | pretrained_settings, 34 | ) 35 | from ._base import EncoderMixin 36 | 37 | 38 | class SENetEncoder(SENet, EncoderMixin): 39 | def __init__(self, out_channels, depth=5, **kwargs): 40 | super().__init__(**kwargs) 41 | 42 | self._out_channels = out_channels 43 | self._depth = depth 44 | self._in_channels = 3 45 | 46 | del self.last_linear 47 | del self.avg_pool 48 | 49 | def get_stages(self): 50 | return [ 51 | nn.Identity(), 52 | self.layer0[:-1], 53 | nn.Sequential(self.layer0[-1], self.layer1), 54 | self.layer2, 55 | self.layer3, 56 | self.layer4, 57 | ] 58 | 59 | def forward(self, x): 60 | stages = self.get_stages() 61 | 62 | features = [] 63 | for i in range(self._depth + 1): 64 | x = stages[i](x) 65 | features.append(x) 66 | 67 | return features 68 | 69 | def load_state_dict(self, state_dict, **kwargs): 70 | state_dict.pop("last_linear.bias", None) 71 | state_dict.pop("last_linear.weight", None) 72 | super().load_state_dict(state_dict, **kwargs) 73 | 74 | 75 | senet_encoders = { 76 | "senet154": { 77 | "encoder": SENetEncoder, 78 | "pretrained_settings": pretrained_settings["senet154"], 79 | "params": { 80 | "out_channels": (3, 128, 256, 512, 1024, 2048), 81 | "block": SEBottleneck, 82 | "dropout_p": 0.2, 83 | "groups": 64, 84 | "layers": [3, 8, 36, 3], 85 | "num_classes": 1000, 86 | "reduction": 16, 87 | }, 88 | }, 89 | "se_resnet50": { 90 | "encoder": SENetEncoder, 91 | "pretrained_settings": pretrained_settings["se_resnet50"], 92 | "params": { 93 | "out_channels": (3, 64, 256, 512, 1024, 2048), 94 | "block": SEResNetBottleneck, 95 | "layers": [3, 4, 6, 3], 96 | "downsample_kernel_size": 1, 97 | "downsample_padding": 0, 98 | "dropout_p": None, 99 | "groups": 1, 100 | "inplanes": 64, 101 | "input_3x3": False, 102 | "num_classes": 1000, 103 | "reduction": 16, 104 | }, 105 | }, 106 | "se_resnet101": { 107 | 
"encoder": SENetEncoder, 108 | "pretrained_settings": pretrained_settings["se_resnet101"], 109 | "params": { 110 | "out_channels": (3, 64, 256, 512, 1024, 2048), 111 | "block": SEResNetBottleneck, 112 | "layers": [3, 4, 23, 3], 113 | "downsample_kernel_size": 1, 114 | "downsample_padding": 0, 115 | "dropout_p": None, 116 | "groups": 1, 117 | "inplanes": 64, 118 | "input_3x3": False, 119 | "num_classes": 1000, 120 | "reduction": 16, 121 | }, 122 | }, 123 | "se_resnet152": { 124 | "encoder": SENetEncoder, 125 | "pretrained_settings": pretrained_settings["se_resnet152"], 126 | "params": { 127 | "out_channels": (3, 64, 256, 512, 1024, 2048), 128 | "block": SEResNetBottleneck, 129 | "layers": [3, 8, 36, 3], 130 | "downsample_kernel_size": 1, 131 | "downsample_padding": 0, 132 | "dropout_p": None, 133 | "groups": 1, 134 | "inplanes": 64, 135 | "input_3x3": False, 136 | "num_classes": 1000, 137 | "reduction": 16, 138 | }, 139 | }, 140 | "se_resnext50_32x4d": { 141 | "encoder": SENetEncoder, 142 | "pretrained_settings": pretrained_settings["se_resnext50_32x4d"], 143 | "params": { 144 | "out_channels": (3, 64, 256, 512, 1024, 2048), 145 | "block": SEResNeXtBottleneck, 146 | "layers": [3, 4, 6, 3], 147 | "downsample_kernel_size": 1, 148 | "downsample_padding": 0, 149 | "dropout_p": None, 150 | "groups": 32, 151 | "inplanes": 64, 152 | "input_3x3": False, 153 | "num_classes": 1000, 154 | "reduction": 16, 155 | }, 156 | }, 157 | "se_resnext101_32x4d": { 158 | "encoder": SENetEncoder, 159 | "pretrained_settings": pretrained_settings["se_resnext101_32x4d"], 160 | "params": { 161 | "out_channels": (3, 64, 256, 512, 1024, 2048), 162 | "block": SEResNeXtBottleneck, 163 | "layers": [3, 4, 23, 3], 164 | "downsample_kernel_size": 1, 165 | "downsample_padding": 0, 166 | "dropout_p": None, 167 | "groups": 32, 168 | "inplanes": 64, 169 | "input_3x3": False, 170 | "num_classes": 1000, 171 | "reduction": 16, 172 | }, 173 | }, 174 | } 175 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/dpn.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
24 | """ 25 | 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | 30 | from pretrainedmodels.models.dpn import DPN 31 | from pretrainedmodels.models.dpn import pretrained_settings 32 | 33 | from ._base import EncoderMixin 34 | 35 | 36 | class DPNEncoder(DPN, EncoderMixin): 37 | def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): 38 | super().__init__(**kwargs) 39 | self._stage_idxs = stage_idxs 40 | self._depth = depth 41 | self._out_channels = out_channels 42 | self._in_channels = 3 43 | 44 | del self.last_linear 45 | 46 | def get_stages(self): 47 | return [ 48 | nn.Identity(), 49 | nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act), 50 | nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]), 51 | self.features[self._stage_idxs[0] : self._stage_idxs[1]], 52 | self.features[self._stage_idxs[1] : self._stage_idxs[2]], 53 | self.features[self._stage_idxs[2] : self._stage_idxs[3]], 54 | ] 55 | 56 | def forward(self, x): 57 | 58 | stages = self.get_stages() 59 | 60 | features = [] 61 | for i in range(self._depth + 1): 62 | x = stages[i](x) 63 | if isinstance(x, (list, tuple)): 64 | features.append(F.relu(torch.cat(x, dim=1), inplace=True)) 65 | else: 66 | features.append(x) 67 | 68 | return features 69 | 70 | def load_state_dict(self, state_dict, **kwargs): 71 | state_dict.pop("last_linear.bias", None) 72 | state_dict.pop("last_linear.weight", None) 73 | super().load_state_dict(state_dict, **kwargs) 74 | 75 | 76 | dpn_encoders = { 77 | "dpn68": { 78 | "encoder": DPNEncoder, 79 | "pretrained_settings": pretrained_settings["dpn68"], 80 | "params": { 81 | "stage_idxs": (4, 8, 20, 24), 82 | "out_channels": (3, 10, 144, 320, 704, 832), 83 | "groups": 32, 84 | "inc_sec": (16, 32, 32, 64), 85 | "k_r": 128, 86 | "k_sec": (3, 4, 12, 3), 87 | "num_classes": 1000, 88 | "num_init_features": 10, 89 | "small": True, 90 | "test_time_pool": True, 91 | }, 92 | }, 93 | "dpn68b": { 94 | "encoder": DPNEncoder, 95 | "pretrained_settings": pretrained_settings["dpn68b"], 96 | "params": { 97 | "stage_idxs": (4, 8, 20, 24), 98 | "out_channels": (3, 10, 144, 320, 704, 832), 99 | "b": True, 100 | "groups": 32, 101 | "inc_sec": (16, 32, 32, 64), 102 | "k_r": 128, 103 | "k_sec": (3, 4, 12, 3), 104 | "num_classes": 1000, 105 | "num_init_features": 10, 106 | "small": True, 107 | "test_time_pool": True, 108 | }, 109 | }, 110 | "dpn92": { 111 | "encoder": DPNEncoder, 112 | "pretrained_settings": pretrained_settings["dpn92"], 113 | "params": { 114 | "stage_idxs": (4, 8, 28, 32), 115 | "out_channels": (3, 64, 336, 704, 1552, 2688), 116 | "groups": 32, 117 | "inc_sec": (16, 32, 24, 128), 118 | "k_r": 96, 119 | "k_sec": (3, 4, 20, 3), 120 | "num_classes": 1000, 121 | "num_init_features": 64, 122 | "test_time_pool": True, 123 | }, 124 | }, 125 | "dpn98": { 126 | "encoder": DPNEncoder, 127 | "pretrained_settings": pretrained_settings["dpn98"], 128 | "params": { 129 | "stage_idxs": (4, 10, 30, 34), 130 | "out_channels": (3, 96, 336, 768, 1728, 2688), 131 | "groups": 40, 132 | "inc_sec": (16, 32, 32, 128), 133 | "k_r": 160, 134 | "k_sec": (3, 6, 20, 3), 135 | "num_classes": 1000, 136 | "num_init_features": 96, 137 | "test_time_pool": True, 138 | }, 139 | }, 140 | "dpn107": { 141 | "encoder": DPNEncoder, 142 | "pretrained_settings": pretrained_settings["dpn107"], 143 | "params": { 144 | "stage_idxs": (5, 13, 33, 37), 145 | "out_channels": (3, 128, 376, 1152, 2432, 2688), 146 | "groups": 50, 147 | "inc_sec": (20, 64, 64, 128), 148 | 
"k_r": 200, 149 | "k_sec": (4, 8, 20, 3), 150 | "num_classes": 1000, 151 | "num_init_features": 128, 152 | "test_time_pool": True, 153 | }, 154 | }, 155 | "dpn131": { 156 | "encoder": DPNEncoder, 157 | "pretrained_settings": pretrained_settings["dpn131"], 158 | "params": { 159 | "stage_idxs": (5, 13, 41, 45), 160 | "out_channels": (3, 128, 352, 832, 1984, 2688), 161 | "groups": 40, 162 | "inc_sec": (16, 32, 32, 128), 163 | "k_r": 160, 164 | "k_sec": (4, 8, 28, 3), 165 | "num_classes": 1000, 166 | "num_init_features": 128, 167 | "test_time_pool": True, 168 | }, 169 | }, 170 | } 171 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/timm_mobilenetv3.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | from ._base import EncoderMixin 6 | 7 | 8 | def _make_divisible(x, divisible_by=8): 9 | return int(np.ceil(x * 1.0 / divisible_by) * divisible_by) 10 | 11 | 12 | class MobileNetV3Encoder(nn.Module, EncoderMixin): 13 | def __init__(self, model_name, width_mult, depth=5, **kwargs): 14 | super().__init__() 15 | if "large" not in model_name and "small" not in model_name: 16 | raise ValueError("MobileNetV3 wrong model name {}".format(model_name)) 17 | 18 | self._mode = "small" if "small" in model_name else "large" 19 | self._depth = depth 20 | self._out_channels = self._get_channels(self._mode, width_mult) 21 | self._in_channels = 3 22 | 23 | # minimal models replace hardswish with relu 24 | self.model = timm.create_model( 25 | model_name=model_name, 26 | scriptable=True, # torch.jit scriptable 27 | exportable=True, # onnx export 28 | features_only=True, 29 | ) 30 | 31 | def _get_channels(self, mode, width_mult): 32 | if mode == "small": 33 | channels = [16, 16, 24, 48, 576] 34 | else: 35 | channels = [16, 24, 40, 112, 960] 36 | channels = [ 37 | 3, 38 | ] + [_make_divisible(x * width_mult) for x in channels] 39 | return tuple(channels) 40 | 41 | def get_stages(self): 42 | if self._mode == "small": 43 | return [ 44 | nn.Identity(), 45 | nn.Sequential( 46 | self.model.conv_stem, 47 | self.model.bn1, 48 | self.model.act1, 49 | ), 50 | self.model.blocks[0], 51 | self.model.blocks[1], 52 | self.model.blocks[2:4], 53 | self.model.blocks[4:], 54 | ] 55 | elif self._mode == "large": 56 | return [ 57 | nn.Identity(), 58 | nn.Sequential( 59 | self.model.conv_stem, 60 | self.model.bn1, 61 | self.model.act1, 62 | self.model.blocks[0], 63 | ), 64 | self.model.blocks[1], 65 | self.model.blocks[2], 66 | self.model.blocks[3:5], 67 | self.model.blocks[5:], 68 | ] 69 | else: 70 | ValueError("MobileNetV3 mode should be small or large, got {}".format(self._mode)) 71 | 72 | def forward(self, x): 73 | stages = self.get_stages() 74 | 75 | features = [] 76 | for i in range(self._depth + 1): 77 | x = stages[i](x) 78 | features.append(x) 79 | 80 | return features 81 | 82 | def load_state_dict(self, state_dict, **kwargs): 83 | state_dict.pop("conv_head.weight", None) 84 | state_dict.pop("conv_head.bias", None) 85 | state_dict.pop("classifier.weight", None) 86 | state_dict.pop("classifier.bias", None) 87 | self.model.load_state_dict(state_dict, **kwargs) 88 | 89 | 90 | mobilenetv3_weights = { 91 | "tf_mobilenetv3_large_075": { 92 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth" # noqa 93 | }, 94 | "tf_mobilenetv3_large_100": { 95 | "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth" # noqa 96 | }, 97 | "tf_mobilenetv3_large_minimal_100": { 98 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth" # noqa 99 | }, 100 | "tf_mobilenetv3_small_075": { 101 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth" # noqa 102 | }, 103 | "tf_mobilenetv3_small_100": { 104 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth" # noqa 105 | }, 106 | "tf_mobilenetv3_small_minimal_100": { 107 | "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth" # noqa 108 | }, 109 | } 110 | 111 | pretrained_settings = {} 112 | for model_name, sources in mobilenetv3_weights.items(): 113 | pretrained_settings[model_name] = {} 114 | for source_name, source_url in sources.items(): 115 | pretrained_settings[model_name][source_name] = { 116 | "url": source_url, 117 | "input_range": [0, 1], 118 | "mean": [0.485, 0.456, 0.406], 119 | "std": [0.229, 0.224, 0.225], 120 | "input_space": "RGB", 121 | } 122 | 123 | 124 | timm_mobilenetv3_encoders = { 125 | "timm-mobilenetv3_large_075": { 126 | "encoder": MobileNetV3Encoder, 127 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_075"], 128 | "params": {"model_name": "tf_mobilenetv3_large_075", "width_mult": 0.75}, 129 | }, 130 | "timm-mobilenetv3_large_100": { 131 | "encoder": MobileNetV3Encoder, 132 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_100"], 133 | "params": {"model_name": "tf_mobilenetv3_large_100", "width_mult": 1.0}, 134 | }, 135 | "timm-mobilenetv3_large_minimal_100": { 136 | "encoder": MobileNetV3Encoder, 137 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_minimal_100"], 138 | "params": {"model_name": "tf_mobilenetv3_large_minimal_100", "width_mult": 1.0}, 139 | }, 140 | "timm-mobilenetv3_small_075": { 141 | "encoder": MobileNetV3Encoder, 142 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_075"], 143 | "params": {"model_name": "tf_mobilenetv3_small_075", "width_mult": 0.75}, 144 | }, 145 | "timm-mobilenetv3_small_100": { 146 | "encoder": MobileNetV3Encoder, 147 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_100"], 148 | "params": {"model_name": "tf_mobilenetv3_small_100", "width_mult": 1.0}, 149 | }, 150 | "timm-mobilenetv3_small_minimal_100": { 151 | "encoder": MobileNetV3Encoder, 152 | "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_minimal_100"], 153 | "params": {"model_name": "tf_mobilenetv3_small_minimal_100", "width_mult": 1.0}, 154 | }, 155 | } 156 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/SMP_qubvel/encoders/efficientnet.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for 
encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | import torch.nn as nn 26 | from efficientnet_pytorch import EfficientNet 27 | from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params 28 | 29 | from ._base import EncoderMixin 30 | 31 | 32 | class EfficientNetEncoder(EfficientNet, EncoderMixin): 33 | def __init__(self, stage_idxs, out_channels, model_name, depth=5): 34 | 35 | blocks_args, global_params = get_model_params(model_name, override_params=None) 36 | super().__init__(blocks_args, global_params) 37 | 38 | self._stage_idxs = stage_idxs 39 | self._out_channels = out_channels 40 | self._depth = depth 41 | self._in_channels = 3 42 | 43 | del self._fc 44 | 45 | def get_stages(self): 46 | return [ 47 | nn.Identity(), 48 | nn.Sequential(self._conv_stem, self._bn0, self._swish), 49 | self._blocks[: self._stage_idxs[0]], 50 | self._blocks[self._stage_idxs[0] : self._stage_idxs[1]], 51 | self._blocks[self._stage_idxs[1] : self._stage_idxs[2]], 52 | self._blocks[self._stage_idxs[2] :], 53 | ] 54 | 55 | def forward(self, x): 56 | stages = self.get_stages() 57 | 58 | block_number = 0.0 59 | drop_connect_rate = self._global_params.drop_connect_rate 60 | 61 | features = [] 62 | for i in range(self._depth + 1): 63 | 64 | # Identity and Sequential stages 65 | if i < 2: 66 | x = stages[i](x) 67 | 68 | # Block stages need drop_connect rate 69 | else: 70 | for module in stages[i]: 71 | drop_connect = drop_connect_rate * block_number / len(self._blocks) 72 | block_number += 1.0 73 | x = module(x, drop_connect) 74 | 75 | features.append(x) 76 | 77 | return features 78 | 79 | def load_state_dict(self, state_dict, **kwargs): 80 | state_dict.pop("_fc.bias", None) 81 | state_dict.pop("_fc.weight", None) 82 | super().load_state_dict(state_dict, **kwargs) 83 | 84 | 85 | def _get_pretrained_settings(encoder): 86 | pretrained_settings = { 87 | "imagenet": { 88 | "mean": [0.485, 0.456, 0.406], 89 | "std": [0.229, 0.224, 0.225], 90 | "url": url_map[encoder], 91 | "input_space": "RGB", 92 | "input_range": [0, 1], 93 | }, 94 | "advprop": { 95 | "mean": [0.5, 0.5, 0.5], 96 | "std": [0.5, 0.5, 0.5], 97 | "url": url_map_advprop[encoder], 98 | "input_space": "RGB", 99 | "input_range": [0, 1], 100 | }, 101 | } 102 | return pretrained_settings 103 | 104 | 105 | efficient_net_encoders = { 106 | "efficientnet-b0": { 107 | "encoder": EfficientNetEncoder, 108 | "pretrained_settings": _get_pretrained_settings("efficientnet-b0"), 109 | "params": { 110 | "out_channels": (3, 32, 24, 40, 112, 320), 111 | "stage_idxs": (3, 5, 9, 16), 112 | "model_name": "efficientnet-b0", 113 | }, 114 | }, 115 | "efficientnet-b1": { 116 | "encoder": EfficientNetEncoder, 117 | "pretrained_settings": 
_get_pretrained_settings("efficientnet-b1"), 118 | "params": { 119 | "out_channels": (3, 32, 24, 40, 112, 320), 120 | "stage_idxs": (5, 8, 16, 23), 121 | "model_name": "efficientnet-b1", 122 | }, 123 | }, 124 | "efficientnet-b2": { 125 | "encoder": EfficientNetEncoder, 126 | "pretrained_settings": _get_pretrained_settings("efficientnet-b2"), 127 | "params": { 128 | "out_channels": (3, 32, 24, 48, 120, 352), 129 | "stage_idxs": (5, 8, 16, 23), 130 | "model_name": "efficientnet-b2", 131 | }, 132 | }, 133 | "efficientnet-b3": { 134 | "encoder": EfficientNetEncoder, 135 | "pretrained_settings": _get_pretrained_settings("efficientnet-b3"), 136 | "params": { 137 | "out_channels": (3, 40, 32, 48, 136, 384), 138 | "stage_idxs": (5, 8, 18, 26), 139 | "model_name": "efficientnet-b3", 140 | }, 141 | }, 142 | "efficientnet-b4": { 143 | "encoder": EfficientNetEncoder, 144 | "pretrained_settings": _get_pretrained_settings("efficientnet-b4"), 145 | "params": { 146 | "out_channels": (3, 48, 32, 56, 160, 448), 147 | "stage_idxs": (6, 10, 22, 32), 148 | "model_name": "efficientnet-b4", 149 | }, 150 | }, 151 | "efficientnet-b5": { 152 | "encoder": EfficientNetEncoder, 153 | "pretrained_settings": _get_pretrained_settings("efficientnet-b5"), 154 | "params": { 155 | "out_channels": (3, 48, 40, 64, 176, 512), 156 | "stage_idxs": (8, 13, 27, 39), 157 | "model_name": "efficientnet-b5", 158 | }, 159 | }, 160 | "efficientnet-b6": { 161 | "encoder": EfficientNetEncoder, 162 | "pretrained_settings": _get_pretrained_settings("efficientnet-b6"), 163 | "params": { 164 | "out_channels": (3, 56, 40, 72, 200, 576), 165 | "stage_idxs": (9, 15, 31, 45), 166 | "model_name": "efficientnet-b6", 167 | }, 168 | }, 169 | "efficientnet-b7": { 170 | "encoder": EfficientNetEncoder, 171 | "pretrained_settings": _get_pretrained_settings("efficientnet-b7"), 172 | "params": { 173 | "out_channels": (3, 64, 48, 80, 224, 640), 174 | "stage_idxs": (11, 18, 38, 55), 175 | "model_name": "efficientnet-b7", 176 | }, 177 | }, 178 | } 179 | -------------------------------------------------------------------------------- /wama_modules/thirdparty_lib/VC3D_kenshohara/resnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = ['ResNeXt', 'resnet50', 'resnet101'] 9 | 10 | 11 | def conv3x3x3(in_planes, out_planes, stride=1): 12 | # 3x3x3 convolution with padding 13 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, 14 | stride=stride, padding=1, bias=False) 15 | 16 | 17 | def downsample_basic_block(x, planes, stride): 18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), 20 | out.size(2), out.size(3), 21 | out.size(4)).zero_() 22 | if isinstance(out.data, torch.cuda.FloatTensor): 23 | zero_pads = zero_pads.cuda() 24 | 25 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 26 | 27 | return out 28 | 29 | 30 | class ResNeXtBottleneck(nn.Module): 31 | expansion = 2 32 | 33 | def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None): 34 | super(ResNeXtBottleneck, self).__init__() 35 | mid_planes = cardinality * int(planes / 32) 36 | self.conv1 = nn.Conv3d(inplanes, mid_planes, kernel_size=1, bias=False) 37 | self.bn1 = nn.BatchNorm3d(mid_planes) 38 | self.conv2 = nn.Conv3d(mid_planes, mid_planes, kernel_size=3, stride=stride, 39 | padding=1, 
groups=cardinality, bias=False) 40 | self.bn2 = nn.BatchNorm3d(mid_planes) 41 | self.conv3 = nn.Conv3d(mid_planes, planes * self.expansion, kernel_size=1, bias=False) 42 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv3(out) 59 | out = self.bn3(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class ResNeXt(nn.Module): 71 | 72 | def __init__(self, block, layers, sample_size = None, sample_duration = None, shortcut_type='B', cardinality=32, num_classes=400, last_fc=True): 73 | self.last_fc = last_fc 74 | 75 | self.inplanes = 64 76 | super(ResNeXt, self).__init__() 77 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), 78 | padding=(3, 3, 3), bias=False) 79 | self.bn1 = nn.BatchNorm3d(64) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 82 | self.layer1 = self._make_layer(block, 128, layers[0], shortcut_type, cardinality) 83 | self.layer2 = self._make_layer(block, 256, layers[1], shortcut_type, cardinality, stride=2) 84 | self.layer3 = self._make_layer(block, 512, layers[2], shortcut_type, cardinality, stride=2) 85 | self.layer4 = self._make_layer(block, 1024, layers[3], shortcut_type, cardinality, stride=2) 86 | # last_duration = math.ceil(sample_duration / 16) 87 | # last_size = math.ceil(sample_size / 32) 88 | # self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) 89 | # self.fc = nn.Linear(cardinality * 32 * block.expansion, num_classes) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv3d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 94 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 95 | elif isinstance(m, nn.BatchNorm3d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | 99 | def _make_layer(self, block, planes, blocks, shortcut_type, cardinality, stride=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | if shortcut_type == 'A': 103 | downsample = partial(downsample_basic_block, 104 | planes=planes * block.expansion, 105 | stride=stride) 106 | else: 107 | downsample = nn.Sequential( 108 | nn.Conv3d(self.inplanes, planes * block.expansion, 109 | kernel_size=1, stride=stride, bias=False), 110 | nn.BatchNorm3d(planes * block.expansion) 111 | ) 112 | 113 | layers = [] 114 | layers.append(block(self.inplanes, planes, cardinality, stride, downsample)) 115 | self.inplanes = planes * block.expansion 116 | for i in range(1, blocks): 117 | layers.append(block(self.inplanes, planes, cardinality)) 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | f_list = [] 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | x = self.relu(x) 126 | x = self.maxpool(x) 127 | f_list.append(x) 128 | 129 | x = self.layer1(x) 130 | f_list.append(x) 131 | x = self.layer2(x) 132 | f_list.append(x) 133 | x = self.layer3(x) 134 | f_list.append(x) 135 | x = self.layer4(x) 136 | f_list.append(x) 137 | return f_list 138 | 139 | def get_fine_tuning_parameters(model, ft_begin_index): 140 | if ft_begin_index == 0: 141 | return model.parameters() 142 | 143 | ft_module_names = [] 144 | for i in range(ft_begin_index, 5): 145 | ft_module_names.append('layer{}'.format(i)) 146 | ft_module_names.append('fc') 147 | 148 | parameters = [] 149 | for k, v in model.named_parameters(): 150 | for ft_module in ft_module_names: 151 | if ft_module in k: 152 | parameters.append({'params': v}) 153 | break 154 | else: 155 | parameters.append({'params': v, 'lr': 0.0}) 156 | 157 | return parameters 158 | 159 | def resnet50(**kwargs): 160 | """Constructs a ResNeXt-50 model. 161 | """ 162 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 6, 3], **kwargs) 163 | return model 164 | 165 | def resnet101(**kwargs): 166 | """Constructs a ResNeXt-101 model. 167 | """ 168 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs) 169 | return model 170 | 171 | def resnet152(**kwargs): 172 | """Constructs a ResNeXt-152 model. 173 | """ 174 | model = ResNeXt(ResNeXtBottleneck, [3, 8, 36, 3], **kwargs) 175 | return model 176 | 177 | def generate_model(model_depth): 178 | if model_depth == 50: 179 | model = resnet50() 180 | elif model_depth == 101: 181 | model = resnet101() 182 | elif model_depth == 152: 183 | model = resnet152() 184 | else: 185 | model = None 186 | return model --------------------------------------------------------------------------------
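Usage note (a minimal sketch, not one of the repository files): the backbone and encoder files dumped above all replace the original classification forward() with one that returns a list of multi-scale feature maps (f_list / features), so they can be plugged in wherever an encoder exposing intermediate features is expected. The snippet below drives the 3D ResNeXt from resnext.py directly; the import path is assumed to mirror the file path shown in this dump, and the input size (16 frames of 112x112) is an arbitrary example.

import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnext import generate_model

# build the 3D ResNeXt-101 defined above; its forward() returns a feature list
model = generate_model(101)
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 16, 112, 112)   # (batch, channels, frames/depth, H, W)
    feature_maps = model(x)               # outputs of stem, layer1 ... layer4

for i, f in enumerate(feature_maps):
    print(i, tuple(f.shape))              # channel dims: 64, 256, 512, 1024, 2048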