├── .all-contributorsrc ├── .flake8 ├── .github ├── weekly-digest.yml └── workflows │ └── format-lint.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── .pylintrc ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── __init__.py ├── config ├── config_validator.py ├── pretrained_model_url.yaml ├── prune │ └── cifar100 │ │ ├── densenet_small_l2mag.py │ │ ├── densenet_small_l2mag_slim.py │ │ ├── densenet_small_l2mag_slim_finetune.py │ │ ├── densenet_small_lr_rewinding.py │ │ ├── densenet_small_lth.py │ │ ├── densenet_small_slim.py │ │ ├── mixnet_l_l2_mag.py │ │ ├── mixnet_l_slim.py │ │ ├── simplenet_kd_lth.py │ │ ├── simplenet_l2mag.py │ │ ├── simplenet_l2mag_slim_finetune.py │ │ ├── simplenet_lth.py │ │ └── simplenet_slim.py ├── quantize │ └── cifar100 │ │ ├── densenet_121.py │ │ ├── densenet_small.py │ │ ├── micronet.py │ │ └── simplenet.py └── train │ └── cifar100 │ ├── densenet_121.py │ ├── densenet_201.py │ ├── densenet_small.py │ ├── densenet_small_finetune.py │ ├── densenet_small_kd.py │ ├── micronet.py │ ├── mixnet_l.py │ ├── mixnet_s.py │ ├── resnet18.py │ ├── simplenet.py │ ├── simplenet_finetune.py │ └── simplenet_kd.py ├── environment.yml ├── model_decomposition.py ├── mypy.ini ├── prune.py ├── quantize.py ├── requirements-dev.txt ├── run_check.sh ├── run_docker.sh ├── shrink.py ├── src ├── augmentation │ ├── methods.py │ ├── policies.py │ └── transforms.py ├── criterions.py ├── format.py ├── logger.py ├── lr_schedulers.py ├── models │ ├── __init__.py │ ├── adjmodule_getter.py │ ├── common_activations.py │ ├── common_layers.py │ ├── densenet.py │ ├── mixnet.py │ ├── quant_densenet.py │ ├── quant_mixnet.py │ ├── quant_resnet.py │ ├── quant_simplenet.py │ ├── resnet.py │ ├── simplenet.py │ └── utils.py ├── plotter.py ├── regularizers.py ├── runners │ ├── __init__.py │ ├── pruner.py │ ├── quantizer.py │ ├── runner.py │ ├── shrinker.py │ ├── trainer.py │ └── validator.py ├── tensor_decomposition │ └── decomposition.py └── utils.py ├── tests 
└── test_dummy.py ├── train.py └── val.py /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "contributors": [ 8 | { 9 | "login": "Curt-Park", 10 | "name": "Jinwoo Park (Curt)", 11 | "avatar_url": "https://avatars3.githubusercontent.com/u/14961526?v=4", 12 | "profile": "https://github.com/Curt-Park", 13 | "contributions": [ 14 | "code" 15 | ] 16 | }, 17 | { 18 | "login": "Hoonyyhoon", 19 | "name": "Junghoon Kim", 20 | "avatar_url": "https://avatars0.githubusercontent.com/u/25141842?v=4", 21 | "profile": "https://github.com/Hoonyyhoon", 22 | "contributions": [ 23 | "code" 24 | ] 25 | }, 26 | { 27 | "login": "HSShin0", 28 | "name": "Hyungseok Shin", 29 | "avatar_url": "https://avatars0.githubusercontent.com/u/44793742?v=4", 30 | "profile": "https://github.com/HSShin0", 31 | "contributions": [ 32 | "code" 33 | ] 34 | }, 35 | { 36 | "login": "Ingenjoy", 37 | "name": "Juhee Lee", 38 | "avatar_url": "https://avatars0.githubusercontent.com/u/18753708?v=4", 39 | "profile": "https://www.linkedin.com/in/juhee-lee-393342126/", 40 | "contributions": [ 41 | "code" 42 | ] 43 | }, 44 | { 45 | "login": "JeiKeiLim", 46 | "name": "Jongkuk Lim", 47 | "avatar_url": "https://avatars.githubusercontent.com/u/10356193?v=4", 48 | "profile": "https://limjk.ai", 49 | "contributions": [ 50 | "code" 51 | ] 52 | }, 53 | { 54 | "login": "ulken94", 55 | "name": "Haneol Kim", 56 | "avatar_url": "https://avatars.githubusercontent.com/u/58245037?v=4", 57 | "profile": "https://github.com/ulken94", 58 | "contributions": [ 59 | "code" 60 | ] 61 | } 62 | ], 63 | "contributorsPerLine": 7, 64 | "projectName": "model_compression", 65 | "projectOwner": "j-marple-dev", 66 | "repoType": "github", 67 | "repoHost": "https://github.com", 68 | "skipCi": true 69 | } 70 | -------------------------------------------------------------------------------- /.flake8: 
-------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, ANN101, D414, W503 3 | max-line-length = 88 4 | docstrings-convention = google 5 | -------------------------------------------------------------------------------- /.github/weekly-digest.yml: -------------------------------------------------------------------------------- 1 | # Configuration for weekly-digest - https://github.com/apps/weekly-digest 2 | publishDay: thu 3 | canPublishIssues: true 4 | canPublishPullRequests: true 5 | canPublishContributors: true 6 | canPublishStargazers: true 7 | canPublishCommits: true 8 | -------------------------------------------------------------------------------- /.github/workflows/format-lint.yml: -------------------------------------------------------------------------------- 1 | name: format-lint 2 | 3 | on: push 4 | 5 | jobs: 6 | format-lint: 7 | runs-on: ubuntu-18.04 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: conda-incubator/setup-miniconda@v2.0.1 11 | with: 12 | activate-environment: model_compression 13 | environment-file: environment.yml 14 | python-version: 3.8 15 | auto-activate-base: false 16 | - shell: bash -l {0} 17 | run: | 18 | conda info 19 | conda list 20 | - name: Format with black, isort 21 | shell: bash -l {0} 22 | run: make format 23 | - name: Lint with pylint, mypy, flake8 using pytest 24 | shell: bash -l {0} 25 | run: make test 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | 
*.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # saved files while running 132 | save/* 133 | 134 | # Decomposed files 135 | decompose/* 136 | 137 | # caches 138 | wandb 139 | .DS_Store 140 | 141 | # Ignore Docker setting 142 | .last_exec_cont_id.txt 143 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | multi_line_output=3 3 | include_trailing_comma=True 4 | force_grid_wrap=0 5 | combine_as_imports=True 6 | line_length=88 7 | force_sort_within_sections=True 8 | known_third_party=wandb 9 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: format 5 | name: format 6 | language: system 7 | entry: make format 8 | types: [python] 9 | - id: test 10 | name: test 11 | language: system 12 | entry: make test 13 | types: [python] 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/tensorrt:21.05-py3 2 | 3 | LABEL maintainer="Jongkuk Lim " 4 | 5 | ENV DEBIAN_FRONTEND=noninteractive 6 | ENV TZ=Asia/Seoul 7 | 8 | ARG UID=1000 9 | ARG GID=1000 10 | RUN groupadd 
-g $GID -o user && useradd -m -u $UID -g $GID -o -s /bin/bash user 11 | 12 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 13 | RUN apt-get update && apt-get install -y sudo dialog apt-utils tzdata 14 | RUN echo "%sudo ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers && echo "user:user" | chpasswd && adduser user sudo 15 | 16 | WORKDIR /home/user 17 | USER user 18 | 19 | # Install Display dependencies 20 | RUN sudo apt-get update && sudo apt-get install -y libgl1-mesa-dev && sudo apt-get -y install jq 21 | 22 | # Install pip3 and C++ linter 23 | RUN sudo apt-get install -y clang-format cppcheck 24 | RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py --force-reinstall && python3 -m pip install --upgrade pip 25 | RUN python3 -m pip install wheel cpplint 26 | 27 | # Install doxygen for C++ documentation 28 | RUN sudo apt-get update && sudo apt-get install -y flex bison && sudo apt-get autoremove -y 29 | RUN git clone -b Release_1_9_2 https://github.com/doxygen/doxygen.git \ 30 | && cd doxygen \ 31 | && mkdir build \ 32 | && cd build \ 33 | && cmake -G "Unix Makefiles" .. \ 34 | && make -j `cat /proc/cpuinfo | grep cores | wc -l` \ 35 | && sudo make install 36 | 37 | # Install PyTorch CUDA 11.1 38 | RUN python3 -m pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html 39 | 40 | # Install other development dependencies 41 | COPY ./requirements-dev.txt ./ 42 | RUN python3 -m pip install -r requirements-dev.txt 43 | RUN rm requirements-dev.txt 44 | 45 | # Download libtorch 46 | RUN wget -q https://download.pytorch.org/libtorch/cu111/libtorch-cxx11-abi-shared-with-deps-1.9.1%2Bcu111.zip \ 47 | && unzip libtorch-cxx11-abi-shared-with-deps-1.9.1+cu111.zip \ 48 | && mkdir libs \ 49 | && mv libtorch libs/libtorch \ 50 | && rm libtorch-cxx11-abi-shared-with-deps-1.9.1+cu111.zip 51 | 52 | # Install cmake 3.21.0 version. 
53 | RUN wget -q https://github.com/Kitware/CMake/releases/download/v3.21.0/cmake-3.21.0-linux-x86_64.tar.gz \ 54 | && tar -xzvf cmake-3.21.0-linux-x86_64.tar.gz \ 55 | && sudo ln -s /home/user/cmake-3.21.0-linux-x86_64/bin/cmake /usr/bin/cmake \ 56 | && sudo ln -s /home/user/root/cmake-3.21.0-linux-x86_64/bin/ctest /usr/bin/ctest \ 57 | && sudo ln -s /home/user/root/cmake-3.21.0-linux-x86_64/bin/cpack /usr/bin/cpack \ 58 | && rm cmake-3.21.0-linux-x86_64.tar.gz 59 | 60 | # Terminal environment 61 | RUN git clone https://github.com/JeiKeiLim/my_term.git \ 62 | && cd my_term \ 63 | && ./run.sh 64 | 65 | # Fix error messages with vim plugins 66 | RUN cd /home/user/.vim_runtime/sources_non_forked && rm -rf tlib vim-fugitive && git clone https://github.com/tomtom/tlib_vim.git tlib && git clone https://github.com/tpope/vim-fugitive.git 67 | 68 | # Install vim 8.2 with YCM 69 | RUN sudo apt-get install -y software-properties-common \ 70 | && sudo add-apt-repository ppa:jonathonf/vim \ 71 | && sudo add-apt-repository ppa:ubuntu-toolchain-r/test \ 72 | && sudo apt-get update \ 73 | && sudo apt-get install -y vim g++-8 libstdc++6 74 | 75 | RUN cd /home/user/.vim_runtime/my_plugins \ 76 | && git clone --recursive https://github.com/ycm-core/YouCompleteMe.git \ 77 | && cd YouCompleteMe \ 78 | && CC=gcc-8 CXX=g++-8 python3 install.py --clangd-completer 79 | 80 | # Install DALI 81 | RUN python3 -m pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110 82 | 83 | # Add PATH 84 | RUN echo "export PATH=/home/user/.local/bin:\$PATH" >> /home/user/.bashrc 85 | RUN echo "export LC_ALL=C.UTF-8 && export LANG=C.UTF-8" >> /home/user/.bashrc 86 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jinwoo Park (Curt) 4 | 5 | Permission is hereby granted, free of charge, to any 
person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | format: 2 | black . --exclude checkpoint --exclude wandb --exclude save 3 | isort . --skip checkpoint --skip wandb --skip save 4 | 5 | test: 6 | black . --check --exclude checkpoint --exclude wandb --exclude save 7 | isort . --check-only --skip checkpoint --skip wandb --skip save 8 | env PYTHONPATH=. pytest --pylint --flake8 --mypy --ignore=checkpoint --ignore=wandb --ignore=save --ignore=config 9 | 10 | install: 11 | conda env create -f environment.yml 12 | 13 | dev: 14 | pip install pre-commit 15 | pre-commit install 16 | 17 | docker-push: 18 | docker build -t jmarpledev/model_compression . 
19 | docker push jmarpledev/model_compression 20 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize module. 2 | 3 | - Author: Haneol Kim 4 | - Contact: hekim@jmarple.ai 5 | """ 6 | -------------------------------------------------------------------------------- /config/pretrained_model_url.yaml: -------------------------------------------------------------------------------- 1 | # densenet // L=190, k=40, num_classes=100 2 | '044b8c6cb6955787de239cc58180bc693c01a97c3c29dd8ec5b6c222': 3 | dir_name: densenet_large_numclasses_100 4 | file_name: model_best.pth.tar 5 | link: https://drive.google.com/u/0/uc?export=download&confirm=ilUf&id=1Mwr6pmdWN7r4wgmE1nYqEUDyuwLt5ce5 6 | # simplenet // num_classes=100 7 | '25b245b7bf273bea50974f22893e3329d285d9cbf5c99ba3942af8c1': 8 | dir_name: simplenet_numclasses_100 9 | file_name: model_best.pth.tar 10 | link: https://drive.google.com/u/0/uc?id=12rKUVZNDucVfxi4G3fLJwahF2dxKDOfZ&export=download 11 | -------------------------------------------------------------------------------- /config/prune/cifar100/densenet_small_l2mag.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for magnitude channel-wise pruning.
3 | 4 | - Author: Junghoon Kim 5 | - Email: jhkim@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import densenet_small 9 | 10 | train_config = densenet_small.config 11 | config = { 12 | "TRAIN_CONFIG": train_config, 13 | "N_PRUNING_ITER": 15, 14 | "PRUNE_METHOD": "Magnitude", 15 | "PRUNE_PARAMS": dict( 16 | PRUNE_AMOUNT=0.2, 17 | NORM=2, 18 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 19 | TRAIN_START_FROM=0, 20 | PRUNE_AT_BEST=False, 21 | ), 22 | } 23 | -------------------------------------------------------------------------------- /config/prune/cifar100/densenet_small_l2mag_slim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for magnitude channel-wise pruning. 3 | 4 | - Author: Junghoon Kim 5 | - Email: jhkim@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import densenet_small 9 | 10 | train_config = densenet_small.config 11 | config = { 12 | "TRAIN_CONFIG": train_config, 13 | "N_PRUNING_ITER": 15, 14 | "PRUNE_METHOD": "Magnitude", 15 | "PRUNE_PARAMS": dict( 16 | PRUNE_AMOUNT=0.2, 17 | NORM=2, 18 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 19 | TRAIN_START_FROM=0, 20 | PRUNE_AT_BEST=False, 21 | ), 22 | } 23 | -------------------------------------------------------------------------------- /config/prune/cifar100/densenet_small_l2mag_slim_finetune.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for magnitude channel-wise pruning.
3 | 4 | - Author: Junghoon Kim 5 | - Email: jhkim@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import densenet_small, densenet_small_finetune 9 | 10 | train_config = densenet_small.config 11 | regularizer_params = { 12 | "REGULARIZER": "BnWeight", 13 | "REGULARIZER_PARAMS": dict(coeff=1e-5), 14 | "EPOCHS": train_config["EPOCHS"], 15 | } 16 | train_config.update(regularizer_params) 17 | 18 | finetune_config = densenet_small_finetune.config 19 | regularizer_params = { 20 | "REGULARIZER": "BnWeight", 21 | "REGULARIZER_PARAMS": dict(coeff=1e-5), 22 | "EPOCHS": finetune_config["EPOCHS"], 23 | } 24 | finetune_config.update(regularizer_params) 25 | 26 | config = { 27 | "TRAIN_CONFIG": train_config, 28 | "TRAIN_CONFIG_AT_PRUNE": finetune_config, 29 | "N_PRUNING_ITER": 15, 30 | "PRUNE_METHOD": "SlimMagnitude", 31 | "PRUNE_PARAMS": dict( 32 | PRUNE_AMOUNT=0.1, 33 | NORM=2, 34 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 35 | TRAIN_START_FROM=0, 36 | PRUNE_AT_BEST=False, 37 | ), 38 | } 39 | -------------------------------------------------------------------------------- /config/prune/cifar100/densenet_small_lr_rewinding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for Learning Rate Rewinding. 
3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import densenet_small 9 | 10 | train_config = densenet_small.config 11 | config = { 12 | "TRAIN_CONFIG": train_config, 13 | "N_PRUNING_ITER": 15, 14 | "PRUNE_METHOD": "LotteryTicketHypothesis", 15 | "PRUNE_PARAMS": dict( 16 | PRUNE_AMOUNT=0.2, 17 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 18 | TRAIN_START_FROM=0, 19 | PRUNE_AT_BEST=False, 20 | ), 21 | } 22 | -------------------------------------------------------------------------------- /config/prune/cifar100/densenet_small_lth.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for Lottery Ticket Hypothesis. 3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import densenet_small 9 | 10 | train_config = densenet_small.config 11 | config = { 12 | "TRAIN_CONFIG": train_config, 13 | "N_PRUNING_ITER": 15, 14 | "PRUNE_METHOD": "LotteryTicketHypothesis", 15 | "PRUNE_PARAMS": dict( 16 | PRUNE_AMOUNT=0.2, STORE_PARAM_BEFORE=0, TRAIN_START_FROM=0, PRUNE_AT_BEST=False 17 | ), 18 | } 19 | -------------------------------------------------------------------------------- /config/prune/cifar100/densenet_small_slim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for network slimming + L2 magnitude pruning. 
3 | 4 | - Author: Junghoon Kim 5 | - Email: jhkim@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import densenet_small 9 | 10 | train_config = densenet_small.config 11 | train_config.update({"REGULARIZER": "BnWeight", "REGULARIZER_PARAMS": dict(coeff=1e-5)}) 12 | config = { 13 | "TRAIN_CONFIG": train_config, 14 | "N_PRUNING_ITER": 15, 15 | "PRUNE_METHOD": "SlimMagnitude", 16 | "PRUNE_PARAMS": dict( 17 | PRUNE_AMOUNT=0.2, 18 | NORM=2, 19 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 20 | TRAIN_START_FROM=0, 21 | PRUNE_AT_BEST=False, 22 | ), 23 | } 24 | -------------------------------------------------------------------------------- /config/prune/cifar100/mixnet_l_l2_mag.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for magnitude layerwise pruning. 3 | 4 | - Author: Junghoon Kim 5 | - Email: jhkim@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import mixnet_l 9 | 10 | train_config = mixnet_l.config 11 | train_config.update({"BATCH_SIZE": 128}) 12 | config = { 13 | "TRAIN_CONFIG": train_config, 14 | "N_PRUNING_ITER": 15, 15 | "PRUNE_METHOD": "Magnitude", 16 | "PRUNE_PARAMS": dict( 17 | PRUNE_AMOUNT=0.2, 18 | NORM=2, 19 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 20 | TRAIN_START_FROM=0, 21 | PRUNE_AT_BEST=False, 22 | ), 23 | } 24 | -------------------------------------------------------------------------------- /config/prune/cifar100/mixnet_l_slim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for naive lottery ticket hypothesis.
3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import mixnet_l 9 | 10 | train_config = mixnet_l.config 11 | train_config.update( 12 | { 13 | "REGULARIZER": "BnWeight", 14 | "REGULARIZER_PARAMS": dict(coeff=1e-5), 15 | "BATCH_SIZE": 128, 16 | } 17 | ) 18 | config = { 19 | "TRAIN_CONFIG": train_config, 20 | "N_PRUNING_ITER": 15, 21 | "PRUNE_METHOD": "NetworkSlimming", 22 | "PRUNE_PARAMS": dict( 23 | PRUNE_AMOUNT=0.2, 24 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 25 | TRAIN_START_FROM=0, 26 | PRUNE_AT_BEST=False, 27 | ), 28 | } 29 | -------------------------------------------------------------------------------- /config/prune/cifar100/simplenet_kd_lth.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for naive lottery ticket hypothesis with kd, simplenet. 3 | 4 | - Author: Junghoon Kim 5 | - Email: jhkim@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import simplenet_kd 9 | 10 | config = { 11 | "TRAIN_CONFIG": simplenet_kd.config, 12 | "N_PRUNING_ITER": 15, 13 | "PRUNE_METHOD": "LotteryTicketHypothesis", 14 | "PRUNE_PARAMS": dict( 15 | PRUNE_AMOUNT=0.2, 16 | STORE_PARAM_BEFORE=simplenet_kd.config["EPOCHS"], 17 | TRAIN_START_FROM=0, 18 | PRUNE_AT_BEST=False, 19 | ), 20 | } 21 | -------------------------------------------------------------------------------- /config/prune/cifar100/simplenet_l2mag.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for Magnitude layerwise pruning. 
3 | 4 | - Author: Junghoon Kim 5 | - Email: jhkim@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import simplenet 9 | 10 | train_config = simplenet.config 11 | train_config.update({"REGULARIZER": "BnWeight", "REGULARIZER_PARAMS": dict(coeff=1e-5)}) 12 | config = { 13 | "TRAIN_CONFIG": train_config, 14 | "N_PRUNING_ITER": 5, 15 | "PRUNE_METHOD": "SlimMagnitude", 16 | "PRUNE_PARAMS": dict( 17 | PRUNE_AMOUNT=0.2, 18 | NORM=2, 19 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 20 | TRAIN_START_FROM=0, 21 | PRUNE_AT_BEST=False, 22 | ), 23 | } 24 | -------------------------------------------------------------------------------- /config/prune/cifar100/simplenet_l2mag_slim_finetune.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for slimming simple network. 3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import simplenet, simplenet_finetune 9 | 10 | train_config = simplenet.config 11 | regularizer_params = { 12 | "REGULARIZER": "BnWeight", 13 | "REGULARIZER_PARAMS": dict(coeff=1e-5), 14 | "EPOCHS": train_config["EPOCHS"], 15 | } 16 | train_config.update(regularizer_params) 17 | 18 | finetune_config = simplenet_finetune.config 19 | regularizer_params = { 20 | "REGULARIZER": "BnWeight", 21 | "REGULARIZER_PARAMS": dict(coeff=1e-5), 22 | "EPOCHS": finetune_config["EPOCHS"], 23 | } 24 | finetune_config.update(regularizer_params) 25 | 26 | 27 | config = { 28 | "TRAIN_CONFIG": train_config, 29 | "TRAIN_CONFIG_AT_PRUNE": finetune_config, 30 | "N_PRUNING_ITER": 15, 31 | "PRUNE_METHOD": "SlimMagnitude", 32 | "PRUNE_PARAMS": dict( 33 | PRUNE_AMOUNT=0.1, 34 | NORM=2, 35 | STORE_PARAM_BEFORE=10, 36 | TRAIN_START_FROM=0, 37 | PRUNE_AT_BEST=False, 38 | ), 39 | } 40 | -------------------------------------------------------------------------------- /config/prune/cifar100/simplenet_lth.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for Lottery Ticket Hypothesis with simplenet. 3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import simplenet 9 | 10 | config = { 11 | "TRAIN_CONFIG": simplenet.config, 12 | "N_PRUNING_ITER": 5, 13 | "PRUNE_METHOD": "LotteryTicketHypothesis", 14 | "PRUNE_PARAMS": dict( 15 | PRUNE_AMOUNT=0.2, STORE_PARAM_BEFORE=5, TRAIN_START_FROM=0, PRUNE_AT_BEST=False 16 | ), 17 | } 18 | -------------------------------------------------------------------------------- /config/prune/cifar100/simplenet_slim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for slimming simple network. 3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import simplenet 9 | 10 | train_config = simplenet.config 11 | train_config.update({"REGULARIZER": "BnWeight", "REGULARIZER_PARAMS": dict(coeff=1e-5)}) 12 | config = { 13 | "TRAIN_CONFIG": train_config, 14 | "N_PRUNING_ITER": 5, 15 | "PRUNE_METHOD": "NetworkSlimming", 16 | "PRUNE_PARAMS": dict( 17 | PRUNE_AMOUNT=0.2, 18 | STORE_PARAM_BEFORE=train_config["EPOCHS"], 19 | TRAIN_START_FROM=0, 20 | PRUNE_AT_BEST=False, 21 | ), 22 | } 23 | -------------------------------------------------------------------------------- /config/quantize/cifar100/densenet_121.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for quantization for densenet_121.
3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import densenet_121 9 | 10 | config = densenet_121.config 11 | config.update( 12 | { 13 | "MODEL_NAME": "quant_densenet", 14 | "LR_SCHEDULER_PARAMS": dict(warmup_epochs=0, start_lr=1e-4), 15 | "LR": 1e-4, 16 | "EPOCHS": 5, 17 | } 18 | ) 19 | -------------------------------------------------------------------------------- /config/quantize/cifar100/densenet_small.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for quantization for densenet_small. 3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import densenet_small 9 | 10 | config = densenet_small.config 11 | config.update( 12 | { 13 | "MODEL_NAME": "quant_densenet", 14 | "LR_SCHEDULER_PARAMS": dict(warmup_epochs=0, start_lr=1e-4), 15 | "LR": 1e-4, 16 | "EPOCHS": 5, 17 | } 18 | ) 19 | -------------------------------------------------------------------------------- /config/quantize/cifar100/micronet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for quantization for MicroNet (CIFAR100). 3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import micronet 9 | 10 | config = micronet.config 11 | config.update( 12 | { 13 | "MODEL_NAME": "quant_mixnet", 14 | "LR_SCHEDULER_PARAMS": dict(warmup_epochs=0, start_lr=1e-4), 15 | "LR": 1e-4, 16 | "EPOCHS": 2, 17 | } 18 | ) 19 | -------------------------------------------------------------------------------- /config/quantize/cifar100/simplenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for quantization for simplenet. 
3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | from config.train.cifar100 import simplenet 9 | 10 | config = simplenet.config 11 | config.update( 12 | { 13 | "MODEL_NAME": "quant_simplenet", 14 | "LR_SCHEDULER_PARAMS": dict(warmup_epochs=0, start_lr=1e-4), 15 | "LR": 1e-4, 16 | "EPOCHS": 2, 17 | } 18 | ) 19 | -------------------------------------------------------------------------------- /config/train/cifar100/densenet_121.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for training densenet_121. 3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | import os 9 | 10 | config = { 11 | "SEED": 777, 12 | "AUG_TRAIN": "randaugment_train_cifar100_224", 13 | "AUG_TRAIN_PARAMS": dict(n_select=2, level=None), 14 | "AUG_TEST": "simple_augment_test_cifar100_224", 15 | "CUTMIX": dict(beta=1, prob=0.5), 16 | "DATASET": "CIFAR100", 17 | "MODEL_NAME": "densenet", 18 | "MODEL_PARAMS": dict( 19 | num_classes=100, 20 | inplanes=24, 21 | growthRate=32, 22 | compressionRate=2, 23 | block_configs=(6, 12, 24, 16), 24 | small_input=False, 25 | efficient=False, 26 | ), 27 | "CRITERION": "CrossEntropy", 28 | "CRITERION_PARAMS": dict(num_classes=100, label_smoothing=0.1), 29 | "LR_SCHEDULER": "WarmupCosineLR", 30 | "LR_SCHEDULER_PARAMS": dict( 31 | warmup_epochs=5, start_lr=1e-3, min_lr=1e-5, n_rewinding=1 32 | ), 33 | "BATCH_SIZE": 128, 34 | "LR": 0.1, 35 | "MOMENTUM": 0.9, 36 | "WEIGHT_DECAY": 1e-4, 37 | "NESTEROV": True, 38 | "EPOCHS": 300, 39 | "N_WORKERS": os.cpu_count(), 40 | } 41 | -------------------------------------------------------------------------------- /config/train/cifar100/densenet_201.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for training densenet_201.
3 | 4 | - Author: Curt-Park 5 | - Email: jwpark@jmarple.ai 6 | """ 7 | 8 | import os 9 | 10 | config = { 11 | "SEED": 777, 12 | "AUG_TRAIN": "randaugment_train_cifar100_224", 13 | "AUG_TRAIN_PARAMS": dict(n_select=2, level=None), 14 | "AUG_TEST": "simple_augment_test_cifar100_224", 15 | "CUTMIX": dict(beta=1, prob=0.5), 16 | "DATASET": "CIFAR100", 17 | "MODEL_NAME": "densenet", 18 | "MODEL_PARAMS": dict( 19 | num_classes=100, 20 | inplanes=24, 21 | growthRate=32, 22 | compressionRate=2, 23 | block_configs=(6, 12, 48, 32), 24 | small_input=False, 25 | efficient=False, 26 | ), 27 | "CRITERION": "CrossEntropy", 28 | "CRITERION_PARAMS": dict(num_classes=100, label_smoothing=0.1), 29 | "LR_SCHEDULER": "WarmupCosineLR", 30 | "LR_SCHEDULER_PARAMS": dict( 31 | warmup_epochs=5, start_lr=1e-3, min_lr=1e-5, n_rewinding=1 32 | ), 33 | "BATCH_SIZE": 128, 34 | "LR": 0.1, 35 | "MOMENTUM": 0.9, 36 | "WEIGHT_DECAY": 1e-4, 37 | "NESTEROV": True, 38 | "EPOCHS": 300, 39 | "N_WORKERS": os.cpu_count(), 40 | } 41 | -------------------------------------------------------------------------------- /config/train/cifar100/densenet_small.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for training densenet_small. 
# -*- coding: utf-8 -*-
"""Configurations for training densenet_small (cifar100).

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""

import os

config = {
    "SEED": 777,
    # data pipeline
    "AUG_TRAIN": "randaugment_train_cifar100",
    "AUG_TRAIN_PARAMS": {"n_select": 2, "level": None},
    "AUG_TEST": "simple_augment_test_cifar100",
    "CUTMIX": {"beta": 1, "prob": 0.5},
    "DATASET": "CIFAR100",
    # model
    "MODEL_NAME": "densenet",
    "MODEL_PARAMS": {
        "num_classes": 100,
        "inplanes": 24,
        "growthRate": 12,
        "compressionRate": 2,
        "block_configs": (16, 16, 16),
    },
    # loss
    "CRITERION": "CrossEntropy",
    "CRITERION_PARAMS": {"num_classes": 100, "label_smoothing": 0.1},
    # optimisation schedule
    "LR_SCHEDULER": "WarmupCosineLR",
    "LR_SCHEDULER_PARAMS": {
        "warmup_epochs": 10,
        "start_lr": 1e-3,
        "min_lr": 1e-5,
        "n_rewinding": 1,
    },
    "BATCH_SIZE": 64,
    "LR": 0.1,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-4,
    "EPOCHS": 300,
    "N_WORKERS": os.cpu_count(),
}
# -*- coding: utf-8 -*-
"""Configurations for fine-tuning densenet_small (cifar100).

- Author: Junghoon Kim
- Email: jhkim@jmarple.ai
"""
import os

config = {
    "SEED": 777,
    # data pipeline
    "AUG_TRAIN": "randaugment_train_cifar100",
    "AUG_TRAIN_PARAMS": {"n_select": 2, "level": None},
    "CUTMIX": {"beta": 1, "prob": 0.5},
    "AUG_TEST": "simple_augment_test_cifar100",
    "DATASET": "CIFAR100",
    # model
    "MODEL_NAME": "densenet",
    "MODEL_PARAMS": {
        "num_classes": 100,
        "inplanes": 24,
        "growthRate": 12,
        "compressionRate": 2,
        "block_configs": (16, 16, 16),
    },
    # loss
    "CRITERION": "CrossEntropy",
    "CRITERION_PARAMS": {"num_classes": 100},
    # optimisation schedule
    "LR_SCHEDULER": "WarmupCosineLR",
    "LR_SCHEDULER_PARAMS": {
        "warmup_epochs": 5,
        "start_lr": 1e-3,
        "min_lr": 5e-6,
        "n_rewinding": 1,
    },
    "BATCH_SIZE": 64,
    "LR": 0.001,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-4,
    "NESTEROV": True,
    "EPOCHS": 10,
    "N_WORKERS": os.cpu_count(),
}
# -*- coding: utf-8 -*-
"""Configurations for knowledge distillation with densenet.

The student settings start from ``densenet_small.config``; the entries in
``config_override`` replace or extend them.

- Author: Junghoon Kim
- Email: jhkim@jmarple.ai
"""

from config.train.cifar100 import densenet_small

# Entries that differ from (or are added to) the densenet_small baseline.
config_override = {
    "CRITERION": "HintonKLD",
    "CRITERION_PARAMS": dict(
        T=4.0,
        alpha=0.9,
        teacher_model_name="densenet",
        teacher_model_params=dict(
            num_classes=100,
            inplanes=24,
            growthRate=40,
            compressionRate=2,
            block_configs=(31, 31, 31),
        ),
        crossentropy_params=dict(num_classes=100),
    ),
    "LR_SCHEDULER": "WarmupCosineLR",
    "LR_SCHEDULER_PARAMS": dict(warmup_epochs=3, start_lr=1e-4),
    "BATCH_SIZE": 32,
    "LR": 0.1,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-4,
    "EPOCHS": 5,
}

# Merge into a NEW dict. Aliasing densenet_small.config and calling .update()
# on it rewrote the baseline module's dict in place, so any later import of
# densenet_small in the same process silently saw the distillation settings.
config = {**densenet_small.config, **config_override}
# -*- coding: utf-8 -*-
"""Configurations for training micronet (cifar100).

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""

import os

config = {
    "SEED": 777,
    # data pipeline
    "AUG_TRAIN": "randaugment_train_cifar100",
    "AUG_TRAIN_PARAMS": {"n_select": 2, "level": None},
    "AUG_TEST": "simple_augment_test_cifar100",
    "CUTMIX": {"beta": 1.0, "prob": 0.5},
    "DATASET": "CIFAR100",
    # model
    "MODEL_NAME": "mixnet",
    "MODEL_PARAMS": {
        "num_classes": 100,
        "model_type": "MICRONET",
        "dataset": "CIFAR100",
    },
    # loss
    "CRITERION": "CrossEntropy",
    "CRITERION_PARAMS": {"num_classes": 100, "label_smoothing": 0.1},
    # optimisation schedule
    "LR_SCHEDULER": "WarmupCosineLR",
    "LR_SCHEDULER_PARAMS": {
        "warmup_epochs": 10,
        "start_lr": 1e-3,
        "min_lr": 1e-4,
        "n_rewinding": 1,
        "decay": 0.0,
    },
    "BATCH_SIZE": 32,
    "LR": 0.1,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-5,
    "NESTEROV": True,
    "EPOCHS": 600,
    "N_WORKERS": os.cpu_count(),
}
# -*- coding: utf-8 -*-
"""Configurations for training mixnet_l (cifar100).

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""

import os

config = {
    "SEED": 777,
    # data pipeline
    "AUG_TRAIN": "randaugment_train_cifar100",
    "AUG_TRAIN_PARAMS": {"n_select": 2, "level": None},
    "AUG_TEST": "simple_augment_test_cifar100",
    "CUTMIX": {"beta": 1.0, "prob": 0.5},
    "DATASET": "CIFAR100",
    # model
    "MODEL_NAME": "mixnet",
    "MODEL_PARAMS": {
        "num_classes": 100,
        "model_type": "L",
        "dataset": "CIFAR100",
    },
    # loss
    "CRITERION": "CrossEntropy",
    "CRITERION_PARAMS": {"num_classes": 100, "label_smoothing": 0.1},
    # optimisation schedule
    "LR_SCHEDULER": "WarmupCosineLR",
    "LR_SCHEDULER_PARAMS": {
        "warmup_epochs": 10,
        "start_lr": 1e-3,
        "min_lr": 1e-4,
        "n_rewinding": 1,
        "decay": 0.0,
    },
    "BATCH_SIZE": 256,
    "LR": 0.1,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-4,
    "EPOCHS": 300,
    "N_WORKERS": os.cpu_count(),
}
# -*- coding: utf-8 -*-
"""Configurations for training mixnet_s (cifar100).

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""

import os

config = {
    "SEED": 777,
    # data pipeline
    "AUG_TRAIN": "randaugment_train_cifar100",
    "AUG_TRAIN_PARAMS": {"n_select": 2, "level": 14},
    "AUG_TEST": "simple_augment_test_cifar100",
    "CUTMIX": {"beta": 1.0, "prob": 0.5},
    "DATASET": "CIFAR100",
    # model
    "MODEL_NAME": "mixnet",
    "MODEL_PARAMS": {
        "num_classes": 100,
        "model_type": "S",
        "dataset": "CIFAR100",
    },
    # loss
    "CRITERION": "CrossEntropy",
    "CRITERION_PARAMS": {"num_classes": 100},
    # optimisation schedule
    "LR_SCHEDULER": "WarmupCosineLR",
    "LR_SCHEDULER_PARAMS": {"warmup_epochs": 10, "start_lr": 1e-3},
    "BATCH_SIZE": 256,
    "LR": 0.1,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-4,
    "EPOCHS": 300,
    "N_WORKERS": os.cpu_count(),
}
# -*- coding: utf-8 -*-
"""Configurations for training resnet18 (cifar100).

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""

import os

config = {
    "SEED": 777,
    # data pipeline
    "AUG_TRAIN": "randaugment_train_cifar100",
    "AUG_TRAIN_PARAMS": {"n_select": 2, "level": 14},
    "AUG_TEST": "simple_augment_test_cifar100",
    "CUTMIX": {"beta": 1.0, "prob": 0.5},
    "DATASET": "CIFAR100",
    # model
    "MODEL_NAME": "resnet",
    "MODEL_PARAMS": {"num_classes": 100, "model_type": "resnet18"},
    # loss
    "CRITERION": "CrossEntropy",
    "CRITERION_PARAMS": {"num_classes": 100},
    # optimisation schedule
    "LR_SCHEDULER": "WarmupCosineLR",
    "LR_SCHEDULER_PARAMS": {"warmup_epochs": 10, "start_lr": 1e-3},
    "BATCH_SIZE": 256,
    "LR": 0.1,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-4,
    "NESTEROV": True,
    "EPOCHS": 300,
    "N_WORKERS": os.cpu_count(),
}
# -*- coding: utf-8 -*-
"""Baseline configurations for training simplenet (cifar100).

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""

import os

config = {
    "SEED": 777,
    # data pipeline
    "AUG_TRAIN": "simple_augment_train_cifar100",
    "AUG_TEST": "simple_augment_test_cifar100",
    "DATASET": "CIFAR100",
    # model
    "MODEL_NAME": "simplenet",
    "MODEL_PARAMS": {"num_classes": 100},
    # loss
    "CRITERION": "CrossEntropy",
    "CRITERION_PARAMS": {"num_classes": 100},
    # optimisation schedule
    "LR_SCHEDULER": "WarmupCosineLR",
    "LR_SCHEDULER_PARAMS": {"warmup_epochs": 3, "start_lr": 1e-3},
    "BATCH_SIZE": 64,
    "LR": 0.1,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-4,
    "NESTEROV": True,
    "EPOCHS": 5,
    "N_WORKERS": os.cpu_count(),
}
# -*- coding: utf-8 -*-
"""Configurations for knowledge distillation with simplenet.

The student settings start from ``simplenet.config``; the entries in
``config_override`` replace or extend them.

- Author: Junghoon Kim
- Email: jhkim@jmarple.ai
"""

from config.train.cifar100 import simplenet

# Entries that differ from (or are added to) the simplenet baseline.
config_override = {
    "CRITERION": "HintonKLD",
    "CRITERION_PARAMS": dict(
        T=4.0,
        alpha=0.9,
        teacher_model_name="simplenet",
        teacher_model_params=dict(num_classes=100),
        crossentropy_params=dict(num_classes=100),
    ),
    "BATCH_SIZE": 16,
    "LR": 0.1,
    "MOMENTUM": 0.9,
    "WEIGHT_DECAY": 1e-4,
    "EPOCHS": 5,
}

# Merge into a NEW dict. Aliasing simplenet.config and calling .update() on it
# rewrote the baseline module's dict in place, so any later import of
# simplenet in the same process silently saw the distillation settings.
config = {**simplenet.config, **config_override}
"""Tensor decomposition.

- Author: Haneol Kim
- Contact: hekim@jmarple.ai
"""

import argparse
from copy import deepcopy
import glob
import os
from pathlib import Path
import shutil
import time
from typing import Dict, List, Tuple

import numpy as np  # noqa: F401  pylint: disable=unused-import
import torch
from torch import nn

from src.logger import colorstr, get_logger
from src.runners import initialize
from src.runners.validator import Validator
from src.tensor_decomposition.decomposition import decompose_model
from src.utils import count_param

LOGGER = get_logger(__name__)


def get_parser() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Despite its name, this returns the parsed ``argparse.Namespace`` (read from
    ``sys.argv``), not the parser object.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/train/cifar100/densenet_201.py",
        help="Configuration path (.py)",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="Input log directory name to resume in save/checkpoint",
    )
    parser.add_argument("--gpu", default=0, type=int, help="GPU id to use")
    parser.add_argument("--multi-gpu", action="store_true", help="Multi-GPU use.")
    # NOTE(review): --dst is parsed but never read below; weights are written
    # under "decompose/<resume dir>". Confirm whether it should be wired in.
    parser.add_argument(
        "--dst",
        type=str,
        default=os.path.join("exp", "decompose"),
        help="Export directory. Directory will be {dst}/decompose/{DATE}_runs1, ...",
    )
    parser.add_argument(
        "--prune-step",
        default=0.01,
        type=float,
        help="Prunning trial max step. Maximum step while searching prunning ratio with binary search. Pruning will be applied priro to decomposition. If prune-step is equal or smaller than 0.0, prunning will not be applied.",
    )
    parser.add_argument(
        "--loss-thr",
        default=0.1,
        type=float,
        help="Loss value to compare original model output and decomposed model output to judge to switch to decomposed conv.",
    )
    parser.add_argument(
        "--half", dest="half", action="store_true", help="Use half precision"
    )
    parser.add_argument(
        "--log",
        dest="log",
        action="store_true",
        help="Logging the tensor decomposition results.",
    )
    parser.add_argument(
        "--file_name",
        type=str,
        default="decomposed_model.pt",
        help="Decomposed model's file name",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.set_defaults(multi_gpu=False)
    parser.set_defaults(half=False)
    parser.set_defaults(log=False)
    return parser.parse_args()


def log_result(
    ori_param: int,
    decomp_param: int,
    ori_time: float,
    decomp_time: float,
    ori_result: float,
    decomp_result: float,
) -> Tuple[Dict[str, str], List[str]]:
    """Format the before/after comparison as printable lines.

    Args:
        ori_param: parameter count of the original model.
        decomp_param: parameter count of the decomposed model.
        ori_time: seconds taken to validate the original model.
        decomp_time: seconds taken to validate the decomposed model.
        ori_result: accuracy of the original model.
        decomp_result: accuracy of the decomposed model.

    Returns:
        A ``(log_dict, dict_keys)`` pair: the formatted line for each key, and
        the keys in the order they should be printed.
    """
    log_dict = {
        "ori_param": f" Original # of param : {ori_param}",
        "decomp_param": f"Decomposed # of param : {decomp_param}",
        "ori_time": f"Time took (Original) : {ori_time:.5f}s",
        "decomp_time": f"Time took (Decomposed) : {decomp_time:.5f}s",
        "ori_result": f"Original model accuray : {ori_result}",
        "decomp_result": f"Decomposed model accuray : {decomp_result}",
    }
    # dicts preserve insertion order, so the key list matches the order above
    return log_dict, list(log_dict)


def run_decompose(
    args: argparse.Namespace,
    validator: Validator,
    device: torch.device,
) -> Tuple[nn.Module, float]:
    """Run tensor decomposition on the validator's model and re-validate it.

    Args:
        args: parsed command-line arguments. Reads ``prune_step``, ``loss_thr``,
            ``log`` and ``resume``.
        validator: validation runner holding the model to decompose. Its
            ``model`` attribute is replaced with the decomposed model.
        device: device the decomposed model is moved to for validation.

    Returns:
        ``(decomposed_model, decomposed_result)`` where ``decomposed_result`` is
        the ``"model_acc"`` value reported by the validator for the decomposed
        model.
    """
    t0 = time.monotonic()
    ori_result = validator.run()[1]["model_acc"]
    origin_time_took = time.monotonic() - t0

    model = validator.model
    # nn.Module.cpu() moves the module in place, so the original model is on
    # the CPU from here on; it is only used for parameter counting below.
    decomposed_model = deepcopy(validator.model.cpu())
    decompose_model(
        decomposed_model, loss_thr=args.loss_thr, prune_step=args.prune_step
    )

    LOGGER.info(
        f"Decomposed with prunning step: {args.prune_step}, decomposition loss threshold: {args.loss_thr}"
    )

    decomposed_model.to(device)
    decomposed_model.eval()

    validator.model = decomposed_model
    t0 = time.monotonic()
    decomposed_result = validator.run()[1]["model_acc"]
    decomposed_time_took = time.monotonic() - t0

    log_dict, log_keys = log_result(
        ori_param=count_param(model),
        decomp_param=count_param(decomposed_model),
        ori_time=origin_time_took,
        decomp_time=decomposed_time_took,
        ori_result=ori_result,
        decomp_result=decomposed_result,
    )

    for key in log_keys:
        LOGGER.info(log_dict[key])

    if args.log:
        log_file = os.path.join(args.resume, "decompose_log.txt")
        with open(log_file, "w") as f:
            # writelines() adds no separators; terminate every entry so the
            # log file has one result per line instead of a single run-on line.
            f.writelines(f"{log_dict[key]}\n" for key in log_keys)

    return decomposed_model, decomposed_result


def _resume_dir_name(resume: str) -> str:
    """Return the sub-directory name used under ``decompose/`` for ``resume``.

    Keeps the original convention (third "/"-separated component) and falls
    back to the last path component, or "" for an empty path, instead of
    raising IndexError on shorter paths.
    """
    parts = resume.split("/")
    if len(parts) > 2:
        return parts[2]
    return os.path.basename(os.path.normpath(resume)) if resume else ""


def main() -> None:
    """Decompose the model selected by the command line and save the result."""
    args = get_parser()

    torch.manual_seed(args.seed)
    LOGGER.info(f"Random Seed: {args.seed}")

    # initialize
    config, dir_prefix, device = initialize(
        "train", args.config, args.resume, args.multi_gpu, args.gpu
    )

    validator = Validator(
        config=config,
        dir_prefix=dir_prefix,
        device=device,
        half=args.half,
        checkpt_dir="train",
    )
    decomp_model, _ = run_decompose(args, validator, device)

    weight_dir = Path("decompose") / _resume_dir_name(args.resume)
    weight_dir.mkdir(parents=True, exist_ok=True)
    weight_path = weight_dir / Path(args.file_name)

    torch.save({"model": decomp_model.cpu().half(), "decomposed": True}, weight_path)
    # report only after the file has actually been written
    LOGGER.info(f"Decomposed model saved in {colorstr('cyan', 'bold', weight_path)}")

    # Copy the run's *.py files next to the weights. Done in-process rather
    # than through a shell `cp` string, so paths with spaces or shell
    # metacharacters are safe and failures are not silently ignored.
    for py_file in glob.glob(os.path.join(args.resume, "*.py")):
        shutil.copy(py_file, weight_dir)


if __name__ == "__main__":
    main()
# -*- coding: utf-8 -*-
"""Pruning Runner.

Parses the command line, initializes the run, looks up the pruner class named
by ``config["PRUNE_METHOD"]`` in ``src.runners.pruner`` and runs it.

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""


import argparse
import importlib

from src.runners import curr_time, initialize

# arguments
parser = argparse.ArgumentParser(description="Model pruner.")
parser.add_argument("--multi-gpu", action="store_true", help="Multi-GPU use")
parser.add_argument("--gpu", default=0, type=int, help="GPU id to use")
parser.add_argument(
    "--resume",
    type=str,
    default="",
    help="Input checkpoint directory name",
)
parser.add_argument(
    "--wlog", dest="wlog", action="store_true", help="Turns on wandb logging"
)
# NOTE(review): this default does not match the config/prune/cifar100/ layout
# of the repository tree; confirm the intended default path.
parser.add_argument(
    "--config",
    type=str,
    default="config/prune/simplenet_kd.py",
    help="Configuration path",
)
parser.set_defaults(multi_gpu=False)
# The wandb flag's destination is `wlog`; the previous `log=False` default
# only created an unrelated, unused attribute on the namespace.
parser.set_defaults(wlog=False)
args = parser.parse_args()

# initialize
config, dir_prefix, device = initialize(
    "prune", args.config, args.resume, args.multi_gpu, args.gpu
)

# run pruning
wandb_name = args.resume or curr_time
wandb_init_params = dict(config=config, name=wandb_name, group=args.config)

# Resolve the pruner class by name from the config.
pruner_module = importlib.import_module("src.runners.pruner")
Pruner = getattr(pruner_module, config["PRUNE_METHOD"])

pruner = Pruner(
    config=config,
    dir_prefix=dir_prefix,
    wandb_log=args.wlog,
    wandb_init_params=wandb_init_params,
    device=device,
)
pruner.run(args.resume)
# -*- coding: utf-8 -*-
"""Quantization Runner.

Parses the command line, initializes the run, resolves the checkpoint to
quantize and hands everything to ``Quantizer``.

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""


import argparse
import os
import shutil

from src.runners import curr_time, initialize
from src.runners.quantizer import Quantizer

# arguments
parser = argparse.ArgumentParser(description="Model quantizer.")
parser.add_argument(
    "--resume", type=str, default="", help="Input log directory name to resume"
)
parser.add_argument(
    "--check-acc",
    dest="check_acc",
    action="store_true",
    help="Check inference accuracy",
)
parser.add_argument(
    "--wlog", dest="wlog", action="store_true", help="Turns on wandb logging"
)
parser.add_argument(
    "--static",
    dest="static",
    action="store_true",
    help="Post-training static quantization",
)
parser.add_argument(
    "--backend", type=str, default="fbgemm", help="pytorch quantization backend"
)
parser.add_argument("--config", type=str, help="Configuration path")
parser.add_argument("--checkpoint", type=str, help="input checkpoint path to quantize")
parser.set_defaults(check_acc=False)
parser.set_defaults(wlog=False)
parser.set_defaults(static=False)
args = parser.parse_args()

# get config and directory path prefix for logging
config, dir_prefix, _ = initialize("quantize", args.config, args.resume)

if not args.resume:
    # A fresh run needs an existing checkpoint. Report it as a usage error:
    # an `assert` would be stripped under `python -O` and the failure would
    # then surface later inside shutil.copyfile.
    if not (args.checkpoint and os.path.exists(args.checkpoint)):
        parser.error("--checkpoint required")
    checkpoint_path = args.checkpoint
    # keep a copy of the original weights inside the run directory
    shutil.copyfile(args.checkpoint, os.path.join(dir_prefix, "orig_model.pth.tar"))
else:
    # a resumed run reuses the copy stored by the first run
    checkpoint_path = os.path.join(dir_prefix, "orig_model.pth.tar")

# wandb
wandb_name = curr_time
wandb_init_params = dict(config=config, name=wandb_name, group=args.config)

# run quantization
quantizer = Quantizer(
    config=config,
    checkpoint_path=checkpoint_path,
    dir_prefix=dir_prefix,
    static=args.static,
    check_acc=args.check_acc,
    backend=args.backend,
    wandb_log=args.wlog,
    wandb_init_params=wandb_init_params,
)
quantizer.run(args.resume)
#!/bin/bash
#
# Shell script for formating, linting and unit test
#
# Each argument names one command from CMD_NAME; the commands are run in the
# order given and the script exits with the status of the last failing one
# (0 when all succeed).
#
# - Author: Jongkuk Lim
# - Contact: limjk@jmarple.ai

# Bash 3 does not support hash dictionary.
# hput and hget are alternative workaround.
# Usage)
#   hput $VAR_NAME $KEY $VALUE
hput() {
  # stores the value in a variable named "<VAR_NAME><KEY>"
  eval "$1""$2"='$3'
}

# Usage)
#   `hget $VAR_NAME $KEY`
hget() {
  # prints the variable named "<VAR_NAME><KEY>" (empty when never set)
  eval echo '${'"$1$2"'#hash}'
}

# Define command names
CMD_NAME=(
  "format"
  "lint"
  "test"
  "doc"
  "doc_server"
  "init_conda"
  "init_precommit"
  "init"
  "all"
)

# Define descriptions
hput CMD_DESC format "Run formating"
hput CMD_DESC lint "Run linting check"
hput CMD_DESC test "Run unit test"
hput CMD_DESC doc "Generate MKDocs document"
hput CMD_DESC doc_server "Run MKDocs hosting server (in local)"
hput CMD_DESC init_conda "Create conda environment with default name"
hput CMD_DESC init_precommit "Install pre-commit plugin"
hput CMD_DESC init "Run init-conda and init-precommit"
hput CMD_DESC all "Run formating, linting and unit test"

# Define commands
hput CMD_LIST format "black . && isort . && docformatter -i -r . --wrap-summaries 88 --wrap-descriptions 88"
hput CMD_LIST lint "env PYTHONPATH=. pytest --pylint --mypy --flake8 --ignore tests --ignore cpp --ignore config --ignore save"
# NOTE(review): --cov=scripts names a directory that is not in the repository
# tree (sources live under src/); confirm the intended coverage target.
hput CMD_LIST test "env PYTHONPATH=. pytest tests --cov=scripts --cov-report term-missing --cov-report html"
hput CMD_LIST doc "env PYTHONPATH=. mkdocs build --no-directory-urls"
hput CMD_LIST doc_server "env PYTHONPATH=. mkdocs serve -a 127.0.0.1:8000 --no-livereload"
hput CMD_LIST init_conda "conda env create -f environment.yml"
hput CMD_LIST init_precommit "pre-commit install --hook-type pre-commit --hook-type pre-push"
hput CMD_LIST init "`hget CMD_LIST init_conda` && `hget CMD_LIST init_precommit`"
hput CMD_LIST all "`hget CMD_LIST format` && `hget CMD_LIST lint` && `hget CMD_LIST test`"

# Start from success so the final `exit` never depends on an unset variable.
exitCode=0

for _arg in $@
do
  if [[ `hget CMD_LIST $_arg` == "" ]]; then
    echo "$_arg is not valid option!"
    echo "--------------- $0 Usage ---------------"
    for _key in ${CMD_NAME[@]}
    do
      echo "$0 $_key - `hget CMD_DESC $_key`"
    done
    # An unknown command is an error. Exiting 0 here reported success to CI
    # and also discarded the failure status of any command already run.
    exit 1
  else
    cmd=`hget CMD_LIST $_arg`
    echo "Run $cmd"
    eval $cmd

    result=$?
    if [ $result -ne 0 ]; then
      exitCode=$result
    fi
  fi
done

exit $exitCode
#!/bin/bash
#
# Docker build image, run container, execute last container.
#
# Usage: <script> build|run|exec [:tag] [extra docker arguments...]
#
# - Author: Jongkuk Lim
# - Contact: limjk@jmarple.ai

# NOTE: `xhost +` turns off X server access control for every host.
xhost +

ORG=jmarpledev

# Project name = name of the current directory, lower-cased.
PRJ_NAME=${PWD##*/}
PRJ_NAME=${PRJ_NAME,,}

DOCKER_TAG=$ORG/$PRJ_NAME

# Collect the arguments and drop the first one (the sub-command); what is left
# is forwarded to docker. The slicing below depends on CMD_ARGS staying an
# array while its first element is overwritten, so the statement order matters.
CMD_ARGS=( ${@} )
CMD_ARGS=${CMD_ARGS[*]:1}

# A second argument starting with ":" is appended to the image tag and is also
# removed from the forwarded arguments.
if [[ $2 == :* ]]; then
    DOCKER_TAG=$DOCKER_TAG$2
    CMD_ARGS=${CMD_ARGS[*]:2}
fi

if [ "$1" = "build" ]; then
    echo "Building a docker image with tagname $DOCKER_TAG and arguments $CMD_ARGS"
    # the host user's uid/gid are passed to the Dockerfile as build arguments
    docker build . -t $DOCKER_TAG $CMD_ARGS --build-arg UID=`id -u` --build-arg GID=`id -g`
elif [ "$1" = "run" ]; then
    echo "Run a docker image with tagname $DOCKER_TAG and arguments $CMD_ARGS"

    # Detached, privileged container with all GPUs, the host X11 socket, /dev
    # and the project directory mounted, sharing the host network.
    docker run -tid --privileged --gpus all \
        -e DISPLAY=${DISPLAY} \
        -e TERM=xterm-256color \
        -v /tmp/.X11-unix:/tmp/.X11-unix:ro \
        -v /dev:/dev \
        -v $PWD:/home/user/$PRJ_NAME \
        --network host \
        $CMD_ARGS \
        $DOCKER_TAG /bin/bash

    # Remember the id of the most recently created container for `exec`.
    last_cont_id=$(docker ps -qn 1)
    echo $(docker ps -qn 1) > $PWD/.last_exec_cont_id.txt

    docker exec -ti $last_cont_id /bin/bash
elif [ "$1" = "exec" ]; then
    echo "Execute the last docker container"

    # Re-attach to the container id recorded by the last `run`.
    last_cont_id=$(tail -1 $PWD/.last_exec_cont_id.txt)
    docker start ${last_cont_id}
    docker exec -ti ${last_cont_id} /bin/bash
else
    # unknown or missing sub-command: print usage
    echo ""
    echo "============= $0 [Usages] ============"
    echo "1) $0 build : build docker image"
    echo " build --no-cache : Build docker image without cache"
    echo "2) $0 run : launch a new docker container"
    echo "3) $0 exec : execute last container launched"
fi
# -*- coding: utf-8 -*-
"""Shrink model, and save and run.

Parses the command line, initializes the run, stores a copy of the input
checkpoint in the run directory and hands everything to ``Shrinker``.

- Author: Junghoon Kim
- Email: jhkim@jmarple.ai
"""


import argparse
import os
import shutil

from src.runners import initialize
from src.runners.shrinker import Shrinker

# arguments
parser = argparse.ArgumentParser(description="Model shrinker.")
parser.add_argument("--gpu", default=0, type=int, help="GPU id to use")
parser.add_argument("--checkpoint", type=str, help="input checkpoint path to quantize")
parser.add_argument("--config", type=str, help="Pruning configuration path")
args = parser.parse_args()

# get config and directory path prefix for logging
config, dir_prefix, device = initialize(
    mode="shrink", config_path=args.config, gpu_id=args.gpu
)

# An existing checkpoint is mandatory. Report it as a usage error: an `assert`
# would be stripped under `python -O` and the failure would then surface later
# inside shutil.copyfile.
if not (args.checkpoint and os.path.exists(args.checkpoint)):
    parser.error("--checkpoint required")
# keep a copy of the original weights inside the run directory
shutil.copyfile(args.checkpoint, os.path.join(dir_prefix, "orig_model.pth.tar"))

# run shrinking
shrinker = Shrinker(
    config=config,
    checkpoint_path=args.checkpoint,
    dir_prefix=dir_prefix,
    device=device,
)
shrinker.run()
# -*- coding: utf-8 -*-
"""Augmentation methods.

- Author: Curt-Park
- Email: jwpark@jmarple.ai
- Reference:
    https://arxiv.org/pdf/1805.09501.pdf
    https://github.com/kakaobrain/fast-autoaugment/
"""

from abc import ABC
from itertools import chain
import random
from typing import List, Optional, Tuple

from PIL.Image import Image
import numpy as np
import torch
from torch.utils.data import Dataset

from src.augmentation.transforms import transforms_info
from src.utils import get_rand_bbox_coord, to_onehot


class Augmentation(ABC):
    """Abstract class used by all augmentation methods."""

    def __init__(self, n_level: int = 10) -> None:
        """Initialize.

        Args:
            n_level: number of discrete magnitude levels.
        """
        self.transforms_info = transforms_info()
        self.n_level = n_level

    def _apply_augment(self, img: Image, name: str, level: int) -> Image:
        """Apply and get the augmented image.

        Args:
            img (Image): an image to augment
            name (str): key of the transform in ``self.transforms_info``
            level (int): magnitude of augmentation in [0, n_level)

        returns:
            Image: an augmented image
        """
        assert 0 <= level < self.n_level
        augment_fn, low, high = self.transforms_info[name]
        # map the discrete level linearly onto the transform's [low, high) range
        return augment_fn(img.copy(), level * (high - low) / self.n_level + low)


class SequentialAugmentation(Augmentation):
    """Sequential augmentation class.

    Applies every ``(name, probability, level)`` policy in order, each with its
    own probability.
    """

    def __init__(
        self,
        policies: List[Tuple[str, float, int]],
        n_level: int = 10,
    ) -> None:
        """Initialize."""
        super().__init__(n_level)
        self.policies = policies

    def __call__(self, img: Image) -> Image:
        """Run augmentations."""
        for name, pr, level in self.policies:
            if random.random() > pr:
                continue
            img = self._apply_augment(img, name, level)
        return img


class AutoAugmentation(Augmentation):
    """Auto augmentation class.

    Samples ``n_select`` policies without replacement and applies their
    sub-policies in order, each with its own probability.

    References:
        https://arxiv.org/pdf/1805.09501.pdf
    """

    def __init__(
        self,
        policies: List[List[Tuple[str, float, int]]],
        n_select: int = 1,
        n_level: int = 10,
    ) -> None:
        """Initialize."""
        super().__init__(n_level)
        self.policies = policies
        self.n_select = n_select

    def __call__(self, img: Image) -> Image:
        """Run augmentations."""
        chosen_policies = random.sample(self.policies, k=self.n_select)
        for name, pr, level in chain.from_iterable(chosen_policies):
            if random.random() > pr:
                continue
            img = self._apply_augment(img, name, level)
        return img


class RandAugmentation(Augmentation):
    """Random augmentation class.

    Samples ``n_select`` transforms without replacement and applies each one,
    either at the fixed ``level`` or at a freshly drawn random level.

    References:
        RandAugment: Practical automated data augmentation with a reduced search space
        (https://arxiv.org/abs/1909.13719)
    """

    def __init__(
        self,
        transforms: List[str],
        n_select: int = 2,
        level: Optional[int] = 14,
        n_level: int = 31,
    ) -> None:
        """Initialize.

        Args:
            transforms: names of the candidate transforms.
            n_select: number of transforms applied per image.
            level: fixed magnitude in [0, n_level). Any other value (including
                None) means a random level is drawn for every transform.
            n_level: number of discrete magnitude levels.
        """
        super().__init__(n_level)
        self.n_select = n_select
        # exact-int check on purpose: bools are not accepted as a level
        self.level = level if type(level) is int and 0 <= level < n_level else None
        self.transforms = transforms

    def __call__(self, img: Image) -> Image:
        """Run augmentations."""
        chosen_transforms = random.sample(self.transforms, k=self.n_select)
        for transf in chosen_transforms:
            # Compare against None: a fixed level of 0 is valid but falsy, and
            # a plain truthiness test replaced it with a random level.
            if self.level is not None:
                level = self.level
            else:
                level = random.randint(0, self.n_level - 1)
            img = self._apply_augment(img, transf, level)
        return img


class CutMix(Dataset):
    """A Dataset class for CutMix.

    Wraps another dataset and returns ``(image, soft one-hot label)`` pairs.

    References:
        https://github.com/ildoonet/cutmix
    """

    def __init__(
        self, dataset: Dataset, num_classes: int, beta: float = 1.0, prob: float = 0.5
    ) -> None:
        """Initialize.

        Args:
            dataset: wrapped dataset yielding ``(image, int label)``.
            num_classes: length of the one-hot label vector.
            beta: both parameters of the beta distribution the box size is drawn from.
            prob: probability of applying CutMix to a sample.
        """
        self.dataset = dataset
        self.num_classes = num_classes
        self.beta = beta
        self.prob = prob

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert image and label to a cutmix image and label.

        Combine two training samples by cutting and pasting two images along a random
        box. The ground truth label is also "mixed" via the combination ratio. The
        combination ratio is sampled from a beta distribution.
        """
        img, label = self.dataset[index]  # label: int
        label = torch.tensor([label], dtype=torch.long)
        label_onehot = to_onehot(label, self.num_classes)
        # sampling the length ratio of random box to the image
        len_ratio = np.sqrt(np.random.beta(self.beta, self.beta))

        # skip mixing with probability (1 - prob), or when the box is negligible
        if random.random() > self.prob or len_ratio < 1e-3:
            return img, label_onehot.squeeze_(0)

        # sizes of the last two image dimensions
        w, h = img.size()[-2], img.size()[-1]
        (x0, y0), (x1, y1) = get_rand_bbox_coord(w, h, len_ratio)
        # compute the combination ratio
        comb_ratio = (x1 - x0) * (y1 - y0) / (w * h)

        # paste the box region of a randomly chosen second sample into the image
        rand_ind = np.random.randint(len(self))
        rand_img, rand_label = self.dataset[rand_ind]
        rand_label = torch.tensor([rand_label], dtype=torch.long)
        img[:, x0:x1, y0:y1] = rand_img[:, x0:x1, y0:y1]
        label_onehot = (1 - comb_ratio) * label_onehot + comb_ratio * to_onehot(
            rand_label, self.num_classes
        )
        return img, label_onehot.squeeze_(0)

    def __len__(self) -> int:
        """Get length of dataset."""
        return len(self.dataset)  # type: ignore
def simple_augment_test_cifar100_224() -> transforms.Compose:
    """Return the CIFAR100 test-time pipeline that first resizes to 224.

    The steps are ``Resize(224)``, ``ToTensor()`` and normalization with the
    CIFAR100 mean/std constants.
    """
    normalize = transforms.Normalize(CIFAR100_INFO["MEAN"], CIFAR100_INFO["STD"])
    steps = [
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
def autoaugment_train_cifar100_riair() -> transforms.Compose:
    """Build RIAIR's auto-augmentation training pipeline for CIFAR100.

    Every sub-policy holds exactly one ``(name, probability, level)`` operation,
    and ``AutoAugmentation`` is created with ``n_select=2``.
    """
    single_ops = [
        ("Invert", 0.2, 2),
        ("Contrast", 0.4, 4),
        ("Rotate", 0.5, 1),
        ("TranslateX", 0.4, 3),
        ("Sharpness", 0.5, 3),
        ("ShearY", 0.3, 4),
        ("TranslateY", 0.6, 8),
        ("AutoContrast", 0.6, 3),
        ("Equalize", 0.5, 5),
        ("Solarize", 0.4, 4),
        ("Color", 0.5, 5),
        ("Posterize", 0.2, 2),
        ("Brightness", 0.4, 5),
        ("Cutout", 0.3, 3),
        ("ShearX", 0.1, 3),
    ]
    # each operation becomes its own one-element sub-policy
    policies = [[op] for op in single_ops]
    pipeline = [
        AutoAugmentation(policies, n_select=2),
        transforms.RandomCrop(32, padding=4, fill=FILLCOLOR),
        transforms.RandomHorizontalFlip(),
        SequentialAugmentation([("Cutout", 1.0, 9)]),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_INFO["MEAN"], CIFAR100_INFO["STD"]),
    ]
    return transforms.Compose(pipeline)
def transforms_info() -> Dict[
    str, Tuple[Callable[[Image, float], Image], float, float]
]:
    """Return a table mapping each transform's name to ``(function, low, high)``.

    The key is the function's ``__name__``. ``low`` and ``high`` are the two ends
    of the transform's range; note that for ``Solarize`` and ``Posterize`` the
    first value is larger than the second.
    """
    table = {}
    for fn, low, high in (
        (Identity, 0.0, 0.0),
        (Invert, 0.0, 0.0),
        (Contrast, 0.0, 0.9),
        (AutoContrast, 0.0, 0.0),
        (Rotate, 0.0, 30.0),
        (TranslateX, 0.0, 150 / 331),
        (TranslateY, 0.0, 150 / 331),
        (Sharpness, 0.0, 0.9),
        (ShearX, 0.0, 0.3),
        (ShearY, 0.0, 0.3),
        (Color, 0.0, 0.9),
        (Brightness, 0.0, 0.9),
        (Equalize, 0.0, 0.0),
        (Solarize, 256.0, 0.0),
        (Posterize, 8, 4),
        (Cutout, 0, 0.5),
    ):
        table[fn.__name__] = (fn, low, high)
    return table
def Rotate(img: Image, magnitude: float) -> Image:
    """Rotate the image by ``magnitude`` degrees.

    The image is rotated in RGBA mode, composited over a solid
    ``FILLCOLOR_RGBA`` canvas of the same size using the rotated image itself as
    the mask, and finally converted back to the input's mode.
    """
    rotated = img.convert("RGBA").rotate(magnitude)
    canvas = PIL.Image.new("RGBA", rotated.size, FILLCOLOR_RGBA)
    blended = PIL.Image.composite(rotated, canvas, rotated)
    return blended.convert(img.mode)
class Criterion(nn.Module):
    """Common parent of the loss modules; it only remembers the device."""

    def __init__(self, device: torch.device) -> None:
        """Store the device so that subclasses can use it.

        Args:
            device (torch.device): device type(GPU, CPU)
        """
        super(Criterion, self).__init__()
        self.device = device
class CrossEntropy(Criterion):
    """Crossentropy loss accepting soft labels.

    Attributes:
        log_softmax (nn.Module): log softmax function (applied over dim 1).
        num_classes (int): number of classes in dataset, to get onehot labels
        label_smoothing (float): smoothing factor; smoothing is applied only
            when it is greater than 0.
    """

    def __init__(
        self, device: torch.device, num_classes: int, label_smoothing: float
    ) -> None:
        """Initialize cross entropy loss.

        Args:
            device (torch.device): device the targets are moved to.
            num_classes (int): number of classes in dataset.
            label_smoothing (float): smoothing factor (0 disables smoothing).
        """
        super().__init__(device)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.num_classes = num_classes
        self.label_smoothing = label_smoothing

    def forward(
        self, model: nn.Module, images: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward model, calculate loss.

        Args:
            model (nn.Module): model that is called on ``images``.
            images (torch.Tensor): input images.
            labels (torch.Tensor): labels for input images.
        Returns:
            loss (torch.Tensor): calculated loss.
            logit (Dict[str, torch.Tensor]): model output under the key "model".
        """
        logit = model(images)
        return self.calculate_loss(logit=logit, labels=labels), {"model": logit}

    def calculate_loss(self, logit: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Calculate loss.

        Pure part of calculate loss, does not contain model forward procedure,
        so that it can be combined with other loss.

        Args:
            logit (torch.Tensor): model output,
                (https://developers.google.com/machine-learning/glossary/#logits).
            labels (torch.Tensor): labels for input images.
        Returns:
            loss (torch.Tensor): calculated loss, averaged over dim 0.
        """
        # if labels are index values -> expand to onehot for compatibility
        # NOTE(review): presumably to_onehot passes soft/one-hot labels through
        # unchanged; verify against src.utils.
        target = utils.to_onehot(labels=labels, num_classes=self.num_classes).to(
            self.device
        )
        pred = self.log_softmax(logit)

        # get smooth labels
        if self.label_smoothing > 0.0:
            target = self.add_label_smoothing(target)

        return torch.mean(torch.sum(-target * pred, dim=1))

    @torch.no_grad()
    def add_label_smoothing(self, target: torch.Tensor) -> torch.Tensor:
        """Add smoothness in labels.

        Non-zero entries are scaled by ``1 - label_smoothing``; the zero entries
        of each row share ``label_smoothing`` equally. The input tensor is left
        untouched.

        Args:
            target (torch.Tensor): label matrix; rows are reduced over dim 1.

        Returns:
            torch.Tensor: a new tensor holding the smoothed labels.
        """
        nonzero_idxs = target != 0.0
        nonzero_cnt = nonzero_idxs.sum(dim=1, keepdim=True).float()

        # Out-of-place multiply: `target` may be the very tensor the caller
        # passed in, and scaling it in place would silently change their labels.
        scaled_target = target * (1 - self.label_smoothing)
        smooth_target = torch.ones_like(target).to(self.device)
        # A row with no zero entry divides by zero here, but every element of
        # such a row is overwritten by the assignment below.
        smooth_target *= self.label_smoothing / (self.num_classes - nonzero_cnt)
        smooth_target[nonzero_idxs] = scaled_target[nonzero_idxs]
        return smooth_target
def colorstr(*args: Any) -> str:
    """Wrap a string in ANSI escape codes for color and text format.

    See https://en.wikipedia.org/wiki/ANSI_escape_code. The last positional
    argument is the text; every argument before it names a color or a text
    format. When only the text is given it is rendered "blue" and "bold".

    Color names:
        "black", "red", "green", "yellow", "blue", "magenta", "cyan", "white",
        "bright_black", "bright_red", "bright_green", "bright_yellow",
        "bright_blue", "bright_magenta", "bright_cyan", "bright_white"

    Text formats:
        "bold", "underline"

    Args:
        *args: color/format names followed by the text.
            Ex) colorstr("red", "bold", "Hello world")
            returns "Hello world" wrapped in the codes for red and bold.

    Return:
        the text prefixed with the requested codes and followed by the reset code.
    """
    if len(args) > 1:
        styles, text = args[:-1], args[-1]
    else:
        styles, text = ("blue", "bold"), args[0]

    codes = {
        # basic colors
        "black": "\033[30m",
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        # bright colors
        "bright_black": "\033[90m",
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        # misc
        "end": "\033[0m",
        "bold": "\033[1m",
        "underline": "\033[4m",
    }

    prefix = "".join(codes[style] for style in styles)
    return prefix + f"{text}" + codes["end"]
class LrScheduler(ABC):
    """Interface implemented by every learning rate scheduler in this module."""

    @abstractmethod
    def __call__(self, optimizer: Optimizer, epoch: int) -> None:
        """Set optimizer's learning rate for the given epoch.

        Args:
            optimizer (Optimizer): optimizer whose learning rate is adjusted.
            epoch (int): epoch index the learning rate is requested for.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
def get_lr_scheduler(name: str, lr_scheduler_params: Dict[str, Any]) -> LrScheduler:
    """Create the LR scheduler named ``name``.

    Args:
        name (str): expression that evaluates to the scheduler class.
        lr_scheduler_params (Dict[str, Any]): keyword arguments passed to it.

    Returns:
        LrScheduler: the object returned by calling the evaluated class.
    """
    # NOTE(review): eval() runs `name` as a Python expression, so it must only
    # ever receive values from trusted configuration.
    scheduler_factory = eval(name)
    return scheduler_factory(**lr_scheduler_params)
class AdjModuleGetter:
    """Adjacent module getter used by Shrinker.

    This gets adjacent module information to use for model shrinking. Assume the model
    consists of conv-bn-relu sequence.

    On construction a forward hook is put on every module, one forward pass with
    random input is executed, and the autograd graph reachable from the model
    output is recorded.

    Attributes:
        model (nn.Module): the inspected model.
        module_ahead (Dict[Any, nn.Module]): ``grad_fn`` of a module's output ->
            that module. When several modules produce the same ``grad_fn`` the
            hook that ran last wins.
        op_behind (Dict[nn.Module, Any]): module -> ``grad_fn`` of its output.
        last_conv_shape (int): last dimension of the first input of an
            ``nn.Flatten`` module seen during the forward pass (0 if none).
        op_backward_graph (Dict[Any, List[Any]]): autograd node -> its non-None
            ``next_functions`` nodes.
    """

    def __init__(
        self, model: nn.Module, input_size: Tuple[int, ...], device: torch.device
    ) -> None:
        """Initialize.

        Args:
            model (nn.Module): model to analyse; it is called once.
            input_size (Tuple[int, ...]): shape of the random probe input.
            device (torch.device): device the probe input is created on.
        """
        self.model = model
        # op: module_ahead[op] = module
        self.module_ahead: Dict[Any, nn.Module] = dict()
        # op_behind[module] = op
        self.op_behind: Dict[nn.Module, Any] = dict()
        self.last_conv_shape = 0

        # register hooks
        hooks = [
            module.register_forward_hook(self._hook_fn)
            for module in self.model.modules()
        ]

        try:
            # execute hooks and create backward operators graph
            rand_in = torch.randn(input_size, device=device)
            out = self.model(rand_in)
            self.op_backward_graph = self._create_backward_op_graph(out)
        finally:
            # remove hooks even when the probing forward pass fails, so the
            # model is never left with stale hooks attached
            for hook in hooks:
                hook.remove()

    def find_modules_ahead_of(
        self, module: nn.Module, target_type: "type"
    ) -> List[Any]:
        """Find all modules ahead of the input backward operator.

        Starting at the ``grad_fn`` recorded for ``module``, the backward graph
        is followed; on each path the search stops at the first node whose
        recorded module has exactly the type ``target_type``.

        Args:
            module (nn.Module): module whose output node is the starting point.
            target_type (type): exact module type to look for.

        Returns:
            List[Any]: the modules found, one entry per path that reached one.
        """

        def _search(op: Any) -> List[Any]:
            if op in self.module_ahead and type(self.module_ahead[op]) is target_type:
                return [self.module_ahead[op]]
            prev_ops = self.op_backward_graph.get(op)
            if not prev_ops:
                return []
            return list(chain.from_iterable(_search(prev_op) for prev_op in prev_ops))

        return _search(self.op_behind[module])

    def find_module_next_to(
        self, module: nn.Module, target_type: "type"
    ) -> Union[nn.Module, None]:
        """Find a single module right next to the input module.

        Only modules whose type is ``type(module)`` or ``target_type`` are
        considered, in ``model.modules()`` order.

        Args:
            module (nn.Module): reference module.
            target_type (type): exact type the following module must have.

        Returns:
            Union[nn.Module, None]: the module directly after ``module`` in that
            filtered order if it has type ``target_type``, otherwise None.
        """
        wanted_types = {type(module), target_type}
        layers = [v for v in self.model.modules() if type(v) in wanted_types]
        for current, following in zip(layers, layers[1:]):
            if current is module and type(following) is target_type:
                return following
        return None

    def _create_backward_op_graph(self, out: torch.Tensor) -> Dict[Any, List[Any]]:
        """Create a graph that contains backward operators' information.

        Args:
            out (torch.Tensor): model output whose ``grad_fn`` is the root.

        Returns:
            Dict[Any, List[Any]]: every reachable node mapped to the list of its
            non-None ``next_functions`` nodes.
        """
        graph: Dict[Any, List[Any]] = dict()

        # Iterative depth-first walk (an explicit stack instead of recursion) so
        # that very deep networks cannot hit Python's recursion limit.
        stack: List[Any] = [out.grad_fn]
        while stack:
            var = stack.pop()
            if var in graph:
                continue
            next_ops = [
                edge[0]
                for edge in getattr(var, "next_functions", ())
                if edge[0] is not None
            ]
            graph[var] = next_ops
            # push in reverse so that children are visited in their original order
            stack.extend(reversed(next_ops))
        return graph

    def _hook_fn(
        self, module: nn.Module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor
    ) -> None:
        """Record the autograd node of ``module``'s output (forward hook).

        NOTE(review): assumes ``out`` is a single tensor; modules returning other
        structures are not handled here.
        """
        self.module_ahead[out.grad_fn] = module
        self.op_behind[module] = out.grad_fn

        if type(module) is nn.Flatten:
            self.last_conv_shape = inp[0].size()[-1]
class HSigmoid(nn.Module):
    """Hard sigmoid, computed as ``relu6(x + 3) / 6``."""

    def __init__(self, inplace: bool = True) -> None:
        """Create the inner ReLU6.

        Args:
            inplace (bool): forwarded to ``nn.ReLU6``.
        """
        super().__init__()
        self.relu6 = nn.ReLU6(inplace=inplace)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the hard sigmoid to ``x``."""
        # x + 3 creates a new tensor, so ReLU6 acts on that sum and not on x
        shifted = x + 3
        return self.relu6(shifted) / 6
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Common layer modules.

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""

from typing import Any, List

import torch
import torch.nn as nn

from src.models.common_activations import HSigmoid, QuantizableHSigmoid
import src.models.utils as model_utils


class Identity(nn.Module):
    """Module that returns its input unchanged."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input."""
        return x


class ConvBN(nn.Module):
    """Conv2d + BatchNorm2d.

    To make the Conv2d depthwise, set ``groups`` equal to ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        momentum: float = 0.01,
    ) -> None:
        """Initialize.

        Args:
            in_channels: input channels of the convolution.
            out_channels: output channels of the convolution and BatchNorm.
            kernel_size: square kernel size.
            stride: convolution stride.
            groups: convolution groups.
            bias: whether the convolution has a bias term.
            momentum: BatchNorm momentum.
        """
        super().__init__()
        # Padding is derived from the kernel size: (k - 1) // 2.
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels, momentum=momentum)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, then batch normalization."""
        return self.bn(self.conv(x))


class ConvBNReLU(nn.Module):
    """Conv2d + BatchNorm2d + ReLU.

    To make the Conv2d depthwise, set ``groups`` equal to ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        momentum: float = 0.01,
    ) -> None:
        """Initialize.

        Args:
            in_channels: input channels of the convolution.
            out_channels: output channels of the convolution and BatchNorm.
            kernel_size: square kernel size.
            stride: convolution stride.
            groups: convolution groups.
            bias: whether the convolution has a bias term.
            momentum: BatchNorm momentum.
        """
        super().__init__()
        # Padding is derived from the kernel size: (k - 1) // 2.
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels, momentum=momentum)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, batch normalization, then ReLU."""
        return self.relu(self.bn(self.conv(x)))


class SqueezeExcitation(nn.Module):
    """Squeeze and Excitation layer."""

    def __init__(self, in_channels: int, se_ratio: float) -> None:
        """Initialize.

        Args:
            in_channels: channels of the tensor being gated.
            se_ratio: hidden width as a fraction of ``in_channels``
                (the hidden width is at least 1).
        """
        super().__init__()
        hidden_channels = max(1, int(in_channels * se_ratio))
        self.se_reduce = ConvBNReLU(in_channels, hidden_channels, bias=True)
        self.se_expand = ConvBN(hidden_channels, in_channels, bias=True)
        self.hsig = HSigmoid()

    def _mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Multiply two tensors (elementwise); overridden for quantization."""
        return x * y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by a gate computed from its mean over dims 2 and 3."""
        se_tensor = torch.mean(x, dim=[2, 3], keepdim=True)
        gate = self.se_expand(self.se_reduce(se_tensor))
        return self._mul(self.hsig(gate), x)


class QuantizableSqueezeExcitation(SqueezeExcitation):
    """Squeeze and Excitation layer using ``FloatFunctional`` arithmetic."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize; ``kwargs`` are forwarded to ``SqueezeExcitation``."""
        super().__init__(**kwargs)
        self.mul = nn.quantized.FloatFunctional()
        # Replace the activation created by the base class.
        self.hsig = QuantizableHSigmoid()

    def _mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Multiply two tensors (elementwise)."""
        return self.mul.mul(x, y)


class MDConvBlock(nn.Module):
    """Mixed-depthwise Conv2d-BN2d-(ReLU)."""

    def __init__(
        self, in_channels: int, n_chunks: int, stride: int = 1, with_relu: bool = True
    ) -> None:
        """Initialize.

        Args:
            in_channels: total input channels, split into ``n_chunks`` groups.
            n_chunks: number of channel chunks; chunk ``i`` uses kernel
                size ``2 * i + 3``.
            stride: stride of every chunk's convolution.
            with_relu: use ``ConvBNReLU`` blocks when true, ``ConvBN`` otherwise.
        """
        super().__init__()
        self.n_chunks = n_chunks
        self.split_in_channels = model_utils.split_channels(in_channels, n_chunks)

        block_cls = ConvBNReLU if with_relu else ConvBN
        self.blocks = nn.ModuleList()
        for idx in range(self.n_chunks):
            chunk_channels = self.split_in_channels[idx]
            self.blocks.append(
                block_cls(
                    in_channels=chunk_channels,
                    out_channels=chunk_channels,
                    kernel_size=2 * idx + 3,
                    stride=stride,
                    # groups == in_channels makes each chunk depthwise
                    groups=chunk_channels,
                )
            )

    def _cat(self, block_res: List[torch.Tensor]) -> torch.Tensor:
        """Concat channels of block results; overridden for quantization."""
        return torch.cat(block_res, dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split channels, run each chunk through its block, concatenate."""
        split = torch.split(x, self.split_in_channels, dim=1)
        block_res = []
        # torch.jit.script doesn't recognize zip(self.blocks, split)
        for i, block in enumerate(self.blocks):
            block_res.append(block(split[i]))
        return self._cat(block_res)


class QuantizableMDConvBlock(MDConvBlock):
    """Mixed-depthwise Conv2d-BN2d-(ReLU) using ``FloatFunctional`` concat."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize; ``kwargs`` are forwarded to ``MDConvBlock``."""
        super().__init__(**kwargs)
        self.cat = nn.quantized.FloatFunctional()

    def _cat(self, block_res: List[torch.Tensor]) -> torch.Tensor:
        """Concat channels of block results."""
        return self.cat.cat(block_res, dim=1)
# --------------------------------------------------------------------------------
# /src/models/densenet.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Fixed DenseNet Model.

All blocks consist of ConvBNReLU for quantization.

- Author: Curt-Park
- Email: jwpark@jmarple.ai
- References:
    https://github.com/bearpaw/pytorch-classification
    https://github.com/gpleiss/efficient_densenet_pytorch
"""

import math
from typing import Any, Tuple

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

from src.models.common_layers import ConvBNReLU


class Bottleneck(nn.Module):
    """Bottleneck block for DenseNet."""

    def __init__(
        self,
        inplanes: int,
        expansion: int,
        growthRate: int,
        efficient: bool,
    ) -> None:
        """Initialize.

        Args:
            inplanes: channels of the concatenated input features.
            expansion: the 1x1 conv produces ``expansion * growthRate`` channels.
            growthRate: channels produced by this block.
            efficient: run the concat + 1x1 conv under gradient checkpointing.
        """
        super().__init__()
        planes = expansion * growthRate
        self.conv1 = ConvBNReLU(inplanes, planes, kernel_size=1)
        self.conv2 = ConvBNReLU(planes, growthRate, kernel_size=3)
        self.efficient = efficient

    def _expand(self, *features: torch.Tensor) -> torch.Tensor:
        """Concatenate the features on dim 1 and apply the 1x1 conv."""
        concated_features = torch.cat(features, 1)
        return self.conv1(concated_features)

    def forward(self, *prev_features: torch.Tensor) -> torch.Tensor:
        """Forward."""
        # Checkpointing is only used when some input requires grad.
        if self.efficient and any(feat.requires_grad for feat in prev_features):
            out = cp.checkpoint(self._expand, *prev_features)
        else:
            out = self._expand(*prev_features)
        return self.conv2(out)


class DenseBlock(nn.Module):
    """Densenet block."""

    def __init__(
        self,
        inplanes: int,
        blocks: int,
        expansion: int,
        growth_rate: int,
        efficient: bool,
        Layer: "type" = Bottleneck,
    ) -> None:
        """Initialize.

        Args:
            inplanes: channels entering the block.
            blocks: number of layers in the block.
            expansion: forwarded to each layer.
            growth_rate: channels added by each layer.
            efficient: forwarded to each layer.
            Layer: layer class to instantiate.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(blocks):
            # Layer i sees the block input plus the outputs of the i
            # preceding layers.
            self.layers.append(
                Layer(
                    inplanes=inplanes + i * growth_rate,
                    expansion=expansion,
                    growthRate=growth_rate,
                    efficient=efficient,
                )
            )

    def forward(self, init_features: torch.Tensor) -> torch.Tensor:
        """Forward."""
        features = [init_features]
        for layer in self.layers:
            features.append(layer(*features))
        return torch.cat(features, dim=1)


class Transition(nn.Module):
    """Transition between blocks: 1x1 ConvBNReLU followed by 2x2 avg pool."""

    def __init__(self, inplanes: int, outplanes: int) -> None:
        """Initialize."""
        super().__init__()
        self.conv = ConvBNReLU(inplanes, outplanes, kernel_size=1)
        self.avg_pool = nn.AvgPool2d(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward."""
        return self.avg_pool(self.conv(x))


class DenseNet(nn.Module):
    """DenseNet architecture."""

    def __init__(
        self,
        num_classes: int,
        inplanes: int,
        expansion: int = 4,
        growthRate: int = 12,
        compressionRate: int = 2,
        block_configs: Tuple[int, ...] = (6, 12, 24, 16),
        small_input: bool = True,  # e.g. CIFAR100
        efficient: bool = False,  # memory efficient dense block
        Block: "type" = DenseBlock,
    ) -> None:
        """Initialize.

        Args:
            num_classes: output size of the final linear layer.
            inplanes: channels produced by the stem.
            expansion: bottleneck expansion factor.
            growthRate: channels added by each bottleneck layer.
            compressionRate: channel divisor applied by each transition.
            block_configs: number of layers in each dense block.
            small_input: use a stride-1 3x3 stem instead of 7x7 stride-2 + max pool.
            efficient: forwarded to the dense blocks.
            Block: dense block class to instantiate.
        """
        super().__init__()

        self.growthRate = growthRate
        # ``self.inplanes`` tracks the running channel count while building.
        self.inplanes = inplanes
        self.expansion = expansion

        if small_input:
            self.stem = ConvBNReLU(3, self.inplanes, kernel_size=3, stride=1)
        else:
            self.stem = nn.Sequential(
                ConvBNReLU(3, self.inplanes, kernel_size=7, stride=2),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
            )

        layers = []
        for i, n_bottleneck in enumerate(block_configs):
            dense_block = Block(
                self.inplanes, n_bottleneck, expansion, growthRate, efficient
            )
            layers.append(dense_block)
            self.inplanes += n_bottleneck * self.growthRate
            # not add transition at the end
            if i < len(block_configs) - 1:
                layers.append(self._make_transition(compressionRate))
        self.dense_blocks = nn.Sequential(*layers)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()  # type: ignore
        self.fc = nn.Linear(self.inplanes, num_classes)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_transition(self, compressionRate: int) -> nn.Module:
        """Make a transition and update the running channel count."""
        inplanes = self.inplanes
        # Floor division already floors; no extra math.floor is needed.
        outplanes = int(self.inplanes // compressionRate)
        self.inplanes = outplanes
        return Transition(inplanes, outplanes)

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Actual forward procedures."""
        x = self.stem(x)
        x = self.dense_blocks(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward."""
        return self._forward_impl(x)


def get_model(**kwargs: Any) -> nn.Module:
    """Construct a DenseNet model; ``kwargs`` are forwarded to ``DenseNet``."""
    return DenseNet(**kwargs)


# --------------------------------------------------------------------------------
# /src/models/mixnet.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""MixNet - S / M / L, MicroNet.

* Note: SHRINKING IS NOT SUPPORTED!

- Author: Curt-Park
- Email: jwpark@jmarple.ai
- Paper: https://arxiv.org/abs/1907.09595
- Differences from the original model:
    Every Mixblock has a skip connection
    Swish function replaced with HSwish
    Mixblock doesn't use group conv operation
    Squeeze-and-Excitation is located behind projection
- Reference:
    https://github.com/leaderj1001/Mixed-Depthwise-Convolutional-Kernels
    https://github.com/Kthyeon/micronet_neurips_challenge
"""


from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from src.models.common_activations import HSwish
from src.models.common_layers import (
    ConvBN,
    ConvBNReLU,
    Identity,
    MDConvBlock,
    SqueezeExcitation,
)


def round_filters(
    n_filters: int,
    multiplier: float = 1.0,
    divisor: int = 8,
    min_depth: Optional[int] = None,
) -> int:
    """Get the number of channels.

    Scales ``n_filters`` by ``multiplier`` and rounds to a multiple of
    ``divisor`` that is at least ``min_depth`` (``divisor`` when unset).
    A falsy ``multiplier`` returns ``n_filters`` unchanged.
    """
    if not multiplier:
        return n_filters

    n_filters = int(n_filters * multiplier)
    min_depth = min_depth or divisor
    n_filters_new = max(min_depth, int(n_filters + divisor / 2) // divisor * divisor)
    # Do not let rounding shrink the width by more than 10%.
    if n_filters_new < 0.9 * n_filters:
        n_filters_new += divisor
    return n_filters_new


class MixBlock(nn.Module):
    """MixBlock: Using different kernel sizes for each channel chunk."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        n_chunks: int,
        stride: int,
        expand_ratio: float,
        se_ratio: float,
        hswish: bool,
    ) -> None:
        """Initialize.

        Args:
            in_channels: input channels.
            out_channels: output channels.
            n_chunks: number of channel chunks for the mixed depthwise conv.
            stride: stride of the depthwise conv (and of the skip projection).
            expand_ratio: expanded width is ``int(in_channels * expand_ratio)``.
            se_ratio: squeeze-excitation ratio; ``None`` disables SE.
            hswish: use HSwish activations when true, ReLU otherwise.
        """
        super().__init__()
        self.in_channels = in_channels
        self.exp_channels = int(in_channels * expand_ratio)
        self.out_channels = out_channels
        self.n_chunks = n_chunks
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.has_se = se_ratio is not None
        self.se_ratio = se_ratio
        self.hswish = hswish

        # Pointwise expansion (skipped when it would not change the width).
        self.expand_conv = Identity()
        if self.in_channels != self.exp_channels:
            self.expand_conv = (
                nn.Sequential(
                    ConvBN(self.in_channels, self.exp_channels, kernel_size=1),
                    HSwish(inplace=True),
                )
                if self.hswish
                else ConvBNReLU(self.in_channels, self.exp_channels, kernel_size=1)
            )

        # With hswish the activation follows the block; otherwise the ReLU is
        # built into each chunk.
        self.mdconv = nn.Sequential(
            MDConvBlock(
                self.exp_channels,
                n_chunks=self.n_chunks,
                stride=self.stride,
                with_relu=not self.hswish,
            ),
            Identity() if not self.hswish else HSwish(inplace=True),
        )

        self.proj_conv = ConvBN(self.exp_channels, self.out_channels, kernel_size=1)

        self.se = (
            SqueezeExcitation(self.out_channels, self.se_ratio)
            if self.has_se
            else Identity()
        )

        # The skip path is projected only when the shape changes.
        self.downsample = (
            ConvBN(
                self.in_channels, self.out_channels, kernel_size=1, stride=self.stride
            )
            if self.stride != 1 or self.in_channels != self.out_channels
            else Identity()
        )

    def _add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Sum two tensors (elementwise); overridden for quantization."""
        return x + y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward."""
        out = self.expand_conv(x)
        out = self.mdconv(out)
        out = self.proj_conv(out)
        out = self.se(out)
        out = self._add(out, self.downsample(x))
        return out


class MixNet(nn.Module):
    """MixNet architecture."""

    def __init__(
        self,
        stem: int,
        stem_stride: int,
        head: int,
        last_out_channels: int,
        block_args: Tuple[List[Any], ...],
        dropout: float = 0.2,
        num_classes: int = 1000,
        Block: "type" = MixBlock,
    ) -> None:
        """Initialize.

        Args:
            stem: output channels of the stem conv.
            stem_stride: stride of the stem conv.
            head: output channels of the 1x1 head conv; a falsy value
                disables the head.
            last_out_channels: output channels of the last block.
            block_args: one row per block, in the order (in_channels,
                out_channels, n_chunks, stride, expand_ratio, se_ratio, hswish).
            dropout: probability of the ``nn.Dropout`` module.
            num_classes: output size of the final linear layer.
            Block: block class to instantiate.
        """
        super().__init__()
        self.block_args = block_args

        self.stem = nn.Sequential(
            ConvBN(
                in_channels=3,
                out_channels=stem,
                kernel_size=3,
                stride=stem_stride,
            ),
            HSwish(inplace=True),
        )

        layers = []
        for (
            in_channels,
            out_channels,
            n_chunks,
            stride,
            expand_ratio,
            se_ratio,
            hswish,
        ) in block_args:
            layers.append(
                Block(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    n_chunks=n_chunks,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    se_ratio=se_ratio,
                    hswish=hswish,
                )
            )
        self.layers = nn.Sequential(*layers)

        if head:
            self.head = nn.Sequential(
                ConvBN(
                    in_channels=last_out_channels,
                    out_channels=head,
                    kernel_size=1,
                ),
                HSwish(inplace=True),
            )
        else:
            self.head = Identity()
            # Without a head the classifier reads the last block directly.
            head = last_out_channels

        self.adapt_avg_pool2d = nn.AdaptiveAvgPool2d((1, 1))
        # NOTE(review): this module is created but never called in
        # ``_forward_impl`` below - confirm whether dropout is meant to be
        # applied before ``fc``.
        self.dropout = nn.Dropout(p=dropout)

        self.fc = nn.Linear(head, num_classes)

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Actual forward procedure."""
        out = self.stem(x)
        out = self.layers(out)
        out = self.head(out)
        out = self.adapt_avg_pool2d(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward."""
        return self._forward_impl(x)


def get_model_kwargs(model_type: str, num_classes: int, dataset: str) -> Dict[str, Any]:
    """Return the model kwargs according to the model type.

    Raises:
        NotImplementedError: if ``model_type`` is not one of
            "MICRONET", "S", "M", "L".
    """
    builders = {
        "MICRONET": micronet,
        "S": mixnet_s,
        "M": mixnet_m,
        "L": mixnet_l,
    }
    if model_type not in builders:
        raise NotImplementedError
    return builders[model_type](num_classes=num_classes, dataset=dataset)


def get_model(model_type: str, num_classes: int, dataset: str) -> nn.Module:
    """Construct a MixNet model."""
    kwargs = get_model_kwargs(model_type, num_classes, dataset)
    return MixNet(**kwargs)


def micronet(
    num_classes: int = 100,
    multiplier: float = 1.0,
    divisor: int = 8,
    min_depth: Optional[int] = None,
    dataset: str = "IMAGENET",
) -> Dict[str, Any]:
    """Build MixNet-SS.

    Only ``dataset == "CIFAR100"`` is implemented; ``multiplier``,
    ``divisor`` and ``min_depth`` are accepted but not used.

    Raises:
        NotImplementedError: for any other dataset (including the default).
    """
    if dataset != "CIFAR100":
        raise NotImplementedError

    # in_channels, out_channels, n_chunks, stride, expand_ratio, se_ratio, hswish
    small = (
        [32, 16, 1, 1, 3, None, False],
        [16, 16, 1, 1, 3, None, False],
        [16, 32, 1, 2, 3, None, False],
        [32, 32, 1, 1, 3, 0.25, True],
        [32, 48, 1, 1, 3, 0.25, True],
        [48, 48, 1, 1, 3, 0.25, True],
        [48, 48, 1, 1, 3, 0.25, True],
        [48, 72, 1, 2, 3, 0.25, True],
        [72, 72, 1, 1, 3, 0.25, True],
        [72, 72, 1, 1, 3, 0.25, True],
        [72, 72, 1, 1, 3, 0.25, True],
        [72, 72, 1, 1, 3, 0.25, True],
        [72, 80, 1, 2, 3, 0.25, True],
        [80, 88, 1, 1, 3, 0.25, True],
        [88, 88, 1, 1, 3, 0.25, True],
        [88, 106, 1, 1, 3, 0.25, True],
    )

    return dict(
        stem=32,
        stem_stride=1,
        head=0,  # head not used
        last_out_channels=106,
        block_args=small,
        num_classes=num_classes,
        dropout=0.3,
    )


def _mixnet_s_blocks(early_stride: int) -> Tuple[List[Any], ...]:
    """Return fresh MixNet-S block rows; ``early_stride`` is the 2nd row's stride."""
    # in_channels, out_channels, n_chunks, stride, expand_ratio, se_ratio, hswish
    return (
        [16, 16, 1, 1, 1, None, False],
        [16, 24, 1, early_stride, 6, None, False],
        [24, 24, 1, 1, 3, None, False],
        [24, 40, 3, 2, 6, 0.5, True],
        [40, 40, 2, 1, 6, 0.5, True],
        [40, 40, 2, 1, 6, 0.5, True],
        [40, 40, 2, 1, 6, 0.5, True],
        [40, 80, 3, 2, 6, 0.25, True],
        [80, 80, 2, 1, 6, 0.25, True],
        [80, 80, 2, 1, 6, 0.25, True],
        [80, 120, 3, 1, 6, 0.5, True],
        [120, 120, 4, 1, 3, 0.5, True],
        [120, 120, 4, 2, 3, 0.5, True],
        [120, 200, 5, 1, 6, 0.5, True],
        [200, 200, 4, 1, 6, 0.5, True],
        [200, 200, 4, 1, 6, 0.5, True],
    )


def mixnet_s(
    num_classes: int = 100,
    multiplier: float = 1.0,
    divisor: int = 8,
    min_depth: Optional[int] = None,
    dataset: str = "IMAGENET",
) -> Dict[str, Any]:
    """Build MixNet-S.

    The IMAGENET and CIFAR100 variants differ only in the stem stride and
    in the stride of the second block (2 for IMAGENET, 1 for CIFAR100).
    ``divisor`` and ``min_depth`` are accepted but not used.

    Raises:
        NotImplementedError: for any other dataset.
    """
    if dataset == "IMAGENET":
        stem_stride = 2
    elif dataset == "CIFAR100":
        stem_stride = 1
    else:
        raise NotImplementedError

    return dict(
        stem=round_filters(16, multiplier),
        stem_stride=stem_stride,
        head=round_filters(1536, multiplier),
        last_out_channels=round_filters(200, multiplier),
        block_args=_mixnet_s_blocks(early_stride=stem_stride),
        num_classes=num_classes,
    )


def _mixnet_m_blocks(early_stride: int) -> Tuple[List[Any], ...]:
    """Return fresh MixNet-M block rows; ``early_stride`` is the 2nd row's stride."""
    # in_channels, out_channels, n_chunks, stride, expand_ratio, se_ratio, hswish
    return (
        [24, 24, 1, 1, 1, None, False],
        [24, 32, 3, early_stride, 6, None, False],
        [32, 32, 1, 1, 3, None, False],
        [32, 40, 4, 2, 6, 0.5, True],
        [40, 40, 2, 1, 6, 0.5, True],
        [40, 40, 2, 1, 6, 0.5, True],
        [40, 40, 2, 1, 6, 0.5, True],
        [40, 80, 3, 2, 6, 0.25, True],
        [80, 80, 4, 1, 6, 0.25, True],
        [80, 80, 4, 1, 6, 0.25, True],
        [80, 80, 4, 1, 6, 0.25, True],
        [80, 120, 1, 1, 6, 0.5, True],
        [120, 120, 4, 1, 3, 0.5, True],
        [120, 120, 4, 1, 3, 0.5, True],
        [120, 120, 4, 1, 3, 0.5, True],
        [120, 200, 4, 2, 6, 0.5, True],
        [200, 200, 4, 1, 6, 0.5, True],
        [200, 200, 4, 1, 6, 0.5, True],
        [200, 200, 4, 1, 6, 0.5, True],
    )


def mixnet_m(
    num_classes: int = 1000,
    multiplier: float = 1.0,
    divisor: int = 8,
    min_depth: Optional[int] = None,
    dataset: str = "IMAGENET",
) -> Dict[str, Any]:
    """Build MixNet-M.

    The IMAGENET and CIFAR100 variants differ only in the stem stride and
    in the stride of the second block (2 for IMAGENET, 1 for CIFAR100).
    ``divisor`` and ``min_depth`` are accepted but not used.

    Raises:
        NotImplementedError: for any other dataset.
    """
    if dataset == "IMAGENET":
        stem_stride = 2
    elif dataset == "CIFAR100":
        stem_stride = 1
    else:
        raise NotImplementedError

    medium: Tuple[List[Any], ...] = _mixnet_m_blocks(early_stride=stem_stride)
    stem = round_filters(24, multiplier)
    last_out_channels = round_filters(200, multiplier)
    # NOTE(review): the head width is computed with multiplier=1.0, so it does
    # not scale with ``multiplier`` - confirm this is intended.
    head = round_filters(1536, multiplier=1.0)

    # Scale the in/out channels of every block row in place.
    for line in medium:
        line[0] = round_filters(line[0], multiplier)
        line[1] = round_filters(line[1], multiplier)

    return dict(
        stem=stem,
        stem_stride=stem_stride,
        head=head,
        last_out_channels=last_out_channels,
        block_args=medium,
        dropout=0.25,
        num_classes=num_classes,
    )


def mixnet_l(num_classes: int = 1000, dataset: str = "IMAGENET") -> Dict[str, Any]:
    """Build MixNet-L (MixNet-M with a 1.3 channel multiplier)."""
    return mixnet_m(num_classes=num_classes, multiplier=1.3, dataset=dataset)


# --------------------------------------------------------------------------------
# /src/models/quant_densenet.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# """DenseNet model for quantization.
# (remainder of the module docstring of src/models/quant_densenet.py)
# - Author: Curt-Park
# - Email: jwpark@jmarple.ai

from typing import Any, List, Tuple

import torch
import torch.nn as nn
from torch.quantization import DeQuantStub, QuantStub, fuse_modules

from src.models.common_layers import ConvBN, ConvBNReLU
from src.models.densenet import Bottleneck, DenseBlock, DenseNet


class QuantizableBottleneck(Bottleneck):
    """Quantizable Bottleneck layer."""

    def __init__(
        self,
        inplanes: int,
        expansion: int,
        growthRate: int,
        efficient: bool,
    ) -> None:
        """Initialize.

        ``efficient`` is accepted for signature compatibility but the base
        class is always built with ``efficient=False``.
        """
        super().__init__(inplanes, expansion, growthRate, efficient=False)
        self.cat = nn.quantized.FloatFunctional()

    # arbitrary sized input makes failure when quantizating models
    def forward(self, prev_features: List[torch.Tensor]) -> torch.Tensor:
        """Forward; takes the previous features as a single list."""
        # checkpoint doesn't work in scripted models
        out = self.cat.cat(prev_features, dim=1)
        out = self.conv1(out)
        out = self.conv2(out)
        return out


class QuantizableDenseBlock(DenseBlock):
    """Quantizable Densenet block."""

    def __init__(
        self,
        inplanes: int,
        blocks: int,
        expansion: int,
        growth_rate: int,
        efficient: bool,
        Layer: "type" = QuantizableBottleneck,
    ) -> None:
        """Initialize; arguments are forwarded to ``DenseBlock``."""
        super().__init__(inplanes, blocks, expansion, growth_rate, efficient, Layer)
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, init_features: torch.Tensor) -> torch.Tensor:
        """Forward."""
        features = [init_features]
        for layer in self.layers:
            # Layers receive the list itself rather than unpacked arguments.
            features.append(layer(features))
        return self.cat.cat(features, dim=1)


class QuantizableDenseNet(DenseNet):
    """Quantizable DenseNet architecture."""

    def __init__(
        self,
        num_classes: int,
        inplanes: int,
        expansion: int = 4,
        growthRate: int = 12,
        compressionRate: int = 2,
        block_configs: Tuple[int, ...] = (6, 12, 24, 16),
        small_input: bool = True,  # e.g. CIFAR100
        efficient: bool = False,  # memory efficient dense block
        Block: "type" = QuantizableDenseBlock,
    ) -> None:
        """Initialize; arguments are forwarded to ``DenseNet``."""
        self.inplanes = 0  # type annotation
        super().__init__(
            num_classes,
            inplanes,
            expansion,
            growthRate,
            compressionRate,
            block_configs,
            small_input,
            efficient,
            Block,
        )
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward: quant stub, the DenseNet forward, dequant stub."""
        x = self.quant(x)
        x = self._forward_impl(x)
        return self.dequant(x)

    def fuse_model(self) -> None:
        """Fuse modules and create intrinsic operators."""
        for m in self.modules():
            # Exact type checks: only these two wrapper classes are fused.
            if type(m) is ConvBNReLU:
                fuse_modules(m, ["conv", "bn", "relu"], inplace=True)
            if type(m) is ConvBN:
                fuse_modules(m, ["conv", "bn"], inplace=True)


def get_model(**kwargs: Any) -> nn.Module:
    """Construct a DenseNet model for quantization."""
    return QuantizableDenseNet(**kwargs)


# --------------------------------------------------------------------------------
# /src/models/quant_mixnet.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# """MixNet for quantization - S / M / L.
#
# * Note: SHRINKING IS NOT SUPPORTED!
# (remainder of the module docstring of src/models/quant_mixnet.py)
# - Author: Curt-Park
# - Email: jwpark@jmarple.ai
# - Paper: https://arxiv.org/abs/1907.09595
# - Reference: https://github.com/leaderj1001/Mixed-Depthwise-Convolutional-Kernels


from typing import Any

import torch
import torch.nn as nn
from torch.quantization import DeQuantStub, QuantStub, fuse_modules

from src.models.common_activations import QuantizableHSwish
from src.models.common_layers import (
    ConvBN,
    ConvBNReLU,
    Identity,
    QuantizableMDConvBlock,
    QuantizableSqueezeExcitation,
)
from src.models.mixnet import MixBlock, MixNet, get_model_kwargs


class QuantizableMixBlock(MixBlock):
    """MixBlock: Using different kernel sizes for each channel chunk."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize; ``kwargs`` are forwarded to ``MixBlock``.

        The base class builds the float sub-modules first; the ones that
        have a quantizable variant are then replaced below.
        """
        super().__init__(**kwargs)
        self.add = nn.quantized.FloatFunctional()

        if self.in_channels != self.exp_channels:
            self.expand_conv = (
                nn.Sequential(
                    ConvBN(self.in_channels, self.exp_channels, kernel_size=1),
                    QuantizableHSwish(inplace=True),
                )
                if self.hswish
                else ConvBNReLU(self.in_channels, self.exp_channels, kernel_size=1)
            )

        self.mdconv = nn.Sequential(
            QuantizableMDConvBlock(
                in_channels=self.exp_channels,
                n_chunks=self.n_chunks,
                stride=self.stride,
                with_relu=not self.hswish,
            ),
            Identity() if not self.hswish else QuantizableHSwish(inplace=True),
        )

        self.se = (
            QuantizableSqueezeExcitation(
                in_channels=self.out_channels,
                se_ratio=self.se_ratio,
            )
            if self.has_se
            else Identity()
        )

    def _add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Sum two tensors (elementwise)."""
        return self.add.add(x, y)


class QuantizableMixNet(MixNet):
    """MixNet architecture wrapped in quant / dequant stubs."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize; ``kwargs`` are forwarded to ``MixNet``."""
        super().__init__(**kwargs)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward: quant stub, the MixNet forward, dequant stub."""
        x = self.quant(x)
        x = self._forward_impl(x)
        return self.dequant(x)

    def fuse_model(self) -> None:
        """Fuse modules and create intrinsic operators."""
        for module in self.modules():
            # Exact type checks: only these two wrapper classes are fused.
            if type(module) is ConvBNReLU:
                fuse_modules(module, ["conv", "bn", "relu"], inplace=True)
            if type(module) is ConvBN:
                fuse_modules(module, ["conv", "bn"], inplace=True)


def get_model(model_type: str, num_classes: int, dataset: str) -> nn.Module:
    """Construct a quantizable MixNet model."""
    kwargs = get_model_kwargs(model_type, num_classes, dataset)
    kwargs.update(dict(Block=QuantizableMixBlock))
    return QuantizableMixNet(**kwargs)


# --------------------------------------------------------------------------------
# /src/models/quant_resnet.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Quantizable ResNet model loader.

* Note: SHRINKING IS NOT SUPPORTED!

- Author: Curt-Park
- Email: jwpark@jmarple.ai
"""


import importlib

import torch.nn as nn


def get_model(model_type: str, num_classes: int, pretrained: bool = False) -> nn.Module:
    """Construct a ResNet model from ``torchvision.models.quantization``.

    Args:
        model_type: one of "resnet18", "resnet50", "resnext101_32x8d".
        num_classes: forwarded to the torchvision constructor.
        pretrained: forwarded to the torchvision constructor.
    """
    assert model_type in ["resnet18", "resnet50", "resnext101_32x8d"]
    # torchvision is imported lazily, when a model is actually requested.
    quant_models = importlib.import_module("torchvision.models.quantization")
    constructor = getattr(quant_models, model_type)
    return constructor(pretrained=pretrained, num_classes=num_classes)


# --------------------------------------------------------------------------------
# /src/models/quant_simplenet.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# """Simple CNN Model for quantization.
class QuantizableSimpleNet(SimpleNet):
    """SimpleNet with quantization stubs placed around its forward pass."""

    def __init__(self, num_classes: int) -> None:
        """Build the base SimpleNet and attach the quant/dequant stubs."""
        super().__init__(num_classes)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``quant`` -> ``_forward_impl`` -> ``dequant`` on the input."""
        stub_in = self.quant(x)
        logits = self._forward_impl(stub_in)
        return self.dequant(logits)

    def fuse_model(self) -> None:
        """Fuse conv, bn and relu of every ConvBNReLU block in place.

        Fused modules are provided for common patterns in CNNs.
        Combining several operations together (like convolution and relu)
        allows for better quantization accuracy.

        References:
            https://pytorch.org/docs/stable/quantization.html#torch-nn-intrinsic
        """
        fused_children = ["conv", "bn", "relu"]
        for layer in self.modules():
            # exact type match: subclasses of ConvBNReLU are not fused here
            if type(layer) is not ConvBNReLU:
                continue
            torch.quantization.fuse_modules(layer, fused_children, inplace=True)
def get_model(model_type: str, num_classes: int, pretrained: bool = False) -> nn.Module:
    """Construct a torchvision ResNet-family model by name.

    Args:
        model_type: name of the torchvision constructor; must be one of the
            names listed below.
        num_classes: forwarded to the constructor as ``num_classes``.
        pretrained: forwarded to the constructor as ``pretrained``.
    """
    supported = (
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnet152",
        "resnext50_32x4d",
        "resnext101_32x8d",
        "wide_resnet50_2",
        "wide_resnet101_2",
    )
    assert model_type in supported
    # torchvision is imported lazily, only after the name has been accepted
    tv_models = __import__("torchvision.models", fromlist=[""])
    constructor = getattr(tv_models, model_type)
    return constructor(pretrained=pretrained, num_classes=num_classes)
def load_decomposed_model(
    weight_path: str, model_cfg_path: str = "", load_ema: bool = True
) -> Optional[nn.Module]:
    """Load a pickled PyTorch model from a checkpoint file.

    Args:
        weight_path: checkpoint path (ends with .pt). The file holds either a
            dict with a "model" entry (and optionally an "ema" entry) or an
            nn.Module itself.
        model_cfg_path: path of a model config. This loader does not build a
            model from it; it only decides which warning is logged when no
            model could be loaded.
        load_ema: prefer the "ema" entry of a dict checkpoint when it exists
            and is not None.

    Return:
        The loaded model moved to CPU and cast to float,
        None if no model could be loaded.

    Raises:
        KeyError: if a dict checkpoint has no "model" entry and "ema" is not used.
    """
    # Bound up front so the no-weights path and unsupported checkpoint types
    # end in the "return None" branch instead of a NameError.
    ckpt_model: Optional[nn.Module] = None

    if weight_path == "":
        LOGGER.warning(
            "Providing "
            + colorstr("bold", "no weights path")
            + " will validate a randomly initialized model. Please use only for a experiment purpose."
        )
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts;
        # the model is moved to CPU below anyway.
        ckpt = torch.load(weight_path, map_location="cpu")
        if isinstance(ckpt, dict):
            use_ema = load_ema and ckpt.get("ema") is not None
            ckpt_model = ckpt["ema" if use_ema else "model"]
        elif isinstance(ckpt, nn.Module):
            ckpt_model = ckpt

        if ckpt_model is not None:
            ckpt_model = ckpt_model.cpu().float()

    if ckpt_model is None:
        if model_cfg_path == "":
            LOGGER.warning("No weights and no model_cfg has been found.")
        else:
            LOGGER.warning(
                "No model could be loaded from the weights; building a model "
                "from model_cfg is not supported by this loader."
            )
        return None

    return ckpt_model
def get_model_size_mb(model: nn.Module) -> float:
    """Get the size of the model's saved state dict in megabytes (1e6 bytes).

    The state dict is written to a file inside a private temporary directory,
    so nothing in the working directory is touched and the file is removed even
    when saving fails.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the historical file name: the archive written by torch.save can
        # embed the file stem, so the measured size stays comparable.
        tmp_path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), tmp_path)
        return os.path.getsize(tmp_path) / 1e6
def sparsity(
    params_all: Tuple[Tuple[nn.Module, str], ...],
    module_types: Tuple[Any, ...] = (
        nn.Conv2d,
        nn.Linear,
        nn.BatchNorm1d,
        nn.BatchNorm2d,
    ),
) -> float:
    """Get the percentage of exact zeros among the selected parameters.

    Args:
        params_all: (module, parameter name) pairs to inspect.
        module_types: only modules whose type is exactly one of these are
            counted; subclasses are skipped.

    Return:
        100 * zeros / elements over the counted tensors, 0.0 if nothing was counted.
    """
    zeros = 0
    elements = 0
    for module, param_name in params_all:
        if not any(type(module) is candidate for candidate in module_types):
            continue
        tensor = getattr(module, param_name)
        zeros += int(torch.sum(tensor == 0.0).item())
        elements += tensor.nelement()

    if elements == 0:
        return 0.0
    return 100.0 * zeros / elements
def dot2bracket(s: str) -> str:
    """Replace layer names with valid names for pruning.

    Every ".<digits>" in the name is rewritten as "[<digits>]", which turns the
    dotted names used for sequential children into index expressions.

    Test:
       >>> dot2bracket("dense2.1.bn1.bias")
       'dense2[1].bn1.bias'
       >>> dot2bracket("dense2.13.bn1.bias")
       'dense2[13].bn1.bias'
       >>> dot2bracket("conv2.123.bn1.bias")
       'conv2[123].bn1.bias'
       >>> dot2bracket("dense2.6.conv2.5.bn1.bias")
       'dense2[6].conv2[5].bn1.bias'
       >>> dot2bracket("model.6")
       'model[6]'
       >>> dot2bracket("vgg.2.conv2.bn.2")
       'vgg[2].conv2.bn[2]'
       >>> dot2bracket("features.11")
       'features[11]'
       >>> dot2bracket("dense_blocks.0.0.conv1")
       'dense_blocks[0][0].conv1'
    """
    # A single substitution pass works on the original string, so a rewrite
    # never shifts the position of a later match (the previous list-editing
    # version corrupted names once it had to insert a character mid-string).
    return re.sub(r"\.([0-9]+)", r"[\1]", s)
def split_channels(n_channels: int, n_chunks: int) -> List[int]:
    """Split a channel count into ``n_chunks`` integer chunk sizes.

    Every chunk gets ``n_channels // n_chunks`` channels and the first chunk
    additionally receives whatever is left over, so the sizes always add up
    to ``n_channels``.
    """
    sizes: List[int] = []
    for _ in range(n_chunks):
        sizes.append(n_channels // n_chunks)
    leftover = n_channels - sum(sizes)
    sizes[0] = sizes[0] + leftover
    return sizes
def _get_prune_statistics(
    self, model: nn.Module
) -> Tuple[List[str], PruneStat, PruneStat]:
    """Get prune statistics for each masked Conv2d / Linear layer.

    Only modules whose type is exactly nn.Conv2d or nn.Linear and that carry
    both a ``weight_mask`` and a ``weight`` attribute are considered.

    Args:
        model: model to inspect.

    Return:
        (layer names, per-layer counts, per-layer ratios). ``pruned`` /
        ``remained`` come from zeros in ``weight_mask``; ``zero`` / ``nonzero``
        come from zeros in ``weight``.

    Side effects:
        Updates ``self.total_sparsity`` with the overall pruned fraction
        (0.0 when no masked layer was found).
    """
    layer_names: List[str] = []

    pruned_params, remained_params = [], []
    zero_params, nonzero_params = [], []
    pruned_ratio, remained_ratio = [], []
    zero_ratio, nonzero_ratio = [], []

    for name, module in model.named_modules():
        if type(module) not in (nn.Conv2d, nn.Linear):
            continue
        if not (hasattr(module, "weight_mask") and hasattr(module, "weight")):
            continue
        layer_names.append(name)

        mask = getattr(module, "weight_mask")
        total = mask.nelement()
        pruned = int(torch.sum(mask == 0.0).item())
        # Zeros are counted on the weight itself; counting them on the mask
        # again would just repeat the pruned/remained bar of the plot.
        zero = int(torch.sum(getattr(module, "weight") == 0.0).item())

        pruned_params.append(pruned)
        remained_params.append(total - pruned)

        pruned_ratio.append(pruned / total)
        remained_ratio.append(1 - pruned / total)

        zero_params.append(zero)
        nonzero_params.append(total - zero)

        zero_ratio.append(zero / total)
        nonzero_ratio.append(1 - zero / total)

    params = PruneStat(pruned_params, remained_params, zero_params, nonzero_params)
    ratio = PruneStat(pruned_ratio, remained_ratio, zero_ratio, nonzero_ratio)

    # A model without any masked layer has nothing to divide by.
    n_pruned = sum(pruned_params)
    n_all = n_pruned + sum(remained_params)
    self.total_sparsity = n_pruned / n_all if n_all else 0.0

    return layer_names, params, ratio
def _annotate_on_bar(
    self,
    ax: matplotlib.axes.Axes,
    bars: List[matplotlib.axes.Axes.bar],
) -> None:
    """Annotate every bar with its own height, positioned at that height."""
    for rect in bars:
        bar_height = rect.get_height()
        # the height is both the displayed value and the y position of the label
        self._ax_annotate(ax, rect, bar_height, bar_height)
def _get_fig(
    self, labels: List[str]
) -> Tuple[matplotlib.pyplot.figure, matplotlib.axes.Axes]:
    """Create a figure and axes wide enough for one bar slot per label.

    The width is both margins plus ``len(labels) + 1`` bar widths, with a
    lower bound of 8; the height is fixed at 14.
    """
    margins = self.leftmargin + self.rightmargin
    slots_width = (len(labels) + 1) * self.width
    figwidth = max(margins + slots_width, 8)

    return plt.subplots(figsize=(figwidth, 14))
class BnWeight(nn.Module):
    """Apply L1 regularizer on BatchNorm weight.

    Reference:
        Learning Efficient Convolutional Networks through Network Slimming
        (https://arxiv.org/pdf/1708.06519.pdf)

    Attributes:
        coeff (float): factor applied to the L1 norm of each BatchNorm2d weight.
    """

    def __init__(self, coeff: float) -> None:
        """Store the regularization coefficient."""
        super().__init__()
        self.coeff = coeff

    def forward(self, model: nn.Module) -> float:
        """Sum ``coeff * ||weight||_1`` over all BatchNorm2d modules of ``model``.

        Returns the float 0.0 when the model has no BatchNorm2d; otherwise the
        accumulated value is a tensor.
        """
        # NOTE(review): the norm is taken on `weight.data`, which is detached from
        # autograd, so this value carries no gradient - confirm that is intended.
        penalty = 0.0
        batchnorms = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
        for bn in batchnorms:
            penalty = penalty + self.coeff * torch.norm(input=bn.weight.data, p=1)
        return penalty
def initialize(
    mode: str,
    config_path: str,
    resume: str = "",
    multi_gpu: bool = False,
    gpu_id: int = -1,
) -> Tuple[Dict[str, Any], str, torch.device]:
    """Set up device, output directories, config, logger and random seed.

    Args:
        mode: one of "train", "prune", "quantize", "shrink", "val"; selects the
            config validator ("val" is checked with the train validator).
        config_path: config file to run; ignored when ``resume`` is given.
        resume: existing run directory; its first ``*.py`` file is used as config.
        multi_gpu: use the generic "cuda" device when CUDA is available.
        gpu_id: index of a single GPU to use when ``multi_gpu`` is off.

    Return:
        (config dict, run directory, torch device).

    Raises:
        NotImplementedError: if ``mode`` is not one of the known modes.
    """
    # device: CPU unless CUDA is available and one of the GPU options applies
    device = torch.device("cpu")
    if torch.cuda.is_available():
        visible = [f"{idx}" for idx in range(torch.cuda.device_count())]
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(visible)
        if multi_gpu:
            device = torch.device("cuda")
        elif 0 <= gpu_id < torch.cuda.device_count():
            device = torch.device(f"cuda:{gpu_id}")

    # output tree under "save"; a fresh run also gets its own timestamped folder
    subdirs = ["", "data", "checkpoint"]
    if not resume:
        subdirs.append(os.path.join("checkpoint", curr_time))
    for subdir in subdirs:
        target = os.path.join("save", subdir)
        if not os.path.exists(target):
            os.mkdir(target)

    # pick the config: the one stored in the resumed run, or a copy of the given one
    if resume:
        dir_prefix = resume
        assert os.path.exists(dir_prefix), f"{dir_prefix} does not exist"
        config_path = glob.glob(os.path.join(dir_prefix, "*.py"))[0]
    else:
        assert os.path.exists(config_path), "--config required"
        dir_prefix = os.path.join(checkpt_path, curr_time)
        shutil.copyfile(
            config_path, os.path.join(dir_prefix, os.path.basename(config_path))
        )
    config = run_path(config_path)["config"]

    # the log file is named after the config file
    log_stem = os.path.splitext(os.path.basename(config_path))[0]
    utils.set_logger(filename=os.path.join(dir_prefix, f"{log_stem}.log"))

    # config validation check
    validators = {
        "train": TrainConfigValidator,
        "prune": PruneConfigValidator,
        "quantize": QuantizeConfigValidator,
        "shrink": ShrinkConfigValidator,
        "val": TrainConfigValidator,
    }
    if mode not in validators:
        raise NotImplementedError
    validators[mode](config).check()

    utils.set_random_seed(config["SEED"])

    return config, dir_prefix, device
def run(self, resume_info_path: str = "") -> None:
    """Run quantization.

    Steps, in order: save and log the original model, prepare (fuse) it,
    calibrate (static) or train (quantization-aware), convert it to a
    quantized model, then script it. After each stage the state is saved
    under ``self.dir_prefix`` and its file size is logged together with an
    accuracy string.

    Args:
        resume_info_path: forwarded to ``self.trainer.run`` in the
            quantization-aware training branch only.
    """
    # print_datatypes(self.model, "original model")
    self.trainer.warmup_one_iter()
    # baseline: size of the unquantized weights; the accuracy is the one stored
    # in the checkpoint (self.orig_acc)
    orig_model_path = os.path.join(self.dir_prefix, "orig_model.pth")
    torch.save(self.model.state_dict(), orig_model_path)
    size = os.path.getsize(orig_model_path) / 1e6
    logger.info(f"Acc: {self.orig_acc} %\tSize: {size:.6f} MB")

    # fuse the model
    self._prepare()
    # print_datatypes(self.model, "Fused model")

    # post training static quantization
    if self.static:
        logger.info("Post Training Static Quantization: Run calibration")
        # calibration is a single warm-up iteration through the prepared model
        self.trainer.warmup_one_iter()
    # quantization-aware training
    else:
        logger.info("Quantization Aware Training: Run training")
        self.trainer.run(resume_info_path)
        # stop updating observers and batch-norm statistics before conversion
        self.model.apply(torch.quantization.disable_observer)
        self.model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
        # load the best model
        # NOTE(review): assumed to belong to the training branch only - confirm
        self._load_best_model()

    # quantize the model
    quantized_model = self._quantize(self.model)
    if self.check_acc:
        _, acc = self.trainer.test_one_epoch()
        acc = f"{acc['model_acc']:.2f}"
    else:
        # no accuracy requested: run one warm-up iteration instead of a test epoch
        self.trainer.warmup_one_iter()
        acc = "None"
    quantized_model_path = os.path.join(self.dir_prefix, "quantized_model.pth")
    torch.save(quantized_model.state_dict(), quantized_model_path)
    size = os.path.getsize(quantized_model_path) / 1e6
    logger.info(f"Acc: {acc} %\tSize: {size:.6f} MB")

    # script the model
    scripted_model = torch.jit.script(quantized_model)
    # print_datatypes(scripted_model, "Scripted model")

    if self.check_acc:
        _, acc = self.trainer.test_one_epoch_model(scripted_model)
        acc = f"{acc['model_acc']:.2f}"
    else:
        self.trainer.warmup_one_iter()
        acc = "None"
    # the scripted module is saved whole with torch.jit.save, not as a state dict
    scripted_model_path = os.path.join(self.dir_prefix, "scripted_model.pth")
    torch.jit.save(scripted_model, scripted_model_path)
    size = os.path.getsize(scripted_model_path) / 1e6
    logger.info(f"Acc: {acc} %\tSize: {size:.6f} MB")
def _load_masks(self) -> None:
    """Re-attach the stored pruning masks to the current model.

    Does nothing when ``self.mask`` is empty. Otherwise the pruning
    reparameterization is re-created through ``model_utils.dummy_pruning``
    and every model buffer whose qualified name is a key of ``self.mask`` is
    replaced by the stored tensor.
    """
    if not self.mask:
        return

    model_utils.dummy_pruning(self.params_all)
    for name, _ in self.model.named_buffers():
        if name not in self.mask:
            continue

        # "a.0.conv.weight_mask" -> ("a.0.conv", "weight_mask"); a buffer that
        # lives directly on the root module yields an empty module path.
        module_name, _, mask_name = name.rpartition(".")

        # Resolve the owner by walking the dotted path with getattr. The
        # previous eval("self.model." + path) was a syntax error for indexed
        # children such as "features.0.conv" and executed arbitrary text.
        module = self.model
        for attr in filter(None, module_name.split(".")):
            module = getattr(module, attr)

        module._buffers[mask_name] = self.mask[name]
class Runner(ABC):
    """Abstract class used by runners (e.g. trainer, pruner).

    Stores the run configuration and the output directory, and knows how to
    look up the most recently recorded checkpoint path.
    """

    def __init__(self, config: Dict[str, Any], dir_prefix: str) -> None:
        """Initialize.

        Args:
            config: configuration dictionary of the run.
            dir_prefix: directory that holds the checkpoint-path log.
        """
        self.config = config
        self.dir_prefix = dir_prefix
        # extension of checkpoint files
        self.fileext = "pth.tar"
        # name of the log file (inside dir_prefix) listing checkpoint paths
        self.checkpt_paths = "checkpt_paths.log"

    @abstractmethod
    def run(self, resume_info_path: str = "") -> None:
        """Run the module."""
        pass

    def _fetch_latest_checkpt(self) -> str:
        """Fetch the latest checkpoint file path from the log file.

        Returns:
            The last non-empty line of the log without its line terminator,
            or "" when the log file does not exist or holds no path.
        """
        checkpt_paths = os.path.join(self.dir_prefix, self.checkpt_paths)
        if not os.path.exists(checkpt_paths):
            return ""

        with open(checkpt_paths, "r") as checkpts:
            # Strip only the line terminator. Blindly dropping the last
            # character truncates the path when the file lacks a final "\n".
            recorded = [line.rstrip("\r\n") for line in checkpts]

        # Ignore blank lines so a trailing empty line does not hide the path.
        for path in reversed(recorded):
            if path:
                return path
        return ""
def __init__(
    self,
    config: Dict[str, Any],
    dir_prefix: str,
    checkpt_dir: str,
    device: torch.device,
    finetune: str = "",
    wandb_log: bool = False,
    wandb_init_params: Dict[str, Any] = None,
    half: bool = False,
    test_preprocess_hook: Callable[[nn.Module], nn.Module] = None,
) -> None:
    """Build the model and prepare everything needed for training.

    Args:
        config: training configuration; "MODEL_NAME" and "MODEL_PARAMS" are
            read here, the rest by ``setup_train_configuration``.
        dir_prefix: output directory of the run.
        checkpt_dir: checkpoint sub-directory, forwarded to ``reset``.
        device: device the model is moved to.
        finetune: path of a checkpoint to load before training; ignored when
            empty or when the path does not exist.
        wandb_log: whether to initialize wandb.
        wandb_init_params: keyword arguments for ``wandb.init``.
        half: convert the model to half precision.
        test_preprocess_hook: stored on the instance as given.
    """
    super(Trainer, self).__init__(config, dir_prefix)
    self.half, self.device, self.wandb_log = half, device, wandb_log
    self.reset(checkpt_dir)
    self.test_preprocess_hook = test_preprocess_hook

    # instantiate the network on the target device
    net = model_utils.get_model(
        self.config["MODEL_NAME"], self.config["MODEL_PARAMS"]
    )
    self.model = net.to(self.device)
    if device == torch.device("cuda"):  # multi-gpu
        self.model = torch.nn.DataParallel(self.model).to(self.device)

    # optionally start from weights of an existing checkpoint
    if finetune and os.path.exists(finetune):
        self.load_model(finetune)

    if self.half:
        self.model.half()

    # report the model size
    n_params = model_utils.count_model_params(self.model)
    logger.info(
        f"Created a model {self.config['MODEL_NAME']} with {(n_params / 10**6):.2f}M params"
    )

    logger.info("Setup train configuration")
    self.setup_train_configuration(self.config)

    # experiment tracking
    if wandb_log:
        wandb.init(**(wandb_init_params or {}))

    # per-output count of correct predictions within the current epoch
    self.n_correct_epoch: DefaultDict[str, int] = defaultdict(lambda: 0)
def reset(self, checkpt_dir: str) -> None:
    """Reset the training state and make sure the checkpoint directory exists.

    Sets ``best_acc`` to 0.0 and ``epoch`` to 0, and points
    ``model_save_dir`` at ``<dir_prefix>/<checkpt_dir>``.

    Args:
        checkpt_dir: checkpoint directory, relative to ``self.dir_prefix``.
    """
    self.checkpt_dir = checkpt_dir
    self.best_acc = 0.0
    self.epoch = 0

    # best model path
    self.model_save_dir = os.path.join(self.dir_prefix, checkpt_dir)
    # makedirs(exist_ok=True) also creates missing parents (nested
    # checkpt_dir, absent dir_prefix) and avoids the check-then-create race
    # of os.path.exists() followed by os.mkdir().
    os.makedirs(self.model_save_dir, exist_ok=True)
def log_one_epoch(
    self, epoch: int, log_info: List[Tuple[str, float, Callable[[float], str]]]
) -> None:
    """Emit the statistics of one finished epoch.

    Args:
        epoch: index of the epoch that just ran.
        log_info: ``(name, value, formatter)`` triples; the formatter turns
            the value into the text written to the log line.
    """
    header = f"Epoch: [{epoch} | {self.config['EPOCHS']-1}]\t"
    fields = [f"{name}: " + fmt(val) for name, val, fmt in log_info]
    logger.info(header + "\t".join(fields))

    if not self.wandb_log:
        return

    # wandb receives the raw values, not the formatted text
    model_utils.wlog_weight(self.model)
    wandb.log({name: val for name, val, _ in log_info})
@torch.no_grad()
def test_one_epoch_model(self, model: nn.Module) -> Tuple[float, Dict[str, float]]:
    """Evaluate ``model`` on the whole test loader.

    Runs with gradients disabled and puts ``model`` in eval mode.

    Args:
        model: network to evaluate; it is passed to ``self.criterion``.

    Returns:
        ``(average loss over the batches, accuracy dictionary)`` where the
        accuracy comes from ``self._get_epoch_acc(is_test=True)``.

    Raises:
        ZeroDivisionError: if the test loader yields no batch.
    """
    losses = []
    model.eval()

    # testloaders contain same length(iteration) of batch dataset
    for data in progressbar(self.testloader, prefix="[Test]\t"):
        images, labels = data[0].to(self.device), data[1].to(self.device)

        if self.half:
            images = images.half()

        # forward only: the criterion returns the loss and the model outputs
        loss, outputs = self.criterion(model, images=images, labels=labels)
        self._count_correct_prediction(outputs, labels)
        losses.append(loss.item())

    # NOTE: a leftover pdb.set_trace() used to sit here and stopped every
    # evaluation at an interactive prompt.
    avg_loss = sum(losses) / len(losses)
    acc = self._get_epoch_acc(is_test=True)
    return avg_loss, acc
self.model.state_dict(), 311 | "optimizer": self.optimizer.state_dict(), 312 | "test_acc": self.best_acc, 313 | } 314 | 315 | filepath = os.path.join(model_path, f"{filename}.{self.fileext}") 316 | torch.save(params, filepath) 317 | logger.info( 318 | f"Saved all params in {model_path}{os.path.sep}{filename}.{self.fileext}" 319 | ) 320 | 321 | if record_path: 322 | with open( 323 | os.path.join(self.dir_prefix, self.checkpt_paths), "a" 324 | ) as checkpts: 325 | checkpts.write(filepath + "\n") 326 | 327 | def load_model(self, model_path: str, with_mask: bool = True) -> None: 328 | """Load weights and masks.""" 329 | checkpt = torch.load(model_path, map_location=self.device) 330 | model_utils.initialize_params( 331 | self.model, checkpt["state_dict"], with_mask=with_mask 332 | ) 333 | logger.info(f"Loaded the model from {model_path}") 334 | 335 | def load_params(self, model_path: str, with_mask: bool = True) -> None: 336 | """Load weights and masks.""" 337 | checkpt = torch.load(model_path, map_location=self.device) 338 | model_utils.initialize_params( 339 | self.model, checkpt["state_dict"], with_mask=with_mask 340 | ) 341 | model_utils.initialize_params( 342 | self.optimizer, checkpt["optimizer"], with_mask=False 343 | ) 344 | self.best_acc = checkpt["test_acc"] 345 | logger.info(f"Loaded parameters from {model_path}") 346 | 347 | def get_model_save_dir(self) -> str: 348 | """Get model save directory.""" 349 | return self.model_save_dir 350 | 351 | def load_best_model(self) -> None: 352 | """Load current best model.""" 353 | self.resume() 354 | 355 | def _count_correct_prediction( 356 | self, logits: Dict[str, torch.Tensor], labels: torch.Tensor 357 | ) -> None: 358 | """Count correct prediction in one iteration.""" 359 | if len(labels.size()) != 1: # For e.g., CutMix labels 360 | return 361 | for module_name, logit in logits.items(): 362 | _, predicted = torch.max(F.softmax(logit, dim=1).data, 1) 363 | n_correct = int((predicted == labels).sum().cpu()) 364 | 
import pdb 365 | 366 | pdb.set_trace() 367 | self.n_correct_epoch[module_name] += n_correct 368 | 369 | def _get_epoch_acc(self, is_test: bool = False) -> Dict[str, float]: 370 | """Get accuracy and reset statistics.""" 371 | n_total = ( 372 | len(self.testloader.dataset) if is_test else len(self.trainloader.dataset) 373 | ) 374 | acc = dict() 375 | for module_name in self.n_correct_epoch: 376 | accuracy = 100 * self.n_correct_epoch[module_name] / n_total 377 | acc.update({module_name + "_acc": accuracy}) 378 | self.n_correct_epoch.clear() 379 | 380 | return acc 381 | -------------------------------------------------------------------------------- /src/runners/validator.py: -------------------------------------------------------------------------------- 1 | """Validator for models. 2 | 3 | - Author: Haneol Kim. 4 | - Contact: hekim@jmarple.ai 5 | """ 6 | 7 | from collections import defaultdict 8 | import os 9 | from typing import Any, DefaultDict, Dict, Optional, Tuple, Union 10 | 11 | from progressbar import progressbar 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | 17 | from src.criterions import get_criterion 18 | from src.logger import get_logger 19 | from src.models import utils as model_utils 20 | from src.regularizers import get_regularizer 21 | from src.runners.runner import Runner 22 | from src.utils import count_param, get_dataloader, get_dataset, select_device 23 | 24 | LOGGER = get_logger(__name__) 25 | 26 | 27 | class Validator(Runner): 28 | """Validator for models.""" 29 | 30 | def __init__( 31 | self, 32 | config: Dict[str, Any], 33 | dir_prefix: str, 34 | checkpt_dir: str, 35 | device: Union[str, torch.device] = "cpu", 36 | half: bool = False, 37 | decomposed: bool = False, 38 | weight_path: Optional[str] = None, 39 | ) -> None: 40 | """Initialize vaildator.""" 41 | if decomposed and weight_path is None: 42 | raise ValueError("If decomposed, the weight_path should be given.") 
43 | elif not decomposed and weight_path: 44 | decomposed = True 45 | 46 | super(Validator, self).__init__(config, dir_prefix) 47 | if isinstance(device, torch.device): 48 | self.device = device 49 | else: 50 | self.device = select_device(device) 51 | self.half = half 52 | 53 | self.decomposed = decomposed 54 | self.weight_path = weight_path 55 | 56 | # create a model 57 | if self.decomposed and self.weight_path: 58 | self.model = model_utils.load_decomposed_model(self.weight_path) 59 | self.model.to(self.device) 60 | if device == torch.device("cuda"): 61 | self.model = torch.nn.DataParallel(self.model).to(self.device) 62 | 63 | else: 64 | model_name = self.config["MODEL_NAME"] 65 | model_config = self.config["MODEL_PARAMS"] 66 | self.model = model_utils.get_model(model_name, model_config).to(self.device) 67 | if device == torch.device("cuda"): # multi-gpu 68 | self.model = torch.nn.DataParallel(self.model).to(self.device) 69 | 70 | if self.half: 71 | self.model.half() 72 | 73 | self.setup_val_configuration() 74 | self.n_correct_epoch: DefaultDict[str, int] = defaultdict(lambda: 0) 75 | 76 | LOGGER.info(f"Model parameters: {count_param(self.model)}") 77 | 78 | def setup_val_configuration(self) -> None: 79 | """Set up validation configuration.""" 80 | # get datasets 81 | trainset, testset = get_dataset( 82 | self.config["DATASET"], 83 | self.config["AUG_TRAIN"], 84 | self.config["AUG_TEST"], 85 | self.config["AUG_TRAIN_PARAMS"], 86 | self.config["AUG_TEST_PARAMS"], 87 | ) 88 | 89 | self.input_size = trainset[0][0].size() 90 | LOGGER.info("Datasets prepared") 91 | 92 | _, self.testloader = get_dataloader( 93 | trainset, 94 | testset, 95 | self.config["BATCH_SIZE"], 96 | self.config["N_WORKERS"], 97 | ) 98 | LOGGER.info("Dataloader prepared") 99 | 100 | # define criterion and optimizer 101 | self.criterion = get_criterion( 102 | criterion_name=self.config["CRITERION"], 103 | criterion_params=self.config["CRITERION_PARAMS"], 104 | device=self.device, 105 | ) 106 | 107 
def run(self, resume_info_path: str = "") -> Tuple[float, dict]:
    """Evaluate the model once on the test set.

    Args:
        resume_info_path: unused; accepted so the signature stays compatible
            with the abstract ``Runner.run(resume_info_path="")``.

    Returns:
        ``(test loss, accuracy dictionary)`` as returned by
        ``self.test_one_epoch()``.
    """
    test_loss, acc = self.test_one_epoch()
    return test_loss, acc
predicted = torch.max(F.softmax(logit, dim=1).data, 1) 166 | n_correct = int((predicted == labels).sum().cpu()) 167 | self.n_correct_epoch[module_name] += n_correct 168 | 169 | def load_model(self, model_path: str, with_mask: bool = True) -> None: 170 | """Load weights and masks.""" 171 | checkpt = torch.load(model_path, map_location=self.device) 172 | 173 | model_utils.initialize_params( 174 | self.model, checkpt["state_dict"], with_mask=with_mask 175 | ) 176 | LOGGER.info(f"Loaded the model from {model_path}") 177 | 178 | def _get_epoch_acc(self, is_test: bool = False) -> Dict[str, float]: 179 | """Get accuracy and reset statistics.""" 180 | n_total = ( 181 | len(self.testloader.dataset) if is_test else len(self.trainloader.dataset) 182 | ) 183 | acc = dict() 184 | for module_name in self.n_correct_epoch: 185 | accuracy = 100 * self.n_correct_epoch[module_name] / n_total 186 | acc.update({module_name + "_acc": accuracy}) 187 | self.n_correct_epoch.clear() 188 | 189 | return acc 190 | 191 | def resume(self) -> int: 192 | """Set to resume the training.""" 193 | last_epoch = -1 194 | latest_file_path = self._fetch_latest_checkpt() 195 | if latest_file_path and os.path.exists(latest_file_path): 196 | self.load_params(latest_file_path) 197 | _, self.checkpt_dir, filename = latest_file_path.rsplit(os.path.sep, 2) 198 | # fetch the last epoch from the filename 199 | last_epoch = int(filename.split("_", 1)[0]) 200 | return last_epoch + 1 201 | 202 | def load_params(self, model_path: str, with_mask: bool = True) -> None: 203 | """Load weights and masks.""" 204 | checkpt = torch.load(model_path, map_location=self.device) 205 | model_utils.initialize_params( 206 | self.model, checkpt["state_dict"], with_mask=with_mask 207 | ) 208 | model_utils.initialize_params( 209 | self.optimizer, checkpt["optimizer"], with_mask=False 210 | ) 211 | self.best_acc = checkpt["test_acc"] 212 | LOGGER.info(f"Loaded parameters from {model_path}") 213 | 
def tau(x: np.ndarray, alpha: float) -> np.ndarray:
    """Compute tau value for EVBsigma2.

    Evaluates ``0.5 * (y + sqrt(y**2 - 4 * alpha))`` with
    ``y = x - (1 + alpha)``, element-wise.

    Args:
        x: value to compute tau.
        alpha: alpha blending parameter.

    Return:
        tau value from x
    """
    shifted = x - (1 + alpha)
    root = np.sqrt(shifted**2 - 4 * alpha)
    return 0.5 * (shifted + root)
def EVBMF(
    Y: torch.Tensor, sigma2: Optional[float] = None, H: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[str, np.ndarray]]:
    """Compute EVBMF.

    Implementation of the analytical solution to EVBMF.

    (Empirical Variational Bayes Matrix Factorization)

    This function can be used to calculate the analytical solution to empirical VBMF.
    This is based on the paper and MatLab code by Nakajima et al.:
    "Global analytic solution of fully-observed variational Bayesian matrix factorization."

    Notes
    -----
        If sigma2 is unspecified, it is estimated by minimizing the free energy
        (bounded scalar minimization of ``EVBsigma2``).
        If H is unspecified, it is set to L, the first dimension of Y.

    Args:
        Y : 2-D matrix
            Input matrix that is to be factorized. Y has shape (L,M), where L<=M.
            It is handed directly to ``np.linalg.svd``.
            NOTE(review): annotated as torch.Tensor although the computation is
            NumPy based - presumably relies on implicit conversion; confirm.

        sigma2 : float or None (default=None)
            Variance of the noise on Y.

        H : int or None (default = None)
            Maximum rank of the factorized matrices.

    Returns:
        U : numpy-array
            Left-singular vectors, restricted to the retained components.

        S : numpy-array
            Diagonal matrix of the shrunken singular values ``d``.

        V : numpy-array
            Right-singular vectors, restricted to the retained components.

        post : dictionary
            Dictionary containing the computed posterior values, with keys
            "ma", "mb", "sa2", "sb2", "cacb" (arrays of length H),
            "sigma2" and "F".


    References
    ----------
        .. [1] Nakajima, Shinichi, et al. "Global analytic solution of fully-observed variational Bayesian matrix factorization." Journal of Machine Learning Research 14.Jan (2013): 1-37.

        .. [2] Nakajima, Shinichi, et al. "Perfect dimensionality recovery by variational Bayesian PCA." Advances in Neural Information Processing Systems. 2012.
    """
    L, M = Y.shape  # has to be L<=M

    if H is None:
        H = L

    alpha = L / M
    tauubar = 2.5129 * np.sqrt(alpha)

    # SVD of the input matrix, max rank of H
    U, s, V = np.linalg.svd(Y)
    U = U[:, :H]
    s = s[:H]
    V = V[:H].T

    # Calculate residual (energy of the singular values cut off by H)
    residual = 0.0
    if H < L:
        residual = np.sum(np.sum(Y**2) - np.sum(s**2))

    # Estimation of the variance when sigma2 is unspecified
    if sigma2 is None:
        xubar = (1 + tauubar) * (1 + alpha / tauubar)
        eH_ub = int(np.min([np.ceil(L / (1 + alpha)) - 1, H]))
        upper_bound = (np.sum(s**2) + residual) / (L * M)
        lower_bound = np.max([s[eH_ub] ** 2 / (M * xubar), np.mean(s[eH_ub:] ** 2) / M])

        # scale is fixed to 1.0, so the rescaling below is a no-op
        scale = 1.0  # /lower_bound
        s = s * np.sqrt(scale)
        residual = residual * scale
        lower_bound = lower_bound * scale
        upper_bound = upper_bound * scale

        sigma2_opt = minimize_scalar(
            EVBsigma2,
            args=(L, M, s, residual, xubar),
            bounds=[lower_bound, upper_bound],
            method="Bounded",
        )
        sigma2 = sigma2_opt.x

    # Threshold gamma term
    threshold = np.sqrt(M * sigma2 * (1 + tauubar) * (1 + alpha / tauubar))
    # number of singular values above the threshold = retained components
    pos = np.sum(s > threshold)

    # Formula (15) from [2]
    d = np.multiply(
        s[:pos] / 2,
        1
        - np.divide((L + M) * sigma2, s[:pos] ** 2)
        + np.sqrt(
            (1 - np.divide((L + M) * sigma2, s[:pos] ** 2)) ** 2
            - 4 * L * M * sigma2**2 / s[:pos] ** 4
        ),
    )

    # Computation of the posterior
    post = {}
    post["ma"] = np.zeros(H)
    post["mb"] = np.zeros(H)
    post["sa2"] = np.zeros(H)
    post["sb2"] = np.zeros(H)
    post["cacb"] = np.zeros(H)

    # NOTE: this local array shadows the module-level function tau() for the
    # remainder of this function.
    tau = np.multiply(d, s[:pos]) / (M * sigma2)
    delta = np.multiply(np.sqrt(np.divide(M * d, L * s[:pos])), 1 + alpha / tau)

    # only the first `pos` entries are filled; the rest stay zero
    post["ma"][:pos] = np.sqrt(np.multiply(d, delta))
    post["mb"][:pos] = np.sqrt(np.divide(d, delta))
    post["sa2"][:pos] = np.divide(sigma2 * delta, s[:pos])
    post["sb2"][:pos] = np.divide(sigma2, np.multiply(delta, s[:pos]))
    post["cacb"][:pos] = np.sqrt(np.multiply(d, s[:pos]) / (L * M))
    post["sigma2"] = sigma2  # type: ignore
    post["F"] = 0.5 * (
        L * M * np.log(2 * np.pi * sigma2)
        + (residual + np.sum(s**2)) / sigma2
        + np.sum(M * np.log(tau + 1) + L * np.log(tau / alpha + 1) - M * tau)
    )

    return U[:, :pos], np.diag(d), V[:, :pos], post
def decompose_model(
    model: nn.Module, loss_thr: float = 0.1, prune_step: float = 0.01
) -> None:
    """Decompose conv in model recursively.

    Decompose (n, n) conv to (1, 1) -> (n, n) -> (1, 1)
        --> n > 1
    Note that this is in-place operation.

    A Conv2d child is only replaced when its parent is an ``nn.ModuleList``
    or when the parent exposes it as the attribute ``conv``; any other Conv2d
    child is left untouched (a parent without a ``conv`` attribute used to
    raise AttributeError here).

    Args:
        model: PyTorch model.
        loss_thr: loss threshold to compare between original conv with decomposed conv.
            loss = (o1 - o2).abs().sum() / o1.numel()
                o1: original conv out
                o2: decomposed conv out
        prune_step: resolution of the pruning-ratio search.
            Pruning ratios are explored by binary search in (0, 1), starting
            at 0.5; the search stops once two successive ratios differ by
            less than prune_step.
            if prune_step is equal or smaller than 0.0, pruning will not be applied.
    """
    for i, (name, module) in enumerate(model.named_children()):
        if list(module.children()):
            decompose_model(
                module, loss_thr=loss_thr, prune_step=prune_step
            )  # Call recursively

        if not isinstance(module, nn.Conv2d):
            continue

        if isinstance(model, nn.ModuleList):
            conv = model[i]
        else:
            # Only the child registered as `conv` is handled; getattr with a
            # default avoids an AttributeError on parents without it.
            conv = getattr(model, "conv", None)

        if conv is not module:
            continue

        if conv.kernel_size == (1, 1):
            continue

        # random probe batch; channel and spatial shape follow the kernel
        test_input = torch.rand((1024, *conv.weight.shape[1:]))
        origin_out = conv(test_input)
        decomposed_conv = None

        # Run decomposition before searching prunning ratio to check if it's worthy.
        decomposed_conv_candidate, loss = decompose_layer_evaluation(
            conv, test_input, origin_out
        )
        LOGGER.info(f"{name} (Prune: {0.0:.3f}): Loss(mean): {loss}, ")

        run_b_search = False
        if loss < loss_thr:
            decomposed_conv = decomposed_conv_candidate
            # the search is skipped entirely for a non-positive step
            run_b_search = prune_step > 0

        max_prune_ratio = 1.0
        min_prune_ratio = 0.0
        prune_ratio = (max_prune_ratio + min_prune_ratio) / 2

        while run_b_search:  # Binary search for pruning ratio.
            original_conv = deepcopy(conv)
            if prune_ratio > 0.0:
                prune.l1_unstructured(original_conv, name="weight", amount=prune_ratio)
                prune.remove(original_conv, "weight")

            decomposed_conv_candidate, loss = decompose_layer_evaluation(
                original_conv, test_input, origin_out
            )

            LOGGER.info(f"{name} (Prune: {prune_ratio:.3f}): Loss(mean): {loss}, ")

            if loss < loss_thr:
                # acceptable: keep this candidate and try pruning harder
                min_prune_ratio = prune_ratio
                decomposed_conv = decomposed_conv_candidate
            else:
                max_prune_ratio = prune_ratio

            next_prune_ratio = (max_prune_ratio + min_prune_ratio) / 2
            # prune_step > 0 inside this loop, so this also covers a zero step
            if abs(prune_ratio - next_prune_ratio) < prune_step:
                break

            prune_ratio = next_prune_ratio

        if decomposed_conv is None:
            LOGGER.info(" |---------- Skip switching to decomposed conv.")
            continue

        # keep the attributes of the original conv on the replacement
        for attr_name in ["in_channels", "out_channels", "kernel_size"]:
            setattr(decomposed_conv, attr_name, getattr(conv, attr_name))

        if isinstance(model, nn.ModuleList):
            model[i] = decomposed_conv
        else:
            model.conv = decomposed_conv

        LOGGER.info(" |---------- Switching conv to decomposed conv")
def estimate_ranks(layer: nn.Conv2d) -> List[int]:
    """Estimate the Tucker ranks for the given layer.

    The weight tensor is unfolded along mode 0 and mode 1, and each unfolding
    is handed to EVBMF (VBMF rank estimation).

    Args:
        layer: Conv2d module.

    Return:
        estimated ranks, one for mode 0 and one for mode 1.
    """
    kernel = layer.weight.data
    diagonals = []
    for mode in (0, 1):
        _, diag, _, _ = EVBMF(tl.base.unfold(kernel, mode))
        diagonals.append(diag)
    return [diagonals[0].shape[0], diagonals[1].shape[1]]


def tucker_decomposition_conv_layer(layer: nn.Conv2d) -> nn.Sequential:
    """Perform Tucker decomposition on the Conv2d Layer.

    The ranks are estimated with a Python implementation of VBMF
    https://github.com/CasvandenBogaard/VBMF

    Args:
        layer: Conv2d module.

    Return:
        nn.Sequential object with the Tucker decomposition.
        Which consists of (1, 1) conv -> (n, n) conv -> (1, 1) conv
    """
    ranks = estimate_ranks(layer)
    LOGGER.info(f"{layer} : VBMF Estimated ranks: {ranks}")
    core, [last, first] = partial_tucker(
        layer.weight.data, modes=[0, 1], rank=ranks, init="svd"
    )

    keeps_bias = layer.bias is not None
    # Settings shared by the two pointwise convolutions.
    pointwise_cfg = dict(kernel_size=1, stride=1, padding=0, dilation=layer.dilation)

    # Pointwise conv: first.shape[0] input channels -> first.shape[1] channels.
    reduce_conv = torch.nn.Conv2d(
        first.shape[0], first.shape[1], bias=False, **pointwise_cfg
    )

    # Regular conv with the original spatial configuration:
    # core.shape[1] input channels -> core.shape[0] output channels.
    core_conv = torch.nn.Conv2d(
        core.shape[1],
        core.shape[0],
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        dilation=layer.dilation,
        bias=False,
    )

    # Pointwise conv: last.shape[1] channels -> last.shape[0] output channels.
    # Only this layer carries a bias, and only if the source layer had one.
    expand_conv = torch.nn.Conv2d(
        last.shape[1], last.shape[0], bias=keeps_bias, **pointwise_cfg
    )

    # Load the factors as convolution weights.
    reduce_conv.weight.data = torch.transpose(first, 1, 0).unsqueeze(-1).unsqueeze(-1)
    core_conv.weight.data = core
    expand_conv.weight.data = last.unsqueeze(-1).unsqueeze(-1)
    if keeps_bias:
        expand_conv.bias.data = layer.bias.data  # type: ignore

    return nn.Sequential(reduce_conv, core_conv, expand_conv)
def set_random_seed(seed: int) -> None:
    """Seed the NumPy, Python and PyTorch random number generators."""
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    # for CuDNN backend
    cudnn = torch.backends.cudnn
    if not cudnn.is_available():
        return
    cudnn.deterministic = True  # type: ignore
    cudnn.benchmark = False  # type: ignore


def get_rand_bbox_coord(
    w: int, h: int, len_ratio: float
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Get the corner coordinates of a random box clipped to a (w, h) area.

    Returns:
        ((x0, y0), (x1, y1)): the two opposite corners of the box.
    """
    half_w = int(len_ratio * w) // 2
    half_h = int(len_ratio * h) // 2

    # Box centre; both bounds are inclusive. x is drawn before y.
    center_x = random.randint(0, w)
    center_y = random.randint(0, h)

    low_corner = (max(0, center_x - half_w), max(0, center_y - half_h))
    high_corner = (min(w, center_x + half_w), min(h, center_y + half_h))
    return low_corner, high_corner
def to_onehot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert index based labels into one-hot based labels.

    Only 1-dimensional inputs are converted (to a float tensor). Anything else
    is treated as already one-hot/soft (e.g. [0.9, 0.01, 0.03,...]) and is
    returned unchanged.
    """
    if labels.dim() == 1:
        return F.one_hot(labels, num_classes).float()
    return labels


def get_dataset(
    dataset_name: str = "CIFAR100",
    transform_train: str = "simple_augment_train_cifar100",
    transform_test: str = "simple_augment_test_cifar100",
    transform_train_params: Optional[Dict[str, int]] = None,
    transform_test_params: Optional[Dict[str, int]] = None,
) -> Tuple[VisionDataset, VisionDataset]:
    """Get dataset for training and testing.

    Args:
        dataset_name: name of a class in ``torchvision.datasets``.
        transform_train: name of a policy factory in ``src.augmentation.policies``
            used for the training split.
        transform_test: name of a policy factory in ``src.augmentation.policies``
            used for the test split.
        transform_train_params: keyword arguments for the training policy
            factory; ``None`` means no arguments.
        transform_test_params: keyword arguments for the test policy factory;
            ``None`` means no arguments.

    Returns:
        (trainset, testset), both rooted at "save/data".
    """
    # preprocessing policies
    policies = __import__("src.augmentation.policies", fromlist=[""])
    # `or {}` covers the None default for BOTH parameter dicts; previously the
    # test params were unpacked as-is and `**None` raised TypeError.
    train_transform = getattr(policies, transform_train)(
        **(transform_train_params or {})
    )
    test_transform = getattr(policies, transform_test)(
        **(transform_test_params or {})
    )

    # pytorch dataset
    dataset_cls = getattr(
        __import__("torchvision.datasets", fromlist=[""]), dataset_name
    )
    trainset = dataset_cls(
        root="save/data", train=True, download=True, transform=train_transform
    )
    testset = dataset_cls(
        root="save/data", train=False, download=False, transform=test_transform
    )

    return trainset, testset


def get_dataloader(
    trainset: VisionDataset,
    testset: VisionDataset,
    batch_size: int,
    n_workers: int,
) -> Tuple[data.DataLoader, data.DataLoader]:
    """Get dataloader for training and testing.

    The training loader shuffles and the test loader does not. Both pin
    memory whenever CUDA is available.
    """
    shared_kwargs = dict(
        pin_memory=torch.cuda.is_available(),
        num_workers=n_workers,
        batch_size=batch_size,
    )
    trainloader = data.DataLoader(trainset, shuffle=True, **shared_kwargs)
    testloader = data.DataLoader(testset, shuffle=False, **shared_kwargs)
    return trainloader, testloader
def get_latest_file(filepath: str, pattern: str = "*") -> str:
    """Get the latest file from the input filepath.

    "Latest" is the entry with the largest ``os.path.getctime`` among the
    paths matching ``pattern`` inside ``filepath``. Returns "" if nothing matches.
    """
    candidates = glob.glob(os.path.join(filepath, pattern))
    return max(candidates, key=os.path.getctime, default="")


def set_logger(
    filename: str,
    mode: str = "a",
    level: int = logging.DEBUG,
    maxbytes: int = 1024 * 1024 * 10,  # default: 10Mbyte
    backupcnt: int = 100,
) -> None:
    """Configure the "model_compression" logger for the console and a file.

    Args:
        filename: path of the rotating log file.
        mode: file open mode of the log file.
        level: level set on the logger itself (both handlers accept DEBUG).
        maxbytes: size at which the log file is rotated.
        backupcnt: number of rotated files to keep.
    """
    logger = logging.getLogger("model_compression")
    logger.setLevel(level)

    # Calling this function again must reconfigure the logger, not stack
    # another pair of handlers on top (which duplicated every record).
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    chdlr = logging.StreamHandler(sys.stdout)
    chdlr.setLevel(logging.DEBUG)
    cfmts = "%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s"
    chdlr.setFormatter(logging.Formatter(cfmts))
    logger.addHandler(chdlr)

    fhdlr = logging.handlers.RotatingFileHandler(
        filename, mode=mode, maxBytes=maxbytes, backupCount=backupcnt
    )
    fhdlr.setLevel(logging.DEBUG)
    ffmts = (
        "%(asctime)s - "
        "%(processName)s - %(threadName)s - "
        "%(filename)s:%(lineno)d - %(levelname)s - %(message)s"
    )
    fhdlr.setFormatter(logging.Formatter(ffmts))
    logger.addHandler(fhdlr)


def get_logger() -> logging.Logger:
    """Get the "model_compression" logger instance."""
    return logging.getLogger("model_compression")


def count_param(model: nn.Module) -> int:
    """Count number of all parameters.

    Args:
        model: PyTorch model.

    Return:
        Sum of # of parameters
    """
    return sum(param.numel() for param in model.parameters())


def select_device(device: str = "", batch_size: Optional[int] = None) -> torch.device:
    """Select torch device.

    Args:
        device: 'cpu' or '0' or '0, 1, 2, 3' format string. Any non-empty value
            other than 'cpu' is written to the CUDA_VISIBLE_DEVICES environment
            variable.
        batch_size: when given and more than one GPU is visible, it must be a
            multiple of the GPU count.

    Returns:
        ``torch.device("cuda:0")`` if CUDA is used, otherwise ``torch.device("cpu")``.

    Raises:
        AssertionError: a CUDA device was requested but CUDA is unavailable, or
            batch_size is not a multiple of the GPU count.
    """
    cpu_request = device.lower() == "cpu"
    if device and not cpu_request:
        os.environ["CUDA_VISIBLE_DEVICES"] = device
        assert torch.cuda.is_available(), (
            "CUDA unavailable, invalid device %s requested" % device
        )

    cuda = (not cpu_request) and torch.cuda.is_available()
    if not cuda:
        print("Using CPU")
    else:
        bytes_per_mb = 1024**2
        ng = torch.cuda.device_count()
        if ng > 1 and batch_size:
            assert (
                batch_size % ng == 0
            ), "batch-size %g not multiple of GPU count %g" % (batch_size, ng)
        properties = [torch.cuda.get_device_properties(i) for i in range(ng)]
        prefix = "Using CUDA "
        for i, prop in enumerate(properties):
            if i == 1:
                # From the second device on, pad instead of repeating the prefix.
                prefix = " " * len(prefix)
            print(
                "%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)"
                % (prefix, i, prop.name, prop.total_memory / bytes_per_mb)
            )

    print("")
    return torch.device("cuda:0" if cuda else "cpu")
# arguments
parser = argparse.ArgumentParser(description="Model trainer.")
parser.add_argument("--multi-gpu", action="store_true", help="Multi-GPU use")
parser.add_argument("--gpu", default=0, type=int, help="GPU id to use")
parser.add_argument(
    "--finetune", type=str, default="", help="Model path to finetune (pth.tar)"
)
parser.add_argument(
    "--resume",
    type=str,
    default="",
    help="Input log directory name to resume in save/checkpoint",
)
parser.add_argument(
    "--half", dest="half", action="store_true", help="Use half precision"
)
parser.add_argument(
    "--wlog", dest="wlog", action="store_true", help="Turns on wandb logging"
)
parser.add_argument(
    "--config",
    type=str,
    # The training configs live under config/train/cifar100/; the previous
    # default (config/train/simplenet.py) pointed at a non-existent file.
    default="config/train/cifar100/simplenet.py",
    help="Configuration path (.py)",
)
parser.set_defaults(half=False)
parser.set_defaults(multi_gpu=False)
parser.set_defaults(wlog=False)
args = parser.parse_args()

# initialize
config, dir_prefix, device = initialize(
    "train", args.config, args.resume, args.multi_gpu, args.gpu
)
if args.resume:
    # Resuming: reuse the resumed run name and ignore --finetune.
    wandb_name = args.resume
    finetune = ""
else:
    wandb_name = curr_time
    finetune = args.finetune
wandb_init_params = dict(config=config, name=wandb_name, group=args.config)

# run training
trainer = Trainer(
    config=config,
    dir_prefix=dir_prefix,
    checkpt_dir="train",
    finetune=finetune,
    wandb_log=args.wlog,
    wandb_init_params=wandb_init_params,
    device=device,
    half=args.half,
)
trainer.run(args.resume)
LOGGER = get_logger(__name__)


def get_parser() -> argparse.Namespace:
    """Build the command line parser and return the parsed arguments."""
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    cli.add_argument("--multi-gpu", action="store_true", help="Multi-GPU use")
    cli.add_argument("--gpu", default=0, type=int, help="GPU id to use")
    cli.add_argument(
        "--resume",
        type=str,
        default="",
        help="Input log directory name to resume in save/checkpoint",
    )
    cli.add_argument(
        "--half", dest="half", action="store_true", help="Use half precision"
    )
    cli.add_argument(
        "--config",
        type=str,
        default="config/train/cifar100/densenet_201.py",
        help="Configuration path (.py)",
    )
    cli.add_argument(
        "--decomp", dest="decomposed", action="store_true", help="Use decomposed model."
    )
    cli.add_argument(
        "--decomp_dir",
        type=str,
        default="",
        help="Decomposed model weight file path (e.g. decompose/220714_180306/weight.pt).",
    )
    cli.set_defaults(half=False)
    cli.set_defaults(multi_gpu=False)
    return cli.parse_args()


if __name__ == "__main__":
    args = get_parser()

    use_decomposed = bool(args.decomposed and args.decomp_dir)
    if use_decomposed and not args.decomp_dir.endswith(".pt"):
        raise ValueError("The decomposed dir should be end with pt")

    if use_decomposed:
        # The run directory is taken from the first two components of the
        # weight path, i.e. <dir>/<run>/<weight>.pt.
        path_parts = args.decomp_dir.split("/")
        resume_dir = os.path.join(path_parts[0], path_parts[1])
    else:
        resume_dir = args.resume

    config, dir_prefix, device = initialize(
        "val", args.config, resume_dir, args.multi_gpu, args.gpu
    )
    # Both modes forward --decomp_dir unchanged as the weight path.
    weight_path = args.decomp_dir

    print(config)
    validator = Validator(
        config=config,
        dir_prefix=dir_prefix,
        checkpt_dir="train",
        device=device,
        half=args.half,
        decomposed=args.decomposed,
        weight_path=weight_path,
    )

    _, acc = validator.run()
    LOGGER.info(f"accuracy : {acc['model_acc']}%")