├── tools ├── __init__.py ├── README.md ├── convert_xception_from_keras.py ├── convert_vgg_from_timm.py ├── convert_densenet_from_timm.py ├── convert_convmixer_from_timm.py ├── convert_resnet_from_timm.py ├── convert_hgnet_from_timm.py ├── convert_inception_next_from_timm.py ├── convert_vit_from_timm.py ├── convert_convnext_from_timm.py ├── convert_mobilenet_v2_from_timm.py ├── convert_regnet_from_timm.py ├── convert_inception_v3_from_timm.py ├── convert_ghostnet_from_timm.py ├── convert_repvgg_from_timm.py ├── convert_mobilenet_v3_from_timm.py ├── convert_mobileone_from_timm.py ├── convert_mobilevit_from_timm.py └── convert_efficientnet_from_timm.py ├── kimm ├── _src │ ├── blocks │ │ ├── __init__.py │ │ ├── squeeze_and_excitation.py │ │ ├── depthwise_separation.py │ │ ├── conv2d.py │ │ ├── transformer.py │ │ └── inverted_residual.py │ ├── export │ │ ├── __init__.py │ │ ├── export_onnx_test.py │ │ ├── export_tflite_test.py │ │ ├── export_onnx.py │ │ └── export_tflite.py │ ├── layers │ │ ├── __init__.py │ │ ├── layer_scale_test.py │ │ ├── learnable_affine_test.py │ │ ├── position_embedding_test.py │ │ ├── attention_test.py │ │ ├── layer_scale.py │ │ ├── learnable_affine.py │ │ ├── position_embedding.py │ │ ├── attention.py │ │ └── reparameterizable_conv2d_test.py │ ├── utils │ │ ├── __init__.py │ │ ├── module_utils.py │ │ ├── make_divisble.py │ │ ├── model_utils.py │ │ ├── model_utils_test.py │ │ ├── model_registry_test.py │ │ ├── model_registry.py │ │ └── timm_utils.py │ ├── version.py │ ├── models │ │ └── __init__.py │ └── kimm_export.py ├── models │ ├── xception │ │ └── __init__.py │ ├── base_model │ │ └── __init__.py │ ├── inception_v3 │ │ └── __init__.py │ ├── vgg │ │ └── __init__.py │ ├── convmixer │ │ └── __init__.py │ ├── inception_next │ │ └── __init__.py │ ├── densenet │ │ └── __init__.py │ ├── mobileone │ │ └── __init__.py │ ├── resnet │ │ └── __init__.py │ ├── mobilenet_v2 │ │ └── __init__.py │ ├── repvgg │ │ └── __init__.py │ ├── hgnet │ │ └── 
__init__.py │ ├── convnext │ │ └── __init__.py │ ├── ghostnet │ │ └── __init__.py │ ├── mobilevit │ │ └── __init__.py │ ├── vision_transformer │ │ └── __init__.py │ ├── mobilenet_v3 │ │ └── __init__.py │ ├── regnet │ │ └── __init__.py │ ├── efficientnet │ │ └── __init__.py │ └── __init__.py ├── export │ └── __init__.py ├── utils │ └── __init__.py ├── timm_utils │ └── __init__.py ├── __init__.py ├── layers │ └── __init__.py └── blocks │ └── __init__.py ├── docs └── banner │ └── kimm.png ├── shell ├── format.sh ├── api_gen.sh └── export_models.sh ├── api_gen.py ├── requirements.txt ├── .pre-commit-config.yaml ├── .github ├── dependabot.yml └── workflows │ ├── release.yml │ └── actions.yml ├── conftest.py ├── pyproject.toml └── .gitignore /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimm/_src/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimm/_src/export/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimm/_src/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimm/_src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/banner/kimm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/james77777778/keras-image-models/HEAD/docs/banner/kimm.png -------------------------------------------------------------------------------- 
@kimm_export("kimm")
def version():
    """Return the package version string held in the module's `__version__`."""
    return __version__
#!/bin/bash
# Sort imports, format and lint the current working directory, using the
# tool configuration stored in the pyproject.toml next to the `shell/` folder.
set -Eeuo pipefail

# Parent of the directory that contains this script. Every expansion is
# quoted so that a checkout path containing spaces or glob characters is not
# word-split before it reaches `dirname`.
base_dir="$(dirname "$(dirname "$0")")"

isort --sp "${base_dir}/pyproject.toml" .
black --config "${base_dir}/pyproject.toml" .
ruff check --config "${base_dir}/pyproject.toml" --fix .
13 | -------------------------------------------------------------------------------- /kimm/models/vgg/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from kimm._src.models.vgg import VGG11 8 | from kimm._src.models.vgg import VGG13 9 | from kimm._src.models.vgg import VGG16 10 | from kimm._src.models.vgg import VGG19 11 | -------------------------------------------------------------------------------- /kimm/models/convmixer/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from kimm._src.models.convmixer import ConvMixer736D32 8 | from kimm._src.models.convmixer import ConvMixer1024D20 9 | from kimm._src.models.convmixer import ConvMixer1536D20 10 | -------------------------------------------------------------------------------- /kimm/models/inception_next/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 
import namex

from kimm._src.version import __version__

# Regenerate the public API files of the `kimm` package from the `_src`
# code directory.
namex.generate_api_files(package="kimm", code_directory="_src")

# Add version string: re-read the package `__init__.py` and write it back
# with a trailing `__version__` assignment.
init_path = "kimm/__init__.py"
version_line = f'__version__ = "{__version__}"\n'

with open(init_path, "r") as init_file:
    generated_source = init_file.read()

with open(init_path, "w") as init_file:
    init_file.write(generated_source + version_line)
5 | """ 6 | 7 | from kimm._src.models.mobileone import MobileOneS0 8 | from kimm._src.models.mobileone import MobileOneS1 9 | from kimm._src.models.mobileone import MobileOneS2 10 | from kimm._src.models.mobileone import MobileOneS3 11 | -------------------------------------------------------------------------------- /kimm/models/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from kimm._src.models.resnet import ResNet18 8 | from kimm._src.models.resnet import ResNet34 9 | from kimm._src.models.resnet import ResNet50 10 | from kimm._src.models.resnet import ResNet101 11 | from kimm._src.models.resnet import ResNet152 12 | -------------------------------------------------------------------------------- /kimm/timm_utils/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from kimm._src.utils.timm_utils import assign_weights 8 | from kimm._src.utils.timm_utils import is_same_weights 9 | from kimm._src.utils.timm_utils import separate_keras_weights 10 | from kimm._src.utils.timm_utils import separate_torch_state_dict 11 | -------------------------------------------------------------------------------- /kimm/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 
def make_divisible(
    v,
    divisor: int = 8,
    min_value: typing.Optional[float] = None,
    round_limit: float = 0.9,
):
    """Round `v` to a multiple of `divisor`, bounded below by `min_value`.

    Args:
        v: The value to round.
        divisor: The step the result is aligned to. Defaults to `8`.
        min_value: Lower bound for the result. Any falsy value (`None` or
            `0`) falls back to `divisor`.
        round_limit: If the rounded value is smaller than `round_limit * v`,
            one extra `divisor` is added. Defaults to `0.9`.

    Returns:
        The larger of the lower bound and the rounded multiple, increased by
        `divisor` when it fell below `round_limit * v`.
    """
    lower_bound = min_value if min_value else divisor
    # Adding half a divisor before the floor division rounds to the closest
    # multiple instead of always rounding down.
    closest_multiple = (int(v + divisor / 2) // divisor) * divisor
    result = max(lower_bound, closest_multiple)
    # Make sure that round down does not go down by more than 10%.
    if result < round_limit * v:
        return result + divisor
    return result
5 | """ 6 | 7 | from kimm._src.models.mobilenet_v2 import MobileNetV2W050 8 | from kimm._src.models.mobilenet_v2 import MobileNetV2W100 9 | from kimm._src.models.mobilenet_v2 import MobileNetV2W110 10 | from kimm._src.models.mobilenet_v2 import MobileNetV2W120 11 | from kimm._src.models.mobilenet_v2 import MobileNetV2W140 12 | -------------------------------------------------------------------------------- /kimm/models/repvgg/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from kimm._src.models.repvgg import RepVGGA0 8 | from kimm._src.models.repvgg import RepVGGA1 9 | from kimm._src.models.repvgg import RepVGGA2 10 | from kimm._src.models.repvgg import RepVGGB0 11 | from kimm._src.models.repvgg import RepVGGB1 12 | from kimm._src.models.repvgg import RepVGGB2 13 | from kimm._src.models.repvgg import RepVGGB3 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Working GPU setup 2 | # CUDA 12.2, CUDNN 8.9 3 | # 4 | # tensorflow==2.15.0.post1 5 | # 6 | # --index-url https://download.pytorch.org/whl/cu121 7 | # torch torchvision 8 | # 9 | # -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 10 | # "jax[cuda12_local]" 11 | 12 | # Following is for github runner 13 | tensorflow-cpu>=2.16.1 14 | 15 | --extra-index-url https://download.pytorch.org/whl/cpu 16 | torch>=2.1.0 17 | torchvision>=0.16.0 18 | 19 | jax[cpu] 20 | 21 | keras>=3.3.0 22 | -------------------------------------------------------------------------------- /kimm/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. 
Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from kimm._src.blocks.conv2d import apply_conv2d_block 8 | from kimm._src.blocks.depthwise_separation import ( 9 | apply_depthwise_separation_block, 10 | ) 11 | from kimm._src.blocks.inverted_residual import apply_inverted_residual_block 12 | from kimm._src.blocks.squeeze_and_excitation import apply_se_block 13 | from kimm._src.blocks.transformer import apply_mlp_block 14 | from kimm._src.blocks.transformer import apply_transformer_block 15 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | # Convert Model Weights 2 | 3 | ## Convert the weights from `timm` and `keras` 4 | 5 | - Use TensorFlow backend 6 | 7 | ```bash 8 | # At project root ./kimm/ 9 | ./shell/export_models.sh 10 | ``` 11 | 12 | ## Upload to Releases 13 | 14 | Set up `gh` 15 | 16 | [https://github.com/cli/cli/blob/trunk/docs/install_linux.md](https://github.com/cli/cli/blob/trunk/docs/install_linux.md) 17 | 18 | Upload the converted file 19 | 20 | ```bash 21 | # --clobber means overwrite the existing file 22 | gh release upload <tag> <files>... --clobber 23 | 24 | # For example: 25 | gh release upload 0.1.0 exported/* --clobber 26 | ``` 27 | -------------------------------------------------------------------------------- /kimm/models/hgnet/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 
class LayerScaleTest(testing.TestCase, parameterized.TestCase):
    """Layer-level test case for `LayerScale`."""

    @pytest.mark.requires_trainable_backend
    def test_basic(self):
        """Run `run_layer_test` on `LayerScale(axis=-1)` with a (1, 10) input."""
        shape = (1, 10)
        # The output is expected to keep the input shape, and the layer is
        # expected to own exactly one trainable weight and nothing else.
        expectations = {
            "expected_output_shape": shape,
            "expected_num_trainable_weights": 1,
            "expected_num_non_trainable_weights": 0,
            "expected_num_losses": 0,
            "supports_masking": False,
        }
        self.run_layer_test(
            LayerScale,
            init_kwargs={"axis": -1},
            input_shape=shape,
            **expectations,
        )
5 | """ 6 | 7 | from kimm._src.models.mobilevit import MobileViTS 8 | from kimm._src.models.mobilevit import MobileViTV2W050 9 | from kimm._src.models.mobilevit import MobileViTV2W075 10 | from kimm._src.models.mobilevit import MobileViTV2W100 11 | from kimm._src.models.mobilevit import MobileViTV2W125 12 | from kimm._src.models.mobilevit import MobileViTV2W150 13 | from kimm._src.models.mobilevit import MobileViTV2W175 14 | from kimm._src.models.mobilevit import MobileViTV2W200 15 | from kimm._src.models.mobilevit import MobileViTXS 16 | from kimm._src.models.mobilevit import MobileViTXXS 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-merge-conflict 7 | - id: check-toml 8 | - id: check-yaml 9 | - id: end-of-file-fixer 10 | files: \.py$ 11 | - id: debug-statements 12 | files: \.py$ 13 | - id: trailing-whitespace 14 | files: \.py$ 15 | 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.13.2 18 | hooks: 19 | - id: isort 20 | 21 | - repo: https://github.com/psf/black-pre-commit-mirror 22 | rev: 24.4.2 23 | hooks: 24 | - id: black 25 | 26 | - repo: https://github.com/astral-sh/ruff-pre-commit 27 | rev: v0.4.4 28 | hooks: 29 | - id: ruff 30 | args: 31 | - --fix 32 | - id: ruff-format 33 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
class LearnableAffineTest(testing.TestCase, parameterized.TestCase):
    """Layer-level test case for `LearnableAffine`."""

    @pytest.mark.requires_trainable_backend
    def test_basic(self):
        """Run `run_layer_test` on `LearnableAffine` with a (1, 10) input."""
        shape = (1, 10)
        layer_kwargs = {"scale_value": 1.0, "bias_value": 0.0}
        # The output is expected to keep the input shape, and the layer is
        # expected to own exactly two trainable weights and nothing else.
        expectations = {
            "expected_output_shape": shape,
            "expected_num_trainable_weights": 2,
            "expected_num_non_trainable_weights": 0,
            "expected_num_losses": 0,
            "supports_masking": False,
        }
        self.run_layer_test(
            LearnableAffine,
            init_kwargs=layer_kwargs,
            input_shape=shape,
            **expectations,
        )
5 | """ 6 | 7 | from kimm._src.models.vision_transformer import VisionTransformerBase16 8 | from kimm._src.models.vision_transformer import VisionTransformerBase32 9 | from kimm._src.models.vision_transformer import VisionTransformerLarge16 10 | from kimm._src.models.vision_transformer import VisionTransformerLarge32 11 | from kimm._src.models.vision_transformer import VisionTransformerSmall16 12 | from kimm._src.models.vision_transformer import VisionTransformerSmall32 13 | from kimm._src.models.vision_transformer import VisionTransformerTiny16 14 | from kimm._src.models.vision_transformer import VisionTransformerTiny32 15 | -------------------------------------------------------------------------------- /kimm/models/mobilenet_v3/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from kimm._src.models.mobilenet_v3 import LCNet035 8 | from kimm._src.models.mobilenet_v3 import LCNet050 9 | from kimm._src.models.mobilenet_v3 import LCNet075 10 | from kimm._src.models.mobilenet_v3 import LCNet100 11 | from kimm._src.models.mobilenet_v3 import LCNet150 12 | from kimm._src.models.mobilenet_v3 import MobileNetV3W050Small 13 | from kimm._src.models.mobilenet_v3 import MobileNetV3W075Small 14 | from kimm._src.models.mobilenet_v3 import MobileNetV3W100Large 15 | from kimm._src.models.mobilenet_v3 import MobileNetV3W100LargeMinimal 16 | from kimm._src.models.mobilenet_v3 import MobileNetV3W100Small 17 | from kimm._src.models.mobilenet_v3 import MobileNetV3W100SmallMinimal 18 | -------------------------------------------------------------------------------- /kimm/_src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from kimm._src.models import convmixer 2 | from kimm._src.models import convnext 3 | from kimm._src.models import 
#!/bin/bash
# Run every weight-conversion tool in sequence.
#
# `-e` is enabled (matching the other scripts under shell/) so that the first
# failing conversion aborts the run; without it the final success message was
# printed even when a conversion had failed. `-x` traces each command.
set -Eeuxo pipefail

# NOTE(review): an empty CUDA_VISIBLE_DEVICES presumably keeps the conversions
# on CPU, and TF_CPP_MIN_LOG_LEVEL=3 presumably silences TensorFlow's native
# logging - confirm before changing either value.
export CUDA_VISIBLE_DEVICES=
export TF_CPP_MIN_LOG_LEVEL=3
# Select the TensorFlow backend for Keras.
export KERAS_BACKEND=tensorflow

python3 -m tools.convert_convmixer_from_timm
python3 -m tools.convert_convnext_from_timm
python3 -m tools.convert_densenet_from_timm
python3 -m tools.convert_efficientnet_from_timm
python3 -m tools.convert_ghostnet_from_timm
python3 -m tools.convert_ghostnet_v3_from_github
python3 -m tools.convert_hgnet_from_timm
python3 -m tools.convert_inception_next_from_timm
python3 -m tools.convert_inception_v3_from_timm
python3 -m tools.convert_mobilenet_v2_from_timm
python3 -m tools.convert_mobilenet_v3_from_timm
python3 -m tools.convert_mobileone_from_timm
python3 -m tools.convert_mobilevit_from_timm
python3 -m tools.convert_regnet_from_timm
python3 -m tools.convert_repvgg_from_timm
python3 -m tools.convert_resnet_from_timm
python3 -m tools.convert_vgg_from_timm
python3 -m tools.convert_vit_from_timm
python3 -m tools.convert_xception_from_keras

# Only reached when every conversion above exited with status 0.
echo "Export finished successfully!"
class ExportOnnxTest(testing.TestCase, parameterized.TestCase):
    """Smoke test for `export_onnx` (torch backend, channels_first only)."""

    @classmethod
    def setUpClass(cls):
        # Remember the global data format so it can be restored afterwards.
        cls.original_image_data_format = backend.image_data_format()

    @classmethod
    def tearDownClass(cls):
        # Undo the `channels_first` switch made by the test below.
        backend.set_image_data_format(cls.original_image_data_format)

    def get_model(self):
        """Return `(input_shape, model)` for a small randomly initialized net."""
        small_model = models.mobilenet_v3.MobileNetV3W050Small(
            include_preprocessing=False, weights=None
        )
        # channels_first
        return [3, 224, 224], small_model

    @pytest.mark.skipif(
        backend.backend() != "torch", reason="Requires torch backend."
    )
    def DISABLE_test_export_onnx_use(self):
        # TODO: turn on this test
        # SystemError:
        # returned a result with an exception set
        backend.set_image_data_format("channels_first")
        input_shape, model = self.get_model()
        temp_dir = self.get_temp_dir()
        onnx_path = f"{temp_dir}/model.onnx"

        export_onnx.export_onnx(model, input_shape, onnx_path)
@kimm_export(parent_path=["kimm.utils"])
def get_reparameterized_model(model: BaseModel):
    """Build the reparameterized counterpart of `model`.

    The model's config is copied with `reparameterized=True` and
    `weights=None`, a new instance of the same class is created from it, and
    the weights are transferred layer by layer. Layers that expose
    `get_reparameterized_weights` have their fused kernel/bias written into
    the target layer's `reparameterized_conv2d`; every other layer has its
    weights copied one-to-one.

    Args:
        model: A `BaseModel` to convert to its reparameterized form. It must
            expose a `get_reparameterized_model` attribute.

    Returns:
        `model` itself if its config already has `reparameterized=True`,
        otherwise a new instance of the same class in reparameterized form.

    Raises:
        ValueError: If `model` has no `get_reparameterized_model` attribute.
    """
    if not hasattr(model, "get_reparameterized_model"):
        raise ValueError(
            "There is no 'get_reparameterized_model' method in the model. "
            f"Received: model type={type(model)}"
        )

    config = model.get_config()
    # Nothing to do when the model is already in its fused form.
    if config["reparameterized"] is True:
        return model

    # Rebuild the same architecture in fused form without loading weights.
    config["reparameterized"] = True
    config["weights"] = None
    target_model = type(model).from_config(config)

    layer_pairs = zip(model.layers, target_model.layers)
    for source_layer, target_layer in layer_pairs:
        if not hasattr(source_layer, "get_reparameterized_weights"):
            # Plain layer: copy weights in order.
            weight_pairs = zip(source_layer.weights, target_layer.weights)
            for source_weight, target_weight in weight_pairs:
                target_weight.assign(source_weight)
            continue

        # Reparameterizable layer: write the fused kernel and bias.
        fused_kernel, fused_bias = source_layer.get_reparameterized_weights()
        fused_conv = target_layer.reparameterized_conv2d
        fused_conv.kernel.assign(fused_kernel)
        fused_conv.bias.assign(fused_bias)
    return target_model
class ModelUtilsTest(testing.TestCase):
    """Tests for `get_reparameterized_model`."""

    def test_get_reparameterized_model(self):
        # dummy RepVGG with random initialization
        source_model = RepVGG(
            [1, 1, 1, 1],
            [8, 8, 8, 8],
            8,
            include_preprocessing=False,
            weights=None,
        )
        fused_model = get_reparameterized_model(source_model)
        x = random.uniform([1, 32, 32, 3])

        # Both forms must produce the same predictions in inference mode.
        expected = source_model(x, training=False)
        actual = fused_model(x, training=False)

        self.assertAllClose(expected, actual, atol=1e-5)

    def test_get_reparameterized_model_already(self):
        # dummy RepVGG with random initialization and reparameterized=True
        fused_model = RepVGG(
            [1, 1, 1, 1],
            [8, 8, 8, 8],
            8,
            reparameterized=True,
            include_preprocessing=False,
            weights=None,
        )
        result = get_reparameterized_model(fused_model)

        # An already reparameterized model is returned as the same object.
        self.assertIs(fused_model, result)

    def test_get_reparameterized_model_invalid(self):
        # RegNet has no reparameterized form, so the helper must reject it.
        unsupported_model = RegNetX002(weights=None)

        with self.assertRaisesRegex(
            ValueError, "There is no 'get_reparameterized_model' method"
        ):
            get_reparameterized_model(unsupported_model)
def pytest_addoption(parser):
    """Register the `--run_serialization` command line flag."""
    option_kwargs = {
        "action": "store_true",
        "default": False,
        "help": "run serialization tests",
    }
    parser.addoption("--run_serialization", **option_kwargs)


def pytest_configure(config):
    """Limit GPU memory preallocation and register the custom markers."""
    import tensorflow as tf

    # disable tensorflow gpu memory preallocation
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)

    # disable jax gpu memory preallocation
    # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    marker_descriptions = (
        "serialization: mark test as a serialization test",
        "requires_trainable_backend: mark test for trainable backend only",
    )
    for description in marker_descriptions:
        config.addinivalue_line("markers", description)


def pytest_collection_modifyitems(config, items):
    """Attach skip markers to the collected tests.

    Tests whose name contains "serialization" are skipped unless
    `--run_serialization` is given; tests carrying the
    `requires_trainable_backend` keyword are skipped on the numpy backend.
    """
    serialization_enabled = config.getoption("--run_serialization")
    skip_serialization = pytest.mark.skipif(
        not serialization_enabled,
        reason="need --run_serialization option to run",
    )
    skip_untrainable = pytest.mark.skipif(
        backend.backend() == "numpy", reason="require trainable backend"
    )
    for test_item in items:
        if "requires_trainable_backend" in test_item.keywords:
            test_item.add_marker(skip_untrainable)
        if "serialization" in test_item.name:
            test_item.add_marker(skip_serialization)
2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from kimm._src.models.efficientnet import EfficientNetB0 8 | from kimm._src.models.efficientnet import EfficientNetB1 9 | from kimm._src.models.efficientnet import EfficientNetB2 10 | from kimm._src.models.efficientnet import EfficientNetB3 11 | from kimm._src.models.efficientnet import EfficientNetB4 12 | from kimm._src.models.efficientnet import EfficientNetB5 13 | from kimm._src.models.efficientnet import EfficientNetB6 14 | from kimm._src.models.efficientnet import EfficientNetB7 15 | from kimm._src.models.efficientnet import EfficientNetLiteB0 16 | from kimm._src.models.efficientnet import EfficientNetLiteB1 17 | from kimm._src.models.efficientnet import EfficientNetLiteB2 18 | from kimm._src.models.efficientnet import EfficientNetLiteB3 19 | from kimm._src.models.efficientnet import EfficientNetLiteB4 20 | from kimm._src.models.efficientnet import EfficientNetV2B0 21 | from kimm._src.models.efficientnet import EfficientNetV2B1 22 | from kimm._src.models.efficientnet import EfficientNetV2B2 23 | from kimm._src.models.efficientnet import EfficientNetV2B3 24 | from kimm._src.models.efficientnet import EfficientNetV2L 25 | from kimm._src.models.efficientnet import EfficientNetV2M 26 | from kimm._src.models.efficientnet import EfficientNetV2S 27 | from kimm._src.models.efficientnet import EfficientNetV2XL 28 | from kimm._src.models.efficientnet import TinyNetA 29 | from kimm._src.models.efficientnet import TinyNetB 30 | from kimm._src.models.efficientnet import TinyNetC 31 | from kimm._src.models.efficientnet import TinyNetD 32 | from kimm._src.models.efficientnet import TinyNetE 33 | -------------------------------------------------------------------------------- /kimm/_src/layers/attention_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import pytest 3 | from absl.testing 
import parameterized 4 | from keras.src import testing 5 | 6 | from kimm._src.layers.attention import Attention 7 | 8 | 9 | class AttentionTest(testing.TestCase, parameterized.TestCase): 10 | @pytest.mark.requires_trainable_backend 11 | def test_basic_3d(self): 12 | self.run_layer_test( 13 | Attention, 14 | init_kwargs={"hidden_dim": 20, "num_heads": 2}, 15 | input_shape=(1, 10, 20), 16 | expected_output_shape=(1, 10, 20), 17 | expected_num_trainable_weights=3, 18 | expected_num_non_trainable_weights=0, 19 | expected_num_losses=0, 20 | supports_masking=False, 21 | ) 22 | 23 | @pytest.mark.requires_trainable_backend 24 | def test_basic_4d(self): 25 | self.run_layer_test( 26 | Attention, 27 | init_kwargs={"hidden_dim": 20, "num_heads": 2}, 28 | input_shape=(1, 2, 10, 20), 29 | expected_output_shape=(1, 2, 10, 20), 30 | expected_num_trainable_weights=3, 31 | expected_num_non_trainable_weights=0, 32 | expected_num_losses=0, 33 | supports_masking=False, 34 | ) 35 | 36 | def test_invalid_ndim(self): 37 | # Test 2D 38 | inputs = keras.Input(shape=[1]) 39 | with self.assertRaisesRegex( 40 | ValueError, "The ndim of the inputs must be 3 or 4." 41 | ): 42 | Attention(1, 1)(inputs) 43 | 44 | # Test 5D 45 | inputs = keras.Input(shape=[1, 2, 3, 4]) 46 | with self.assertRaisesRegex( 47 | ValueError, "The ndim of the inputs must be 3 or 4." 
# Reference implementations (from `keras.applications`) and the matching
# kimm implementations, paired by position.
ori_model_classes = [
    keras.applications.Xception,
]
keras_model_classes = [
    xception.Xception,
]

for ori_model_class, keras_model_class in zip(
    ori_model_classes, keras_model_classes
):
    # Build the reference model and an uninitialized kimm model.
    input_shape = (299, 299, 3)
    ori_model = ori_model_class(
        input_shape=input_shape, classifier_activation="linear"
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )

    # Transfer the weights through a temporary weights file.
    with tempfile.TemporaryDirectory() as temp_dir:
        weights_path = f"{temp_dir}/model.weights.h5"
        ori_model.save_weights(weights_path)
        keras_model.load_weights(weights_path)

    # Verify that both models give the same outputs on a fixed random input.
    np.random.seed(2023)
    keras_data = np.random.uniform(size=(1, *input_shape)).astype("float32")
    ori_y = keras.ops.convert_to_numpy(ori_model(keras_data, training=False))
    keras_y = keras.ops.convert_to_numpy(
        keras_model(keras_data, training=False)
    )
    np.testing.assert_allclose(ori_y, keras_y, atol=1e-5)
    print(f"{keras_model_class.__name__}: output matched!")

    # Save the converted model.
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
@kimm_export(parent_path=["kimm.blocks"])
def apply_se_block(
    inputs,
    se_ratio: float = 0.25,
    activation: typing.Optional[str] = "relu",
    gate_activation: typing.Optional[str] = "sigmoid",
    make_divisible_number: typing.Optional[int] = None,
    se_input_channels: typing.Optional[int] = None,
    name: str = "se_block",
):
    """Squeeze and Excitation.

    Global average pooling, a 1x1 reduce convolution, an optional activation,
    a 1x1 expand convolution and an optional gate activation; the result is
    multiplied with `inputs`.

    Args:
        inputs: Input tensor. The channel axis is chosen from
            `backend.image_data_format()`.
        se_ratio: Ratio applied to `se_input_channels` to get the number of
            channels of the reduce convolution.
        activation: Activation after the reduce convolution. `None` skips it.
        gate_activation: Activation after the expand convolution. `None`
            skips it.
        make_divisible_number: If given, the reduced channel count is
            computed with `make_divisible` using this divisor; otherwise it
            is rounded with `round`.
        se_input_channels: Channel count the ratio is applied to. Defaults to
            the number of input channels.
        name: Name prefix of the created layers. The final `Multiply` layer
            is named exactly `name`.

    Returns:
        The gated tensor, with the same channel count as `inputs`.
    """
    channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
    input_channels = inputs.shape[channels_axis]
    if se_input_channels is None:
        se_input_channels = input_channels
    if make_divisible_number is None:
        se_channels = round(se_input_channels * se_ratio)
    else:
        se_channels = make_divisible(
            se_input_channels * se_ratio, make_divisible_number
        )

    x = inputs
    x = layers.GlobalAveragePooling2D(
        data_format=backend.image_data_format(),
        keepdims=True,
        name=f"{name}_mean",
    )(x)
    x = layers.Conv2D(
        se_channels, 1, use_bias=True, name=f"{name}_conv_reduce"
    )(x)
    if activation is not None:
        x = layers.Activation(activation, name=f"{name}_act1")(x)
    x = layers.Conv2D(
        input_channels, 1, use_bias=True, name=f"{name}_conv_expand"
    )(x)
    # The gate is controlled by `gate_activation`, not by `activation`:
    # `activation=None` must not silently drop the gate.
    if gate_activation is not None:
        x = layers.Activation(gate_activation, name=f"{name}_gate")(x)
    x = layers.Multiply(name=name)([inputs, x])
    return x
@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class LayerScale(layers.Layer):
    """Multiply the inputs by a learnable per-axis scale `gamma`.

    Args:
        axis: Int, or list/tuple of ints, selecting the input axes along
            which `gamma` has one entry per position. Defaults to `-1`.
        initializer: Initializer for `gamma`. Anything accepted by
            `keras.initializers.get` (instance, string name or serialized
            config). Defaults to `Constant(1e-5)`.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        axis: int = -1,
        initializer: initializers.Initializer = initializers.Constant(1e-5),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.axis = axis
        # Normalize so that `get_config` always serializes a real
        # initializer, including after `from_config` passed a config dict.
        self.initializer = initializers.get(initializer)

    def build(self, input_shape):
        # Store `axis` as a list; tuples are accepted as well as lists.
        if isinstance(self.axis, (list, tuple)):
            self.axis = list(self.axis)
        else:
            self.axis = [self.axis]
        shape = tuple(input_shape[dim] for dim in self.axis)
        self.gamma = self.add_weight(
            shape, initializer=self.initializer, name="gamma"
        )
        self.built = True

    def call(self, inputs, training=None, mask=None):
        # Reshape gamma to size 1 on every axis except the scaled ones so
        # that the multiplication broadcasts over the remaining axes.
        input_shape = inputs.shape
        ndims = len(inputs.shape)
        broadcast_shape = [1] * ndims
        for dim in self.axis:
            broadcast_shape[dim] = input_shape[dim]
        gamma = ops.reshape(self.gamma, broadcast_shape)
        gamma = ops.cast(gamma, self.compute_dtype)
        return ops.multiply(inputs, gamma)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "axis": self.axis,
                "initializer": initializers.serialize(self.initializer),
                "name": self.name,
            }
        )
        return config
11 | def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs): 12 | super().__init__(**kwargs) 13 | if isinstance(scale_value, int): 14 | raise ValueError( 15 | f"scale_value must be a integer. Received: {scale_value}" 16 | ) 17 | if isinstance(bias_value, int): 18 | raise ValueError( 19 | f"bias_value must be a integer. Received: {bias_value}" 20 | ) 21 | self.scale_value = scale_value 22 | self.bias_value = bias_value 23 | 24 | def build(self, input_shape): 25 | self.scale = self.add_weight( 26 | shape=(), 27 | initializer=lambda shape, dtype: ops.cast(self.scale_value, dtype), 28 | trainable=True, 29 | name="scale", 30 | ) 31 | self.bias = self.add_weight( 32 | shape=(), 33 | initializer=lambda shape, dtype: ops.cast(self.bias_value, dtype), 34 | trainable=True, 35 | name="bias", 36 | ) 37 | self.built = True 38 | 39 | def call(self, inputs, training=None, mask=None): 40 | scale = ops.cast(self.scale, self.compute_dtype) 41 | bias = ops.cast(self.bias, self.compute_dtype) 42 | return ops.add(ops.multiply(inputs, scale), bias) 43 | 44 | def get_config(self): 45 | config = super().get_config() 46 | config.update( 47 | { 48 | "scale_value": self.scale_value, 49 | "bias_value": self.bias_value, 50 | "name": self.name, 51 | } 52 | ) 53 | return config 54 | -------------------------------------------------------------------------------- /kimm/_src/blocks/depthwise_separation.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from keras import backend 4 | from keras import layers 5 | 6 | from kimm._src.blocks.conv2d import apply_conv2d_block 7 | from kimm._src.blocks.squeeze_and_excitation import apply_se_block 8 | from kimm._src.kimm_export import kimm_export 9 | 10 | 11 | @kimm_export(parent_path=["kimm.blocks"]) 12 | def apply_depthwise_separation_block( 13 | inputs, 14 | filters: int, 15 | depthwise_kernel_size: int = 3, 16 | pointwise_kernel_size: int = 1, 17 | strides: int = 1, 18 | se_ratio: float 
= 0.0, 19 | activation: typing.Optional[str] = "swish", 20 | se_activation: typing.Optional[str] = "relu", 21 | se_gate_activation: typing.Optional[str] = "sigmoid", 22 | se_make_divisible_number: typing.Optional[int] = None, 23 | pw_activation: typing.Optional[str] = None, 24 | has_skip: bool = True, 25 | bn_epsilon: float = 1e-5, 26 | padding: typing.Optional[typing.Literal["same", "valid"]] = None, 27 | name: str = "depthwise_separation_block", 28 | ): 29 | """Conv2D block + (SqueezeAndExcitation) + Conv2D.""" 30 | channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 31 | input_filters = inputs.shape[channels_axis] 32 | if has_skip and (strides != 1 or input_filters != filters): 33 | raise ValueError( 34 | "If `has_skip=True`, strides must be 1 and `filters` must be the " 35 | "same as input_filters. " 36 | f"Received: strides={strides}, filters={filters}, " 37 | f"input_filters={input_filters}" 38 | ) 39 | 40 | x = inputs 41 | x = apply_conv2d_block( 42 | x, 43 | kernel_size=depthwise_kernel_size, 44 | strides=strides, 45 | activation=activation, 46 | use_depthwise=True, 47 | bn_epsilon=bn_epsilon, 48 | padding=padding, 49 | name=f"{name}_conv_dw", 50 | ) 51 | if se_ratio > 0: 52 | x = apply_se_block( 53 | x, 54 | se_ratio, 55 | activation=se_activation, 56 | gate_activation=se_gate_activation, 57 | make_divisible_number=se_make_divisible_number, 58 | name=f"{name}_se", 59 | ) 60 | x = apply_conv2d_block( 61 | x, 62 | filters, 63 | pointwise_kernel_size, 64 | 1, 65 | activation=pw_activation, 66 | bn_epsilon=bn_epsilon, 67 | padding=padding, 68 | name=f"{name}_conv_pw", 69 | ) 70 | if has_skip: 71 | x = layers.Add()([x, inputs]) 72 | return x 73 | -------------------------------------------------------------------------------- /kimm/_src/export/export_tflite_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from absl.testing import parameterized 3 | from keras import backend 4 
class ExportTFLiteTest(testing.TestCase, parameterized.TestCase):
    """Smoke tests for `export_tflite` in float32, float16 and int8 modes."""

    def get_model_and_representative_dataset(self):
        """Return `(input_shape, model, representative_dataset)`.

        `representative_dataset` is a generator function yielding ten
        single-sample batches of uniform random data in `[0, 255)`; it is
        passed to the int8 export.
        """
        input_shape = [224, 224, 3]
        model = models.mobilenet_v3.MobileNetV3W050Small(
            include_preprocessing=False, weights=None
        )

        def representative_dataset():
            for _ in range(10):
                yield [
                    ops.convert_to_numpy(
                        random.uniform([1, *input_shape], maxval=255.0)
                    )
                ]

        return input_shape, model, representative_dataset

    @classmethod
    def setUpClass(cls):
        # Remember the global data format so it can be restored afterwards.
        cls.original_image_data_format = backend.image_data_format()

    @classmethod
    def tearDownClass(cls):
        backend.set_image_data_format(cls.original_image_data_format)

    @pytest.mark.skipif(
        backend.backend() != "tensorflow", reason="Requires tensorflow backend."
    )
    def test_export_tflite_fp32(self):
        (input_shape, model, _) = self.get_model_and_representative_dataset()
        temp_dir = self.get_temp_dir()

        # Use the `.tflite` suffix like the other tests; this file used to be
        # written with a misleading `.onnx` extension.
        export_tflite.export_tflite(
            model, input_shape, f"{temp_dir}/model_fp32.tflite", "float32"
        )

    @pytest.mark.skipif(
        backend.backend() != "tensorflow", reason="Requires tensorflow backend."
    )
    def test_export_tflite_fp16(self):
        (input_shape, model, _) = self.get_model_and_representative_dataset()
        temp_dir = self.get_temp_dir()

        export_tflite.export_tflite(
            model, input_shape, f"{temp_dir}/model_fp16.tflite", "float16"
        )

    @pytest.mark.skipif(
        backend.backend() != "tensorflow", reason="Requires tensorflow backend."
    )
    def test_export_tflite_int8(self):
        (
            input_shape,
            model,
            representative_dataset,
        ) = self.get_model_and_representative_dataset()
        temp_dir = self.get_temp_dir()

        export_tflite.export_tflite(
            model,
            input_shape,
            f"{temp_dir}/model_int8.tflite",
            "int8",
            representative_dataset,
        )
" 38 | f"Received: strides={strides}, filters={filters}, " 39 | f"input_filters={input_filters}" 40 | ) 41 | x = inputs 42 | 43 | if padding is None: 44 | padding = "same" 45 | if strides > 1: 46 | padding = "valid" 47 | x = layers.ZeroPadding2D( 48 | ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2), 49 | name=f"{name}_pad", 50 | )(x) 51 | 52 | if not use_depthwise: 53 | x = layers.Conv2D( 54 | filters, 55 | kernel_size, 56 | strides, 57 | padding=padding, 58 | groups=groups, 59 | use_bias=False, 60 | name=f"{name}_conv2d", 61 | )(x) 62 | else: 63 | x = layers.DepthwiseConv2D( 64 | kernel_size, 65 | strides, 66 | padding=padding, 67 | use_bias=False, 68 | name=f"{name}_dwconv2d", 69 | )(x) 70 | x = layers.BatchNormalization( 71 | axis=channels_axis, 72 | name=f"{name}_bn", 73 | momentum=bn_momentum, 74 | epsilon=bn_epsilon, 75 | )(x) 76 | if activation is not None: 77 | x = layers.Activation(activation, name=name)(x) 78 | if has_skip: 79 | x = layers.Add()([x, inputs]) 80 | return x 81 | -------------------------------------------------------------------------------- /kimm/_src/blocks/transformer.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from keras import backend 4 | from keras import layers 5 | 6 | from kimm._src.kimm_export import kimm_export 7 | from kimm._src.layers.attention import Attention 8 | 9 | 10 | @kimm_export(parent_path=["kimm.blocks"]) 11 | def apply_mlp_block( 12 | inputs, 13 | hidden_dim: int, 14 | output_dim: typing.Optional[int] = None, 15 | activation: str = "gelu", 16 | use_bias: bool = True, 17 | dropout_rate: float = 0.0, 18 | use_conv_mlp: bool = False, 19 | data_format: typing.Optional[str] = None, 20 | name: str = "mlp_block", 21 | ): 22 | """Dense/Conv2D + Activation + Dense/Conv2D.""" 23 | if data_format is None: 24 | data_format = backend.image_data_format() 25 | dim_axis = -1 if data_format == "channels_last" else 1 26 | input_dim = inputs.shape[dim_axis] 27 | 
@kimm_export(parent_path=["kimm.blocks"])
def apply_transformer_block(
    inputs,
    dim: int,
    num_heads: int,
    mlp_ratio: float = 4.0,
    use_qkv_bias: bool = False,
    projection_dropout_rate: float = 0.0,
    attention_dropout_rate: float = 0.0,
    activation: str = "gelu",
    name: str = "transformer_block",
):
    """Pre-norm transformer block: (LN + Attention) and (LN + MLP).

    Each of the two sub-blocks is wrapped in a residual connection. The MLP
    hidden size is `int(dim * mlp_ratio)` and the MLP reuses
    `projection_dropout_rate` as its dropout rate.
    """
    # data_format must be "channels_last"
    # Attention sub-block with residual connection.
    attn_out = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm1")(
        inputs
    )
    attn_out = Attention(
        dim,
        num_heads,
        use_qkv_bias,
        attention_dropout_rate,
        projection_dropout_rate,
        name=f"{name}_attn",
    )(attn_out)
    attn_out = layers.Add()([inputs, attn_out])

    # MLP sub-block with residual connection.
    mlp_hidden_dim = int(dim * mlp_ratio)
    mlp_out = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm2")(
        attn_out
    )
    mlp_out = apply_mlp_block(
        mlp_out,
        mlp_hidden_dim,
        activation=activation,
        dropout_rate=projection_dropout_rate,
        data_format="channels_last",
        name=f"{name}_mlp",
    )
    return layers.Add()([attn_out, mlp_out])
| import typing 2 | 3 | from keras import backend 4 | from keras import layers 5 | 6 | from kimm._src.blocks.conv2d import apply_conv2d_block 7 | from kimm._src.blocks.squeeze_and_excitation import apply_se_block 8 | from kimm._src.kimm_export import kimm_export 9 | from kimm._src.utils.make_divisble import make_divisible 10 | 11 | 12 | @kimm_export(parent_path=["kimm.blocks"]) 13 | def apply_inverted_residual_block( 14 | inputs, 15 | filters: int, 16 | depthwise_kernel_size: int = 3, 17 | expansion_kernel_size: int = 1, 18 | pointwise_kernel_size: int = 1, 19 | strides: int = 1, 20 | expansion_ratio: float = 1.0, 21 | se_ratio: float = 0.0, 22 | activation: str = "swish", 23 | se_channels: typing.Optional[int] = None, 24 | se_activation: typing.Optional[str] = None, 25 | se_gate_activation: typing.Optional[str] = "sigmoid", 26 | se_make_divisible_number: typing.Optional[int] = None, 27 | bn_epsilon: float = 1e-5, 28 | padding: typing.Optional[typing.Literal["same", "valid"]] = None, 29 | name: str = "inverted_residual_block", 30 | ): 31 | """Conv2D block + DepthwiseConv2D block + (SE) + Conv2D.""" 32 | channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 33 | input_channels = inputs.shape[channels_axis] 34 | hidden_channels = make_divisible(input_channels * expansion_ratio) 35 | has_skip = strides == 1 and input_channels == filters 36 | 37 | x = inputs 38 | # Point-wise expansion 39 | x = apply_conv2d_block( 40 | x, 41 | hidden_channels, 42 | expansion_kernel_size, 43 | 1, 44 | activation=activation, 45 | bn_epsilon=bn_epsilon, 46 | padding=padding, 47 | name=f"{name}_conv_pw", 48 | ) 49 | # Depth-wise convolution 50 | x = apply_conv2d_block( 51 | x, 52 | kernel_size=depthwise_kernel_size, 53 | strides=strides, 54 | activation=activation, 55 | use_depthwise=True, 56 | bn_epsilon=bn_epsilon, 57 | padding=padding, 58 | name=f"{name}_conv_dw", 59 | ) 60 | # Squeeze-and-excitation 61 | if se_ratio > 0: 62 | x = apply_se_block( 63 | x, 64 | 
@kimm_export(parent_path=["kimm.export"])
def export_onnx(
    model: BaseModel,
    input_shape: typing.Union[int, typing.Sequence[int]],
    export_path: typing.Union[str, pathlib.Path],
    batch_size: int = 1,
):
    """Export the model to onnx format (in float32).

    Only torch backend with 'channels_first' is supported. The onnx model will
    be generated using `torch.onnx.export` and optimized through `onnxsim` and
    `onnxoptimizer`.

    Note that `onnx`, `onnxruntime`, `onnxsim` and `onnxoptimizer` must be
    installed.

    Args:
        model: keras.Model, the model to be exported.
        input_shape: int or sequence of int, specifying the shape of the input.
            An int `s` is expanded to `[3, s, s]`, a sequence `(h, w)` to
            `[3, h, w]`, and a sequence of 3 ints is used as is.
        export_path: str or pathlib.Path, specifying the path to export.
        batch_size: int, specifying the batch size of the input,
            defaults to `1`.

    Raises:
        ValueError: If the backend is not torch, the image data format is not
            'channels_first', or `input_shape` has an unsupported length.
        ModuleNotFoundError: If `onnx`, `onnxsim` or `onnxoptimizer` cannot be
            imported.
    """
    if backend.backend() != "torch":
        raise ValueError("`export_onnx` only supports torch backend")
    if backend.image_data_format() != "channels_first":
        raise ValueError(
            "`export_onnx` only supports 'channels_first' data format."
        )
    try:
        import onnx
        import onnxoptimizer
        import onnxsim
    except ModuleNotFoundError as e:
        # Chain the original error so the missing package stays visible.
        raise ModuleNotFoundError(
            "Failed to import 'onnx', 'onnxsim' or 'onnxoptimizer'. "
            "Please install them by the following instruction:\n"
            "'pip install torch onnx onnxsim onnxoptimizer'"
        ) from e

    if isinstance(input_shape, int):
        input_shape = [3, input_shape, input_shape]
    elif len(input_shape) == 2:
        input_shape = [3, input_shape[0], input_shape[1]]
    elif len(input_shape) == 3:
        input_shape = list(input_shape)
    else:
        raise ValueError(
            "`input_shape` must be an int or a sequence of 2 or 3 ints. "
            f"Received: input_shape={input_shape}"
        )

    # Both the tracer and the onnx loader are given a plain string path.
    export_path = str(export_path)

    # Fix input shape
    inputs = layers.Input(
        shape=input_shape, batch_size=batch_size, name="inputs"
    )
    outputs = model(inputs, training=False)
    model = models.Model(inputs, outputs)
    model = model.eval()

    # The dummy input must use the same batch size the model was fixed to
    # above; otherwise `batch_size` would be ignored by the traced graph.
    full_input_shape = [batch_size] + input_shape
    dummy_inputs = ops.ones(full_input_shape, dtype="float32")
    scripted_model = torch.jit.trace(
        model.forward, example_inputs=[dummy_inputs]
    )
    torch.onnx.export(scripted_model, dummy_inputs, export_path)

    # Further optimization
    onnx_model = onnx.load(export_path)
    simplified_model, _ = onnxsim.simplify(onnx_model)
    simplified_model = onnxoptimizer.optimize(simplified_model)
    onnx.save(simplified_model, export_path)
13 | 14 | jobs: 15 | format: 16 | name: Check the code format 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python 3.9 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.9' 24 | - name: Lint 25 | uses: pre-commit/action@v3.0.1 26 | - name: Get pip cache dir 27 | id: pip-cache 28 | run: | 29 | python -m pip install --upgrade pip setuptools 30 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT 31 | - name: Cache pip 32 | uses: actions/cache@v4 33 | with: 34 | path: ${{ steps.pip-cache.outputs.dir }} 35 | key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }} 36 | - name: Install dependencies 37 | run: | 38 | pip install -r requirements.txt --progress-bar off --upgrade 39 | pip install -e ".[tests]" --progress-bar off --upgrade 40 | - name: Check for API changes 41 | run: | 42 | bash shell/api_gen.sh 43 | git status 44 | clean=$(git status | grep "nothing to commit") 45 | if [ -z "$clean" ]; then 46 | echo "Please run shell/api_gen.sh to generate API." 
try:
    import namex
except ImportError:
    namex = None

# These dicts reference "canonical names" only
# (i.e. the first name an object was registered with).
REGISTERED_NAMES_TO_OBJS = {}
REGISTERED_OBJS_TO_NAMES = {}


def register_internal_serializable(path, symbol):
    """Register `symbol` under its canonical name.

    Args:
        path: str, or list/tuple of str. When a list or tuple is given, only
            its first element is used as the registered name.
        symbol: The object to register. It is used as a dict key, so it must
            be hashable.
    """
    if isinstance(path, (list, tuple)):
        name = path[0]
    else:
        name = path
    REGISTERED_NAMES_TO_OBJS[name] = symbol
    REGISTERED_OBJS_TO_NAMES[symbol] = name


def get_symbol_from_name(name):
    """Return the object registered as `name`, or `None` if there is none."""
    return REGISTERED_NAMES_TO_OBJS.get(name, None)


def get_name_from_symbol(symbol):
    """Return the name `symbol` was registered with, or `None`."""
    return REGISTERED_OBJS_TO_NAMES.get(symbol, None)


def _normalize_parent_path(parent_path):
    """Return `parent_path` as a list of paths, rejecting unsupported types."""
    if isinstance(parent_path, str):
        return [parent_path]
    if isinstance(parent_path, list):
        return parent_path
    raise ValueError(
        f"Invalid type for `parent_path` argument: "
        f"Received '{parent_path}' "
        f"of type {type(parent_path)}"
    )


def _build_export_path(parent_path, symbol):
    """Append `symbol.__name__` to the parent path (str) or paths (list)."""
    suffix = f".{symbol.__name__}"
    if isinstance(parent_path, list):
        return [p + suffix for p in parent_path]
    return parent_path + suffix


if namex:

    class kimm_export:
        """Decorator that registers a symbol and tags it with its export path.

        Args:
            parent_path: str or list of str, the parent path(s) under which
                the decorated symbol is exported. Every path must start
                with "kimm".

        Raises:
            ValueError: If `parent_path` has an unsupported type, if a path
                does not start with "kimm", or (on call) if the symbol has
                already been exported.
        """

        def __init__(self, parent_path):
            package = "kimm"

            export_paths = _normalize_parent_path(parent_path)
            for p in export_paths:
                # The bare package name ("kimm") is a valid parent path, so
                # only the prefix is checked.
                if not p.startswith(package):
                    raise ValueError(
                        "All `export_path` values should start with "
                        f"'{package}.'. Received: parent_path={parent_path}"
                    )
            self.package = package
            self.parent_path = parent_path

        def __call__(self, symbol):
            # The id check tells apart "this very object was exported" from
            # "the attribute is merely inherited from an exported parent".
            if hasattr(symbol, "_api_export_path") and (
                symbol._api_export_symbol_id == id(symbol)
            ):
                raise ValueError(
                    f"Symbol {symbol} is already exported as "
                    f"'{symbol._api_export_path}'. "
                    f"Cannot also export it to '{self.parent_path}'."
                )
            path = _build_export_path(self.parent_path, symbol)
            symbol._api_export_path = path
            symbol._api_export_symbol_id = id(symbol)

            register_internal_serializable(path, symbol)
            return symbol

else:

    class kimm_export:
        """Fallback decorator used when `namex` is not installed.

        It only registers the symbol and does not tag it with export
        attributes.

        Args:
            parent_path: str or list of str, the parent path(s) under which
                the decorated symbol is registered.

        Raises:
            ValueError: If `parent_path` is neither a str nor a list.
        """

        def __init__(self, parent_path):
            # Validate the type up front so that a bad argument fails with
            # the same error as the `namex` variant instead of an
            # `UnboundLocalError` at decoration time.
            _normalize_parent_path(parent_path)
            self.parent_path = parent_path

        def __call__(self, symbol):
            path = _build_export_path(self.parent_path, symbol)

            register_internal_serializable(path, symbol)
            return symbol
8 | keywords = [ 9 | "deep-learning", 10 | "model-zoo", 11 | "keras", 12 | "jax", 13 | "tensorflow", 14 | "torch", 15 | "imagenet", 16 | "pretrained-weights", 17 | "timm", 18 | ] 19 | authors = [{ name = "Hong-Yu Chiu", email = "james77777778@gmail.com" }] 20 | maintainers = [{ name = "Hong-Yu Chiu", email = "james77777778@gmail.com" }] 21 | readme = "README.md" 22 | requires-python = ">=3.9" 23 | license = { text = "Apache License 2.0" } 24 | classifiers = [ 25 | "Programming Language :: Python", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Programming Language :: Python :: 3.12", 31 | "Programming Language :: Python :: 3 :: Only", 32 | "Operating System :: Unix", 33 | "Operating System :: MacOS", 34 | "Intended Audience :: Science/Research", 35 | "Topic :: Scientific/Engineering", 36 | "Topic :: Software Development", 37 | ] 38 | dynamic = ["version"] 39 | dependencies = ["keras"] 40 | 41 | [project.urls] 42 | Homepage = "https://github.com/james77777778/keras-image-models" 43 | Documentation = "https://github.com/james77777778/keras-image-models" 44 | Repository = "https://github.com/james77777778/keras-image-models.git" 45 | Issues = "https://github.com/james77777778/keras-image-models/issues" 46 | 47 | [project.optional-dependencies] 48 | tests = [ 49 | # export 50 | "tf2onnx", 51 | "onnx", 52 | "onnxoptimizer", 53 | "onnxsim", 54 | # linter and formatter 55 | "isort", 56 | "ruff", 57 | "black", 58 | "pytest", 59 | "pytest-cov", 60 | "coverage", 61 | # tool 62 | "pre-commit", 63 | "namex", 64 | ] 65 | examples = ["opencv-python", "matplotlib"] 66 | 67 | [tool.setuptools.packages] 68 | find = { include = ["kimm*"] } 69 | 70 | [tool.setuptools.dynamic] 71 | version = { attr = "kimm.__version__" } 72 | 73 | [tool.black] 74 | line-length = 80 75 | 76 | [tool.ruff] 77 | line-length = 80 78 | lint.select = ["E", "W", "F"] 
@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class PositionEmbedding(layers.Layer):
    """Prepend a learnable class token and add a learnable position embedding.

    The layer accepts a 3-dimensional input whose second-to-last dimension
    must equal `height * width`. One extra position is added along axis 1 for
    the class token.

    Args:
        height: The height of the grid, converted with `int()`.
        width: The width of the grid, converted with `int()`.
        **kwargs: Forwarded to the base `Layer` constructor.
    """

    def __init__(self, height, width, **kwargs):
        super().__init__(**kwargs)
        # We need height and width for saving and loading
        self.height = int(height)
        self.width = int(width)

    def build(self, input_shape):
        """Validate `input_shape` and create the two weights.

        Raises:
            ValueError: If `input_shape` is not 3-dimensional or if
                `height * width` differs from `input_shape[-2]`.
        """
        if len(input_shape) != 3:
            raise ValueError(
                "PositionEmbedding only accepts 3-dimensional input. "
                f"Received: input_shape={input_shape}"
            )
        if self.height * self.width != input_shape[-2]:
            raise ValueError(
                "The embedding size doesn't match the height and width. "
                f"Received: height={self.height}, width={self.width}, "
                f"input_shape={input_shape}"
            )
        # One extra position (+ 1) is reserved for the class token.
        # NOTE(review): `load_own_variables` reads `store["0"]` as `pos_embed`
        # and `store["1"]` as `cls_token`; presumably these keys follow the
        # creation order below, so do not reorder the two calls -- confirm.
        self.pos_embed = self.add_weight(
            shape=[1, input_shape[-2] + 1, input_shape[-1]],
            initializer="random_normal",
            name="pos_embed",
        )
        self.cls_token = self.add_weight(
            shape=[1, 1, input_shape[-1]], initializer="zeros", name="cls_token"
        )
        self.built = True

    def call(self, inputs, training=None, mask=None):
        """Concatenate the tiled class token in front of `inputs` along axis 1
        and add the position embedding. `training` and `mask` are unused.
        """
        input_shape = ops.shape(inputs)
        # Repeat the class token along axis 0 to match the size of `inputs`.
        x = ops.concatenate(
            [ops.tile(self.cls_token, [input_shape[0], 1, 1]), inputs],
            axis=1,
        )
        x = ops.add(x, self.pos_embed)
        return x

    def compute_output_shape(self, input_shape):
        """Return `input_shape` as a list with axis 1 increased by one."""
        output_shape = list(input_shape)
        output_shape[1] = output_shape[1] + 1
        return output_shape

    def save_own_variables(self, store):
        """Save the weights and additionally store `height` and `width`."""
        super().save_own_variables(store)
        # Add height and width information
        store["height"] = self.height
        store["width"] = self.width

    def load_own_variables(self, store):
        """Load the weights, resizing the position embedding if needed.

        When the stored `height`/`width` equal the ones of this layer, the
        stored values are assigned directly. Otherwise the grid part of the
        stored position embedding is resized bilinearly to
        `[self.height, self.width]` while the first position is kept as is.
        """
        # NOTE(review): the `[...]` indexing presumably reads the full value
        # of an array-like entry (e.g. an h5py dataset) -- confirm.
        old_height = int(store["height"][...])
        old_width = int(store["width"][...])
        if old_height == self.height and old_width == self.width:
            self.pos_embed.assign(store["0"])
            self.cls_token.assign(store["1"])
            return

        # Resize the embedding if there is a shape mismatch
        pos_embed = store["0"]
        # Split off the first position (the extra slot added in `build`) from
        # the `old_height * old_width` grid positions.
        pos_embed_prefix, pos_embed = pos_embed[:, :1], pos_embed[:, 1:]
        pos_embed_dim = pos_embed.shape[-1]
        pos_embed = ops.cast(pos_embed, "float32")
        # Restore the 2D grid layout so the embedding can be resized as an
        # image with the embedding dimension as channels.
        pos_embed = ops.reshape(pos_embed, [1, old_height, old_width, -1])
        pos_embed = ops.image.resize(
            pos_embed,
            size=[self.height, self.width],
            interpolation="bilinear",
            antialias=True,
            data_format="channels_last",
        )
        # Flatten the grid back to a sequence and re-attach the first position.
        pos_embed = ops.reshape(pos_embed, [1, -1, pos_embed_dim])
        pos_embed = ops.concatenate([pos_embed_prefix, pos_embed], axis=1)
        self.pos_embed.assign(pos_embed)
        self.cls_token.assign(store["1"])

    def get_config(self):
        """Return the base config extended with `height`, `width`, `name`."""
        config = super().get_config()
        config.update(
            {"height": self.height, "width": self.width, "name": self.name}
        )
        return config
38 | self.assertEqual(MODEL_REGISTRY[1]["weights"], "imagenet") 39 | 40 | def test_add_model_to_registry_invalid(self): 41 | clear_registry() 42 | add_model_to_registry(DummyModel, None) 43 | with self.assertWarnsRegex(Warning, "MODEL_REGISTRY already contains"): 44 | add_model_to_registry(DummyModel, None) 45 | 46 | def test_list_models(self): 47 | clear_registry() 48 | add_model_to_registry(DummyModel, None) 49 | add_model_to_registry(DummyFeatureExtractor, "imagenet") 50 | 51 | # all models 52 | result = list_models() 53 | self.assertEqual(len(result), 2) 54 | self.assertTrue(DummyModel.__name__ in result) 55 | self.assertTrue(DummyFeatureExtractor.__name__ in result) 56 | 57 | # filter name 58 | result = list_models("DummyModel") 59 | self.assertEqual(len(result), 1) 60 | self.assertTrue(DummyModel.__name__ in result) 61 | self.assertTrue(DummyFeatureExtractor.__name__ not in result) 62 | 63 | # filter feature_extractor 64 | result = list_models(feature_extractor=True) 65 | self.assertEqual(len(result), 1) 66 | self.assertTrue(DummyModel.__name__ not in result) 67 | self.assertTrue(DummyFeatureExtractor.__name__ in result) 68 | 69 | # filter weights="imagenet" 70 | result = list_models(weights="imagenet") 71 | self.assertEqual(len(result), 1) 72 | self.assertTrue(DummyModel.__name__ not in result) 73 | self.assertTrue(DummyFeatureExtractor.__name__ in result) 74 | 75 | # filter weights=True 76 | result = list_models(weights=True) 77 | self.assertEqual(len(result), 1) 78 | self.assertTrue(DummyModel.__name__ not in result) 79 | self.assertTrue(DummyFeatureExtractor.__name__ in result) 80 | 81 | # filter multiple conditions 82 | result = list_models(feature_extractor=True, weights=False) 83 | self.assertEqual(len(result), 0) 84 | self.assertTrue(DummyModel.__name__ not in result) 85 | self.assertTrue(DummyFeatureExtractor.__name__ not in result) 86 | 87 | result = list_models( 88 | "Dummy", feature_extractor=True, weights="imagenet" 89 | ) 90 | 
self.assertEqual(len(result), 1) 91 | self.assertTrue(DummyModel.__name__ not in result) 92 | self.assertTrue(DummyFeatureExtractor.__name__ in result) 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is 
recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # Keras 163 | *.keras 164 | exported 165 | 166 | # Exported model 167 | *.tflite 168 | *.onnx 169 | -------------------------------------------------------------------------------- /kimm/_src/export/export_tflite.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import tempfile 3 | import typing 4 | 5 | from keras import backend 6 | from keras import layers 7 | from keras import models 8 | from keras.src.utils.module_utils import tensorflow as tf 9 | 10 | from kimm._src.kimm_export import kimm_export 11 | from kimm._src.models.base_model import BaseModel 12 | 13 | 14 | @kimm_export(parent_path=["kimm.export"]) 15 | def export_tflite( 16 | model: BaseModel, 17 | input_shape: typing.Union[int, typing.Sequence[int]], 18 | export_path: typing.Union[str, pathlib.Path], 19 | export_dtype: typing.Literal["float32", "float16", "int8"] = "float32", 20 | representative_dataset: typing.Optional[typing.Iterator] = None, 21 | batch_size: int = 1, 22 | ): 23 | """Export the model to tflite format. 24 | 25 | Only TensorFlow backend with 'channels_last' is supported. The 26 | tflite model will be generated using 27 | `tf.lite.TFLiteConverter.from_saved_model` and optimized through tflite 28 | built-in functions. 29 | 30 | Note that when exporting an `int8` tflite model, `representative_dataset` 31 | must be passed. 32 | 33 | Args: 34 | model: keras.Model, the model to be exported. 35 | input_shape: int or sequence of int, specifying the shape of the input. 36 | export_path: str or pathlib.Path, specifying the path to export. 37 | export_dtype: str, specifying the export dtype. 38 | representative_dataset: None or Iterator, the calibration dataset for 39 | exporting int8 tflite. 40 | batch_size: int, specifying the batch size of the input, 41 | defaults to `1`. 
42 | """ 43 | if backend.backend() not in ("tensorflow",): 44 | raise ValueError("`export_tflite` only supports TensorFlow backend") 45 | if backend.image_data_format() != "channels_last": 46 | raise ValueError( 47 | "`export_tflite` only supports 'channels_last' data format." 48 | ) 49 | if export_dtype not in ("float32", "float16", "int8"): 50 | raise ValueError( 51 | "`export_dtype` must be one of ('float32', 'float16', 'int8'). " 52 | f"Received: export_dtype={export_dtype}" 53 | ) 54 | if export_dtype == "int8" and representative_dataset is None: 55 | raise ValueError( 56 | "For full integer quantization, a `representative_dataset` should " 57 | "be specified." 58 | ) 59 | if isinstance(input_shape, int): 60 | input_shape = [input_shape, input_shape, 3] 61 | elif len(input_shape) == 2: 62 | input_shape = [input_shape[0], input_shape[1], 3] 63 | elif len(input_shape) == 3: 64 | input_shape = input_shape 65 | 66 | # Fix input shape 67 | inputs = layers.Input(shape=input_shape, batch_size=batch_size) 68 | outputs = model(inputs, training=False) 69 | model = models.Model(inputs, outputs) 70 | 71 | # Construct TFLiteConverter 72 | with tempfile.TemporaryDirectory() as temp_dir: 73 | temp_path = pathlib.Path(temp_dir, "temp_saved_model") 74 | model.export(temp_path) 75 | converter = tf.lite.TFLiteConverter.from_saved_model(str(temp_path)) 76 | 77 | # Configure converter 78 | if export_dtype != "float32": 79 | converter.optimizations = [tf.lite.Optimize.DEFAULT] 80 | if export_dtype == "int8": 81 | converter.target_spec.supported_ops = [ 82 | tf.lite.OpsSet.TFLITE_BUILTINS_INT8 83 | ] 84 | converter.inference_input_type = tf.int8 85 | converter.inference_output_type = tf.int8 86 | elif export_dtype == "float16": 87 | converter.target_spec.supported_types = [tf.float16] 88 | if representative_dataset is not None: 89 | converter.representative_dataset = representative_dataset 90 | 91 | # Convert 92 | tflite_model = converter.convert() 93 | 94 | # Export 95 | with 
import sys
import typing
import warnings

from kimm._src.kimm_export import kimm_export

# Each entry is a dict with the following keys:
# {
#     "name",  # str
#     "feature_extractor",  # bool
#     "feature_keys",  # list of str
#     "weights",  # None or str
# }
MODEL_REGISTRY: typing.List[typing.Dict[str, typing.Any]] = []


def _match_string(query: str, target: str):
    """Return whether `query` is a case-insensitive subsequence of `target`.

    Spaces, underscores and dots are removed from `query` before matching.
    An empty query matches every target.
    """
    query = query.lower().replace(" ", "").replace("_", "").replace(".", "")
    target = target.lower()
    matched_idx = -1
    for q_char in query:
        # Each query character must appear after the previously matched one.
        matched_idx = target.find(q_char, matched_idx + 1)
        if matched_idx == -1:
            return False
    return True


def _match_weights(query, stored):
    """Return whether the stored weights name satisfies the `weights` filter.

    `None` matches everything, `True` matches any pretrained weights, `False`
    matches entries without weights and a `str` is compared
    case-insensitively with the stored (lower-cased) name.
    """
    if query is None:
        return True
    if query is True:
        return stored is not None
    if query is False:
        return stored is None
    if isinstance(query, str):
        return query == stored or query.lower() == stored
    # Unsupported filter types are not considered.
    return True


def clear_registry():
    """Remove all entries from `MODEL_REGISTRY`."""
    MODEL_REGISTRY.clear()


def add_model_to_registry(model_cls, weights: typing.Optional[str] = None):
    """Add `model_cls` to `MODEL_REGISTRY` and to its module's `__all__`.

    Args:
        model_cls: The model class to register.
        weights: An optional `str` specifying the name of the pretrained
            weights. It is stored in lower case.

    Raises:
        ValueError: If `weights` is neither `None` nor a `str`.
    """
    from kimm._src.models.base_model import BaseModel

    # Validate first so that an invalid call leaves no side effects behind.
    if weights is not None:
        if not isinstance(weights, str):
            raise ValueError(
                "`weights` must be one of (None, str). "
                f"Received: weights={weights}"
            )
        weights = weights.lower()

    # Deal with __all__
    mod = sys.modules[model_cls.__module__]
    model_name = model_cls.__name__
    if not hasattr(mod, "__all__"):
        mod.__all__ = [model_name]
    elif model_name not in mod.__all__:
        # Avoid duplicated names when a model is registered more than once.
        mod.__all__.append(model_name)

    # Add model information
    feature_extractor = False
    feature_keys = []
    if issubclass(model_cls, BaseModel):
        feature_extractor = True
        feature_keys = model_cls.available_feature_keys
    for info in MODEL_REGISTRY:
        if info["name"] == model_name:
            warnings.warn(
                f"MODEL_REGISTRY already contains name={model_cls.__name__}!"
            )
    MODEL_REGISTRY.append(
        {
            "name": model_name,
            "feature_extractor": feature_extractor,
            "feature_keys": feature_keys,
            "weights": weights,
        }
    )


@kimm_export(parent_path=["kimm", "kimm.utils"])
def list_models(
    name: typing.Optional[str] = None,
    feature_extractor: typing.Optional[bool] = None,
    weights: typing.Optional[typing.Union[bool, str]] = None,
):
    """List the models with the given arguments.

    Args:
        name: An optional `str` specifying the substring of the name of the
            model to search for. If not specified, all models will be included.
        feature_extractor: Whether to include models that support
            feature extraction. Defaults to `None`, which means this
            argument is not considered.
        weights: An optional boolean or `str` specifying the name of the
            pretrained weights. The available values are (`"imagenet"`).
            Defaults to `None`, which means this argument is not considered.

    Returns:
        A sorted list of model names.
    """
    # A name is collected as soon as one of its entries satisfies every
    # filter, so the result does not depend on the registration order when a
    # name has been registered more than once.
    result_names: typing.Set[str] = set()
    for info in MODEL_REGISTRY:
        # Match string (simple implementation)
        if name is not None and not _match_string(name, info["name"]):
            continue

        # Filter by feature_extractor and weights
        if (
            feature_extractor is not None
            and info["feature_extractor"] != feature_extractor
        ):
            continue
        if not _match_weights(weights, info["weights"]):
            continue
        result_names.add(info["name"])
    return sorted(result_names)
""" 38 | input_shape = [224, 224, 3] 39 | torch_model = timm.create_model(timm_model_name, pretrained=True) 40 | torch_model = torch_model.eval() 41 | trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( 42 | torch_model.state_dict() 43 | ) 44 | keras_model = keras_model_class( 45 | input_shape=input_shape, 46 | include_preprocessing=False, 47 | classifier_activation="linear", 48 | weights=None, 49 | ) 50 | trainable_weights, non_trainable_weights = separate_keras_weights( 51 | keras_model 52 | ) 53 | 54 | # for torch_name, (_, keras_name) in zip( 55 | # trainable_state_dict.keys(), trainable_weights 56 | # ): 57 | # print(f"{torch_name} {keras_name}") 58 | 59 | # print(len(trainable_state_dict.keys())) 60 | # print(len(trainable_weights)) 61 | 62 | # exit() 63 | 64 | """ 65 | Assign weights 66 | """ 67 | for keras_weight, keras_name in trainable_weights + non_trainable_weights: 68 | keras_name: str 69 | torch_name = keras_name 70 | torch_name = torch_name.replace("_", ".") 71 | # blocks 72 | torch_name = torch_name.replace("conv2d", "") 73 | torch_name = torch_name.replace("pre.logits", "pre_logits") 74 | # head 75 | torch_name = torch_name.replace("classifier", "head.fc") 76 | 77 | # weights naming mapping 78 | torch_name = torch_name.replace("kernel", "weight") # conv2d 79 | torch_name = torch_name.replace("gamma", "weight") # bn 80 | torch_name = torch_name.replace("beta", "bias") # bn 81 | torch_name = torch_name.replace("moving.mean", "running_mean") # bn 82 | torch_name = torch_name.replace("moving.variance", "running_var") # bn 83 | 84 | # assign weights 85 | if torch_name in trainable_state_dict: 86 | torch_weights = trainable_state_dict[torch_name].numpy() 87 | elif torch_name in non_trainable_state_dict: 88 | torch_weights = non_trainable_state_dict[torch_name].numpy() 89 | else: 90 | raise ValueError( 91 | "Can't find the corresponding torch weights. 
" 92 | f"Got keras_name={keras_name}, torch_name={torch_name}" 93 | ) 94 | if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): 95 | assign_weights(keras_name, keras_weight, torch_weights) 96 | else: 97 | raise ValueError( 98 | "Can't find the corresponding torch weights. The shape is " 99 | f"mismatched. Got keras_name={keras_name}, " 100 | f"keras_weight shape={keras_weight.shape}, " 101 | f"torch_name={torch_name}, " 102 | f"torch_weights shape={torch_weights.shape}" 103 | ) 104 | 105 | """ 106 | Verify model outputs 107 | """ 108 | np.random.seed(2023) 109 | keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") 110 | torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) 111 | torch_y = torch_model(torch_data) 112 | keras_y = keras_model(keras_data, training=False) 113 | torch_y = torch_y.detach().cpu().numpy() 114 | keras_y = keras.ops.convert_to_numpy(keras_y) 115 | np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) 116 | print(f"{keras_model_class.__name__}: output matched!") 117 | 118 | """ 119 | Save converted model 120 | """ 121 | os.makedirs("exported", exist_ok=True) 122 | export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" 123 | keras_model.save(export_path) 124 | print(f"Export to {export_path}") 125 | -------------------------------------------------------------------------------- /kimm/_src/layers/attention.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras import InputSpec 3 | from keras import layers 4 | from keras import ops 5 | 6 | from kimm._src.kimm_export import kimm_export 7 | 8 | 9 | @kimm_export(parent_path=["kimm.layers"]) 10 | @keras.saving.register_keras_serializable(package="kimm") 11 | class Attention(layers.Layer): 12 | def __init__( 13 | self, 14 | hidden_dim: int, 15 | num_heads: int = 8, 16 | use_qkv_bias: bool = False, 17 | attention_dropout_rate: float = 0.0, 18 | 
@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class Attention(layers.Layer):
    """Multi-head self-attention for 3-D or 4-D inputs.

    A single dense layer produces the packed query/key/value tensor, scaled
    dot-product attention is computed per head over the second-to-last axis,
    and the result is reshaped back to the input shape and passed through an
    output projection. Because the attended tensor is reshaped to the input
    shape, the last input dimension has to equal `hidden_dim`.

    Args:
        hidden_dim: Width of the query/key/value and output projections.
            Must be a positive multiple of `num_heads`.
        num_heads: Number of attention heads.
        use_qkv_bias: Whether the packed qkv dense layer uses a bias.
        attention_dropout_rate: Dropout rate applied to the attention
            probabilities.
        projection_dropout_rate: Dropout rate applied after the output
            projection.
        **kwargs: Forwarded to `keras.layers.Layer`.

    Raises:
        ValueError: If `hidden_dim` is not a positive multiple of
            `num_heads`, or (in `build`) if the input is not 3-D or 4-D.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        use_qkv_bias: bool = False,
        attention_dropout_rate: float = 0.0,
        projection_dropout_rate: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Fail early with a clear message: otherwise the head split in
        # `call` fails inside `ops.reshape` (or `head_dim` becomes 0 and the
        # scale computation divides by zero).
        if hidden_dim <= 0 or num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                "`hidden_dim` must be a positive multiple of `num_heads`. "
                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim ** (-0.5)
        self.use_qkv_bias = use_qkv_bias
        self.attention_dropout_rate = attention_dropout_rate
        self.projection_dropout_rate = projection_dropout_rate

        self.qkv = layers.Dense(
            hidden_dim * 3,
            use_bias=use_qkv_bias,
            dtype=self.dtype_policy,
            name=f"{self.name}_qkv",
        )

        self.attention_dropout = layers.Dropout(
            attention_dropout_rate,
            dtype=self.dtype_policy,
            name=f"{self.name}_attn_drop",
        )
        self.projection = layers.Dense(
            hidden_dim, dtype=self.dtype_policy, name=f"{self.name}_proj"
        )
        self.projection_dropout = layers.Dropout(
            projection_dropout_rate,
            dtype=self.dtype_policy,
            name=f"{self.name}_proj_drop",
        )

    def build(self, input_shape):
        """Build the sub-layers for a 3-D or 4-D `input_shape`."""
        self.input_spec = InputSpec(ndim=len(input_shape))
        if self.input_spec.ndim not in (3, 4):
            raise ValueError(
                "The ndim of the inputs must be 3 or 4. "
                f"Received: input_shape={input_shape}"
            )

        self.qkv.build(input_shape)
        # Attention scores keep every leading axis, then add the head axis
        # and a (tokens, tokens) matrix: [B, H, N, N] for 3-D inputs and
        # [B, P, H, N, N] for 4-D inputs (see the transposes in `call`).
        num_tokens = input_shape[-2]
        attention_input_shape = list(input_shape[:-2]) + [
            self.num_heads,
            num_tokens,
            num_tokens,
        ]
        self.attention_dropout.build(attention_input_shape)
        self.projection.build(input_shape)
        self.projection_dropout.build(input_shape)
        self.built = True

    def call(self, inputs, training=None, mask=None):
        """Apply self-attention; the output has the same shape as `inputs`.

        `mask` is accepted for signature compatibility but is not used.
        """
        input_shape = ops.shape(inputs)
        qkv = self.qkv(inputs)
        if self.input_spec.ndim == 3:
            # [B, N, 3 * C] -> [B, N, 3, H, D] -> [B, H, 3, N, D]
            qkv = ops.reshape(
                qkv,
                [
                    input_shape[0],
                    input_shape[1],
                    3,
                    self.num_heads,
                    self.head_dim,
                ],
            )
            qkv = ops.transpose(qkv, [0, 3, 2, 1, 4])
            q, k, v = ops.unstack(qkv, 3, axis=2)
        else:
            # self.input_spec.ndim==4
            # [B, P, N, 3 * C] -> [B, P, N, 3, H, D] -> [B, P, H, 3, N, D]
            qkv = ops.reshape(
                qkv,
                [
                    input_shape[0],
                    input_shape[1],
                    input_shape[2],
                    3,
                    self.num_heads,
                    self.head_dim,
                ],
            )
            qkv = ops.transpose(qkv, [0, 1, 4, 3, 2, 5])
            q, k, v = ops.unstack(qkv, 3, axis=3)

        # attention
        q = ops.multiply(q, self.scale)
        attn = ops.matmul(q, ops.swapaxes(k, -2, -1))
        attn = ops.softmax(attn)
        attn = self.attention_dropout(attn, training=training)
        x = ops.matmul(attn, v)
        # [..., H, N, D] -> [..., N, H, D] -> input shape (heads merged)
        x = ops.reshape(ops.swapaxes(x, -3, -2), input_shape)
        x = self.projection(x)
        x = self.projection_dropout(x, training=training)
        return x

    def get_config(self):
        """Return the constructor arguments for serialization."""
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "use_qkv_bias": self.use_qkv_bias,
                "attention_dropout_rate": self.attention_dropout_rate,
                "projection_dropout_rate": self.projection_dropout_rate,
                "name": self.name,
            }
        )
        return config
"""
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install timm
"""

import os

import keras
import numpy as np
import timm
import torch

from kimm.models import densenet
from kimm.timm_utils import assign_weights
from kimm.timm_utils import is_same_weights
from kimm.timm_utils import separate_keras_weights
from kimm.timm_utils import separate_torch_state_dict

timm_model_names = [
    "densenet121.ra_in1k",
    "densenet161.tv_in1k",
    "densenet169.tv_in1k",
    "densenet201.tv_in1k",
]
keras_model_classes = [
    densenet.DenseNet121,
    densenet.DenseNet161,
    densenet.DenseNet169,
    densenet.DenseNet201,
]

# Ordered (old, new) substring rewrites that turn a keras variable name into
# the matching timm state-dict key. The order matters: every rule is applied
# to the output of the previous one.
name_rewrites = (
    ("_", "."),
    # stem
    ("conv0.conv2d", "conv0"),
    ("conv0.bn", "norm0"),
    # blocks
    ("conv1.conv2d", "conv1"),
    ("conv1.bn", "norm2"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
    ("gamma", "weight"),  # bn
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # ---- Prepare timm model and keras model ----
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True).eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # ---- Assign weights ----
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        torch_name = keras_name
        for old, new in name_rewrites:
            torch_name = torch_name.replace(old, new)

        # Look in the trainable tensors first, then in the buffers.
        for state_dict in (trainable_state_dict, non_trainable_state_dict):
            if torch_name in state_dict:
                torch_weights = state_dict[torch_name].numpy()
                break
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )

        if not is_same_weights(
            keras_name, keras_weight, torch_name, torch_weights
        ):
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )
        assign_weights(keras_name, keras_weight, torch_weights)

    # ---- Verify model outputs ----
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    # keras is fed channels-last data, torch the channels-first transpose.
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    torch_y = torch_model(torch_data).detach().cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(
        keras_model(keras_data, training=False)
    )
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-5)
    print(f"{keras_model_class.__name__}: output matched!")

    # ---- Save converted model ----
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
"""
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install timm
"""

import os

import keras
import numpy as np
import timm
import torch

from kimm.models import convmixer
from kimm.timm_utils import assign_weights
from kimm.timm_utils import is_same_weights
from kimm.timm_utils import separate_keras_weights
from kimm.timm_utils import separate_torch_state_dict

timm_model_names = [
    "convmixer_768_32.in1k",
    "convmixer_1024_20_ks9_p14.in1k",
    "convmixer_1536_20.in1k",
]
keras_model_classes = [
    convmixer.ConvMixer736D32,
    convmixer.ConvMixer1024D20,
    convmixer.ConvMixer1536D20,
]

# Ordered (old, new) substring rewrites that turn a keras variable name into
# the matching timm state-dict key. The order matters: every rule is applied
# to the output of the previous one.
name_rewrites = (
    ("_", "."),
    # stem
    ("stem.conv2d", "stem.0"),
    ("stem.bn", "stem.2"),
    # blocks
    ("dwconv2d.", ""),
    ("conv2d.", ""),
    # head
    ("classifier", "head"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
    ("gamma", "weight"),  # bn
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # ---- Prepare timm model and keras model ----
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True).eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # ---- Assign weights ----
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        torch_name = keras_name
        for old, new in name_rewrites:
            torch_name = torch_name.replace(old, new)

        # Look in the trainable tensors first, then in the buffers.
        for state_dict in (trainable_state_dict, non_trainable_state_dict):
            if torch_name in state_dict:
                torch_weights = state_dict[torch_name].numpy()
                break
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )

        if not is_same_weights(
            keras_name, keras_weight, torch_name, torch_weights
        ):
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )
        assign_weights(keras_name, keras_weight, torch_weights)

    # ---- Verify model outputs ----
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    # keras is fed channels-last data, torch the channels-first transpose.
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    torch_y = torch_model(torch_data).detach().cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(
        keras_model(keras_data, training=False)
    )
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-4)
    print(f"{keras_model_class.__name__}: output matched!")

    # ---- Save converted model ----
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
"""
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install timm
"""

import os

import keras
import numpy as np
import timm
import torch

from kimm.models import resnet
from kimm.timm_utils import assign_weights
from kimm.timm_utils import is_same_weights
from kimm.timm_utils import separate_keras_weights
from kimm.timm_utils import separate_torch_state_dict

timm_model_names = [
    "resnet18.a1_in1k",
    "resnet34.a1_in1k",
    "resnet50.a1_in1k",
    "resnet101.a1_in1k",
    "resnet152.a1_in1k",
]
keras_model_classes = [
    resnet.ResNet18,
    resnet.ResNet34,
    resnet.ResNet50,
    resnet.ResNet101,
    resnet.ResNet152,
]

# Ordered (old, new) substring rewrites that turn a keras variable name into
# the matching timm state-dict key. The order matters: every rule is applied
# to the output of the previous one.
name_rewrites = (
    ("_", "."),
    # stem
    ("conv.stem.conv2d", "conv1"),
    ("conv.stem.bn", "bn1"),
    # blocks
    ("conv1.conv2d", "conv1"),
    ("conv1.bn", "bn1"),
    ("conv2.conv2d", "conv2"),
    ("conv2.bn", "bn2"),
    ("conv3.conv2d", "conv3"),
    ("conv3.bn", "bn3"),
    ("downsample.conv2d", "downsample.0"),
    ("downsample.bn", "downsample.1"),
    # head
    ("classifier", "fc"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
    ("gamma", "weight"),  # bn
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # ---- Prepare timm model and keras model ----
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True).eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # ---- Assign weights ----
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        torch_name = keras_name
        for old, new in name_rewrites:
            torch_name = torch_name.replace(old, new)

        # Look in the trainable tensors first, then in the buffers.
        for state_dict in (trainable_state_dict, non_trainable_state_dict):
            if torch_name in state_dict:
                torch_weights = state_dict[torch_name].numpy()
                break
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )

        if not is_same_weights(
            keras_name, keras_weight, torch_name, torch_weights
        ):
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )
        assign_weights(keras_name, keras_weight, torch_weights)

    # ---- Verify model outputs ----
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    # keras is fed channels-last data, torch the channels-first transpose.
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    torch_y = torch_model(torch_data).detach().cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(
        keras_model(keras_data, training=False)
    )
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-5)
    print(f"{keras_model_class.__name__}: output matched!")

    # ---- Save converted model ----
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
"""
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install timm
"""

import os

import keras
import numpy as np
import timm
import torch

from kimm.models import hgnet
from kimm.timm_utils import assign_weights
from kimm.timm_utils import is_same_weights
from kimm.timm_utils import separate_keras_weights
from kimm.timm_utils import separate_torch_state_dict

timm_model_names = [
    # HGNet
    "hgnet_tiny.ssld_in1k",
    "hgnet_small.ssld_in1k",
    "hgnet_base.ssld_in1k",
    # HGNetV2
    "hgnetv2_b0.ssld_stage2_ft_in1k",
    "hgnetv2_b1.ssld_stage2_ft_in1k",
    "hgnetv2_b2.ssld_stage2_ft_in1k",
    "hgnetv2_b3.ssld_stage2_ft_in1k",
    "hgnetv2_b4.ssld_stage2_ft_in1k",
    "hgnetv2_b5.ssld_stage2_ft_in1k",
    "hgnetv2_b6.ssld_stage2_ft_in1k",
]
keras_model_classes = [
    hgnet.HGNetTiny,
    hgnet.HGNetSmall,
    hgnet.HGNetBase,
    hgnet.HGNetV2B0,
    hgnet.HGNetV2B1,
    hgnet.HGNetV2B2,
    hgnet.HGNetV2B3,
    hgnet.HGNetV2B4,
    hgnet.HGNetV2B5,
    hgnet.HGNetV2B6,
]

# Ordered (old, new) substring rewrites applied after the separator and stem
# handling below. The order matters: every rule is applied to the output of
# the previous one.
name_rewrites = (
    # conv2d
    ("dwconv2d.kernel", "conv.weight"),
    ("conv2d.kernel", "conv.weight"),
    # head
    ("last.conv", "last_conv"),
    ("classifier", "head.fc"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
    ("gamma", "weight"),  # bn
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # ---- Prepare timm model and keras model ----
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True).eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # ---- Assign weights ----
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        torch_name = keras_name.replace("_", ".")
        # stem: names that do not already contain "stem.stem" (the HGNet
        # case in the original script) get the extra "stem." level.
        if "stem.stem" not in torch_name:
            torch_name = torch_name.replace("stem", "stem.stem")
        for old, new in name_rewrites:
            torch_name = torch_name.replace(old, new)

        # Look in the trainable tensors first, then in the buffers.
        for state_dict in (trainable_state_dict, non_trainable_state_dict):
            if torch_name in state_dict:
                torch_weights = state_dict[torch_name].numpy()
                break
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )

        if not is_same_weights(
            keras_name, keras_weight, torch_name, torch_weights
        ):
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )
        assign_weights(keras_name, keras_weight, torch_weights)

    # ---- Verify model outputs ----
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    # keras is fed channels-last data, torch the channels-first transpose.
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    torch_y = torch_model(torch_data).detach().cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(
        keras_model(keras_data, training=False)
    )
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-5)
    print(f"{keras_model_class.__name__}: output matched!")

    # ---- Save converted model ----
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
"""
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install timm
"""

import os

import keras
import numpy as np
import timm
import torch

from kimm.models import inception_next
from kimm.timm_utils import assign_weights
from kimm.timm_utils import is_same_weights
from kimm.timm_utils import separate_keras_weights
from kimm.timm_utils import separate_torch_state_dict

timm_model_names = [
    "inception_next_tiny.sail_in1k",
    "inception_next_small.sail_in1k",
    "inception_next_base.sail_in1k_384",
]
keras_model_classes = [
    inception_next.InceptionNeXtTiny,
    inception_next.InceptionNeXtSmall,
    inception_next.InceptionNeXtBase,
]

# The keras -> timm name translation runs in three ordered stages. Within a
# stage the (old, new) rewrites are applied in sequence, each one to the
# output of the previous one.

# Stage 1: applied before the layer-scale detection.
rewrites_before_layerscale = (
    ("_", "."),
    # stem
    ("stem.0.conv2d.kernel", "stem.0.weight"),
    ("stem.0.conv2d.bias", "stem.0.bias"),
    # blocks
    ("dwconv2d.", ""),
    ("conv2d.", ""),
    ("conv.dw", "conv_dw"),
)
# Stage 2: applied after the layer-scale detection.
rewrites_after_layerscale = (
    ("layerscale.", ""),
    ("token.mixer", "token_mixer"),
    ("dwconv.hw.", "dwconv_hw."),
    ("dwconv.w.", "dwconv_w."),
    ("dwconv.h.", "dwconv_h."),
    # head
    ("classifier", "head.fc2"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
)
# Stage 3: batch-norm statistics, applied after the conditional gamma rule.
rewrites_bn = (
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # ---- Prepare timm model and keras model ----
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True).eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # ---- Assign weights ----
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        torch_name = keras_name
        for old, new in rewrites_before_layerscale:
            torch_name = torch_name.replace(old, new)

        # Layer-scale variables keep their "gamma" name on the torch side,
        # so they must be excluded from the bn gamma -> weight rewrite.
        is_layerscale = "layerscale" in torch_name

        for old, new in rewrites_after_layerscale:
            torch_name = torch_name.replace(old, new)
        if not is_layerscale:
            torch_name = torch_name.replace("gamma", "weight")  # bn
        for old, new in rewrites_bn:
            torch_name = torch_name.replace(old, new)

        # Look in the trainable tensors first, then in the buffers.
        for state_dict in (trainable_state_dict, non_trainable_state_dict):
            if torch_name in state_dict:
                torch_weights = state_dict[torch_name].numpy()
                break
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )

        # Layer-scale weights are assigned without the `is_same_weights`
        # check; every other weight has to pass it.
        if not is_layerscale and not is_same_weights(
            keras_name, keras_weight, torch_name, torch_weights
        ):
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )
        assign_weights(keras_name, keras_weight, torch_weights)

    # ---- Verify model outputs ----
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    # keras is fed channels-last data, torch the channels-first transpose.
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    torch_y = torch_model(torch_data).detach().cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(
        keras_model(keras_data, training=False)
    )
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-4)
    print(f"{keras_model_class.__name__}: output matched!")

    # ---- Save converted model ----
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
Got keras_name={keras_name}, " 118 | f"keras_weight shape={keras_weight.shape}, " 119 | f"torch_name={torch_name}, " 120 | f"torch_weights shape={torch_weights.shape}" 121 | ) 122 | 123 | """ 124 | Verify model outputs 125 | """ 126 | np.random.seed(2023) 127 | keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") 128 | torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) 129 | torch_y = torch_model(torch_data) 130 | keras_y = keras_model(keras_data, training=False) 131 | torch_y = torch_y.detach().cpu().numpy() 132 | keras_y = keras.ops.convert_to_numpy(keras_y) 133 | np.testing.assert_allclose(torch_y, keras_y, atol=1e-4) 134 | print(f"{keras_model_class.__name__}: output matched!") 135 | 136 | """ 137 | Save converted model 138 | """ 139 | os.makedirs("exported", exist_ok=True) 140 | export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" 141 | keras_model.save(export_path) 142 | print(f"Export to {export_path}") 143 | -------------------------------------------------------------------------------- /tools/convert_vit_from_timm.py: -------------------------------------------------------------------------------- 1 | """ 2 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 3 | pip install timm 4 | """ 5 | 6 | import os 7 | 8 | import keras 9 | import numpy as np 10 | import timm 11 | import torch 12 | 13 | from kimm.models import vision_transformer 14 | from kimm.timm_utils import assign_weights 15 | from kimm.timm_utils import is_same_weights 16 | from kimm.timm_utils import separate_keras_weights 17 | from kimm.timm_utils import separate_torch_state_dict 18 | 19 | timm_model_names = [ 20 | "vit_tiny_patch16_384", 21 | # no tiny patch32 weights 22 | "vit_small_patch16_384", 23 | "vit_small_patch32_384", 24 | "vit_base_patch16_384", 25 | "vit_base_patch32_384", 26 | "vit_large_patch16_384", 27 | "vit_large_patch32_384", 28 | ] 29 | keras_model_classes = [ 30 | 
vision_transformer.VisionTransformerTiny16, 31 | vision_transformer.VisionTransformerSmall16, 32 | vision_transformer.VisionTransformerSmall32, 33 | vision_transformer.VisionTransformerBase16, 34 | vision_transformer.VisionTransformerBase32, 35 | vision_transformer.VisionTransformerLarge16, 36 | vision_transformer.VisionTransformerLarge32, 37 | ] 38 | 39 | for timm_model_name, keras_model_class in zip( 40 | timm_model_names, keras_model_classes 41 | ): 42 | """ 43 | Prepare timm model and keras model 44 | """ 45 | input_shape = [384, 384, 3] # use size of 384 for best performance 46 | torch_model = timm.create_model(timm_model_name, pretrained=True) 47 | torch_model = torch_model.eval() 48 | trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( 49 | torch_model.state_dict() 50 | ) 51 | keras_model = keras_model_class( 52 | input_shape=input_shape, 53 | include_preprocessing=False, 54 | classifier_activation="linear", 55 | weights=None, 56 | ) 57 | trainable_weights, non_trainable_weights = separate_keras_weights( 58 | keras_model 59 | ) 60 | 61 | # for torch_name, (_, keras_name) in zip( 62 | # trainable_state_dict.keys(), trainable_weights 63 | # ): 64 | # print(f"{torch_name} {keras_name}") 65 | 66 | # print(len(trainable_state_dict.keys())) 67 | # print(len(trainable_weights)) 68 | 69 | # exit() 70 | 71 | """ 72 | Assign weights 73 | """ 74 | for keras_weight, keras_name in trainable_weights + non_trainable_weights: 75 | keras_name: str 76 | torch_name = keras_name 77 | torch_name = torch_name.replace("_", ".") 78 | # patch embedding 79 | torch_name = torch_name.replace("patch.embed.conv", "patch_embed.proj") 80 | # postition_embedding 81 | torch_name = torch_name.replace( 82 | "postition.embedding.pos.embed", "pos_embed" 83 | ) 84 | torch_name = torch_name.replace( 85 | "postition.embedding.cls.token", "cls_token" 86 | ) 87 | # blocks 88 | torch_name = torch_name.replace("attn", "attn.qkv") 89 | # torch_name = torch_name.replace("attn", 
"attn.proj") 90 | 91 | # weights naming mapping 92 | torch_name = torch_name.replace("kernel", "weight") # conv2d 93 | torch_name = torch_name.replace("gamma", "weight") # bn 94 | torch_name = torch_name.replace("beta", "bias") # bn 95 | torch_name = torch_name.replace("moving.mean", "running_mean") # bn 96 | torch_name = torch_name.replace("moving.variance", "running_var") # bn 97 | 98 | # assign weights 99 | if torch_name in trainable_state_dict: 100 | torch_weights = trainable_state_dict[torch_name].numpy() 101 | elif torch_name in non_trainable_state_dict: 102 | torch_weights = non_trainable_state_dict[torch_name].numpy() 103 | else: 104 | raise ValueError( 105 | "Can't find the corresponding torch weights. " 106 | f"Got keras_name={keras_name}, torch_name={torch_name}" 107 | ) 108 | if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): 109 | assign_weights(keras_name, keras_weight, torch_weights) 110 | # special case for Attention module 111 | elif "attn" in keras_name: 112 | torch_name = torch_name.replace("attn.qkv", "attn.proj") 113 | torch_weights = trainable_state_dict[torch_name].numpy() 114 | assign_weights(keras_name, keras_weight, torch_weights) 115 | else: 116 | raise ValueError( 117 | "Can't find the corresponding torch weights. The shape is " 118 | f"mismatched. 
Got keras_name={keras_name}, " 119 | f"keras_weight shape={keras_weight.shape}, " 120 | f"torch_name={torch_name}, " 121 | f"torch_weights shape={torch_weights.shape}" 122 | ) 123 | 124 | """ 125 | Verify model outputs 126 | """ 127 | np.random.seed(2023) 128 | keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") 129 | torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) 130 | torch_y = torch_model(torch_data) 131 | keras_y = keras_model(keras_data, training=False) 132 | torch_y = torch_y.detach().cpu().numpy() 133 | keras_y = keras.ops.convert_to_numpy(keras_y) 134 | np.testing.assert_allclose(torch_y, keras_y, atol=1e-4) 135 | print(f"{keras_model_class.__name__}: output matched!") 136 | 137 | """ 138 | Save converted model 139 | """ 140 | os.makedirs("exported", exist_ok=True) 141 | export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" 142 | keras_model.save(export_path) 143 | print(f"Export to {export_path}") 144 | -------------------------------------------------------------------------------- /tools/convert_convnext_from_timm.py: -------------------------------------------------------------------------------- 1 | """ 2 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 3 | pip install timm 4 | """ 5 | 6 | import os 7 | 8 | import keras 9 | import numpy as np 10 | import timm 11 | import torch 12 | 13 | from kimm.models import convnext 14 | from kimm.timm_utils import assign_weights 15 | from kimm.timm_utils import is_same_weights 16 | from kimm.timm_utils import separate_keras_weights 17 | from kimm.timm_utils import separate_torch_state_dict 18 | 19 | timm_model_names = [ 20 | "convnext_atto.d2_in1k", 21 | "convnext_femto.d1_in1k", 22 | "convnext_pico.d1_in1k", 23 | "convnext_nano.in12k_ft_in1k", 24 | "convnext_tiny.in12k_ft_in1k", 25 | "convnext_small.in12k_ft_in1k", 26 | "convnext_base.fb_in22k_ft_in1k", 27 | "convnext_large.fb_in22k_ft_in1k", 28 | 
"convnext_xlarge.fb_in22k_ft_in1k", 29 | ] 30 | keras_model_classes = [ 31 | convnext.ConvNeXtAtto, 32 | convnext.ConvNeXtFemto, 33 | convnext.ConvNeXtPico, 34 | convnext.ConvNeXtNano, 35 | convnext.ConvNeXtTiny, 36 | convnext.ConvNeXtSmall, 37 | convnext.ConvNeXtBase, 38 | convnext.ConvNeXtLarge, 39 | convnext.ConvNeXtXLarge, 40 | ] 41 | 42 | for timm_model_name, keras_model_class in zip( 43 | timm_model_names, keras_model_classes 44 | ): 45 | """ 46 | Prepare timm model and keras model 47 | """ 48 | input_shape = [224, 224, 3] 49 | torch_model = timm.create_model(timm_model_name, pretrained=True) 50 | torch_model = torch_model.eval() 51 | trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( 52 | torch_model.state_dict() 53 | ) 54 | keras_model = keras_model_class( 55 | input_shape=input_shape, 56 | include_preprocessing=False, 57 | classifier_activation="linear", 58 | weights=None, 59 | ) 60 | trainable_weights, non_trainable_weights = separate_keras_weights( 61 | keras_model 62 | ) 63 | 64 | # for torch_name, (_, keras_name) in zip( 65 | # trainable_state_dict.keys(), trainable_weights 66 | # ): 67 | # print(f"{torch_name} {keras_name}") 68 | 69 | # print(len(trainable_state_dict.keys())) 70 | # print(len(trainable_weights)) 71 | 72 | # exit() 73 | 74 | """ 75 | Assign weights 76 | """ 77 | for keras_weight, keras_name in trainable_weights + non_trainable_weights: 78 | # prevent gamma to be replaced 79 | is_layerscale = False 80 | keras_name: str 81 | torch_name = keras_name 82 | torch_name = torch_name.replace("_", ".") 83 | 84 | # stem 85 | torch_name = torch_name.replace("stem.0.conv2d.kernel", "stem.0.weight") 86 | torch_name = torch_name.replace("stem.0.conv2d.bias", "stem.0.bias") 87 | 88 | # blocks 89 | torch_name = torch_name.replace("dwconv2d.", "") 90 | torch_name = torch_name.replace("conv2d.", "") 91 | torch_name = torch_name.replace("conv.dw", "conv_dw") 92 | if "layerscale" in torch_name: 93 | is_layerscale = True 94 | 
torch_name = torch_name.replace("layerscale.", "") 95 | # head 96 | torch_name = torch_name.replace("classifier", "head.fc") 97 | 98 | # weights naming mapping 99 | torch_name = torch_name.replace("kernel", "weight") # conv2d 100 | if not is_layerscale: 101 | torch_name = torch_name.replace("gamma", "weight") # bn 102 | torch_name = torch_name.replace("beta", "bias") # bn 103 | torch_name = torch_name.replace("moving.mean", "running_mean") # bn 104 | torch_name = torch_name.replace("moving.variance", "running_var") # bn 105 | 106 | # assign weights 107 | if torch_name in trainable_state_dict: 108 | torch_weights = trainable_state_dict[torch_name].numpy() 109 | elif torch_name in non_trainable_state_dict: 110 | torch_weights = non_trainable_state_dict[torch_name].numpy() 111 | else: 112 | raise ValueError( 113 | "Can't find the corresponding torch weights. " 114 | f"Got keras_name={keras_name}, torch_name={torch_name}" 115 | ) 116 | if is_layerscale: 117 | assign_weights(keras_name, keras_weight, torch_weights) 118 | elif is_same_weights( 119 | keras_name, keras_weight, torch_name, torch_weights 120 | ): 121 | assign_weights(keras_name, keras_weight, torch_weights) 122 | else: 123 | raise ValueError( 124 | "Can't find the corresponding torch weights. The shape is " 125 | f"mismatched. 
Got keras_name={keras_name}, " 126 | f"keras_weight shape={keras_weight.shape}, " 127 | f"torch_name={torch_name}, " 128 | f"torch_weights shape={torch_weights.shape}" 129 | ) 130 | 131 | """ 132 | Verify model outputs 133 | """ 134 | np.random.seed(2023) 135 | keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") 136 | torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) 137 | torch_y = torch_model(torch_data) 138 | keras_y = keras_model(keras_data, training=False) 139 | torch_y = torch_y.detach().cpu().numpy() 140 | keras_y = keras.ops.convert_to_numpy(keras_y) 141 | np.testing.assert_allclose(torch_y, keras_y, atol=1e-4) 142 | print(f"{keras_model_class.__name__}: output matched!") 143 | 144 | """ 145 | Save converted model 146 | """ 147 | os.makedirs("exported", exist_ok=True) 148 | export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" 149 | keras_model.save(export_path) 150 | print(f"Export to {export_path}") 151 | -------------------------------------------------------------------------------- /tools/convert_mobilenet_v2_from_timm.py: -------------------------------------------------------------------------------- 1 | """ 2 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 3 | pip install timm 4 | """ 5 | 6 | import os 7 | 8 | import keras 9 | import numpy as np 10 | import timm 11 | import torch 12 | 13 | from kimm.models import mobilenet_v2 14 | from kimm.timm_utils import assign_weights 15 | from kimm.timm_utils import is_same_weights 16 | from kimm.timm_utils import separate_keras_weights 17 | from kimm.timm_utils import separate_torch_state_dict 18 | 19 | timm_model_names = [ 20 | "mobilenetv2_050.lamb_in1k", 21 | "mobilenetv2_100.ra_in1k", 22 | "mobilenetv2_110d.ra_in1k", 23 | "mobilenetv2_120d.ra_in1k", 24 | "mobilenetv2_140.ra_in1k", 25 | ] 26 | keras_model_classes = [ 27 | mobilenet_v2.MobileNetV2W050, 28 | mobilenet_v2.MobileNetV2W100, 29 | 
mobilenet_v2.MobileNetV2W110, 30 | mobilenet_v2.MobileNetV2W120, 31 | mobilenet_v2.MobileNetV2W140, 32 | ] 33 | 34 | for timm_model_name, keras_model_class in zip( 35 | timm_model_names, keras_model_classes 36 | ): 37 | """ 38 | Prepare timm model and keras model 39 | """ 40 | input_shape = [224, 224, 3] 41 | torch_model = timm.create_model(timm_model_name, pretrained=True) 42 | torch_model = torch_model.eval() 43 | trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( 44 | torch_model.state_dict() 45 | ) 46 | keras_model = keras_model_class( 47 | input_shape=input_shape, 48 | include_preprocessing=False, 49 | classifier_activation="linear", 50 | weights=None, 51 | ) 52 | trainable_weights, non_trainable_weights = separate_keras_weights( 53 | keras_model 54 | ) 55 | 56 | # for torch_name, (_, keras_name) in zip( 57 | # trainable_state_dict.keys(), trainable_weights 58 | # ): 59 | # print(f"{torch_name} {keras_name}") 60 | 61 | # print(len(trainable_state_dict.keys())) 62 | # print(len(trainable_weights)) 63 | 64 | # exit() 65 | 66 | """ 67 | Assign weights 68 | """ 69 | for keras_weight, keras_name in trainable_weights + non_trainable_weights: 70 | keras_name: str 71 | torch_name = keras_name 72 | torch_name = torch_name.replace("_", ".") 73 | # stem 74 | torch_name = torch_name.replace("conv.stem.conv2d", "conv_stem") 75 | torch_name = torch_name.replace("conv.stem.bn", "bn1") 76 | # blocks 77 | if "blocks.0.0" in torch_name: 78 | # depthwise separation block 79 | torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw") 80 | torch_name = torch_name.replace("conv.dw.bn", "bn1") 81 | torch_name = torch_name.replace("conv.pw.conv2d", "conv_pw") 82 | torch_name = torch_name.replace("conv.pw.bn", "bn2") 83 | else: 84 | # inverted residual block 85 | torch_name = torch_name.replace("conv.pw.conv2d", "conv_pw") 86 | torch_name = torch_name.replace("conv.pw.bn", "bn1") 87 | torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw") 88 | 
torch_name = torch_name.replace("conv.dw.bn", "bn2") 89 | torch_name = torch_name.replace("conv.pwl.conv2d", "conv_pwl") 90 | torch_name = torch_name.replace("conv.pwl.bn", "bn3") 91 | # conv head 92 | torch_name = torch_name.replace("conv.head.conv2d", "conv_head") 93 | torch_name = torch_name.replace("conv.head.bn", "bn2") 94 | 95 | # weights naming mapping 96 | torch_name = torch_name.replace("kernel", "weight") # conv2d 97 | torch_name = torch_name.replace("gamma", "weight") # bn 98 | torch_name = torch_name.replace("beta", "bias") # bn 99 | torch_name = torch_name.replace("moving.mean", "running_mean") # bn 100 | torch_name = torch_name.replace("moving.variance", "running_var") # bn 101 | 102 | # assign weights 103 | if torch_name in trainable_state_dict: 104 | torch_weights = trainable_state_dict[torch_name].numpy() 105 | elif torch_name in non_trainable_state_dict: 106 | torch_weights = non_trainable_state_dict[torch_name].numpy() 107 | else: 108 | raise ValueError( 109 | "Can't find the corresponding torch weights. " 110 | f"Got keras_name={keras_name}, torch_name={torch_name}" 111 | ) 112 | if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): 113 | assign_weights(keras_name, keras_weight, torch_weights) 114 | else: 115 | raise ValueError( 116 | "Can't find the corresponding torch weights. The shape is " 117 | f"mismatched. 
Got keras_name={keras_name}, " 118 | f"keras_weight shape={keras_weight.shape}, " 119 | f"torch_name={torch_name}, " 120 | f"torch_weights shape={torch_weights.shape}" 121 | ) 122 | 123 | """ 124 | Verify model outputs 125 | """ 126 | np.random.seed(2023) 127 | keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") 128 | torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) 129 | torch_y = torch_model(torch_data) 130 | keras_y = keras_model(keras_data, training=False) 131 | torch_y = torch_y.detach().cpu().numpy() 132 | keras_y = keras.ops.convert_to_numpy(keras_y) 133 | np.testing.assert_allclose(torch_y, keras_y, atol=2e-5) 134 | print(f"{keras_model_class.__name__}: output matched!") 135 | 136 | """ 137 | Save converted model 138 | """ 139 | os.makedirs("exported", exist_ok=True) 140 | export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" 141 | keras_model.save(export_path) 142 | print(f"Export to {export_path}") 143 | -------------------------------------------------------------------------------- /tools/convert_regnet_from_timm.py: -------------------------------------------------------------------------------- 1 | """ 2 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 3 | pip install timm 4 | """ 5 | 6 | import os 7 | 8 | import keras 9 | import numpy as np 10 | import timm 11 | import torch 12 | 13 | from kimm.models import regnet 14 | from kimm.timm_utils import assign_weights 15 | from kimm.timm_utils import is_same_weights 16 | from kimm.timm_utils import separate_keras_weights 17 | from kimm.timm_utils import separate_torch_state_dict 18 | 19 | timm_model_names = [ 20 | "regnetx_002.pycls_in1k", 21 | "regnety_002.pycls_in1k", 22 | "regnetx_004.pycls_in1k", 23 | "regnety_004.tv2_in1k", 24 | "regnetx_006.pycls_in1k", 25 | "regnety_006.pycls_in1k", 26 | "regnetx_008.tv2_in1k", 27 | "regnety_008.pycls_in1k", 28 | "regnetx_016.tv2_in1k", 29 | "regnety_016.tv2_in1k", 30 | 
"regnetx_032.tv2_in1k", 31 | "regnety_032.ra_in1k", 32 | "regnetx_040.pycls_in1k", 33 | "regnety_040.ra3_in1k", 34 | "regnetx_064.pycls_in1k", 35 | "regnety_064.ra3_in1k", 36 | "regnetx_080.tv2_in1k", 37 | "regnety_080.ra3_in1k", 38 | "regnetx_120.pycls_in1k", 39 | "regnety_120.sw_in12k_ft_in1k", 40 | "regnetx_160.tv2_in1k", 41 | "regnety_160.swag_ft_in1k", 42 | "regnetx_320.tv2_in1k", 43 | "regnety_320.swag_ft_in1k", 44 | ] 45 | keras_model_classes = [ 46 | regnet.RegNetX002, 47 | regnet.RegNetY002, 48 | regnet.RegNetX004, 49 | regnet.RegNetY004, 50 | regnet.RegNetX006, 51 | regnet.RegNetY006, 52 | regnet.RegNetX008, 53 | regnet.RegNetY008, 54 | regnet.RegNetX016, 55 | regnet.RegNetY016, 56 | regnet.RegNetX032, 57 | regnet.RegNetY032, 58 | regnet.RegNetX040, 59 | regnet.RegNetY040, 60 | regnet.RegNetX064, 61 | regnet.RegNetY064, 62 | regnet.RegNetX080, 63 | regnet.RegNetY080, 64 | regnet.RegNetX120, 65 | regnet.RegNetY120, 66 | regnet.RegNetX160, 67 | regnet.RegNetY160, 68 | regnet.RegNetX320, 69 | regnet.RegNetY320, 70 | ] 71 | 72 | for timm_model_name, keras_model_class in zip( 73 | timm_model_names, keras_model_classes 74 | ): 75 | """ 76 | Prepare timm model and keras model 77 | """ 78 | input_shape = [224, 224, 3] 79 | torch_model = timm.create_model(timm_model_name, pretrained=True) 80 | torch_model = torch_model.eval() 81 | trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( 82 | torch_model.state_dict() 83 | ) 84 | keras_model = keras_model_class( 85 | input_shape=input_shape, 86 | include_preprocessing=False, 87 | classifier_activation="linear", 88 | weights=None, 89 | ) 90 | trainable_weights, non_trainable_weights = separate_keras_weights( 91 | keras_model 92 | ) 93 | 94 | # for torch_name, (_, keras_name) in zip( 95 | # trainable_state_dict.keys(), trainable_weights 96 | # ): 97 | # print(f"{torch_name} {keras_name}") 98 | 99 | # print(len(trainable_state_dict.keys())) 100 | # print(len(trainable_weights)) 101 | # 
print(timm_model_name, keras_model_class.__name__) 102 | 103 | # exit() 104 | 105 | """ 106 | Assign weights 107 | """ 108 | for keras_weight, keras_name in trainable_weights + non_trainable_weights: 109 | keras_name: str 110 | torch_name = keras_name 111 | torch_name = torch_name.replace("_", ".") 112 | # stem 113 | torch_name = torch_name.replace("stem_conv2d", "stem.conv") 114 | # blocks 115 | torch_name = torch_name.replace("conv2d", "conv") 116 | # se 117 | torch_name = torch_name.replace("se.conv.reduce", "se.fc1") 118 | torch_name = torch_name.replace("se.conv.expand", "se.fc2") 119 | # head 120 | torch_name = torch_name.replace("classifier", "head.fc") 121 | 122 | # weights naming mapping 123 | torch_name = torch_name.replace("kernel", "weight") # conv2d 124 | torch_name = torch_name.replace("gamma", "weight") # bn 125 | torch_name = torch_name.replace("beta", "bias") # bn 126 | torch_name = torch_name.replace("moving.mean", "running_mean") # bn 127 | torch_name = torch_name.replace("moving.variance", "running_var") # bn 128 | 129 | # assign weights 130 | if torch_name in trainable_state_dict: 131 | torch_weights = trainable_state_dict[torch_name].numpy() 132 | elif torch_name in non_trainable_state_dict: 133 | torch_weights = non_trainable_state_dict[torch_name].numpy() 134 | else: 135 | raise ValueError( 136 | "Can't find the corresponding torch weights. " 137 | f"Got keras_name={keras_name}, torch_name={torch_name}" 138 | ) 139 | if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): 140 | assign_weights(keras_name, keras_weight, torch_weights) 141 | else: 142 | raise ValueError( 143 | "Can't find the corresponding torch weights. The shape is " 144 | f"mismatched. 
Got keras_name={keras_name}, " 145 | f"keras_weight shape={keras_weight.shape}, " 146 | f"torch_name={torch_name}, " 147 | f"torch_weights shape={torch_weights.shape}" 148 | ) 149 | 150 | """ 151 | Verify model outputs 152 | """ 153 | np.random.seed(2023) 154 | keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") 155 | torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) 156 | torch_y = torch_model(torch_data) 157 | keras_y = keras_model(keras_data, training=False) 158 | torch_y = torch_y.detach().cpu().numpy() 159 | keras_y = keras.ops.convert_to_numpy(keras_y) 160 | try: 161 | np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) 162 | except AssertionError as e: 163 | print(timm_model_name, keras_model_class.__name__) 164 | raise e 165 | print(f"{keras_model_class.__name__}: output matched!") 166 | 167 | """ 168 | Save converted model 169 | """ 170 | os.makedirs("exported", exist_ok=True) 171 | export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" 172 | keras_model.save(export_path) 173 | print(f"Export to {export_path}") 174 | -------------------------------------------------------------------------------- /tools/convert_inception_v3_from_timm.py: -------------------------------------------------------------------------------- 1 | """ 2 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 3 | pip install timm 4 | """ 5 | 6 | import os 7 | 8 | import keras 9 | import numpy as np 10 | import timm 11 | import torch 12 | 13 | from kimm.models import inception_v3 14 | from kimm.timm_utils import assign_weights 15 | from kimm.timm_utils import is_same_weights 16 | from kimm.timm_utils import separate_keras_weights 17 | from kimm.timm_utils import separate_torch_state_dict 18 | 19 | timm_model_names = [ 20 | "inception_v3.gluon_in1k", 21 | "inception_v3.gluon_in1k", 22 | ] 23 | keras_model_classes = [ 24 | inception_v3.InceptionV3, 25 | inception_v3.InceptionV3, 26 | ] 27 | 
has_aux_logits_list = [True, False] 28 | 29 | for timm_model_name, keras_model_class, has_aux_logits in zip( 30 | timm_model_names, 31 | keras_model_classes, 32 | has_aux_logits_list, 33 | ): 34 | """ 35 | Prepare timm model and keras model 36 | """ 37 | input_shape = [299, 299, 3] 38 | torch_model = timm.create_model( 39 | timm_model_name, pretrained=True, aux_logits=has_aux_logits 40 | ) 41 | torch_model = torch_model.eval() 42 | trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( 43 | torch_model.state_dict() 44 | ) 45 | keras_model = keras_model_class( 46 | has_aux_logits=has_aux_logits, 47 | input_shape=input_shape, 48 | include_preprocessing=False, 49 | classifier_activation="linear", 50 | weights=None, 51 | ) 52 | trainable_weights, non_trainable_weights = separate_keras_weights( 53 | keras_model 54 | ) 55 | 56 | # for torch_name, (_, keras_name) in zip( 57 | # trainable_state_dict.keys(), trainable_weights 58 | # ): 59 | # print(f"{torch_name} {keras_name}") 60 | 61 | # print(len(trainable_state_dict.keys())) 62 | # print(len(trainable_weights)) 63 | 64 | # exit() 65 | 66 | """ 67 | Preprocess 68 | """ 69 | new_dict = {} 70 | old_keys = trainable_state_dict.keys() 71 | new_keys = [] 72 | for k in old_keys: 73 | new_key = k.replace("_", ".") 74 | new_key = new_key.replace("running.mean", "running_mean") 75 | new_key = new_key.replace("running.var", "running_var") 76 | new_keys.append(new_key) 77 | for k1, k2 in zip(trainable_state_dict.keys(), new_keys): 78 | new_dict[k2] = trainable_state_dict[k1] 79 | trainable_state_dict = new_dict 80 | 81 | new_dict = {} 82 | old_keys = non_trainable_state_dict.keys() 83 | new_keys = [] 84 | for k in old_keys: 85 | new_key = k.replace("_", ".") 86 | new_key = new_key.replace("running.mean", "running_mean") 87 | new_key = new_key.replace("running.var", "running_var") 88 | new_keys.append(new_key) 89 | for k1, k2 in zip(non_trainable_state_dict.keys(), new_keys): 90 | new_dict[k2] = 
non_trainable_state_dict[k1] 91 | non_trainable_state_dict = new_dict 92 | 93 | """ 94 | Assign weights 95 | """ 96 | for keras_weight, keras_name in trainable_weights + non_trainable_weights: 97 | keras_name: str 98 | torch_name = keras_name 99 | torch_name = torch_name.replace("_", ".") 100 | # general 101 | torch_name = torch_name.replace("conv2d", "conv") 102 | # head 103 | torch_name = torch_name.replace("classifier", "fc") 104 | 105 | # weights naming mapping 106 | torch_name = torch_name.replace("kernel", "weight") # conv2d 107 | torch_name = torch_name.replace("gamma", "weight") # bn 108 | torch_name = torch_name.replace("beta", "bias") # bn 109 | torch_name = torch_name.replace("moving.mean", "running_mean") # bn 110 | torch_name = torch_name.replace("moving.variance", "running_var") # bn 111 | 112 | # assign weights 113 | if torch_name in trainable_state_dict: 114 | torch_weights = trainable_state_dict[torch_name].numpy() 115 | elif torch_name in non_trainable_state_dict: 116 | torch_weights = non_trainable_state_dict[torch_name].numpy() 117 | else: 118 | raise ValueError( 119 | "Can't find the corresponding torch weights. " 120 | f"Got keras_name={keras_name}, torch_name={torch_name}" 121 | ) 122 | if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): 123 | assign_weights(keras_name, keras_weight, torch_weights) 124 | else: 125 | raise ValueError( 126 | "Can't find the corresponding torch weights. The shape is " 127 | f"mismatched. 
Got keras_name={keras_name}, " 128 | f"keras_weight shape={keras_weight.shape}, " 129 | f"torch_name={torch_name}, " 130 | f"torch_weights shape={torch_weights.shape}" 131 | ) 132 | 133 | """ 134 | Verify model outputs 135 | """ 136 | np.random.seed(2023) 137 | keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") 138 | torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) 139 | if has_aux_logits: 140 | torch_y = torch_model(torch_data)[0] 141 | keras_y = keras_model(keras_data, training=False)[0] 142 | torch_y = torch_y.detach().cpu().numpy() 143 | keras_y = keras.ops.convert_to_numpy(keras_y) 144 | np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) 145 | else: 146 | torch_y = torch_model(torch_data) 147 | keras_y = keras_model(keras_data, training=False) 148 | torch_y = torch_y.detach().cpu().numpy() 149 | keras_y = keras.ops.convert_to_numpy(keras_y) 150 | np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) 151 | print(f"{keras_model_class.__name__}: output matched!") 152 | 153 | """ 154 | Save converted model 155 | """ 156 | os.makedirs("exported", exist_ok=True) 157 | if has_aux_logits: 158 | export_path = ( 159 | f"exported/{keras_model.name.lower()}_{timm_model_name}_" 160 | "aux_logits.keras" 161 | ) 162 | else: 163 | export_path = ( 164 | f"exported/{keras_model.name.lower()}_{timm_model_name}_" 165 | "no_aux_logits.keras" 166 | ) 167 | keras_model.save(export_path) 168 | print(f"Export to {export_path}") 169 | -------------------------------------------------------------------------------- /tools/convert_ghostnet_from_timm.py: -------------------------------------------------------------------------------- 1 | """ 2 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 3 | pip install timm 4 | """ 5 | 6 | import os 7 | 8 | import keras 9 | import numpy as np 10 | import timm 11 | import torch 12 | 13 | from kimm.models.ghostnet import GhostNet100 14 | from kimm.models.ghostnet import 
GhostNet100V2 15 | from kimm.models.ghostnet import GhostNet130V2 16 | from kimm.models.ghostnet import GhostNet160V2 17 | from kimm.timm_utils import assign_weights 18 | from kimm.timm_utils import is_same_weights 19 | from kimm.timm_utils import separate_keras_weights 20 | from kimm.timm_utils import separate_torch_state_dict 21 | 22 | timm_model_names = [ 23 | "ghostnet_100", 24 | "ghostnetv2_100", 25 | "ghostnetv2_130", 26 | "ghostnetv2_160", 27 | ] 28 | keras_model_classes = [ 29 | GhostNet100, 30 | GhostNet100V2, 31 | GhostNet130V2, 32 | GhostNet160V2, 33 | ] 34 | 35 | for timm_model_name, keras_model_class in zip( 36 | timm_model_names, keras_model_classes 37 | ): 38 | """ 39 | Prepare timm model and keras model 40 | """ 41 | input_shape = [224, 224, 3] 42 | torch_model = timm.create_model(timm_model_name, pretrained=True) 43 | torch_model = torch_model.eval() 44 | trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( 45 | torch_model.state_dict() 46 | ) 47 | keras_model = keras_model_class( 48 | input_shape=input_shape, 49 | include_preprocessing=False, 50 | classifier_activation="linear", 51 | weights=None, 52 | ) 53 | trainable_weights, non_trainable_weights = separate_keras_weights( 54 | keras_model 55 | ) 56 | 57 | # for torch_name, (_, keras_name) in zip( 58 | # trainable_state_dict.keys(), trainable_weights 59 | # ): 60 | # print(f"{torch_name} {keras_name}") 61 | 62 | # print(len(trainable_state_dict.keys())) 63 | # print(len(trainable_weights)) 64 | 65 | # exit() 66 | 67 | """ 68 | Assign weights 69 | """ 70 | for keras_weight, keras_name in trainable_weights + non_trainable_weights: 71 | keras_name: str 72 | torch_name = keras_name 73 | torch_name = torch_name.replace("_", ".") 74 | # stem 75 | torch_name = torch_name.replace("conv.stem.conv2d", "conv_stem") 76 | torch_name = torch_name.replace("conv.stem.bn", "bn1") 77 | # blocks 78 | torch_name = torch_name.replace("primary.conv.conv2d", "primary_conv.0") 79 | torch_name = 
torch_name.replace("primary.conv.bn", "primary_conv.1") 80 | torch_name = torch_name.replace( 81 | "cheap.operation.dwconv2d", "cheap_operation.0" 82 | ) 83 | torch_name = torch_name.replace( 84 | "cheap.operation.bn", "cheap_operation.1" 85 | ) 86 | torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw") 87 | torch_name = torch_name.replace("conv.dw.bn", "bn_dw") 88 | torch_name = torch_name.replace("shortcut1.dwconv2d", "shortcut.0") 89 | torch_name = torch_name.replace("shortcut1.bn", "shortcut.1") 90 | torch_name = torch_name.replace("shortcut2.conv2d", "shortcut.2") 91 | torch_name = torch_name.replace("shortcut2.bn", "shortcut.3") 92 | # se 93 | torch_name = torch_name.replace("se.conv.reduce", "se.conv_reduce") 94 | torch_name = torch_name.replace("se.conv.expand", "se.conv_expand") 95 | # short conv (GhostNetV2) 96 | torch_name = torch_name.replace("short.conv1.conv2d", "short_conv.0") 97 | torch_name = torch_name.replace("short.conv1.bn", "short_conv.1") 98 | torch_name = torch_name.replace("short.conv2.dwconv2d", "short_conv.2") 99 | torch_name = torch_name.replace("short.conv2.bn", "short_conv.3") 100 | torch_name = torch_name.replace("short.conv3.dwconv2d", "short_conv.4") 101 | torch_name = torch_name.replace("short.conv3.bn", "short_conv.5") 102 | # final block 103 | torch_name = torch_name.replace("blocks.9.conv2d", "blocks.9.0.conv") 104 | torch_name = torch_name.replace("blocks.9.bn", "blocks.9.0.bn1") 105 | # conv head 106 | if torch_name.startswith("conv.head"): 107 | torch_name = torch_name.replace("conv.head", "conv_head") 108 | 109 | # weights naming mapping 110 | torch_name = torch_name.replace("kernel", "weight") # conv2d 111 | torch_name = torch_name.replace("gamma", "weight") # bn 112 | torch_name = torch_name.replace("beta", "bias") # bn 113 | torch_name = torch_name.replace("moving.mean", "running_mean") # bn 114 | torch_name = torch_name.replace("moving.variance", "running_var") # bn 115 | 116 | # assign weights 117 | if 
torch_name in trainable_state_dict: 118 | torch_weights = trainable_state_dict[torch_name].numpy() 119 | elif torch_name in non_trainable_state_dict: 120 | torch_weights = non_trainable_state_dict[torch_name].numpy() 121 | else: 122 | raise ValueError( 123 | "Can't find the corresponding torch weights. " 124 | f"Got keras_name={keras_name}, torch_name={torch_name}" 125 | ) 126 | if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): 127 | assign_weights(keras_name, keras_weight, torch_weights) 128 | else: 129 | raise ValueError( 130 | "Can't find the corresponding torch weights. The shape is " 131 | f"mismatched. Got keras_name={keras_name}, " 132 | f"keras_weight shape={keras_weight.shape}, " 133 | f"torch_name={torch_name}, " 134 | f"torch_weights shape={torch_weights.shape}" 135 | ) 136 | 137 | """ 138 | Verify model outputs 139 | """ 140 | np.random.seed(2023) 141 | keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") 142 | torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) 143 | torch_y = torch_model(torch_data) 144 | keras_y = keras_model(keras_data, training=False) 145 | torch_y = torch_y.detach().cpu().numpy() 146 | keras_y = keras.ops.convert_to_numpy(keras_y) 147 | np.testing.assert_allclose(torch_y, keras_y, atol=5e-1) 148 | print(f"{keras_model_class.__name__}: output matched!") 149 | 150 | """ 151 | Save converted model 152 | """ 153 | os.makedirs("exported", exist_ok=True) 154 | export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" 155 | keras_model.save(export_path) 156 | print(f"Export to {export_path}") 157 | -------------------------------------------------------------------------------- /tools/convert_repvgg_from_timm.py: -------------------------------------------------------------------------------- 1 | """ 2 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 3 | pip install timm 4 | """ 5 | 6 | import os 7 | 8 | import keras 9 | import 
timm_model_names = [
    "repvgg_a0.rvgg_in1k",
    "repvgg_a1.rvgg_in1k",
    "repvgg_a2.rvgg_in1k",
    "repvgg_b0.rvgg_in1k",
    "repvgg_b1.rvgg_in1k",
    "repvgg_b2.rvgg_in1k",
    "repvgg_b3.rvgg_in1k",
]
keras_model_classes = [
    repvgg.RepVGGA0,
    repvgg.RepVGGA1,
    repvgg.RepVGGA2,
    repvgg.RepVGGB0,
    repvgg.RepVGGB1,
    repvgg.RepVGGB2,
    repvgg.RepVGGB3,
]

# Ordered (keras fragment, timm fragment) substitutions applied to every Keras
# weight name. Order matters: underscores become dots first, the structural
# patterns are matched while the Keras suffixes (`kernel`, `gamma`, ...) are
# still present, and the generic suffix renames run last.
_NAME_REPLACEMENTS = (
    ("_", "."),
    # repconv2d
    ("skip.gamma", "identity.gamma"),
    ("skip.beta", "identity.beta"),
    ("conv.scale.kernel", "conv_1x1.conv.kernel"),
    ("conv.scale.gamma", "conv_1x1.bn.gamma"),
    ("conv.scale.beta", "conv_1x1.bn.beta"),
    ("conv.kxk.0.kernel", "conv_kxk.conv.kernel"),
    ("conv.kxk.0.gamma", "conv_kxk.bn.gamma"),
    ("conv.kxk.0.beta", "conv_kxk.bn.beta"),
    # repconv2d bn
    ("skip.moving.mean", "identity.moving.mean"),
    ("skip.moving.variance", "identity.moving.variance"),
    ("conv.scale.moving.mean", "conv_1x1.bn.moving.mean"),
    ("conv.scale.moving.variance", "conv_1x1.bn.moving.variance"),
    ("conv.kxk.0.moving.mean", "conv_kxk.bn.moving.mean"),
    ("conv.kxk.0.moving.variance", "conv_kxk.bn.moving.variance"),
    # head
    ("classifier", "head.fc"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
    ("gamma", "weight"),  # bn
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)


def _to_torch_name(keras_name: str) -> str:
    """Translate a Keras weight name into the timm state-dict key.

    Args:
        keras_name: The name produced by `separate_keras_weights`.

    Returns:
        The name after applying `_NAME_REPLACEMENTS` in order.
    """
    torch_name = keras_name
    for keras_part, torch_part in _NAME_REPLACEMENTS:
        torch_name = torch_name.replace(keras_part, torch_part)
    return torch_name


# `zip` stops at the shorter list, so a forgotten entry would silently drop
# models from the conversion instead of failing.
if len(timm_model_names) != len(keras_model_classes):
    raise ValueError(
        "`timm_model_names` and `keras_model_classes` must have the same "
        f"length. Got {len(timm_model_names)} and "
        f"{len(keras_model_classes)}"
    )

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # Prepare timm model and keras model
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True)
    torch_model = torch_model.eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # Assign weights
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        # Skip the reparameterized conv weights: they are never looked up in
        # the timm state dict (presumably it has no matching entry - confirm).
        if "reparam_conv_conv2d" in keras_name:
            continue
        torch_name = _to_torch_name(keras_name)

        if torch_name in trainable_state_dict:
            torch_weights = trainable_state_dict[torch_name].numpy()
        elif torch_name in non_trainable_state_dict:
            torch_weights = non_trainable_state_dict[torch_name].numpy()
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )
        if is_same_weights(keras_name, keras_weight, torch_name, torch_weights):
            assign_weights(keras_name, keras_weight, torch_weights)
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )

    # Verify model outputs on one fixed random image (NHWC for Keras, NCHW
    # for torch).
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        torch_y = torch_model(torch_data)
    keras_y = keras_model(keras_data, training=False)
    torch_y = torch_y.cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(keras_y)
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-5)
    print(f"{keras_model_class.__name__}: output matched!")

    # Save converted model
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
timm_model_names = [
    "mobilenetv3_small_050.lamb_in1k",
    "mobilenetv3_small_075.lamb_in1k",
    "tf_mobilenetv3_small_minimal_100.in1k",
    "mobilenetv3_small_100.lamb_in1k",
    "mobilenetv3_large_100.miil_in21k_ft_in1k",
    "tf_mobilenetv3_large_minimal_100.in1k",
    "lcnet_050.ra2_in1k",
    "lcnet_075.ra2_in1k",
    "lcnet_100.ra2_in1k",
]
keras_model_classes = [
    mobilenet_v3.MobileNetV3W050Small,
    mobilenet_v3.MobileNetV3W075Small,
    mobilenet_v3.MobileNetV3W100SmallMinimal,
    mobilenet_v3.MobileNetV3W100Small,
    mobilenet_v3.MobileNetV3W100Large,
    mobilenet_v3.MobileNetV3W100LargeMinimal,
    mobilenet_v3.LCNet050,
    mobilenet_v3.LCNet075,
    mobilenet_v3.LCNet100,
]

# Each table is an ordered sequence of (keras fragment, timm fragment) pairs.
_STEM_REPLACEMENTS = (
    ("conv.stem.conv2d", "conv_stem"),
    ("conv.stem.bn", "bn1"),
)
# depthwise separation block: depthwise conv first, then pointwise conv
_DEPTHWISE_SEPARATION_REPLACEMENTS = (
    ("conv.dw.dwconv2d", "conv_dw"),
    ("conv.dw.bn", "bn1"),
    ("conv.pw.conv2d", "conv_pw"),
    ("conv.pw.bn", "bn2"),
)
# inverted residual block: pointwise, depthwise, then pointwise-linear
_INVERTED_RESIDUAL_REPLACEMENTS = (
    ("conv.pw.conv2d", "conv_pw"),
    ("conv.pw.bn", "bn1"),
    ("conv.dw.dwconv2d", "conv_dw"),
    ("conv.dw.bn", "bn2"),
    ("conv.pwl.conv2d", "conv_pwl"),
    ("conv.pwl.bn", "bn3"),
)
_SE_REPLACEMENTS = (
    ("se.conv.reduce", "se.conv_reduce"),
    ("se.conv.expand", "se.conv_expand"),
)
_LAST_CONV_BLOCK_REPLACEMENTS = (
    ("conv2d", "conv"),
    ("bn", "bn1"),
)
_HEAD_AND_SUFFIX_REPLACEMENTS = (
    # conv head
    ("conv.head", "conv_head"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
    ("gamma", "weight"),  # bn
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)


def _replace_all(name: str, replacements) -> str:
    """Apply every (old, new) pair of `replacements` to `name`, in order."""
    for old, new in replacements:
        name = name.replace(old, new)
    return name


def _to_torch_name(keras_name: str, keras_class_name: str) -> str:
    """Translate a Keras weight name into the timm state-dict key.

    Args:
        keras_name: The name produced by `separate_keras_weights`.
        keras_class_name: `__name__` of the Keras model class. It selects the
            LCNet mapping and the index of the last conv block
            (`blocks.5.0` for "Small", `blocks.6.0` for "Large").

    Returns:
        The name to look up in the timm state dict.
    """
    torch_name = keras_name.replace("_", ".")
    torch_name = _replace_all(torch_name, _STEM_REPLACEMENTS)
    # LCNet names are always translated with the depthwise separation mapping.
    # Once applied, the block mappings below find nothing left to replace.
    if "LCNet" in keras_class_name:
        torch_name = _replace_all(
            torch_name, _DEPTHWISE_SEPARATION_REPLACEMENTS
        )
    # blocks: `blocks.0.0` uses the depthwise separation mapping, every other
    # name uses the inverted residual mapping.
    if "blocks.0.0" in torch_name:
        torch_name = _replace_all(
            torch_name, _DEPTHWISE_SEPARATION_REPLACEMENTS
        )
    else:
        torch_name = _replace_all(torch_name, _INVERTED_RESIDUAL_REPLACEMENTS)
    # se
    torch_name = _replace_all(torch_name, _SE_REPLACEMENTS)
    # last conv block
    if "Small" in keras_class_name:
        if "blocks.5.0" in torch_name:
            torch_name = _replace_all(
                torch_name, _LAST_CONV_BLOCK_REPLACEMENTS
            )
    if "Large" in keras_class_name:
        if "blocks.6.0" in torch_name:
            torch_name = _replace_all(
                torch_name, _LAST_CONV_BLOCK_REPLACEMENTS
            )
    return _replace_all(torch_name, _HEAD_AND_SUFFIX_REPLACEMENTS)


# `zip` stops at the shorter list, so a forgotten entry would silently drop
# models from the conversion instead of failing.
if len(timm_model_names) != len(keras_model_classes):
    raise ValueError(
        "`timm_model_names` and `keras_model_classes` must have the same "
        f"length. Got {len(timm_model_names)} and "
        f"{len(keras_model_classes)}"
    )

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # Prepare timm model and keras model
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True)
    torch_model = torch_model.eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # Assign weights
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        torch_name = _to_torch_name(keras_name, keras_model_class.__name__)

        if torch_name in trainable_state_dict:
            torch_weights = trainable_state_dict[torch_name].numpy()
        elif torch_name in non_trainable_state_dict:
            torch_weights = non_trainable_state_dict[torch_name].numpy()
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )
        if is_same_weights(keras_name, keras_weight, torch_name, torch_weights):
            assign_weights(keras_name, keras_weight, torch_weights)
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )

    # Verify model outputs on one fixed random image (NHWC for Keras, NCHW
    # for torch).
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        torch_y = torch_model(torch_data)
    keras_y = keras_model(keras_data, training=False)
    torch_y = torch_y.cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(keras_y)
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-4)
    print(f"{keras_model_class.__name__}: output matched!")

    # Save converted model
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
timm_model_names = [
    "mobileone_s0.apple_in1k",
    "mobileone_s1.apple_in1k",
    "mobileone_s2.apple_in1k",
    "mobileone_s3.apple_in1k",
    # "mobileone_s4.apple_in1k",
]
keras_model_classes = [
    mobileone.MobileOneS0,
    mobileone.MobileOneS1,
    mobileone.MobileOneS2,
    mobileone.MobileOneS3,
    # mobileone.MobileOneS4,
]

# Each table is an ordered sequence of (keras fragment, timm fragment) pairs.
# mobile_one_conv2d: skip and scale branches (conv / bn parameters)
_BRANCH_PARAM_REPLACEMENTS = (
    ("skip.gamma", "identity.gamma"),
    ("skip.beta", "identity.beta"),
    ("conv.scale.kernel", "conv_scale.conv.kernel"),
    ("conv.scale.gamma", "conv_scale.bn.gamma"),
    ("conv.scale.beta", "conv_scale.bn.beta"),
)
# mobile_one_conv2d bn: skip and scale branches (moving statistics)
_BRANCH_STAT_REPLACEMENTS = (
    ("skip.moving.mean", "identity.moving.mean"),
    ("skip.moving.variance", "identity.moving.variance"),
    ("conv.scale.moving.mean", "conv_scale.bn.moving.mean"),
    ("conv.scale.moving.variance", "conv_scale.bn.moving.variance"),
)
# (keras suffix, timm suffix) pairs for the `conv.kxk` branches
_KXK_PARAM_SUFFIXES = (
    ("kernel", "conv.kernel"),
    ("gamma", "bn.gamma"),
    ("beta", "bn.beta"),
)
_KXK_STAT_SUFFIXES = (
    ("moving.mean", "bn.moving.mean"),
    ("moving.variance", "bn.moving.variance"),
)
_HEAD_AND_SUFFIX_REPLACEMENTS = (
    # head
    ("classifier", "head.fc"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
    ("gamma", "weight"),  # bn
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)


def _replace_all(name: str, replacements) -> str:
    """Apply every (old, new) pair of `replacements` to `name`, in order."""
    for old, new in replacements:
        name = name.replace(old, new)
    return name


def _rename_kxk(name: str, suffixes) -> str:
    """Rename a `conv.kxk` branch weight to its `conv_kxk` timm form.

    For each (keras suffix, timm suffix) pair, a name containing both
    `conv.kxk` and the Keras suffix gets `conv.kxk` -> `conv_kxk` and the
    suffix replaced. After the first match `conv.kxk` is gone, so at most one
    pair is applied.
    """
    for keras_suffix, torch_suffix in suffixes:
        if "conv.kxk" in name and keras_suffix in name:
            name = name.replace("conv.kxk", "conv_kxk")
            name = name.replace(keras_suffix, torch_suffix)
    return name


def _to_torch_name(keras_name: str) -> str:
    """Translate a Keras weight name into the timm state-dict key.

    Args:
        keras_name: The name produced by `separate_keras_weights`.

    Returns:
        The name to look up in the timm state dict.
    """
    torch_name = keras_name.replace("_", ".")
    # The step order mirrors the original replacement chain exactly.
    torch_name = _replace_all(torch_name, _BRANCH_PARAM_REPLACEMENTS)
    torch_name = _rename_kxk(torch_name, _KXK_PARAM_SUFFIXES)
    torch_name = _replace_all(torch_name, _BRANCH_STAT_REPLACEMENTS)
    torch_name = _rename_kxk(torch_name, _KXK_STAT_SUFFIXES)
    return _replace_all(torch_name, _HEAD_AND_SUFFIX_REPLACEMENTS)


# `zip` stops at the shorter list, so a forgotten entry would silently drop
# models from the conversion instead of failing.
if len(timm_model_names) != len(keras_model_classes):
    raise ValueError(
        "`timm_model_names` and `keras_model_classes` must have the same "
        f"length. Got {len(timm_model_names)} and "
        f"{len(keras_model_classes)}"
    )

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # Prepare timm model and keras model
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True)
    torch_model = torch_model.eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # Assign weights
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        # Skip the reparameterized conv weights: they are never looked up in
        # the timm state dict (presumably it has no matching entry - confirm).
        if "reparam_conv_conv2d" in keras_name:
            continue
        torch_name = _to_torch_name(keras_name)

        if torch_name in trainable_state_dict:
            torch_weights = trainable_state_dict[torch_name].numpy()
        elif torch_name in non_trainable_state_dict:
            torch_weights = non_trainable_state_dict[torch_name].numpy()
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )
        if is_same_weights(keras_name, keras_weight, torch_name, torch_weights):
            assign_weights(keras_name, keras_weight, torch_weights)
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )

    # Verify model outputs on one fixed random image (NHWC for Keras, NCHW
    # for torch).
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        torch_y = torch_model(torch_data)
    keras_y = keras_model(keras_data, training=False)
    torch_y = torch_y.cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(keras_y)
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-4)
    print(f"{keras_model_class.__name__}: output matched!")

    # Save converted model
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
# Each case is passed as keyword arguments to the parameterized tests below.
TEST_CASES = [
    {
        "filters": 16,
        "kernel_size": 3,
        "has_skip": True,
        "has_scale": True,
        "use_depthwise": False,
        "branch_size": 2,
        "data_format": "channels_last",
        "input_shape": (1, 4, 4, 16),
        "output_shape": (1, 4, 4, 16),
        "num_trainable_weights": 11,
        "num_non_trainable_weights": 8,
    },
    {
        "filters": 16,
        "kernel_size": 3,
        "has_skip": True,
        "has_scale": True,
        "use_depthwise": True,
        "branch_size": 3,
        "data_format": "channels_last",
        "input_shape": (1, 4, 4, 16),
        "output_shape": (1, 4, 4, 16),
        "num_trainable_weights": 14,
        "num_non_trainable_weights": 10,
    },
    {
        "filters": 16,
        "kernel_size": 3,
        "has_skip": False,
        "has_scale": True,
        "use_depthwise": False,
        "branch_size": 2,
        "data_format": "channels_last",
        "input_shape": (1, 4, 4, 8),
        "output_shape": (1, 4, 4, 16),
        "num_trainable_weights": 9,
        "num_non_trainable_weights": 6,
    },
    {
        "filters": 16,
        "kernel_size": 5,
        "has_skip": True,
        "has_scale": True,
        "use_depthwise": False,
        "branch_size": 2,
        "data_format": "channels_last",
        "input_shape": (1, 4, 4, 16),
        "output_shape": (1, 4, 4, 16),
        "num_trainable_weights": 11,
        "num_non_trainable_weights": 8,
    },
    {
        "filters": 16,
        "kernel_size": 3,
        "has_skip": True,
        "has_scale": True,
        "use_depthwise": False,
        "branch_size": 2,
        "data_format": "channels_first",
        "input_shape": (1, 16, 4, 4),
        "output_shape": (1, 16, 4, 4),
        "num_trainable_weights": 11,
        "num_non_trainable_weights": 8,
    },
    {
        "filters": 16,
        "kernel_size": 1,
        "has_skip": True,
        "has_scale": False,
        "use_depthwise": False,
        "branch_size": 2,
        "data_format": "channels_last",
        "input_shape": (1, 4, 4, 16),
        "output_shape": (1, 4, 4, 16),
        "num_trainable_weights": 8,
        "num_non_trainable_weights": 6,
    },
    {
        "filters": 16,
        "kernel_size": 1,
        "has_skip": False,
        "has_scale": False,
        "use_depthwise": True,
        "branch_size": 3,
        "data_format": "channels_last",
        "input_shape": (1, 4, 4, 16),
        "output_shape": (1, 4, 4, 16),
        "num_trainable_weights": 9,
        "num_non_trainable_weights": 6,
    },
]


class ReparameterizableConv2DTest(testing.TestCase, parameterized.TestCase):
    """Tests for `ReparameterizableConv2D` driven by `TEST_CASES`."""

    def _skip_if_unsupported(self, data_format):
        """Skip the running test for channels_first on the tensorflow backend."""
        if data_format != "channels_first":
            return
        if backend.backend() != "tensorflow":
            return
        self.skipTest(
            "Conv2D in tensorflow backend with 'channels_first' is limited "
            "to be supported"
        )

    @parameterized.parameters(TEST_CASES)
    @pytest.mark.requires_trainable_backend
    def test_basic(
        self,
        filters,
        kernel_size,
        has_skip,
        has_scale,
        use_depthwise,
        branch_size,
        data_format,
        input_shape,
        output_shape,
        num_trainable_weights,
        num_non_trainable_weights,
    ):
        """Run the generic layer test and check shape and weight counts."""
        self._skip_if_unsupported(data_format)
        init_kwargs = {
            "filters": filters,
            "kernel_size": kernel_size,
            "has_skip": has_skip,
            "has_scale": has_scale,
            "use_depthwise": use_depthwise,
            "branch_size": branch_size,
            "data_format": data_format,
        }
        self.run_layer_test(
            ReparameterizableConv2D,
            init_kwargs=init_kwargs,
            input_shape=input_shape,
            expected_output_shape=output_shape,
            expected_num_trainable_weights=num_trainable_weights,
            expected_num_non_trainable_weights=num_non_trainable_weights,
            expected_num_losses=0,
            supports_masking=False,
        )

    @parameterized.parameters(TEST_CASES)
    def test_get_reparameterized_weights(
        self,
        filters,
        kernel_size,
        has_skip,
        has_scale,
        use_depthwise,
        branch_size,
        data_format,
        input_shape,
        output_shape,
        num_trainable_weights,
        num_non_trainable_weights,
    ):
        """Check the fused kernel/bias reproduce the multi-branch output.

        `output_shape` and the weight counts are unused here; they are
        accepted because every key of a test case is passed as a keyword.
        """
        self._skip_if_unsupported(data_format)
        shared_kwargs = {
            "filters": filters,
            "kernel_size": kernel_size,
            "has_skip": has_skip,
            "has_scale": has_scale,
            "use_depthwise": use_depthwise,
            "branch_size": branch_size,
            "data_format": data_format,
        }
        layer = ReparameterizableConv2D(**shared_kwargs)
        layer.build(input_shape)
        reparameterized_layer = ReparameterizableConv2D(
            reparameterized=True, **shared_kwargs
        )
        reparameterized_layer.build(input_shape)
        x = random.uniform(input_shape)

        # Copy the fused weights of the multi-branch layer into the single
        # conv of the reparameterized layer, then compare both outputs.
        kernel, bias = layer.get_reparameterized_weights()
        fused_conv = reparameterized_layer.reparameterized_conv2d
        fused_conv.kernel.assign(kernel)
        fused_conv.bias.assign(bias)
        y1 = layer(x, training=False)
        y2 = reparameterized_layer(x, training=False)

        self.assertAllClose(y1, y2, atol=1e-3)

    def test_invalid_args(self):
        """Building a depthwise layer with 4 filters on 8 channels must fail."""
        layer = ReparameterizableConv2D(
            filters=4,
            kernel_size=3,
            has_skip=False,
            has_scale=False,
            use_depthwise=True,
            branch_size=1,
            data_format="channels_last",
        )
        with self.assertRaisesRegex(ValueError, "must be the same as"):
            layer.build([1, 4, 4, 8])
timm_model_names = [
    "mobilevit_xxs.cvnets_in1k",
    "mobilevit_xs.cvnets_in1k",
    "mobilevit_s.cvnets_in1k",
    "mobilevitv2_050.cvnets_in1k",
    "mobilevitv2_075.cvnets_in1k",
    "mobilevitv2_100.cvnets_in1k",
    "mobilevitv2_125.cvnets_in1k",
    "mobilevitv2_150.cvnets_in22k_ft_in1k_384",
    "mobilevitv2_175.cvnets_in22k_ft_in1k_384",
    "mobilevitv2_200.cvnets_in22k_ft_in1k_384",
]
keras_model_classes = [
    mobilevit.MobileViTXXS,
    mobilevit.MobileViTXS,
    mobilevit.MobileViTS,
    mobilevit.MobileViTV2W050,
    mobilevit.MobileViTV2W075,
    mobilevit.MobileViTV2W100,
    mobilevit.MobileViTV2W125,
    mobilevit.MobileViTV2W150,
    mobilevit.MobileViTV2W175,
    mobilevit.MobileViTV2W200,
]

# Ordered (keras fragment, timm fragment) substitutions applied to every Keras
# weight name. Order matters: e.g. `attn` -> `attn.qkv` must run before the
# mobilevitv2 patterns that match `attn.qkv.qkv.proj.conv2d`, and the generic
# suffix renames run last.
_NAME_REPLACEMENTS = (
    ("_", "."),
    # stem
    ("stem.conv2d", "stem.conv"),
    # inverted residual block
    ("conv.pw.conv2d", "conv1_1x1.conv"),
    ("conv.pw.bn", "conv1_1x1.bn"),
    ("conv.dw.dwconv2d", "conv2_kxk.conv"),
    ("conv.dw.bn", "conv2_kxk.bn"),
    ("conv.pwl.conv2d", "conv3_1x1.conv"),
    ("conv.pwl.bn", "conv3_1x1.bn"),
    # mobilevit block
    ("conv.kxk.conv2d", "conv_kxk.conv"),
    ("conv.kxk.bn", "conv_kxk.bn"),
    ("conv.1x1", "conv_1x1"),
    ("attn", "attn.qkv"),
    ("conv.proj.conv2d", "conv_proj.conv"),
    ("conv.proj.bn", "conv_proj.bn"),
    ("conv.fusion.conv2d", "conv_fusion.conv"),
    ("conv.fusion.bn", "conv_fusion.bn"),
    # mobilevitv2 block
    ("conv.kxk.dwconv2d", "conv_kxk.conv"),
    ("attn.qkv.qkv.proj.conv2d", "attn.qkv_proj"),
    ("attn.qkv.out.proj.conv2d", "attn.out_proj"),
    ("mlp.fc1.conv2d", "mlp.fc1"),
    ("mlp.fc2.conv2d", "mlp.fc2"),
    # final block
    ("final.conv.conv2d", "final_conv.conv"),
    ("final.conv.bn", "final_conv.bn"),
    # head
    ("classifier", "head.fc"),
    # weights naming mapping
    ("kernel", "weight"),  # conv2d
    ("gamma", "weight"),  # bn
    ("beta", "bias"),  # bn
    ("moving.mean", "running_mean"),  # bn
    ("moving.variance", "running_var"),  # bn
)


def _to_torch_name(keras_name: str) -> str:
    """Translate a Keras weight name into the timm state-dict key.

    Args:
        keras_name: The name produced by `separate_keras_weights`.

    Returns:
        The name after applying `_NAME_REPLACEMENTS` in order.
    """
    torch_name = keras_name
    for keras_part, torch_part in _NAME_REPLACEMENTS:
        torch_name = torch_name.replace(keras_part, torch_part)
    return torch_name


# `zip` stops at the shorter list, so a forgotten entry would silently drop
# models from the conversion instead of failing.
if len(timm_model_names) != len(keras_model_classes):
    raise ValueError(
        "`timm_model_names` and `keras_model_classes` must have the same "
        f"length. Got {len(timm_model_names)} and "
        f"{len(keras_model_classes)}"
    )

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # Prepare timm model and keras model
    input_shape = [256, 256, 3]  # use size of 384 for best performance
    torch_model = timm.create_model(timm_model_name, pretrained=True)
    torch_model = torch_model.eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # Assign weights
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        torch_name = _to_torch_name(keras_name)

        if torch_name in trainable_state_dict:
            torch_weights = trainable_state_dict[torch_name].numpy()
        elif torch_name in non_trainable_state_dict:
            torch_weights = non_trainable_state_dict[torch_name].numpy()
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )
        if is_same_weights(keras_name, keras_weight, torch_name, torch_weights):
            assign_weights(keras_name, keras_weight, torch_weights)
        # special case for Attention module: every `attn` weight is first
        # mapped to `attn.qkv`; when that candidate does not match, retry with
        # the `attn.proj` entry.
        elif "attn" in keras_name:
            torch_name = torch_name.replace("attn.qkv", "attn.proj")
            # Fail with the same context as the other lookup errors instead
            # of a bare KeyError.
            if torch_name not in trainable_state_dict:
                raise ValueError(
                    "Can't find the corresponding torch weights. "
                    f"Got keras_name={keras_name}, torch_name={torch_name}"
                )
            torch_weights = trainable_state_dict[torch_name].numpy()
            assign_weights(keras_name, keras_weight, torch_weights)
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )

    # Verify model outputs on one fixed random image (NHWC for Keras, NCHW
    # for torch).
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        torch_y = torch_model(torch_data)
    keras_y = keras_model(keras_data, training=False)
    torch_y = torch_y.cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(keras_y)
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-3)
    print(f"{keras_model_class.__name__}: output matched!")

    # Save converted model
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
def _is_useless_weights(name: str) -> bool:
    """Return True if `name` is a torch entry that is never converted.

    Only keys containing `num_batches_tracked` are considered useless.
    """
    return "num_batches_tracked" in name


def _is_non_trainable_weights(name: str) -> bool:
    """Return True if `name` contains `running_mean` or `running_var`."""
    return "running_mean" in name or "running_var" in name


@kimm_export(parent_path=["kimm.timm_utils"])
def separate_torch_state_dict(state_dict: typing.OrderedDict):
    """Separate the torch state dict into trainable and non-trainable parts.

    Keys flagged by `_is_useless_weights` are dropped from both results.
    The input mapping itself is not modified.

    Args:
        state_dict: A `collections.OrderedDict`.

    Returns:
        A tuple containing the trainable and non-trainable state dicts.
    """
    # Start from two copies (keeps the mapping type and key order) and delete
    # from each copy the entries that do not belong to it.
    trainable_state_dict = state_dict.copy()
    non_trainable_state_dict = state_dict.copy()
    for key in state_dict:
        if _is_useless_weights(key):
            del trainable_state_dict[key]
            del non_trainable_state_dict[key]
        elif _is_non_trainable_weights(key):
            del trainable_state_dict[key]
        else:
            del non_trainable_state_dict[key]
    return trainable_state_dict, non_trainable_state_dict


@kimm_export(parent_path=["kimm.timm_utils"])
def separate_keras_weights(keras_model: keras.Model):
    """Separate the Keras model into trainable and non-trainable parts.

    A layer that exposes a `_sublayers` attribute contributes the weights of
    each of its sub-layers (named after the sub-layer). Any other layer
    contributes its own weights.

    Args:
        keras_model: A `keras.Model` instance.

    Returns:
        A tuple containing the trainable and non-trainable state lists. Each
        list contains (`keras.Variable`, name) pairs, where name is
        `<owner layer name>_<weight name>`.
    """
    trainable_weights = []
    non_trainable_weights = []
    for layer in keras_model.layers:
        if hasattr(layer, "_sublayers"):
            owners = layer._sublayers
        else:
            owners = [layer]
        for owner in owners:
            prefix = owner.name + "_"
            for weight in owner.trainable_weights:
                trainable_weights.append((weight, prefix + weight.name))
            for weight in owner.non_trainable_weights:
                non_trainable_weights.append((weight, prefix + weight.name))
    return trainable_weights, non_trainable_weights


# Name fragments that mark a 4D Keras weight as a convolution kernel.
_CONV_NAME_TAGS = ("conv", "pointwise", "dwconv2d", "depthwise")


@kimm_export(parent_path=["kimm.timm_utils"])
def assign_weights(
    keras_name: str, keras_weight: keras.Variable, torch_weight: np.ndarray
):
    """Assign the torch weights to the keras weights based on the arguments.

    Some basic criterion:
    1. 4D must be a convolution weights (also check the name)
    2. 2D must be a dense weights
    3. 1D must be a vector weights
    4. 0D must be a scalar weights

    Args:
        keras_name: A `str` representing the name of the target weights.
        keras_weight: A `keras.Variable` representing the target weights.
        torch_weight: A `numpy.ndarray` representing the original source
            weights.

    Raises:
        ValueError: If the weights cannot be matched by the rules above
            (including a scalar target whose source is neither a scalar nor
            a 1D array), or if the underlying transpose/assign rejects the
            shapes.
    """
    keras_shape = tuple(keras_weight.shape)
    keras_rank = len(keras_shape)
    if keras_rank == 4:
        if not any(tag in keras_name for tag in _CONV_NAME_TAGS):
            raise ValueError(
                f"Failed to assign {keras_name}. "
                f"keras weight shape={keras_weight.shape}, "
                f"torch weight shape={torch_weight.shape}"
            )
        # conventional conv2d layer
        conv_kernel = np.transpose(torch_weight, [2, 3, 1, 0])
        if tuple(conv_kernel.shape) == keras_shape:
            keras_weight.assign(conv_kernel)
        else:
            # depthwise conv2d layer; `assign` reports a remaining mismatch
            keras_weight.assign(np.transpose(torch_weight, [2, 3, 0, 1]))
    elif keras_rank == 2:
        # dense layer
        keras_weight.assign(np.transpose(torch_weight))
    elif keras_rank == 1:
        keras_weight.assign(torch_weight)
    elif keras_shape == tuple(torch_weight.shape):
        keras_weight.assign(torch_weight)
    elif keras_rank == 0 and len(torch_weight.shape) == 1:
        # Deal with scalar: the source stores it as a 1D array.
        keras_weight.assign(torch_weight[0])
    else:
        # Also reached for a scalar target with an unsupported source shape,
        # which must fail loudly instead of leaving the variable untouched.
        raise ValueError(
            f"Failed to assign {keras_name}, "
            f"keras_weight.shape={keras_weight.shape}, "
            f"torch_weight.shape={torch_weight.shape}, "
        )


# (keras name suffix, required torch name suffix) pairs.
_WEIGHT_SUFFIX_PAIRS = (
    ("kernel", "weight"),  # Conv kernel
    ("gamma", "weight"),  # BatchNormalization gamma
    ("beta", "bias"),  # BatchNormalization beta
    ("moving_mean", "running_mean"),  # BatchNormalization moving_mean
    ("moving_variance", "running_var"),  # BatchNormalization moving_variance
)


@kimm_export(parent_path=["kimm.timm_utils"])
def is_same_weights(
    keras_name: str,
    keras_weights: keras.Variable,
    torch_name: str,
    torch_weights: np.ndarray,
):
    """Check whether the given keras weights and torch weights are the same.

    The shapes must contain the same dimensions in any order (the layouts
    differ only by a transposition), and the name suffixes must agree, e.g. a
    keras `moving_variance` must map to a torch `running_var`. A scalar keras
    weight is additionally accepted for a torch weight of shape `(1,)`.

    Args:
        keras_name: A `str` representing the name of the target weights.
        keras_weights: A `keras.Variable` representing the target weights.
        torch_name: A `str` representing the name of the original source
            weights.
        torch_weights: A `numpy.ndarray` representing the original source
            weights.

    Returns:
        A boolean indicating whether the two weights are the same.
    """
    keras_shape = tuple(keras_weights.shape)
    torch_shape = tuple(torch_weights.shape)
    if sorted(keras_shape) != sorted(torch_shape):
        # Deal with scalar: keras stores `()`, torch stores `(1,)`.
        return len(keras_shape) == 0 and torch_shape == (1,)
    # `str.endswith` avoids hand-counted slice lengths, which previously made
    # the moving_variance/running_var check unreachable.
    for keras_suffix, torch_suffix in _WEIGHT_SUFFIX_PAIRS:
        if keras_name.endswith(keras_suffix) and not torch_name.endswith(
            torch_suffix
        ):
            return False
    # TODO: is it always true?
    return True
def _keras_to_torch_name(keras_name: str, is_v2: bool) -> str:
    """Translate a Keras weight name into the matching timm state-dict key.

    The replacements are order dependent (earlier rewrites consume patterns
    that later ones would otherwise match), so the statement order must be
    kept when editing.

    Args:
        keras_name: Name produced by `separate_keras_weights`.
        is_v2: Whether the target model class is an EfficientNetV2 variant.

    Returns:
        The candidate key into the torch state dict.
    """
    torch_name = keras_name.replace("_", ".")
    # stem
    torch_name = torch_name.replace("conv.stem.conv2d", "conv_stem")
    torch_name = torch_name.replace("conv.stem.bn", "bn1")
    # blocks
    if is_v2:
        if "blocks.0" in torch_name:
            # normal conv
            torch_name = torch_name.replace("conv2d", "conv")
            torch_name = torch_name.replace("bn", "bn1")
        elif "blocks.1" in torch_name or "blocks.2" in torch_name:
            # edge residual block
            torch_name = torch_name.replace("conv.exp.conv2d", "conv_exp")
            torch_name = torch_name.replace("conv.exp.bn", "bn1")
            torch_name = torch_name.replace("conv.pwl.conv2d", "conv_pwl")
            torch_name = torch_name.replace("conv.pwl.bn", "bn2")
    else:
        if "blocks.0" in torch_name:
            # depthwise separation block
            torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw")
            torch_name = torch_name.replace("conv.dw.bn", "bn1")
            torch_name = torch_name.replace("conv.pw.conv2d", "conv_pw")
            torch_name = torch_name.replace("conv.pw.bn", "bn2")
    # inverted residual block
    torch_name = torch_name.replace("conv.pw.conv2d", "conv_pw")
    torch_name = torch_name.replace("conv.pw.bn", "bn1")
    torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw")
    torch_name = torch_name.replace("conv.dw.bn", "bn2")
    torch_name = torch_name.replace("conv.pwl.conv2d", "conv_pwl")
    torch_name = torch_name.replace("conv.pwl.bn", "bn3")
    # se
    torch_name = torch_name.replace("se.conv.reduce", "se.conv_reduce")
    torch_name = torch_name.replace("se.conv.expand", "se.conv_expand")
    # conv head
    torch_name = torch_name.replace("conv.head.conv2d", "conv_head")
    torch_name = torch_name.replace("conv.head.bn", "bn2")

    # weights naming mapping
    torch_name = torch_name.replace("kernel", "weight")  # conv2d
    torch_name = torch_name.replace("gamma", "weight")  # bn
    torch_name = torch_name.replace("beta", "bias")  # bn
    torch_name = torch_name.replace("moving.mean", "running_mean")  # bn
    torch_name = torch_name.replace("moving.variance", "running_var")  # bn
    return torch_name


# `zip` silently truncates to the shorter list, which would skip models
# without any error if the two lists ever get out of sync.
if len(timm_model_names) != len(keras_model_classes):
    raise ValueError(
        "timm_model_names and keras_model_classes must have the same length. "
        f"Got {len(timm_model_names)} and {len(keras_model_classes)}"
    )

for timm_model_name, keras_model_class in zip(
    timm_model_names, keras_model_classes
):
    # Prepare timm model and keras model
    input_shape = [224, 224, 3]
    torch_model = timm.create_model(timm_model_name, pretrained=True)
    torch_model = torch_model.eval()
    trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict(
        torch_model.state_dict()
    )
    keras_model = keras_model_class(
        input_shape=input_shape,
        include_preprocessing=False,
        classifier_activation="linear",
        weights=None,
    )
    trainable_weights, non_trainable_weights = separate_keras_weights(
        keras_model
    )

    # Assign weights
    is_v2 = "EfficientNetV2" in keras_model_class.__name__
    for keras_weight, keras_name in trainable_weights + non_trainable_weights:
        torch_name = _keras_to_torch_name(keras_name, is_v2)

        # assign weights
        if torch_name in trainable_state_dict:
            torch_weights = trainable_state_dict[torch_name].numpy()
        elif torch_name in non_trainable_state_dict:
            torch_weights = non_trainable_state_dict[torch_name].numpy()
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )
        if is_same_weights(keras_name, keras_weight, torch_name, torch_weights):
            assign_weights(keras_name, keras_weight, torch_weights)
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. The shape is "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )

    # Verify model outputs
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        torch_y = torch_model(torch_data)
    keras_y = keras_model(keras_data, training=False)
    torch_y = torch_y.detach().cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(keras_y)
    np.testing.assert_allclose(torch_y, keras_y, atol=2e-5)
    print(f"{keras_model_class.__name__}: output matched!")

    # Save converted model
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Export to {export_path}")
5 | """ 6 | 7 | from kimm._src.models.base_model import BaseModel 8 | from kimm._src.models.convmixer import ConvMixer736D32 9 | from kimm._src.models.convmixer import ConvMixer1024D20 10 | from kimm._src.models.convmixer import ConvMixer1536D20 11 | from kimm._src.models.convnext import ConvNeXtAtto 12 | from kimm._src.models.convnext import ConvNeXtBase 13 | from kimm._src.models.convnext import ConvNeXtFemto 14 | from kimm._src.models.convnext import ConvNeXtLarge 15 | from kimm._src.models.convnext import ConvNeXtNano 16 | from kimm._src.models.convnext import ConvNeXtPico 17 | from kimm._src.models.convnext import ConvNeXtSmall 18 | from kimm._src.models.convnext import ConvNeXtTiny 19 | from kimm._src.models.convnext import ConvNeXtXLarge 20 | from kimm._src.models.densenet import DenseNet121 21 | from kimm._src.models.densenet import DenseNet161 22 | from kimm._src.models.densenet import DenseNet169 23 | from kimm._src.models.densenet import DenseNet201 24 | from kimm._src.models.efficientnet import EfficientNetB0 25 | from kimm._src.models.efficientnet import EfficientNetB1 26 | from kimm._src.models.efficientnet import EfficientNetB2 27 | from kimm._src.models.efficientnet import EfficientNetB3 28 | from kimm._src.models.efficientnet import EfficientNetB4 29 | from kimm._src.models.efficientnet import EfficientNetB5 30 | from kimm._src.models.efficientnet import EfficientNetB6 31 | from kimm._src.models.efficientnet import EfficientNetB7 32 | from kimm._src.models.efficientnet import EfficientNetLiteB0 33 | from kimm._src.models.efficientnet import EfficientNetLiteB1 34 | from kimm._src.models.efficientnet import EfficientNetLiteB2 35 | from kimm._src.models.efficientnet import EfficientNetLiteB3 36 | from kimm._src.models.efficientnet import EfficientNetLiteB4 37 | from kimm._src.models.efficientnet import EfficientNetV2B0 38 | from kimm._src.models.efficientnet import EfficientNetV2B1 39 | from kimm._src.models.efficientnet import EfficientNetV2B2 40 | 
from kimm._src.models.efficientnet import EfficientNetV2B3 41 | from kimm._src.models.efficientnet import EfficientNetV2L 42 | from kimm._src.models.efficientnet import EfficientNetV2M 43 | from kimm._src.models.efficientnet import EfficientNetV2S 44 | from kimm._src.models.efficientnet import EfficientNetV2XL 45 | from kimm._src.models.efficientnet import TinyNetA 46 | from kimm._src.models.efficientnet import TinyNetB 47 | from kimm._src.models.efficientnet import TinyNetC 48 | from kimm._src.models.efficientnet import TinyNetD 49 | from kimm._src.models.efficientnet import TinyNetE 50 | from kimm._src.models.ghostnet import GhostNetV2W100 51 | from kimm._src.models.ghostnet import GhostNetV2W130 52 | from kimm._src.models.ghostnet import GhostNetV2W160 53 | from kimm._src.models.ghostnet import GhostNetW050 54 | from kimm._src.models.ghostnet import GhostNetW100 55 | from kimm._src.models.ghostnet import GhostNetW130 56 | from kimm._src.models.ghostnet_v3 import GhostNetV3W050 57 | from kimm._src.models.ghostnet_v3 import GhostNetV3W100 58 | from kimm._src.models.ghostnet_v3 import GhostNetV3W130 59 | from kimm._src.models.ghostnet_v3 import GhostNetV3W160 60 | from kimm._src.models.hgnet import HGNetBase 61 | from kimm._src.models.hgnet import HGNetSmall 62 | from kimm._src.models.hgnet import HGNetTiny 63 | from kimm._src.models.hgnet import HGNetV2B0 64 | from kimm._src.models.hgnet import HGNetV2B1 65 | from kimm._src.models.hgnet import HGNetV2B2 66 | from kimm._src.models.hgnet import HGNetV2B3 67 | from kimm._src.models.hgnet import HGNetV2B4 68 | from kimm._src.models.hgnet import HGNetV2B5 69 | from kimm._src.models.hgnet import HGNetV2B6 70 | from kimm._src.models.inception_next import InceptionNeXtBase 71 | from kimm._src.models.inception_next import InceptionNeXtSmall 72 | from kimm._src.models.inception_next import InceptionNeXtTiny 73 | from kimm._src.models.inception_v3 import InceptionV3 74 | from kimm._src.models.mobilenet_v2 import 
MobileNetV2W050 75 | from kimm._src.models.mobilenet_v2 import MobileNetV2W100 76 | from kimm._src.models.mobilenet_v2 import MobileNetV2W110 77 | from kimm._src.models.mobilenet_v2 import MobileNetV2W120 78 | from kimm._src.models.mobilenet_v2 import MobileNetV2W140 79 | from kimm._src.models.mobilenet_v3 import LCNet035 80 | from kimm._src.models.mobilenet_v3 import LCNet050 81 | from kimm._src.models.mobilenet_v3 import LCNet075 82 | from kimm._src.models.mobilenet_v3 import LCNet100 83 | from kimm._src.models.mobilenet_v3 import LCNet150 84 | from kimm._src.models.mobilenet_v3 import MobileNetV3W050Small 85 | from kimm._src.models.mobilenet_v3 import MobileNetV3W075Small 86 | from kimm._src.models.mobilenet_v3 import MobileNetV3W100Large 87 | from kimm._src.models.mobilenet_v3 import MobileNetV3W100LargeMinimal 88 | from kimm._src.models.mobilenet_v3 import MobileNetV3W100Small 89 | from kimm._src.models.mobilenet_v3 import MobileNetV3W100SmallMinimal 90 | from kimm._src.models.mobileone import MobileOneS0 91 | from kimm._src.models.mobileone import MobileOneS1 92 | from kimm._src.models.mobileone import MobileOneS2 93 | from kimm._src.models.mobileone import MobileOneS3 94 | from kimm._src.models.mobilevit import MobileViTS 95 | from kimm._src.models.mobilevit import MobileViTV2W050 96 | from kimm._src.models.mobilevit import MobileViTV2W075 97 | from kimm._src.models.mobilevit import MobileViTV2W100 98 | from kimm._src.models.mobilevit import MobileViTV2W125 99 | from kimm._src.models.mobilevit import MobileViTV2W150 100 | from kimm._src.models.mobilevit import MobileViTV2W175 101 | from kimm._src.models.mobilevit import MobileViTV2W200 102 | from kimm._src.models.mobilevit import MobileViTXS 103 | from kimm._src.models.mobilevit import MobileViTXXS 104 | from kimm._src.models.regnet import RegNetX002 105 | from kimm._src.models.regnet import RegNetX004 106 | from kimm._src.models.regnet import RegNetX006 107 | from kimm._src.models.regnet import RegNetX008 
108 | from kimm._src.models.regnet import RegNetX016 109 | from kimm._src.models.regnet import RegNetX032 110 | from kimm._src.models.regnet import RegNetX040 111 | from kimm._src.models.regnet import RegNetX064 112 | from kimm._src.models.regnet import RegNetX080 113 | from kimm._src.models.regnet import RegNetX120 114 | from kimm._src.models.regnet import RegNetX160 115 | from kimm._src.models.regnet import RegNetX320 116 | from kimm._src.models.regnet import RegNetY002 117 | from kimm._src.models.regnet import RegNetY004 118 | from kimm._src.models.regnet import RegNetY006 119 | from kimm._src.models.regnet import RegNetY008 120 | from kimm._src.models.regnet import RegNetY016 121 | from kimm._src.models.regnet import RegNetY032 122 | from kimm._src.models.regnet import RegNetY040 123 | from kimm._src.models.regnet import RegNetY064 124 | from kimm._src.models.regnet import RegNetY080 125 | from kimm._src.models.regnet import RegNetY120 126 | from kimm._src.models.regnet import RegNetY160 127 | from kimm._src.models.regnet import RegNetY320 128 | from kimm._src.models.repvgg import RepVGGA0 129 | from kimm._src.models.repvgg import RepVGGA1 130 | from kimm._src.models.repvgg import RepVGGA2 131 | from kimm._src.models.repvgg import RepVGGB0 132 | from kimm._src.models.repvgg import RepVGGB1 133 | from kimm._src.models.repvgg import RepVGGB2 134 | from kimm._src.models.repvgg import RepVGGB3 135 | from kimm._src.models.resnet import ResNet18 136 | from kimm._src.models.resnet import ResNet34 137 | from kimm._src.models.resnet import ResNet50 138 | from kimm._src.models.resnet import ResNet101 139 | from kimm._src.models.resnet import ResNet152 140 | from kimm._src.models.vgg import VGG11 141 | from kimm._src.models.vgg import VGG13 142 | from kimm._src.models.vgg import VGG16 143 | from kimm._src.models.vgg import VGG19 144 | from kimm._src.models.vision_transformer import VisionTransformerBase16 145 | from kimm._src.models.vision_transformer import 
VisionTransformerBase32 146 | from kimm._src.models.vision_transformer import VisionTransformerLarge16 147 | from kimm._src.models.vision_transformer import VisionTransformerLarge32 148 | from kimm._src.models.vision_transformer import VisionTransformerSmall16 149 | from kimm._src.models.vision_transformer import VisionTransformerSmall32 150 | from kimm._src.models.vision_transformer import VisionTransformerTiny16 151 | from kimm._src.models.vision_transformer import VisionTransformerTiny32 152 | from kimm._src.models.xception import Xception 153 | from kimm.models import base_model 154 | from kimm.models import convmixer 155 | from kimm.models import convnext 156 | from kimm.models import densenet 157 | from kimm.models import efficientnet 158 | from kimm.models import ghostnet 159 | from kimm.models import hgnet 160 | from kimm.models import inception_next 161 | from kimm.models import inception_v3 162 | from kimm.models import mobilenet_v2 163 | from kimm.models import mobilenet_v3 164 | from kimm.models import mobileone 165 | from kimm.models import mobilevit 166 | from kimm.models import regnet 167 | from kimm.models import repvgg 168 | from kimm.models import resnet 169 | from kimm.models import vgg 170 | from kimm.models import vision_transformer 171 | from kimm.models import xception 172 | --------------------------------------------------------------------------------