├── only_train_once ├── assets │ ├── __init__.py │ └── theme.py ├── operation │ └── __init__.py ├── transform │ ├── __init__.py │ ├── graph_transform.py │ ├── tensor_transform.py │ └── ge.py ├── dependency_graph │ └── __init__.py ├── graph │ ├── __init__.py │ ├── node.py │ └── utils.py ├── subnet_construction │ └── __init__.py └── optimizer │ ├── __init__.py │ ├── hyperparameter.py │ └── importance_score │ ├── __init__.py │ ├── magnitude.py │ ├── cosine_similarity.py │ └── taylor.py ├── requirements.txt ├── visual_examples ├── pruning │ ├── carn.pdf │ ├── phi2.pdf │ ├── llama.pdf │ ├── resnet.pdf │ ├── tel_gan.pdf │ ├── vggbn.pdf │ ├── yolov5.pdf │ ├── bert_base.pdf │ ├── convnext.pdf │ ├── densenet.pdf │ ├── llama_lora.pdf │ ├── DemoNetConcatCase1.pdf │ ├── DemoNetConcatCase2.pdf │ ├── DemoNetWeightShareCase1.pdf │ ├── DemoNetWeightShareCase2.pdf │ ├── DemonetBatchnormPruning.pdf │ ├── DemoNetConvtransposeInCase1.pdf │ ├── DemoNetConvtransposeInCase2.pdf │ ├── DemoNetInstanceNorm2DCase3.pdf │ └── yolov5_with_param_displayed.pdf ├── erasing │ ├── stackedunets_h2spg.pdf │ ├── stackedunets_subnetwork.pdf │ ├── stackedunets_trace_graph.pdf │ └── stackedunets_erasing_dependency_graph.pdf └── README.md ├── sanity_check ├── backends │ ├── peft │ │ ├── import_utils.py │ │ ├── tuners │ │ │ ├── __init__.py │ │ │ ├── prefix_tuning.py │ │ │ ├── prompt_tuning.py │ │ │ └── p_tuning.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── save_and_load.py │ │ │ └── config.py │ │ ├── __init__.py │ │ └── mapping.py │ ├── demonet_weightshare_case2.py │ ├── diffusion │ │ ├── configs │ │ │ ├── celeba.yml │ │ │ ├── bedroom.yml │ │ │ ├── church.yml │ │ │ └── cifar10.yml │ │ └── ema.py │ ├── __init__.py │ ├── demonet_concat_case1.py │ ├── demonet_in_case3.py │ ├── demonet_weightshare_case1.py │ ├── demonet_concat_case2.py │ ├── demonet_in_case4.py │ ├── demonet_convtranspose_in_case1.py │ ├── demonet_convtranspose_in_case2.py │ ├── carn │ │ ├── carn.py │ │ └── ops.py │ ├── demo_group_conv_case1.py │ ├── demonet_batchnorm_pruning.py │ └── densenet.py ├── README.md ├── test_in_case3.py ├── test_in_case4.py ├── test_weight_share_case1.py ├── test_convtranspose_in_case2.py ├── test_convtranspose_in_case1.py ├── test_diffmodel_cifar.py ├── test_diffmodel_celeba.py ├── test_diffmodel_church.py ├── test_diffmodel_bedroom.py ├── test_batchnorm_case1.py ├── test_weight_share_case2.py ├── sanity_check.py ├── test_densenet121.py ├── test_resnet50.py ├── test_vgg16bn.py ├── test_shufflefacenet.py ├── test_resnet18.py ├── test_concat_case1.py ├── test_concat_case2.py ├── test_carn.py ├── test_convnextlarge.py ├── test_convnextxlarge.py ├── test_convnexttiny.py ├── test_groupconv_case1.py ├── test_yolov5.py ├── test_llamav1.py ├── test_bert.py ├── test_llamav2.py ├── test_llamav1_lora.py └── peft_lora │ └── utils.py ├── LICENSE ├── setup.py └── tutorials ├── utils └── utils.py ├── README.md └── 04.oto_distributed_data_parallelism.ipynb /only_train_once/assets/__init__.py: -------------------------------------------------------------------------------- 1 | from .theme import * -------------------------------------------------------------------------------- /only_train_once/operation/__init__.py: -------------------------------------------------------------------------------- 1 | from .operator import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | graphviz 3 | pydot 4 | onnx 5 | 
pygraphviz -------------------------------------------------------------------------------- /only_train_once/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph_transform import * 2 | from .tensor_transform import * -------------------------------------------------------------------------------- /only_train_once/dependency_graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .pruning_dependency import build_pruning_dependency_graph -------------------------------------------------------------------------------- /only_train_once/graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph import * 2 | from .node_group import * 3 | from .node import * -------------------------------------------------------------------------------- /only_train_once/subnet_construction/__init__.py: -------------------------------------------------------------------------------- 1 | from .pruning_compression import automated_pruning_compression -------------------------------------------------------------------------------- /visual_examples/pruning/carn.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/carn.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/phi2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/phi2.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/llama.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/llama.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/resnet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/resnet.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/tel_gan.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/tel_gan.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/vggbn.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/vggbn.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/yolov5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/yolov5.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/bert_base.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/bert_base.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/convnext.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/convnext.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/densenet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/densenet.pdf -------------------------------------------------------------------------------- /only_train_once/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .hesso import HESSO 2 | # from .lhspg import LHSPG 3 | # from .dhspg import DHSPG 4 | # from .h2spg import H2SPG -------------------------------------------------------------------------------- /visual_examples/pruning/llama_lora.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/llama_lora.pdf -------------------------------------------------------------------------------- /visual_examples/erasing/stackedunets_h2spg.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/erasing/stackedunets_h2spg.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/DemoNetConcatCase1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/DemoNetConcatCase1.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/DemoNetConcatCase2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/DemoNetConcatCase2.pdf -------------------------------------------------------------------------------- /visual_examples/erasing/stackedunets_subnetwork.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/erasing/stackedunets_subnetwork.pdf -------------------------------------------------------------------------------- /visual_examples/erasing/stackedunets_trace_graph.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/erasing/stackedunets_trace_graph.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/DemoNetWeightShareCase1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/DemoNetWeightShareCase1.pdf 
-------------------------------------------------------------------------------- /visual_examples/pruning/DemoNetWeightShareCase2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/DemoNetWeightShareCase2.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/DemonetBatchnormPruning.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/DemonetBatchnormPruning.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/DemoNetConvtransposeInCase1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/DemoNetConvtransposeInCase1.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/DemoNetConvtransposeInCase2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/DemoNetConvtransposeInCase2.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/DemoNetInstanceNorm2DCase3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/DemoNetInstanceNorm2DCase3.pdf -------------------------------------------------------------------------------- /visual_examples/pruning/yolov5_with_param_displayed.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/pruning/yolov5_with_param_displayed.pdf -------------------------------------------------------------------------------- /visual_examples/erasing/stackedunets_erasing_dependency_graph.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyic/only_train_once_personal_footprint/HEAD/visual_examples/erasing/stackedunets_erasing_dependency_graph.pdf -------------------------------------------------------------------------------- /only_train_once/assets/theme.py: -------------------------------------------------------------------------------- 1 | THEMES = { 2 | "basic": { 3 | "background_color": "#FFFFFF", 4 | "fill_color": "#E8E8E8", 5 | "outline_color": "#000000", 6 | "font_color": "#000000", 7 | "font_name": "Times", 8 | "font_size": "10", 9 | "margin": "0,0", 10 | "padding": "1.0,0.5", 11 | }, 12 | } -------------------------------------------------------------------------------- /sanity_check/backends/peft/import_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import importlib 16 | 17 | 18 | def is_bnb_available(): 19 | return importlib.util.find_spec("bitsandbytes") is not None 20 | -------------------------------------------------------------------------------- /sanity_check/backends/demonet_weightshare_case2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class DemoNetWeightShareCase2(nn.Module): 5 | def __init__(self): 6 | super(DemoNetWeightShareCase2, self).__init__() 7 | self.conv1 = nn.Conv2d(3, 832, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 8 | self.conv2 = nn.Conv2d(832, 416, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 9 | self.conv3 = nn.Conv2d(832, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 10 | self.conv4 = nn.Conv2d(1536, 416, kernel_size=(1, 1), stride=(1, 1)) 11 | self.conv5 = nn.Conv2d(832, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 12 | self.num_layers = 4 13 | 14 | def forward(self, x): 15 | x = self.conv1(x) 16 | skip_x = self.conv2(x) 17 | for i in range(self.num_layers): 18 | x = torch.cat([self.conv4(self.conv3(x)), skip_x], dim=1) 19 | return self.conv5(x) 20 | -------------------------------------------------------------------------------- /only_train_once/optimizer/hyperparameter.py: -------------------------------------------------------------------------------- 1 | DEFAULT_OPT_PARAMS = { 2 | "sgd": { 3 | "first_momentum": 0.0, 4 | "second_momentum": 0.0, 5 | "dampening": 0.0, 6 | "weight_decay": 0.0, 7 | "lmbda": 1e-3, 8 | "lmbda_amplify": 2, 9 | "hat_lmbda_coeff": 10 10 | } 11 | , 12 | "adam": { 13 | "lr": 1e-3, 14 | "first_momentum": 0.9, 15 | "second_momentum": 0.999, 16 | "dampening": 0.0, 17 | "weight_decay": 0.0, 18 | "lmbda": 1e-2, 19 | "lmbda_amplify": 20, 20 | "hat_lmbda_coeff": 1e3 21 | } 22 | , 23 | "adamw": { 24 | "lr": 1e-3, 25 | "first_momentum": 0.9, 26 | "second_momentum": 0.999, 27 | "dampening": 0.0, 28 | "weight_decay": 1e-2, 29 | "lmbda": 1e-2, 30 | "lmbda_amplify": 20, 31 | "hat_lmbda_coeff": 1e3 32 | } 33 | } 34 | 35 | SUPPORT_GRADIENT_ESTIMATES = ['sgd', 'adam', 'adamw'] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tianyi Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /sanity_check/backends/diffusion/configs/celeba.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CELEBA" 3 | image_size: 64 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | 12 | model: 13 | type: "simple" 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 2, 2, 2, 4] 18 | num_res_blocks: 2 19 | attn_resolutions: [16, ] 20 | dropout: 0.1 21 | var_type: fixedlarge 22 | ema_rate: 0.9999 23 | ema: True 24 | resamp_with_conv: True 25 | 26 | diffusion: 27 | beta_schedule: linear 28 | beta_start: 0.0001 29 | beta_end: 0.02 30 | num_diffusion_timesteps: 1000 31 | 32 | training: 33 | batch_size: 128 34 | n_epochs: 10000 35 | n_iters: 5000000 36 | snapshot_freq: 5000 37 | validation_freq: 20000 38 | 39 | sampling: 40 | batch_size: 32 41 | last_only: True 42 | 43 | optim: 44 | weight_decay: 0.000 45 | optimizer: "Adam" 46 | lr: 0.0002 47 | beta1: 0.9 48 | amsgrad: false 49 | eps: 0.00000001 50 | grad_clip: 1.0 51 | -------------------------------------------------------------------------------- /sanity_check/backends/diffusion/configs/bedroom.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "bedroom" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | 13 | model: 14 | type: "simple" 15 | in_channels: 3 16 | out_ch: 3 17 | ch: 128 18 | ch_mult: [1, 1, 2, 2, 4, 4] 19 | num_res_blocks: 2 20 | attn_resolutions: [16, ] 21 | dropout: 0.0 22 | var_type: fixedsmall 23 | ema_rate: 0.999 24 | ema: True 25 | resamp_with_conv: True 26 | 27 | diffusion: 28 | beta_schedule: linear 29 | beta_start: 0.0001 30 | beta_end: 0.02 31 | num_diffusion_timesteps: 1000 32 | 33 | training: 34 | batch_size: 64 35 | n_epochs: 10000 36 | n_iters: 5000000 37 | snapshot_freq: 5000 38 | validation_freq: 2000 39 | 40 | sampling: 41 | batch_size: 32 42 | last_only: True 43 | 44 | optim: 45 | weight_decay: 0.000 46 | optimizer: "Adam" 47 | lr: 0.00002 48 | beta1: 0.9 49 | amsgrad: false 50 | eps: 0.00000001 51 | -------------------------------------------------------------------------------- /sanity_check/backends/diffusion/configs/church.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "church_outdoor" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | 13 | model: 14 | type: "simple" 15 | in_channels: 3 16 | out_ch: 3 17 | ch: 128 18 | ch_mult: [1, 1, 2, 2, 4, 4] 19 | num_res_blocks: 2 20 | 
attn_resolutions: [16, ] 21 | dropout: 0.0 22 | var_type: fixedsmall 23 | ema_rate: 0.999 24 | ema: True 25 | resamp_with_conv: True 26 | 27 | diffusion: 28 | beta_schedule: linear 29 | beta_start: 0.0001 30 | beta_end: 0.02 31 | num_diffusion_timesteps: 1000 32 | 33 | training: 34 | batch_size: 64 35 | n_epochs: 10000 36 | n_iters: 5000000 37 | snapshot_freq: 5000 38 | validation_freq: 2000 39 | 40 | sampling: 41 | batch_size: 32 42 | last_only: True 43 | 44 | optim: 45 | weight_decay: 0.000 46 | optimizer: "Adam" 47 | lr: 0.00002 48 | beta1: 0.9 49 | amsgrad: false 50 | eps: 0.00000001 51 | -------------------------------------------------------------------------------- /sanity_check/backends/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_llama.modeling_llama import LlamaForCausalLM 2 | from .demonet_concat_case1 import DemoNetConcatCase1 3 | from .demonet_concat_case2 import DemoNetConcatCase2 4 | from .demonet_convtranspose_in_case1 import DemoNetConvtransposeInCase1 5 | from .demonet_convtranspose_in_case2 import DemoNetConvtransposeInCase2 6 | from .demonet_weightshare_case1 import DemoNetWeightShareCase1 7 | from .demonet_weightshare_case2 import DemoNetWeightShareCase2 8 | from .demo_group_conv_case1 import DemoNetGroupConvCase1 9 | from .demonet_in_case3 import DemoNetInstanceNorm2DCase3 10 | from .demonet_in_case4 import DemoNetInstanceNorm2DCase4 11 | from .densenet import densenet121, densenet161, densenet169, densenet201 12 | from .resnet_cifar10 import resnet18_cifar10 13 | from .demonet_batchnorm_pruning import DemonetBatchnormPruning 14 | from .carn.carn import CarnNet 15 | from .convnext import convnext_tiny, convnext_small, convnext_base, convnext_large, convnext_xlarge 16 | # from .shufflefacenet import ShuffleFaceNet 17 | # from .mixnet import mixnet_s, mixnet_m, mixnet_l 18 | from .diffusion.diffusion import DiffModelCIFAR, DiffModelBedroom, DiffModelCeleba, DiffModelChurch -------------------------------------------------------------------------------- /sanity_check/README.md: -------------------------------------------------------------------------------- 1 | # Sanity Check 2 | 3 | We highly recommend running a sanity check to verify that OTO is compatible with the target DNN. The sanity check randomly selects a set of minimally removable structures and marks them as redundant via 4 | 5 | ```python 6 | oto.random_set_zero_groups() 7 | ``` 8 | and then produces a compact sub-network, as presented in [sanity_check](https://github.com/tianyic/only_train_once/blob/main/sanity_check/test_resnet18.py). If the sanity check does not pass, mark the problematic node groups as unprunable via 9 | 10 | ```python 11 | oto.mark_unprunable_by_node_ids() 12 | ``` 13 | For example, in [YOLOv5](https://github.com/tianyic/only_train_once/blob/main/sanity_check/test_yolov5.py), we mark the node groups corresponding to the detection heads as unprunable. 14 | 15 | If all variable groups of the minimally removable structures are pruning zero-invariant groups (PZIGs), the returned sub-network should produce exactly the same output as the full group-sparse model given random inputs. 16 | 17 | Run the sanity check with the command below: 18 | 19 | ```bash 20 | python sanity_check.py 21 | ``` 22 | 23 | Note that some sanity checks require additional dependencies; comment out the ones you do not need.
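For a new architecture, a standalone sanity check can follow the same pattern as the test cases in this folder. Below is a minimal sketch that mirrors `test_resnet50.py`; the model, dummy input shape, and output directory are placeholders to replace with your own.

```python
import torch
import torchvision.models
from only_train_once import OTO

OUT_DIR = './cache'  # placeholder output directory

# Replace with your own model and a dummy input of the correct shape.
model = torchvision.models.resnet50()
dummy_input = torch.rand(1, 3, 224, 224)

oto = OTO(model, dummy_input)
oto.visualize(view=False, out_dir=OUT_DIR)  # optional: inspect the pruning dependency graph

# Randomly mark minimally removable structures as redundant, then build the compact sub-network.
oto.random_set_zero_groups()
oto.construct_subnet(out_dir=OUT_DIR)

full_model = torch.load(oto.full_group_sparse_model_path)
compressed_model = torch.load(oto.compressed_model_path)

# If all pruned variable groups are PZIGs, the two outputs should agree up to numerical tolerance.
max_output_diff = torch.max(torch.abs(full_model(dummy_input) - compressed_model(dummy_input)))
print("Maximum output difference :", max_output_diff.item())
```

If the difference is large, visualize the dependency graph, identify the offending node groups, and mark them as unprunable as described above.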
-------------------------------------------------------------------------------- /sanity_check/backends/diffusion/configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CIFAR10" 3 | image_size: 32 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | 12 | model: 13 | type: "simple" 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 2, 2, 2] 18 | # ch_mult: [1, 2, 2] 19 | num_res_blocks: 2 20 | attn_resolutions: [16, ] 21 | dropout: 0.1 22 | var_type: fixedlarge 23 | ema_rate: 0.9999 24 | ema: True 25 | resamp_with_conv: True 26 | 27 | diffusion: 28 | beta_schedule: linear 29 | beta_start: 0.0001 30 | beta_end: 0.02 31 | num_diffusion_timesteps: 1000 32 | 33 | training: 34 | batch_size: 128 35 | n_epochs: 10000 36 | n_iters: 5000000 37 | snapshot_freq: 5000 38 | validation_freq: 2000 39 | 40 | sampling: 41 | batch_size: 128 42 | last_only: True 43 | ckpt_id: 100000 44 | 45 | optim: 46 | weight_decay: 0.000 47 | optimizer: "Adam" 48 | lr: 0.0002 49 | beta1: 0.9 50 | amsgrad: false 51 | eps: 0.00000001 52 | grad_clip: 1.0 53 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | def parse_requirements(filename): 4 | lineiter = (line.strip() for line in open(filename)) 5 | return [line for line in lineiter if line and not line.startswith("#")] 6 | 7 | reqs = parse_requirements('requirements.txt') 8 | 9 | VERSION = '3.0.1' 10 | DESCRIPTION = 'Only Train Once (OTO): Automatic One-Shot General DNN Training and Compression Framework' 11 | LONG_DESCRIPTION = 'Only Train Once (OTO): Automatic One-Shot General DNN Training and Compression Framework' 12 | 13 | setup( 14 | name="only_train_once", 15 | version=VERSION, 16 | description=DESCRIPTION, 17 | long_description=LONG_DESCRIPTION, 18 | author="Tianyi Chen", 19 | author_email="tiachen@microsoft.com", 20 | license='MIT', 21 | packages=find_packages(), 22 | install_requires=reqs, 23 | url="https://github.com/tianyic/only_train_once", 24 | keywords='automatic, one-shot, structure pruning, sparse optimization', 25 | classifiers= [ 26 | "Development Status :: 3 - Alpha", 27 | "Intended Audience :: Developers", 28 | 'License :: OSI Approved :: MIT License', 29 | "Programming Language :: Python :: 3", 30 | ] 31 | ) -------------------------------------------------------------------------------- /sanity_check/backends/peft/tuners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .lora import LoraConfig, LoraModel 21 | from .adalora import AdaLoraConfig, AdaLoraModel 22 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType 23 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig 24 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit 25 | -------------------------------------------------------------------------------- /sanity_check/test_in_case3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemoNetInstanceNorm2DCase3 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestINCase3(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = DemoNetInstanceNorm2DCase3() 12 | oto = OTO(model, dummy_input) 13 | 14 | oto.visualize(view=False, out_dir=OUT_DIR) 15 | 16 | oto.random_set_zero_groups() 17 | oto.construct_subnet(out_dir=OUT_DIR) 18 | full_model = torch.load(oto.full_group_sparse_model_path) 19 | compressed_model = torch.load(oto.compressed_model_path) 20 | 21 | full_output = full_model(dummy_input) 22 | compressed_output = compressed_model(dummy_input) 23 | 24 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 25 | print("Maximum output difference : ", max_output_diff.item()) 26 | full_model_size = os.stat(oto.full_group_sparse_model_path) 27 | compressed_model_size = os.stat(oto.compressed_model_path) 28 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 29 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 30 | self.assertLessEqual(max_output_diff, 3.0) 31 | -------------------------------------------------------------------------------- /sanity_check/test_in_case4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemoNetInstanceNorm2DCase4 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | 10 | class TestINCase4(unittest.TestCase): 11 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 12 | model = DemoNetInstanceNorm2DCase4() 13 | oto = OTO(model, dummy_input) 14 | 15 | oto.visualize(view=False, out_dir=OUT_DIR) 16 | 17 | oto.random_set_zero_groups() 18 | oto.construct_subnet(out_dir=OUT_DIR) 19 | full_model = torch.load(oto.full_group_sparse_model_path) 20 | compressed_model = torch.load(oto.compressed_model_path) 21 | 22 | full_output = full_model(dummy_input) 23 | compressed_output = compressed_model(dummy_input) 24 | 25 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 26 | print("Maximum output difference : ", max_output_diff.item()) 27 | full_model_size = os.stat(oto.full_group_sparse_model_path) 28 | compressed_model_size = os.stat(oto.compressed_model_path) 29 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 30 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 31 | self.assertLessEqual(max_output_diff, 1e-4) 32 | 33 | -------------------------------------------------------------------------------- /sanity_check/test_weight_share_case1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
only_train_once import OTO 3 | from backends import DemoNetWeightShareCase1 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDemoNetWeighShareCase1(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = DemoNetWeightShareCase1() 12 | oto = OTO(model, dummy_input) 13 | 14 | oto.visualize(view=False, out_dir=OUT_DIR) 15 | 16 | oto.random_set_zero_groups() 17 | oto.construct_subnet(out_dir=OUT_DIR) 18 | full_model = torch.load(oto.full_group_sparse_model_path) 19 | compressed_model = torch.load(oto.compressed_model_path) 20 | 21 | full_output = full_model(dummy_input) 22 | compressed_output = compressed_model(dummy_input) 23 | 24 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 25 | print("Maximum output difference : ", max_output_diff.item()) 26 | full_model_size = os.stat(oto.full_group_sparse_model_path) 27 | compressed_model_size = os.stat(oto.compressed_model_path) 28 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 29 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 30 | self.assertLessEqual(max_output_diff, 1e-4) 31 | -------------------------------------------------------------------------------- /sanity_check/test_convtranspose_in_case2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemoNetConvtransposeInCase2 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestConvTransposeInCase2(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = DemoNetConvtransposeInCase2() 12 | oto = OTO(model, dummy_input) 13 | 14 | oto.visualize(view=False, out_dir=OUT_DIR) 15 | 16 | oto.random_set_zero_groups() 17 | oto.construct_subnet(out_dir=OUT_DIR) 18 | full_model = torch.load(oto.full_group_sparse_model_path) 19 | compressed_model = torch.load(oto.compressed_model_path) 20 | 21 | full_output = full_model(dummy_input) 22 | compressed_output = compressed_model(dummy_input) 23 | 24 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 25 | print("Maximum output difference : ", max_output_diff.item()) 26 | full_model_size = os.stat(oto.full_group_sparse_model_path) 27 | compressed_model_size = os.stat(oto.compressed_model_path) 28 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 29 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 30 | self.assertLessEqual(max_output_diff, 2.0) 31 | -------------------------------------------------------------------------------- /sanity_check/test_convtranspose_in_case1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemoNetConvtransposeInCase1 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestConvTransposeInCase1(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = DemoNetConvtransposeInCase1() 12 | oto = OTO(model, dummy_input) 13 | 14 | oto.visualize(view=False, out_dir=OUT_DIR) 15 | 16 | oto.random_set_zero_groups() 17 | oto.construct_subnet(out_dir=OUT_DIR) 18 | full_model = torch.load(oto.full_group_sparse_model_path) 19 | compressed_model = torch.load(oto.compressed_model_path) 20 | 21 | full_output = full_model(dummy_input, debug=True) 
22 | compressed_output = compressed_model(dummy_input, debug=True) 23 | 24 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 25 | print("Maximum output difference : ", max_output_diff.item()) 26 | full_model_size = os.stat(oto.full_group_sparse_model_path) 27 | compressed_model_size = os.stat(oto.compressed_model_path) 28 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 29 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 30 | self.assertLessEqual(max_output_diff, 3.0) 31 | -------------------------------------------------------------------------------- /visual_examples/README.md: -------------------------------------------------------------------------------- 1 | # Visualization of dependency graphs for pruning and erasing modes 2 | 3 | Visualizing the pruning and erasing dependency graphs is a frequently used tool for debugging errors when applying OTO to new, unseen DNNs. 4 | 5 | In the [`pruning`](https://github.com/tianyic/only_train_once/tree/main/visual_examples/pruning) folder, we provide the generated pruning dependency graphs for the DNNs covered in the [`sanity_check`](https://github.com/tianyic/only_train_once/tree/main/sanity_check), along with some dependency graphs encountered during our daily use of OTO on various applications. 6 | 7 | ```python 8 | # view=True tries to open the generated dependency graph via a PDF reader; set it to False when running on remote servers. 9 | oto.visualize(view=True or False, out_dir=PATH) 10 | ``` 11 | 12 | In the depicted pruning dependency graphs, 13 | 14 | - Nodes marked with the same color form one node group. Nodes in the same node group are dependent on each other and need to be pruned together. 15 | 16 | - A node group is **prunable** if it is filled with a solid color. 17 | 18 | - A node group is **unprunable** if it is outlined with dashed lines. 19 | 20 | - Nodes with a black font color have trainable variables; otherwise, the font color is white. 21 | 22 | 23 | We will provide more explanation of the visualization for the [`erasing`](https://github.com/tianyic/only_train_once/tree/main/visual_examples/erasing) mode.
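As a concrete reference for how the graphs in this folder can be regenerated, the sketch below follows the sanity-check test cases (run from within the `sanity_check` folder, as those tests are); the demo model, output directory, and the `display_params=True` flag (which additionally annotates nodes with their parameters, as in `yolov5_with_param_displayed.pdf`) are illustrative choices.

```python
import torch
from only_train_once import OTO
from backends import DemoNetConcatCase1  # any model covered by the sanity checks works

model = DemoNetConcatCase1()
dummy_input = torch.rand(1, 3, 32, 32)

oto = OTO(model, dummy_input)
# Renders the pruning dependency graph as a PDF under out_dir.
# view=False avoids launching a PDF reader on remote/headless machines;
# display_params=True also displays the parameters attached to each node.
oto.visualize(view=False, out_dir='./visual_examples/pruning', display_params=True)
```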
-------------------------------------------------------------------------------- /sanity_check/test_diffmodel_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DiffModelCIFAR 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDiffModelCIFAR(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 4, 32, 32)): 11 | model = DiffModelCIFAR() 12 | 13 | oto = OTO(model, dummy_input) 14 | # The layout rendering of DiffUnet is out-of-time via graphviz 15 | # oto.visualize(view=False, out_dir=OUT_DIR, display_params=True) 16 | oto.random_set_zero_groups() 17 | 18 | oto.construct_subnet(out_dir=OUT_DIR) 19 | 20 | full_model = torch.load(oto.full_group_sparse_model_path) 21 | compressed_model = torch.load(oto.compressed_model_path) 22 | 23 | full_output = full_model(dummy_input) 24 | compressed_output = compressed_model(dummy_input) 25 | 26 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 27 | print("Maximum output difference " + str(max_output_diff.item())) 28 | # self.assertLessEqual(max_output_diff, 1e-4) 29 | full_model_size = os.stat(oto.full_group_sparse_model_path) 30 | compressed_model_size = os.stat(oto.compressed_model_path) 31 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 32 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") -------------------------------------------------------------------------------- /sanity_check/test_diffmodel_celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DiffModelCeleba 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDiffModelCeleba(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 4, 64, 64)): 11 | model = DiffModelCeleba() 12 | 13 | oto = OTO(model, dummy_input) 14 | # The layout rendering of DiffUnet is out-of-time via graphviz 15 | # oto.visualize(view=False, out_dir=OUT_DIR, display_params=True) 16 | oto.random_set_zero_groups() 17 | 18 | oto.construct_subnet(out_dir=OUT_DIR) 19 | 20 | full_model = torch.load(oto.full_group_sparse_model_path) 21 | compressed_model = torch.load(oto.compressed_model_path) 22 | 23 | full_output = full_model(dummy_input) 24 | compressed_output = compressed_model(dummy_input) 25 | 26 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 27 | print("Maximum output difference " + str(max_output_diff.item())) 28 | # self.assertLessEqual(max_output_diff, 1e-4) 29 | full_model_size = os.stat(oto.full_group_sparse_model_path) 30 | compressed_model_size = os.stat(oto.compressed_model_path) 31 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 32 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") -------------------------------------------------------------------------------- /sanity_check/test_diffmodel_church.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DiffModelChurch 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDiffModelChurch(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 4, 256, 256)): 11 | model = DiffModelChurch() 12 | 13 | oto = OTO(model, 
dummy_input) 14 | # The layout rendering of DiffUnet is out-of-time via graphviz 15 | # oto.visualize(view=False, out_dir=OUT_DIR, display_params=True) 16 | oto.random_set_zero_groups() 17 | 18 | oto.construct_subnet(out_dir=OUT_DIR) 19 | 20 | full_model = torch.load(oto.full_group_sparse_model_path) 21 | compressed_model = torch.load(oto.compressed_model_path) 22 | 23 | full_output = full_model(dummy_input) 24 | compressed_output = compressed_model(dummy_input) 25 | 26 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 27 | print("Maximum output difference " + str(max_output_diff.item())) 28 | # self.assertLessEqual(max_output_diff, 1e-4) 29 | full_model_size = os.stat(oto.full_group_sparse_model_path) 30 | compressed_model_size = os.stat(oto.compressed_model_path) 31 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 32 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") -------------------------------------------------------------------------------- /sanity_check/test_diffmodel_bedroom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DiffModelBedroom 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDiffModelBedroom(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 4, 256, 256)): 11 | model = DiffModelBedroom() 12 | 13 | oto = OTO(model, dummy_input) 14 | # The layout rendering of DiffUnet is out-of-time via graphviz 15 | # oto.visualize(view=False, out_dir=OUT_DIR, display_params=True) 16 | oto.random_set_zero_groups() 17 | 18 | oto.construct_subnet(out_dir=OUT_DIR) 19 | 20 | full_model = torch.load(oto.full_group_sparse_model_path) 21 | compressed_model = torch.load(oto.compressed_model_path) 22 | 23 | full_output = full_model(dummy_input) 24 | compressed_output = compressed_model(dummy_input) 25 | 26 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 27 | print("Maximum output difference " + str(max_output_diff.item())) 28 | # self.assertLessEqual(max_output_diff, 1e-4) 29 | full_model_size = os.stat(oto.full_group_sparse_model_path) 30 | compressed_model_size = os.stat(oto.compressed_model_path) 31 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 32 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") -------------------------------------------------------------------------------- /sanity_check/backends/demonet_concat_case1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class DemoNetConcatCase1(nn.Module): 5 | def __init__(self): 6 | super(DemoNetConcatCase1, self).__init__() 7 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 8 | self.bn_1 = nn.BatchNorm2d(64) 9 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 10 | self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 11 | self.bn_2 = nn.BatchNorm2d(128) 12 | self.bn_3 = nn.BatchNorm2d(128) 13 | self.bn_4 = nn.BatchNorm2d(192) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.conv5 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 16 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 17 | self.gemm1 = nn.Linear(in_features=256, out_features=128, bias=True) 18 | self.gemm2 = 
nn.Linear(in_features=128, out_features=10, bias=True) 19 | 20 | def forward(self, x): 21 | x_1 = self.relu(self.bn_1(self.conv1(x))) 22 | x_1_tmp = self.conv2(x_1) 23 | x_3 = self.conv3(x_1) 24 | x_2 = self.bn_2(x_1_tmp) + self.bn_3(x_1_tmp) + x_3 25 | x_4 = torch.cat([x_1, x_2], dim=1) 26 | x_5 = self.conv5(self.bn_4(x_4)) 27 | x_7 = self.avg_pool(x_5) 28 | x_7 = x_7.view(x_7.size(0), -1) 29 | return self.gemm2(self.gemm1(x_7)) -------------------------------------------------------------------------------- /sanity_check/backends/demonet_in_case3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DemoNetInstanceNorm2DCase3(nn.Module): 5 | def __init__(self): 6 | super(DemoNetInstanceNorm2DCase3, self).__init__() 7 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 8 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 9 | self.conv3 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 10 | self.conv4 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 11 | 12 | self.in_1 = nn.InstanceNorm2d(64, affine=True) 13 | self.in_2 = nn.InstanceNorm2d(128, affine=True) 14 | self.in_3 = nn.InstanceNorm2d(256, affine=True) 15 | self.in_4 = nn.InstanceNorm2d(512, affine=True) 16 | 17 | self.leakyrelu = nn.LeakyReLU() 18 | 19 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 20 | 21 | self.gemm1 = nn.Linear(in_features=512, out_features=128, bias=True) 22 | self.gemm2 = nn.Linear(in_features=128, out_features=10, bias=True) 23 | 24 | def forward(self, x): 25 | x = self.leakyrelu(self.in_1(self.conv1(x))) 26 | x = self.leakyrelu(self.in_2(self.conv2(x))) 27 | x = self.leakyrelu(self.in_3(self.conv3(x))) 28 | x = self.leakyrelu(self.in_4(self.conv4(x))) 29 | x = self.avg_pool(x) 30 | x = x.view(x.size(0), -1) 31 | return self.gemm2(self.gemm1(x)) -------------------------------------------------------------------------------- /sanity_check/test_batchnorm_case1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemonetBatchnormPruning 4 | import unittest 5 | import os 6 | import torch.nn as nn 7 | 8 | OUT_DIR = './cache' 9 | 10 | class TestDemoNetBatchnormPruningCase1(unittest.TestCase): 11 | def test_sanity(self): 12 | model = DemonetBatchnormPruning(13,32,256,5,3,nn.LeakyReLU(),False,256) 13 | dummy_input=[torch.rand(1, 3, 256, 256),torch.rand(1, 4, 256, 256),torch.rand(1, 6, 256, 256)] 14 | oto = OTO(model, dummy_input) 15 | 16 | oto.visualize(view=False, out_dir=OUT_DIR) 17 | oto.random_set_zero_groups() 18 | oto.construct_subnet(out_dir=OUT_DIR) 19 | full_model = torch.load(oto.full_group_sparse_model_path) 20 | compressed_model = torch.load(oto.compressed_model_path) 21 | 22 | full_output = full_model(dummy_input) 23 | compressed_output = compressed_model(dummy_input) 24 | 25 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 26 | print("Maximum output difference : ", max_output_diff.item()) 27 | full_model_size = os.stat(oto.full_group_sparse_model_path) 28 | compressed_model_size = os.stat(oto.compressed_model_path) 29 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 30 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 31 | self.assertLessEqual(max_output_diff, 1e-4) 
-------------------------------------------------------------------------------- /sanity_check/backends/demonet_weightshare_case1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class DemoNetWeightShareCase1(nn.Module): 5 | def __init__(self): 6 | super(DemoNetWeightShareCase1, self).__init__() 7 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 8 | self.bn_1 = nn.BatchNorm2d(64) 9 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 10 | # self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 11 | self.bn_2 = nn.BatchNorm2d(128) 12 | self.bn_3 = nn.BatchNorm2d(128) 13 | self.bn_4 = nn.BatchNorm2d(192) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.conv5 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 16 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 17 | self.gemm1 = nn.Linear(in_features=256, out_features=128, bias=True) 18 | self.gemm2 = nn.Linear(in_features=128, out_features=10, bias=True) 19 | 20 | def forward(self, x): 21 | x_1 = self.relu(self.bn_1(self.conv1(x))) 22 | x_1_tmp = self.conv2(x_1) 23 | x_3 = self.conv2(x_1) 24 | # x_3 = self.conv3(x_1) 25 | x_2 = self.bn_2(x_1_tmp) + self.bn_3(x_1_tmp) + x_3 26 | x_4 = torch.cat([x_1, x_2], dim=1) 27 | x_5 = self.conv5(self.bn_4(x_4)) 28 | x_7 = self.avg_pool(x_5) 29 | x_7 = x_7.view(x_7.size(0), -1) 30 | return self.gemm2(self.gemm1(x_7)) -------------------------------------------------------------------------------- /sanity_check/backends/peft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType 21 | from .other import ( 22 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 23 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, 24 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, 25 | CONFIG_NAME, 26 | WEIGHTS_NAME, 27 | _set_trainable, 28 | bloom_model_postprocess_past_key_value, 29 | prepare_model_for_int8_training, 30 | shift_tokens_right, 31 | transpose, 32 | _get_submodules, 33 | _set_adapter, 34 | _freeze_adapter, 35 | ModulesToSaveWrapper, 36 | ) 37 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict 38 | -------------------------------------------------------------------------------- /sanity_check/backends/demonet_concat_case2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class DemoNetConcatCase2(nn.Module): 5 | def __init__(self): 6 | super(DemoNetConcatCase2, self).__init__() 7 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 8 | self.bn_1 = nn.BatchNorm2d(64) 9 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 10 | self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 11 | self.bn_2 = nn.BatchNorm2d(128) 12 | self.bn_3 = nn.BatchNorm2d(128) 13 | self.bn_4 = nn.BatchNorm2d(768) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.conv5 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 16 | self.conv6 = nn.Conv2d(384, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 17 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 18 | self.gemm1 = nn.Linear(in_features=768, out_features=128, bias=True) 19 | self.gemm2 = nn.Linear(in_features=128, out_features=10, bias=True) 20 | 21 | def forward(self, x): 22 | x_1 = self.relu(self.bn_1(self.conv1(x))) 23 | x_1_tmp = self.conv2(x_1) 24 | x_3 = self.conv3(x_1) 25 | x_2 = self.bn_2(x_1_tmp) + self.bn_3(x_1_tmp) + x_3 26 | x_4 = torch.cat([x_1, x_2], dim=1) 27 | x_5 = self.conv5(x_4) 28 | x_6 = self.conv6(torch.cat([x_5, x_2], dim=1)) 29 | x_6 = self.bn_4(torch.cat([x_6, x_5], dim=1)) 30 | x_7 = self.avg_pool(x_6) 31 | x_7 = x_7.view(x_7.size(0), -1) 32 | return self.gemm2(self.gemm1(x_7)) -------------------------------------------------------------------------------- /sanity_check/backends/demonet_in_case4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DemoNetInstanceNorm2DCase4(nn.Module): 6 | def __init__(self): 7 | super(DemoNetInstanceNorm2DCase4, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 9 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 10 | self.conv3 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 11 | self.conv4 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 12 | 13 | self.in_1 = nn.InstanceNorm2d(64, affine=True) 14 | self.in_2 = nn.InstanceNorm2d(128, affine=True) 15 | self.in_3 = nn.InstanceNorm2d(256, affine=True) 16 | self.in_4 = nn.InstanceNorm2d(512, affine=True) 17 | 18 | self.prelu_1 = nn.PReLU(num_parameters=64) 19 | self.prelu_2 = nn.PReLU(num_parameters=128) 20 | self.prelu_3 = nn.PReLU(num_parameters=256) 21 | self.prelu_4 = nn.PReLU(num_parameters=512) 22 | 23 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 24 | 
25 | self.gemm1 = nn.Linear(in_features=512, out_features=128, bias=True) 26 | self.gemm2 = nn.Linear(in_features=128, out_features=10, bias=True) 27 | 28 | def forward(self, x): 29 | x = self.prelu_1(self.in_1(self.conv1(x))) 30 | x = self.prelu_2(self.in_2(self.conv2(x))) 31 | x = self.prelu_3(self.in_3(self.conv3(x))) 32 | x = self.prelu_4(self.in_4(self.conv4(x))) 33 | x = self.avg_pool(x) 34 | x = x.view(x.size(0), -1) 35 | return self.gemm2(self.gemm1(x)) 36 | -------------------------------------------------------------------------------- /sanity_check/test_weight_share_case2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemoNetWeightShareCase2 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDemoNetWeighShareCase2(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = DemoNetWeightShareCase2() 12 | oto = OTO(model, dummy_input) 13 | 14 | oto.visualize(view=False, out_dir=OUT_DIR) 15 | 16 | unpruned_node_group_ids = ['node-11', 'node-12', 'node-14', 'node-17', 'node-20', 'node-23'] 17 | for node_group in oto._graph.node_groups.values(): 18 | if node_group.id in unpruned_node_group_ids: 19 | node_group.is_prunable = False 20 | 21 | oto.random_set_zero_groups(target_group_sparsity=0.24) 22 | oto.construct_subnet(out_dir=OUT_DIR) 23 | full_model = torch.load(oto.full_group_sparse_model_path) 24 | compressed_model = torch.load(oto.compressed_model_path) 25 | 26 | full_output = full_model(dummy_input) 27 | compressed_output = compressed_model(dummy_input) 28 | 29 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 30 | print("Maximum output difference : ", max_output_diff.item()) 31 | full_model_size = os.stat(oto.full_group_sparse_model_path) 32 | compressed_model_size = os.stat(oto.compressed_model_path) 33 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 34 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 35 | self.assertLessEqual(max_output_diff, 1e-4) 36 | -------------------------------------------------------------------------------- /only_train_once/optimizer/importance_score/__init__.py: -------------------------------------------------------------------------------- 1 | from .magnitude import * 2 | from .cosine_similarity import * 3 | from .taylor import * 4 | import torch 5 | 6 | def calculate_importance_score_dhspg(criteria, param_group): 7 | param_group['importance_scores'] = dict() 8 | with torch.no_grad(): 9 | for cri_name in criteria: 10 | if 'magnitude' == cri_name: 11 | importance_score_by_magnitude_dhspg(param_group) 12 | elif 'avg_magnitude' == cri_name: 13 | importance_score_by_avg_magnitude_dhspg(param_group) 14 | elif 'cosine_similarity' == cri_name: 15 | importance_score_by_cosine_similarity_dhspg(param_group) 16 | elif 'taylor_first_order' == cri_name: 17 | importance_score_by_first_order_taylor_dhspg(param_group) 18 | elif 'taylor_second_order' == cri_name: 19 | importance_score_by_second_order_taylor_dhspg(param_group) 20 | 21 | def calculate_importance_score_lhspg(criteria, param_group, global_params): 22 | with torch.no_grad(): 23 | for cri_name in criteria: 24 | if 'magnitude' in cri_name: 25 | importance_score_by_magnitude_lhspg(param_group) 26 | elif 'avg_magnitude' == cri_name: 27 | importance_score_by_avg_magnitude_lhspg(param_group) 28 | elif 'cosine_similarity' in cri_name: 
29 | importance_score_by_cosine_similarity_lhspg(param_group, global_params) 30 | elif 'taylor_first_order' in cri_name: 31 | importance_score_by_first_order_taylor_lhspg(param_group, global_params) 32 | elif 'taylor_second_order' in cri_name: 33 | importance_score_by_second_order_taylor_lhspg(param_group, global_params) -------------------------------------------------------------------------------- /sanity_check/sanity_check.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import sys 3 | import os 4 | currentdir = os.path.dirname(os.path.realpath(__file__)) 5 | parentdir = os.path.dirname(currentdir) 6 | sys.path.append(parentdir) 7 | 8 | """ 9 | LLM Test Cases, needs to add dependency for transformers 10 | """ 11 | from test_llamav1 import TestLLAMAv1 12 | from test_llamav1_lora import TestLLAMAv1LoRA 13 | from test_bert import TestBert 14 | """ 15 | CNN Test Cases 16 | """ 17 | from test_concat_case1 import TestDemoNetConcatCase1 18 | from test_concat_case2 import TestDemoNetConcatCase2 19 | from test_convtranspose_in_case1 import TestConvTransposeInCase1 20 | from test_convtranspose_in_case2 import TestConvTransposeInCase2 21 | from test_in_case3 import TestINCase3 22 | from test_in_case4 import TestINCase4 23 | from test_weight_share_case1 import TestDemoNetWeighShareCase1 24 | from test_weight_share_case2 import TestDemoNetWeighShareCase2 25 | from test_batchnorm_case1 import TestDemoNetBatchnormPruningCase1 26 | from test_groupconv_case1 import TestGroupConvCase1 27 | 28 | from test_resnet18 import TestResNet18 29 | from test_resnet50 import TestResNet50 30 | from test_densenet121 import TestDenseNet121 31 | from test_vgg16bn import TestVGG16BN 32 | from test_carn import TestCARN 33 | from test_convnexttiny import TestConvNextTiny 34 | from test_convnextlarge import TestConvNextLarge 35 | from test_convnextxlarge import TestConvNextXLarge 36 | from test_yolov5 import TestYolov5 # need to install yolov5 dependency 37 | from test_mixnet import TestMixNet 38 | 39 | from test_diffmodel_cifar import TestDiffModelCIFAR 40 | from test_diffmodel_bedroom import TestDiffModelBedroom 41 | from test_diffmodel_celeba import TestDiffModelCeleba 42 | from test_diffmodel_church import TestDiffModelChurch 43 | 44 | 45 | OUT_DIR = './cache' 46 | 47 | os.makedirs(OUT_DIR, exist_ok=True) 48 | 49 | if __name__ == '__main__': 50 | unittest.main() -------------------------------------------------------------------------------- /sanity_check/test_densenet121.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import densenet121 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDenseNet121(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = densenet121() 12 | oto = OTO(model, dummy_input) 13 | oto.visualize(view=False, out_dir=OUT_DIR) 14 | # For test FLOP and param reductions. 
15 | full_flops = oto.compute_flops(in_million=True)['total'] 16 | full_num_params = oto.compute_num_params(in_million=True) 17 | 18 | oto.random_set_zero_groups() 19 | oto.construct_subnet(out_dir=OUT_DIR) 20 | full_model = torch.load(oto.full_group_sparse_model_path) 21 | compressed_model = torch.load(oto.compressed_model_path) 22 | 23 | full_output = full_model(dummy_input) 24 | compressed_output = compressed_model(dummy_input) 25 | 26 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 27 | print("Maximum output difference : ", max_output_diff.item()) 28 | self.assertLessEqual(max_output_diff, 1e-4) 29 | full_model_size = os.stat(oto.full_group_sparse_model_path) 30 | compressed_model_size = os.stat(oto.compressed_model_path) 31 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 32 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 33 | 34 | # Compute FLOP and param for pruned model after oto.construct_subnet() 35 | pruned_flops = oto.compute_flops(in_million=True)['total'] 36 | pruned_num_params = oto.compute_num_params(in_million=True) 37 | 38 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 39 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/test_resnet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | import torchvision.models 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestResNet50(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 224, 224)): 11 | model = torchvision.models.resnet50() 12 | oto = OTO(model, dummy_input) 13 | oto.visualize(view=False, out_dir=OUT_DIR) 14 | # For test FLOP and param reductions. 
15 | full_flops = oto.compute_flops(in_million=True)['total'] 16 | full_num_params = oto.compute_num_params(in_million=True) 17 | 18 | oto.random_set_zero_groups() 19 | oto.construct_subnet(out_dir=OUT_DIR) 20 | full_model = torch.load(oto.full_group_sparse_model_path) 21 | compressed_model = torch.load(oto.compressed_model_path) 22 | 23 | full_output = full_model(dummy_input) 24 | compressed_output = compressed_model(dummy_input) 25 | 26 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 27 | print("Maximum output difference : ", max_output_diff.item()) 28 | self.assertLessEqual(max_output_diff, 1e-4) 29 | full_model_size = os.stat(oto.full_group_sparse_model_path) 30 | compressed_model_size = os.stat(oto.compressed_model_path) 31 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 32 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 33 | 34 | # Compute FLOP and param for pruned model after oto.construct_subnet() 35 | pruned_flops = oto.compute_flops(in_million=True)['total'] 36 | pruned_num_params = oto.compute_num_params(in_million=True) 37 | 38 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 39 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/test_vgg16bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | import torchvision.models 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestVGG16BN(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 224, 224)): 11 | model = torchvision.models.vgg16_bn() 12 | oto = OTO(model, dummy_input) 13 | oto.visualize(view=False, out_dir=OUT_DIR) 14 | 15 | # For test FLOP and param reductions. 
16 | full_flops = oto.compute_flops(in_million=True)['total'] 17 | full_num_params = oto.compute_num_params(in_million=True) 18 | 19 | oto.random_set_zero_groups() 20 | oto.construct_subnet(out_dir=OUT_DIR) 21 | full_model = torch.load(oto.full_group_sparse_model_path) 22 | compressed_model = torch.load(oto.compressed_model_path) 23 | 24 | full_output = full_model(dummy_input) 25 | compressed_output = compressed_model(dummy_input) 26 | 27 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 28 | print("Maximum output difference : ", max_output_diff.item()) 29 | self.assertLessEqual(max_output_diff, 1e-4) 30 | full_model_size = os.stat(oto.full_group_sparse_model_path) 31 | compressed_model_size = os.stat(oto.compressed_model_path) 32 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 33 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 34 | 35 | # Compute FLOP and param for pruned model after oto.construct_subnet() 36 | pruned_flops = oto.compute_flops(in_million=True)['total'] 37 | pruned_num_params = oto.compute_num_params(in_million=True) 38 | 39 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 40 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/test_shufflefacenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import ShuffleFaceNet 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestShuffleFaceNet(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 112, 112)): 11 | model = ShuffleFaceNet() 12 | for name, param in model.named_parameters(): 13 | print(name, param.shape, param.requires_grad) 14 | 15 | oto = OTO(model, dummy_input) 16 | oto.mark_unprunable_by_node_ids( 17 | [ 18 | 'node-407', 'node-419', 'node-451', 'node-483', 'node-515', \ 19 | 'node-526', 'node-528', 'node-540', 'node-572', 'node-604', \ 20 | 'node-636', 'node-668', 'node-700', 'node-732', 'node-764', \ 21 | 'node-775', 'node-777', 'node-789', 'node-821', 'node-853', \ 22 | 'node-885' 23 | ] 24 | ) 25 | oto.visualize(view=False, out_dir=OUT_DIR) 26 | 27 | oto.random_set_zero_groups() 28 | oto.construct_subnet(out_dir=OUT_DIR) 29 | full_model = torch.load(oto.full_group_sparse_model_path) 30 | compressed_model = torch.load(oto.compressed_model_path) 31 | 32 | full_output = full_model(dummy_input) 33 | compressed_output = compressed_model(dummy_input) 34 | 35 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 36 | print("Maximum output difference " + str(max_output_diff.item())) 37 | self.assertLessEqual(max_output_diff, 1e-4) 38 | full_model_size = os.stat(oto.full_group_sparse_model_path) 39 | compressed_model_size = os.stat(oto.compressed_model_path) 40 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 41 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 42 | 43 | -------------------------------------------------------------------------------- /sanity_check/test_resnet18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | import torchvision.models 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestResNet18(unittest.TestCase): 10 | 
def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = torchvision.models.resnet18() 12 | oto = OTO(model, dummy_input) 13 | oto.visualize(view=False, out_dir=OUT_DIR) 14 | 15 | # Compute FLOP and param for full model. 16 | full_flops = oto.compute_flops(in_million=True)['total'] 17 | full_num_params = oto.compute_num_params(in_million=True) 18 | 19 | oto.random_set_zero_groups() 20 | oto.construct_subnet(out_dir=OUT_DIR) 21 | 22 | full_model = torch.load(oto.full_group_sparse_model_path) 23 | compressed_model = torch.load(oto.compressed_model_path) 24 | 25 | full_output = full_model(dummy_input) 26 | compressed_output = compressed_model(dummy_input) 27 | 28 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 29 | print("Maximum output difference : ", max_output_diff.item()) 30 | self.assertLessEqual(max_output_diff, 1e-4) 31 | full_model_size = os.stat(oto.full_group_sparse_model_path) 32 | compressed_model_size = os.stat(oto.compressed_model_path) 33 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 34 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 35 | 36 | # Compute FLOP and param for pruned model after oto.construct_subnet() 37 | pruned_flops = oto.compute_flops(in_million=True)['total'] 38 | pruned_num_params = oto.compute_num_params(in_million=True) 39 | 40 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 41 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/backends/diffusion/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = ( 22 | 1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | inner_module = module.module 34 | module_copy = type(inner_module)( 35 | inner_module.config).to(inner_module.config.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | module_copy = nn.DataParallel(module_copy) 38 | else: 39 | module_copy = type(module)(module.config).to(module.config.device) 40 | module_copy.load_state_dict(module.state_dict()) 41 | # module_copy = copy.deepcopy(module) 42 | self.ema(module_copy) 43 | return module_copy 44 | 45 | def state_dict(self): 46 | return self.shadow 47 | 48 | def load_state_dict(self, state_dict): 49 | self.shadow = state_dict 50 | -------------------------------------------------------------------------------- /sanity_check/test_concat_case1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemoNetConcatCase1 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDemoNetConcatCase1(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = DemoNetConcatCase1() 12 | oto = OTO(model, dummy_input) 13 | oto.visualize(view=False, out_dir=OUT_DIR) 14 | # Compute FLOP and param for full model. 15 | full_flops = oto.compute_flops(in_million=True)['total'] 16 | full_num_params = oto.compute_num_params(in_million=True) 17 | 18 | oto.random_set_zero_groups() 19 | oto.construct_subnet(out_dir=OUT_DIR) 20 | full_model = torch.load(oto.full_group_sparse_model_path) 21 | compressed_model = torch.load(oto.compressed_model_path) 22 | 23 | full_output = full_model(dummy_input) 24 | compressed_output = compressed_model(dummy_input) 25 | 26 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 27 | print("Maximum output difference : ", max_output_diff.item()) 28 | self.assertLessEqual(max_output_diff, 1e-4) 29 | full_model_size = os.stat(oto.full_group_sparse_model_path) 30 | compressed_model_size = os.stat(oto.compressed_model_path) 31 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 32 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 33 | 34 | # Compute FLOP and param for pruned model after oto.construct_subnet() 35 | pruned_flops = oto.compute_flops(in_million=True)['total'] 36 | pruned_num_params = oto.compute_num_params(in_million=True) 37 | 38 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 39 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/test_concat_case2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemoNetConcatCase2 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestDemoNetConcatCase2(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 32, 32)): 11 | model = DemoNetConcatCase2() 12 | oto = OTO(model, dummy_input) 13 | oto.visualize(view=False, out_dir=OUT_DIR) 14 | # Compute FLOP and param 
for full model. 15 | full_flops = oto.compute_flops(in_million=True)['total'] 16 | full_num_params = oto.compute_num_params(in_million=True) 17 | 18 | oto.random_set_zero_groups() 19 | oto.construct_subnet(out_dir=OUT_DIR) 20 | full_model = torch.load(oto.full_group_sparse_model_path) 21 | compressed_model = torch.load(oto.compressed_model_path) 22 | 23 | full_output = full_model(dummy_input) 24 | compressed_output = compressed_model(dummy_input) 25 | 26 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 27 | print("Maximum output difference : ", max_output_diff.item()) 28 | self.assertLessEqual(max_output_diff, 1e-4) 29 | full_model_size = os.stat(oto.full_group_sparse_model_path) 30 | compressed_model_size = os.stat(oto.compressed_model_path) 31 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 32 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 33 | 34 | # Compute FLOP and param for pruned model after oto.construct_subnet() 35 | pruned_flops = oto.compute_flops(in_million=True)['total'] 36 | pruned_num_params = oto.compute_num_params(in_million=True) 37 | 38 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 39 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/backends/peft/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | __version__ = "0.3.0.dev0" 21 | 22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model 23 | from .peft_model import ( 24 | PeftModel, 25 | PeftModelForCausalLM, 26 | PeftModelForSeq2SeqLM, 27 | PeftModelForSequenceClassification, 28 | PeftModelForTokenClassification, 29 | ) 30 | from .tuners import ( 31 | LoraConfig, 32 | LoraModel, 33 | AdaLoraConfig, 34 | AdaLoraModel, 35 | PrefixEncoder, 36 | PrefixTuningConfig, 37 | PromptEmbedding, 38 | PromptEncoder, 39 | PromptEncoderConfig, 40 | PromptEncoderReparameterizationType, 41 | PromptTuningConfig, 42 | PromptTuningInit, 43 | ) 44 | from .utils import ( 45 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 46 | PeftConfig, 47 | PeftType, 48 | PromptLearningConfig, 49 | TaskType, 50 | bloom_model_postprocess_past_key_value, 51 | get_peft_model_state_dict, 52 | prepare_model_for_int8_training, 53 | set_peft_model_state_dict, 54 | shift_tokens_right, 55 | ) 56 | -------------------------------------------------------------------------------- /sanity_check/test_carn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import CarnNet 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestCARN(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 224, 224)): 11 | scale = 2 12 | model = CarnNet(scale=scale, multi_scale=False, group=1) 13 | oto = OTO(model, (dummy_input, scale)) 14 | oto.visualize(view=False, out_dir=OUT_DIR) 15 | 16 | # Compute FLOP and param for full model. 17 | full_flops = oto.compute_flops(in_million=True)['total'] 18 | full_num_params = oto.compute_num_params(in_million=True) 19 | 20 | oto.random_set_zero_groups() 21 | oto.construct_subnet(out_dir=OUT_DIR) 22 | full_model = torch.load(oto.full_group_sparse_model_path) 23 | compressed_model = torch.load(oto.compressed_model_path) 24 | 25 | full_output = full_model(dummy_input, scale) 26 | compressed_output = compressed_model(dummy_input, scale) 27 | 28 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 29 | print("Maximum output difference : ", max_output_diff.item()) 30 | self.assertLessEqual(max_output_diff, 1e-4) 31 | full_model_size = os.stat(oto.full_group_sparse_model_path) 32 | compressed_model_size = os.stat(oto.compressed_model_path) 33 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 34 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 35 | 36 | # Compute FLOP and param for pruned model after oto.construct_subnet() 37 | pruned_flops = oto.compute_flops(in_million=True)['total'] 38 | pruned_num_params = oto.compute_num_params(in_million=True) 39 | 40 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 41 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/test_convnextlarge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import convnext_large 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestConvNextLarge(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 224, 224)): 11 | # layer_scale_init_value will disable a nn.Parameters gamma. 
12 | # The singleton parameter is not supported in OTOv3 13 | # Pretrained convnext has such parameters, thereby need to skip a few node groups. 14 | model = convnext_large(layer_scale_init_value=-1) 15 | # model = convnext_tiny(pretrained=True) 16 | oto = OTO(model, dummy_input) 17 | 18 | oto.visualize(view=False, out_dir=OUT_DIR) 19 | # Compute FLOP and param for full model. 20 | full_flops = oto.compute_flops(in_million=True)['total'] 21 | full_num_params = oto.compute_num_params(in_million=True) 22 | 23 | oto.random_set_zero_groups() 24 | oto.construct_subnet(out_dir=OUT_DIR) 25 | full_model = torch.load(oto.full_group_sparse_model_path) 26 | compressed_model = torch.load(oto.compressed_model_path) 27 | 28 | full_output = full_model(dummy_input) 29 | compressed_output = compressed_model(dummy_input) 30 | 31 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 32 | print("Maximum output difference : ", max_output_diff.item()) 33 | # self.assertLessEqual(max_output_diff, 1e-4) 34 | full_model_size = os.stat(oto.full_group_sparse_model_path) 35 | compressed_model_size = os.stat(oto.compressed_model_path) 36 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 37 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 38 | 39 | # Compute FLOP and param for pruned model after oto.construct_subnet() 40 | pruned_flops = oto.compute_flops(in_million=True)['total'] 41 | pruned_num_params = oto.compute_num_params(in_million=True) 42 | 43 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 44 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/test_convnextxlarge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import convnext_xlarge 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestConvNextXLarge(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 224, 224)): 11 | # layer_scale_init_value will disable a nn.Parameters gamma. 12 | # The singleton parameter is not supported in OTOv3 13 | # Pretrained convnext has such parameters, thereby need to skip a few node groups. 14 | model = convnext_xlarge(layer_scale_init_value=-1) 15 | # model = convnext_tiny(pretrained=True) 16 | oto = OTO(model, dummy_input) 17 | 18 | oto.visualize(view=False, out_dir=OUT_DIR) 19 | # Compute FLOP and param for full model. 
20 | full_flops = oto.compute_flops(in_million=True)['total'] 21 | full_num_params = oto.compute_num_params(in_million=True) 22 | 23 | oto.random_set_zero_groups() 24 | oto.construct_subnet(out_dir=OUT_DIR) 25 | full_model = torch.load(oto.full_group_sparse_model_path) 26 | compressed_model = torch.load(oto.compressed_model_path) 27 | 28 | full_output = full_model(dummy_input) 29 | compressed_output = compressed_model(dummy_input) 30 | 31 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 32 | print("Maximum output difference : ", max_output_diff.item()) 33 | # self.assertLessEqual(max_output_diff, 1e-4) 34 | full_model_size = os.stat(oto.full_group_sparse_model_path) 35 | compressed_model_size = os.stat(oto.compressed_model_path) 36 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 37 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 38 | 39 | # Compute FLOP and param for pruned model after oto.construct_subnet() 40 | pruned_flops = oto.compute_flops(in_million=True)['total'] 41 | pruned_num_params = oto.compute_num_params(in_million=True) 42 | 43 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 44 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/backends/demonet_convtranspose_in_case1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | AFFINE=True 5 | 6 | class DemoNetConvtransposeInCase1(nn.Module): 7 | def __init__(self): 8 | super(DemoNetConvtransposeInCase1, self).__init__() 9 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 10 | self.bn_1 = nn.InstanceNorm2d(64, affine=AFFINE) 11 | self.bn_2 = nn.InstanceNorm2d(64, affine=AFFINE) 12 | self.leakyrelu = nn.LeakyReLU() 13 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 14 | self.conv3 = nn.Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 15 | self.conv4 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 16 | self.bn_3 = nn.InstanceNorm2d(128, affine=AFFINE) 17 | 18 | self.conv6 = nn.ConvTranspose2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True) 19 | 20 | self.bn_6 = nn.InstanceNorm2d(512, affine=AFFINE) 21 | 22 | self.conv8 = nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 23 | self.conv9 = nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 24 | 25 | self.in_1 = nn.InstanceNorm2d(256, affine=AFFINE) 26 | self.in_2 = nn.InstanceNorm2d(256, affine=AFFINE) 27 | self.conv10 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 28 | 29 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 30 | 31 | self.gemm1 = nn.Linear(in_features=256, out_features=128, bias=True) 32 | self.gemm2 = nn.Linear(in_features=128, out_features=10, bias=True) 33 | 34 | 35 | def forward(self, x, debug=False): 36 | x_1 = self.conv1(x) 37 | x_2 = self.leakyrelu(self.bn_1(x_1)) 38 | x_3 = self.leakyrelu(self.bn_2(x_1)) 39 | x_2 = self.conv3(x_2) 40 | x_3 = self.conv4(self.leakyrelu(self.bn_3(self.conv2(x_3)))) 41 | x = x_2 + x_3 42 | x = self.leakyrelu(self.bn_6(self.conv6(x))) 43 | x = self.in_1(self.conv8(x)) + self.in_2(self.conv9(x)) 44 | 45 | x = self.avg_pool(self.conv10(x)) 46 | x = x.view(x.size(0), -1) 47 | 48 | return self.gemm2(self.gemm1(x)) 
-------------------------------------------------------------------------------- /sanity_check/test_convnexttiny.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import convnext_tiny 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestConvNextTiny(unittest.TestCase): 10 | def test_sanity(self, dummy_input=torch.rand(1, 3, 224, 224)): 11 | # layer_scale_init_value will disable a nn.Parameters gamma. 12 | # The singleton parameter is not supported in OTOv3 13 | # Pretrained convnext has such parameters, thereby need to skip a few node groups. 14 | model = convnext_tiny(layer_scale_init_value=-1) 15 | # model = convnext_tiny(pretrained=True) 16 | oto = OTO(model, dummy_input) 17 | 18 | # # For pretrained convnexttiny 19 | # oto.mark_unprunable_by_node_ids(['node-183', 'node-312', 'node-422', 'node-712']) 20 | 21 | oto.visualize(view=False, out_dir=OUT_DIR) 22 | # Compute FLOP and param for full model. 23 | full_flops = oto.compute_flops(in_million=True)['total'] 24 | full_num_params = oto.compute_num_params(in_million=True) 25 | 26 | oto.random_set_zero_groups() 27 | oto.construct_subnet(out_dir=OUT_DIR) 28 | full_model = torch.load(oto.full_group_sparse_model_path) 29 | compressed_model = torch.load(oto.compressed_model_path) 30 | 31 | full_output = full_model(dummy_input) 32 | compressed_output = compressed_model(dummy_input) 33 | 34 | max_output_diff = torch.max(torch.abs(full_output - compressed_output)) 35 | print("Maximum output difference : ", max_output_diff.item()) 36 | # self.assertLessEqual(max_output_diff, 1e-4) 37 | full_model_size = os.stat(oto.full_group_sparse_model_path) 38 | compressed_model_size = os.stat(oto.compressed_model_path) 39 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 40 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 41 | 42 | # Compute FLOP and param for pruned model after oto.construct_subnet() 43 | pruned_flops = oto.compute_flops(in_million=True)['total'] 44 | pruned_num_params = oto.compute_num_params(in_million=True) 45 | 46 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 47 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /sanity_check/backends/demonet_convtranspose_in_case2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class DemoNetConvtransposeInCase2(nn.Module): 5 | def __init__(self): 6 | super(DemoNetConvtransposeInCase2, self).__init__() 7 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 8 | self.in_1 = nn.InstanceNorm2d(64, affine=True) 9 | self.in_2 = nn.InstanceNorm2d(64, affine=True) 10 | self.leakyrelu = nn.LeakyReLU() 11 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 12 | self.conv3 = nn.Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 13 | self.conv4 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 14 | self.in_3 = nn.InstanceNorm2d(128, affine=True) 15 | self.in_6 = nn.InstanceNorm2d(512, affine=True) 16 | 17 | self.in_4 = nn.InstanceNorm2d(256, affine=True) 18 | self.in_5 = nn.InstanceNorm2d(256, affine=True) 19 | 20 | self.conv5 = nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 21 | 
self.conv6 = nn.ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 22 | self.conv7 = nn.ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 23 | 24 | self.conv8 = nn.Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 25 | self.conv9 = nn.ConvTranspose2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 26 | 27 | 28 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 29 | 30 | self.gemm1 = nn.Linear(in_features=384, out_features=128, bias=True) 31 | self.gemm2 = nn.Linear(in_features=128, out_features=10, bias=True) 32 | 33 | def forward(self, x): 34 | x_1 = self.conv1(x) 35 | x_2 = self.leakyrelu(self.in_1(x_1)) 36 | x_3 = self.leakyrelu(self.in_2(x_1)) 37 | x_2 = self.conv3(x_2) 38 | x_3 = self.conv4(self.leakyrelu(self.in_3(self.conv2(x_3)))) 39 | x = torch.cat([x_2, x_3], dim=1) 40 | x = self.leakyrelu(self.in_6(x)) 41 | 42 | x = self.leakyrelu(torch.cat([self.in_4(self.conv5(x)), self.in_5(self.conv6(x))], dim=1)) 43 | x = self.leakyrelu(self.conv7(x)) 44 | x = self.leakyrelu(self.conv8(x)) + self.leakyrelu(self.conv9(x)) 45 | x = self.avg_pool(x) 46 | x = x.view(x.size(0), -1) 47 | return self.gemm2(self.gemm1(x)) -------------------------------------------------------------------------------- /sanity_check/test_groupconv_case1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | from backends import DemoNetGroupConvCase1 4 | import unittest 5 | import os 6 | 7 | OUT_DIR = './cache' 8 | 9 | class TestGroupConvCase1(unittest.TestCase): 10 | def test_sanity( 11 | self, 12 | dummy_input=( 13 | torch.rand(1, 3, 512, 512), 14 | torch.rand(1, 3, 512, 512), 15 | torch.rand(1, 384, 16, 16), 16 | torch.rand(1, 64, 16, 16) 17 | ) 18 | ): 19 | affine = True 20 | norm_type = 'in' 21 | model = DemoNetGroupConvCase1(norm_type=norm_type, affine=affine) 22 | oto = OTO(model, dummy_input) 23 | unprunable_param_names = [ 24 | 'conv_1.conv1.weight', 25 | 'conv_5.conv2.weight', 26 | 'conv_6.conv1.weight' 27 | ] 28 | oto.mark_unprunable_by_param_names(param_names=unprunable_param_names) 29 | 30 | oto.visualize(view=False, out_dir=OUT_DIR, display_params=True) 31 | # For test FLOP and param reductions. 
32 | full_flops = oto.compute_flops(in_million=True)['total'] 33 | full_num_params = oto.compute_num_params(in_million=True) 34 | 35 | oto.random_set_zero_groups() 36 | oto.construct_subnet(out_dir=OUT_DIR) 37 | full_model = torch.load(oto.full_group_sparse_model_path) 38 | compressed_model = torch.load(oto.compressed_model_path) 39 | 40 | full_output = full_model(*dummy_input) 41 | compressed_output = compressed_model(*dummy_input) 42 | 43 | max_output_diff = torch.max(torch.abs(full_output[0] - compressed_output[0])) 44 | print("Maximum output difference : ", max_output_diff.item()) 45 | full_model_size = os.stat(oto.full_group_sparse_model_path) 46 | compressed_model_size = os.stat(oto.compressed_model_path) 47 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 48 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 49 | self.assertLessEqual(max_output_diff, 1e-3) 50 | 51 | # Compute FLOP and param for pruned model after oto.construct_subnet() 52 | pruned_flops = oto.compute_flops(in_million=True)['total'] 53 | pruned_num_params = oto.compute_num_params(in_million=True) 54 | 55 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 56 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /tutorials/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | def accuracy_topk(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].reshape(-1).view(-1).float().sum(0, keepdim=True) 17 | res.append(correct_k) 18 | return res 19 | 20 | 21 | def check_accuracy(model, testloader, two_input=False): 22 | correct1 = 0 23 | correct5 = 0 24 | total = 0 25 | model = model.eval() 26 | device = next(model.parameters()).device 27 | with torch.no_grad(): 28 | for X, y in testloader: 29 | X = X.to(device) 30 | y = y.to(device) 31 | if two_input: 32 | y_pred = model.forward(X, X) 33 | else: 34 | y_pred = model.forward(X) 35 | total += y.size(0) 36 | 37 | prec1, prec5 = accuracy_topk(y_pred.data, y, topk=(1, 5)) 38 | 39 | correct1 += prec1.item() 40 | correct5 += prec5.item() 41 | 42 | model = model.train() 43 | accuracy1 = correct1 / total 44 | accuracy5 = correct5 / total 45 | return accuracy1, accuracy5 46 | 47 | 48 | def check_accuracy_onnx(model_path, testloader, two_input=False): 49 | import onnxruntime as ort 50 | sess_options = ort.SessionOptions() 51 | ort_sess = ort.InferenceSession(model_path, sess_options) 52 | correct1 = 0 53 | correct5 = 0 54 | total = 0 55 | 56 | for X, y in testloader: 57 | try: 58 | if not two_input: 59 | outputs = ort_sess.run(None, {'input.1': X.numpy()})[0] 60 | else: 61 | outputs = ort_sess.run(None, {'input.1': X.numpy(), 'input.2': X.numpy()})[0] 62 | except: 63 | continue 64 | prec1, prec5 = accuracy_topk(torch.tensor(outputs), y.data, topk=(1, 5)) 65 | correct1 += prec1.item() 66 | correct5 += prec5.item() 67 | total += y.size(0) 68 | 69 | accuracy1 = correct1 / total 70 | accuracy5 = correct5 / total 71 | return accuracy1, accuracy5 72 | 73 | def compute_output_onnx_given_input(model_path, 
input_tensor): 74 | import onnxruntime as ort 75 | ort_sess = ort.InferenceSession(model_path) 76 | output = ort_sess.run(None, {'input.1': input_tensor.numpy()})[0] 77 | return output -------------------------------------------------------------------------------- /sanity_check/backends/carn/carn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .ops import ResidualBlock, BasicBlock, MeanShift, UpsampleBlock 4 | 5 | class Block(nn.Module): 6 | def __init__(self, 7 | in_channels, out_channels, 8 | group=1): 9 | super(Block, self).__init__() 10 | 11 | self.b1 = ResidualBlock(64, 64) 12 | self.b2 = ResidualBlock(64, 64) 13 | self.b3 = ResidualBlock(64, 64) 14 | self.c1 = BasicBlock(64*2, 64, 1, 1, 0) 15 | self.c2 = BasicBlock(64*3, 64, 1, 1, 0) 16 | self.c3 = BasicBlock(64*4, 64, 1, 1, 0) 17 | 18 | def forward(self, x): 19 | c0 = o0 = x 20 | 21 | b1 = self.b1(o0) 22 | c1 = torch.cat([c0, b1], dim=1) 23 | o1 = self.c1(c1) 24 | 25 | b2 = self.b2(o1) 26 | c2 = torch.cat([c1, b2], dim=1) 27 | o2 = self.c2(c2) 28 | 29 | b3 = self.b3(o2) 30 | c3 = torch.cat([c2, b3], dim=1) 31 | o3 = self.c3(c3) 32 | 33 | return o3 34 | 35 | 36 | class CarnNet(nn.Module): 37 | def __init__(self, **kwargs): 38 | super(CarnNet, self).__init__() 39 | 40 | scale = kwargs.get("scale") 41 | multi_scale = kwargs.get("multi_scale") 42 | group = kwargs.get("group", 1) 43 | 44 | self.sub_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=True) 45 | self.add_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=False) 46 | 47 | self.entry = nn.Conv2d(3, 64, 3, 1, 1) 48 | 49 | self.b1 = Block(64, 64) 50 | self.b2 = Block(64, 64) 51 | self.b3 = Block(64, 64) 52 | self.c1 = BasicBlock(64*2, 64, 1, 1, 0) 53 | self.c2 = BasicBlock(64*3, 64, 1, 1, 0) 54 | self.c3 = BasicBlock(64*4, 64, 1, 1, 0) 55 | 56 | self.upsample = UpsampleBlock(64, scale=scale, 57 | multi_scale=multi_scale, 58 | group=group) 59 | self.exit = nn.Conv2d(64, 3, 3, 1, 1) 60 | 61 | def forward(self, x, scale): 62 | x = self.sub_mean(x) 63 | x = self.entry(x) 64 | c0 = o0 = x 65 | 66 | b1 = self.b1(o0) 67 | c1 = torch.cat([c0, b1], dim=1) 68 | o1 = self.c1(c1) 69 | 70 | b2 = self.b2(o1) 71 | c2 = torch.cat([c1, b2], dim=1) 72 | o2 = self.c2(c2) 73 | 74 | b3 = self.b3(o2) 75 | c3 = torch.cat([c2, b3], dim=1) 76 | o3 = self.c3(c3) 77 | 78 | out = self.upsample(o3, scale=scale) 79 | 80 | out = self.exit(out) 81 | out = self.add_mean(out) 82 | 83 | return out -------------------------------------------------------------------------------- /only_train_once/graph/node.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Node: 4 | def __init__(self, id=None, op_name="", op=None, inputs=[], outputs=[], param_names=[], output_shape=[]): 5 | super().__init__() 6 | self.id = id 7 | self.op = op 8 | self.op_name = op_name 9 | self.inputs = ['node-' + str(i) for i in inputs] 10 | self.outputs = ['node-' + str(o) for o in outputs] 11 | self.param_names = param_names 12 | self.node_group_ids = list() 13 | self.pruned_status = { 14 | "out_dim": False, 15 | "in_dim": False 16 | } 17 | self.output_shape = output_shape 18 | self.input_shape = [] 19 | 20 | def __repr__(self) -> str: 21 | return f"Node id: {self.id}, op_name: {self.op_name}, param_names: {self.param_names}" 22 | 23 | @property 24 | def title(self): 25 | if not self.op: 26 | return self.op_name 27 | # Default 28 | title = (self.op_name + '-' + self.op._type) if self.op_name != 
self.op._type else self.op._type 29 | if "kernel_shape" in self.op.cfg_params: 30 | # Kernel 31 | kernel = self.op.cfg_params["kernel_shape"] 32 | title += "x".join(map(str, kernel)) 33 | if "stride" in self.op.cfg_params: 34 | stride = self.op.cfg_params["stride"] 35 | if np.unique(stride).size == 1: 36 | stride = stride[0] 37 | if stride != 1: 38 | title += "/s{}".format(str(stride)) 39 | return title 40 | 41 | def is_stem(self): 42 | if self.op is not None: 43 | if self.op.is_basic: 44 | return self.op.is_stem 45 | else: 46 | return self.is_conv() or self.is_convtranspose() or self.is_linear() 47 | else: 48 | return False 49 | 50 | def is_conv(self): 51 | return self.op_name == "Conv" or self.op_name == 'conv' 52 | 53 | def is_convtranspose(self): 54 | return self.op_name == "ConvTranspose" or self.op_name == 'convtranspose' 55 | 56 | def is_linear(self): 57 | return self.op_name == "Linear" or self.op_name == 'linear' \ 58 | or self.op_name == "Gemm" or self.op_name == "gemm" 59 | 60 | def is_concat(self, axis=None): 61 | # Check if concat at first 62 | _is_concat = self.op_name == "Concat" or self.op_name == 'concat' 63 | if axis == None: 64 | return _is_concat 65 | # Check if axis match 66 | if _is_concat and hasattr(self.op, 'cfg_params'): 67 | if 'axis' in self.op.cfg_params: 68 | return True if self.op.cfg_params['axis'] == axis else False 69 | else: 70 | return False 71 | return _is_concat 72 | 73 | def is_dummy(self): 74 | return True if self.id == 'dummy_input' or self.id == 'dummy_output' else False 75 | 76 | 77 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | We will routinely update the tutorials to cover various use cases in 2024. Please expect slower updates in January due to a recent heavy workload. 4 | 5 | Here are the **empirical principles** that we would like to highlight when applying OTO to new applications outside our tutorials. 6 | 7 | ## Sanity Check 8 | 9 | We highly recommend running a sanity check first to verify that OTO is compatible with the target DNN. The sanity check randomly picks a set of minimally removal structures as redundant 10 | 11 | ```python 12 | oto.random_set_zero_groups() 13 | ``` 14 | and produces a compact subnetwork, as presented in [sanity_check](https://github.com/tianyic/only_train_once/blob/main/sanity_check/test_resnet18.py). If the sanity check does not pass, please mark the problematic node groups as unprunable via either `node_ids` 15 | 16 | ```python 17 | oto.mark_unprunable_by_node_ids() 18 | ``` 19 | or `param_names` 20 | ```python 21 | oto.mark_unprunable_by_param_names() 22 | ``` 23 | For example, in [YOLOv5](https://github.com/tianyic/only_train_once/blob/main/sanity_check/test_yolov5.py), we mark the node groups corresponding to the detection heads as unprunable. In [DemoNetGroupConvCase1](https://github.com/tianyic/only_train_once/blob/main/sanity_check/test_groupconv_case1.py), which originates from a multi-modal DNN, we mark the node groups containing a set of `param_names` as unprunable. 24 | 25 | ## Optimizer setup (Important) 26 | 27 | OTO is designed to integrate **seamlessly** into the existing training pipeline for the full model. This existing pipeline is typically reliable for achieving high performance with the full model. 
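A minimal sketch of this setup (illustrative only; `model`, `dummy_input`, `trainloader`, `criterion`, `max_epoch`, and the SGD values are placeholders for the user's own baseline pipeline, not settings prescribed here): if the baseline trains the full model with SGD at `lr=0.1`, create OTO's optimizer with the same hyperparameters and keep the rest of the training loop unchanged.

```python
from only_train_once import OTO

oto = OTO(model, dummy_input)
# Mirror the baseline optimizer's hyperparameters (here assumed to be SGD with lr=0.1).
# Other hyperparameters should follow the same principle; see the hyperparameter
# reference linked in the next paragraph.
optimizer = oto.hesso(variant='sgd', lr=0.1)

# Train as normal with the existing pipeline.
for epoch in range(max_epoch):
    for X, y in trainloader:
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()

# Afterwards, construct the compact sub-network.
oto.construct_subnet(out_dir='./cache')
```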
28 | 29 | To minimize the effort of [**hyperparameter**](https://github.com/tianyic/only_train_once/blob/cbb3d3dccf95c383e9cddcbaf8592cf3db13817b/only_train_once/__init__.py#L47) tuning while ensuring high performance, we recommend **setting the hyperparameters in OTO's optimizers identical to those in the baseline optimizers**. This setup generally yields satisfactory results in DNN compression across a wide range of applications, from computer vision to natural language processing, and from academic benchmarks to real-world AI products. However, be aware that some applications might require extended training steps to converge due to the reduced learning capacity of sparse models. 30 | 31 | It is important to note that different optimizer setups can lead to significantly different performance outcomes. Additionally, alternative hyperparameter configurations, differing from our baseline recommendation, could further enhance performance. We suggest that users with the interest and resources experiment with different hyperparameter settings to explore these possibilities, which typically delivers the best compressed model. 32 | 33 | 34 | ## Old tutorials 35 | 36 | Tutorials for the old library can be found [here](https://github.com/tianyic/only_train_once/tree/otov2_legacy_backup/tutorials). They cover ResNet50 on CIFAR10, ResNet50 on ImageNet, and VGG16BN on CIFAR10. These tutorials will be refreshed for the new library next year. 37 | -------------------------------------------------------------------------------- /sanity_check/test_yolov5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | import unittest 4 | import os 5 | import onnxruntime as ort 6 | import numpy as np 7 | 8 | OUT_DIR = './cache' 9 | 10 | class TestYolov5(unittest.TestCase): 11 | def test_sanity(self, dummy_input=torch.rand(1, 3, 640, 640)): 12 | model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) 13 | # All parameters in the pretrained Yolov5 are not trainable by default. 14 | for _, param in model.named_parameters(): 15 | param.requires_grad = True 16 | 17 | oto = OTO(model, dummy_input) 18 | # Mark a conv-concat and the detection heads as unprunable 19 | # The node_ids may vary across different torch versions. 20 | oto.mark_unprunable_by_node_ids( 21 | # ['node-229', 'node-329', 'node-443', 'node-553'] 22 | ['node-229', 'node-581', 'node-471', 'node-359'] 23 | ) 24 | # The above can also be achieved by 25 | oto.mark_unprunable_by_param_names( 26 | ['model.model.model.9.cv1.conv.weight', 'model.model.model.24.m.2.weight', \ 27 | 'model.model.model.24.m.1.weight', 'model.model.model.24.m.0.weight'] 28 | ) 29 | 30 | # Display param name and shape in dependency graph visualization 31 | oto.visualize(view=False, out_dir=OUT_DIR, display_params=True) 32 | # Compute FLOP and param for full model. 
33 | full_flops = oto.compute_flops(in_million=True)['total'] 34 | full_num_params = oto.compute_num_params(in_million=True) 35 | 36 | optimizer = oto.hesso( 37 | variant='sgd', 38 | lr=0.1 39 | ) 40 | oto.random_set_zero_groups() 41 | # YOLOv5 has some trouble to directly load torch model 42 | oto.construct_subnet( 43 | out_dir=OUT_DIR, 44 | ckpt_format='onnx' 45 | ) 46 | 47 | full_sess = ort.InferenceSession(oto.full_group_sparse_model_path) 48 | full_output = full_sess.run(None, {'onnx::Cast_0': dummy_input.numpy()}) 49 | compressed_sess = ort.InferenceSession(oto.compressed_model_path) 50 | compressed_output = compressed_sess.run(None, {'onnx::Cast_0': dummy_input.numpy()}) 51 | 52 | max_output_diff = np.max(np.abs(full_output[0] - compressed_output[0])) 53 | print("Maximum output difference : ", max_output_diff.item()) 54 | self.assertLessEqual(max_output_diff, 1e-3) 55 | 56 | full_model_size = os.stat(oto.full_group_sparse_model_path) 57 | compressed_model_size = os.stat(oto.compressed_model_path) 58 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 59 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 60 | 61 | # Compute FLOP and param for pruned model after oto.construct_subnet() 62 | pruned_flops = oto.compute_flops(in_million=True)['total'] 63 | pruned_num_params = oto.compute_num_params(in_million=True) 64 | 65 | print("FLOP reduction (%) : ", 1.0 - pruned_flops / full_flops) 66 | print("Param reduction (%) : ", 1.0 - pruned_num_params / full_num_params) -------------------------------------------------------------------------------- /only_train_once/transform/graph_transform.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import re 3 | from . 
import ge 4 | 5 | class Rename(): 6 | def __init__(self, op=None, name=None, to=None): 7 | assert op or name, "Either op or name must be provided" 8 | assert not(op and name), "Either op or name should be provided, but not both" 9 | assert bool(to), "The to parameter is required" 10 | self.to = to 11 | self.op = re.compile(op) if op else None 12 | self.name = re.compile(name) if name else None 13 | 14 | def apply(self, graph): 15 | for i, node in enumerate(graph.nodes.values()): 16 | if self.op: 17 | node.op_name = self.op.sub(self.to, node.op_name) 18 | if self.name is None: 19 | node.op_name = str(node.op_name) 20 | else: 21 | node.op_name = self.name.sub(self.to, node.op_name) 22 | 23 | class Fold(): 24 | def __init__(self, pattern, to, name=None): 25 | # TODO: validate that op and name are valid 26 | self.pattern = ge.GEParser(pattern).parse() 27 | self.to = to 28 | self.name = name 29 | 30 | def apply(self, graph): 31 | while True: 32 | matches, _ = graph.search(self.pattern) 33 | if not matches: 34 | break 35 | 36 | # Replace pattern with new node 37 | if self.to == "__first__": 38 | combo = matches[0] 39 | elif self.to == "__last__": 40 | combo = matches[-1] 41 | else: 42 | # find the most bottom child 43 | outputs = set() 44 | match_ids = [node.id for node in matches] 45 | for match_node in matches: 46 | for outgoing_node in graph.outgoing(match_node): 47 | if outgoing_node.id not in match_ids: 48 | outputs.add(outgoing_node) 49 | # combine operators 50 | combo_op = matches[0].op 51 | for i in range(1, len(matches)): 52 | combo_op += matches[i].op 53 | combo_op.name = self.to or self.pattern 54 | combo = Node(id=graph.sequence_id(), 55 | op=combo_op, 56 | output_shape=matches[-1].output_shape, 57 | outputs = list(outputs)) # TODO, check bugs 58 | combo._caption = "/".join(filter(None, [l.caption for l in matches])) 59 | graph.replace(matches, combo) 60 | 61 | 62 | class ConvBNFuse(): 63 | def __init__(self, pattern, to, name=None): 64 | self.pattern = ge.GEParser(pattern).parse() 65 | self.to = to 66 | self.name = name 67 | 68 | def apply(self, graph): 69 | graph.fused_conv_bns = list() 70 | while True: 71 | matches, _ = graph.search(self.pattern) 72 | if not matches: 73 | break 74 | for match_node in matches: 75 | match_node._skip_pattern_search = True 76 | graph.fused_conv_bns.append(matches) 77 | 78 | # PyTorch Graph Transforms 79 | FRAMEWORK_TRANSFORMS = [ 80 | Rename(op=r"onnx::(.*)", to=r"\1"), 81 | Rename(op=r"gemm", to=r"linear"), 82 | Rename(op=r"batchnormalization", to="batchnorm"), 83 | ] 84 | 85 | CONV_BN_FUSE = ConvBNFuse("conv > batchnorm", "convbn") 86 | -------------------------------------------------------------------------------- /tutorials/04.oto_distributed_data_parallelism.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f4441533-4d5a-4c9e-8906-1df24bda96ec", 6 | "metadata": {}, 7 | "source": [ 8 | "## Tutorial 4. OTO with Distributed Data Parallelism\n", 9 | "\n", 10 | "\n", 11 | "In this tutorial, we briefly show how to use the pruning mode of OTO to train and prune DNN under **distributed data parallelism (DDP)**.\n", 12 | "\n", 13 | "We acknowledge the contributions from @C0NGTRI123 on DDP support from [issue](https://github.com/tianyic/only_train_once/issues/44)." 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "id": "8067c3a6-bf56-4c85-96fb-2d4da488680b", 19 | "metadata": {}, 20 | "source": [ 21 | "### Step 1. 
Create OTO instance\n", 22 | "\n", 23 | "Create OTO instance **before warping DNN into distributed DNN**. \n", 24 | "\n", 25 | "We recommand to create model and dummy input on CPU." 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "id": "f2ce0399-51c7-4afc-b0f0-35bda3698825", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "from only_train_once import OTO\n", 36 | "\n", 37 | "oto = OTO(model, dummy_input)\n", 38 | "\n", 39 | "\n", 40 | "from torch.nn.parallel import DistributedDataParallel as DDP\n", 41 | "model = model.to(local_rank)\n", 42 | "model = DDP(model, device_ids=[local_rank])" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "id": "01639056-422d-4bfb-89db-ff9f0725e807", 48 | "metadata": {}, 49 | "source": [ 50 | "### Step 2. Create optimizer and train as normal\n", 51 | "\n", 52 | "Set up the `device` as current `local_rank`.\n", 53 | "\n", 54 | "We acknowledge the contribution from @Nadav-out on the [pull request](https://github.com/tianyic/only_train_once/pull/53)." 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "3e8cefb1-85cf-4833-a6ff-ee2f9270f02d", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "optimizer = oto.hesso(\n", 65 | " device=local_rank, \n", 66 | " ... # other arguments are the same as standalone training. \n", 67 | " )\n", 68 | "\n", 69 | "# Train as normal\n", 70 | "optimizer.step()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "fc99fc48-fed7-41a5-ab6d-f2e5686f2003", 76 | "metadata": {}, 77 | "source": [ 78 | "### Step 3. Construct sub-network\n", 79 | "\n", 80 | "Call `oto.construct_subnet()` on only one GPU." 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "7ff0b356", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "if local_rank == 0:\n", 91 | " oto.construct_subnet(...)" 92 | ] 93 | } 94 | ], 95 | "metadata": { 96 | "kernelspec": { 97 | "display_name": "Python 3 (ipykernel)", 98 | "language": "python", 99 | "name": "python3" 100 | }, 101 | "language_info": { 102 | "codemirror_mode": { 103 | "name": "ipython", 104 | "version": 3 105 | }, 106 | "file_extension": ".py", 107 | "mimetype": "text/x-python", 108 | "name": "python", 109 | "nbconvert_exporter": "python", 110 | "pygments_lexer": "ipython3", 111 | "version": "3.11.3" 112 | } 113 | }, 114 | "nbformat": 4, 115 | "nbformat_minor": 5 116 | } 117 | -------------------------------------------------------------------------------- /only_train_once/optimizer/importance_score/magnitude.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once.transform import tensor_transformation, TensorTransform 3 | 4 | def importance_score_by_magnitude_dhspg(param_group): 5 | norm_group = None 6 | for param, p_transform in zip(param_group['params'], param_group['p_transform']): 7 | param_transform = None 8 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 9 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups'], param_group['num_heads']) 10 | else: 11 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups']) 12 | if norm_group == None: 13 | norm_group = torch.norm(param_transform, dim=1) ** 2 14 | else: 15 | norm_group += torch.norm(param_transform, dim=1) ** 2 16 | param_group['importance_scores']['magnitude'] = torch.sqrt(norm_group) 17 | 18 | def importance_score_by_avg_magnitude_dhspg(param_group): 
19 | norm_group = None 20 | group_sizes = 0 21 | for param, p_transform in zip(param_group['params'], param_group['p_transform']): 22 | param_transform = None 23 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 24 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups'], param_group['num_heads']) 25 | else: 26 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups']) 27 | if norm_group == None: 28 | norm_group = torch.norm(param_transform, dim=1) ** 2 29 | else: 30 | norm_group += torch.norm(param_transform, dim=1) ** 2 31 | group_sizes += param_transform.shape[1] 32 | param_group['importance_scores']['avg_magnitude'] = torch.sqrt(norm_group) / float(group_sizes + 1e-6) 33 | 34 | def importance_score_by_magnitude_lhspg(param_group): 35 | norm_group = None 36 | for p_name, param, p_transform in zip(param_group['p_names'], param_group['params'], param_group['p_transform']): 37 | if 'lora_A' in p_name or 'lora_B' in p_name: 38 | continue 39 | param_transform = None 40 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 41 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups'], param_group['num_heads']) 42 | else: 43 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups']) 44 | if norm_group == None: 45 | norm_group = torch.norm(param_transform, dim=1) ** 2 46 | else: 47 | norm_group += torch.norm(param_transform, dim=1) ** 2 48 | param_group['importance_scores']['magnitude'] = torch.sqrt(norm_group) 49 | 50 | def importance_score_by_avg_magnitude_lhspg(param_group): 51 | norm_group = None 52 | group_sizes = 0 53 | for p_name, param, p_transform in zip(param_group['p_names'], param_group['params'], param_group['p_transform']): 54 | if 'lora_A' in p_name or 'lora_B' in p_name: 55 | continue 56 | param_transform = None 57 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 58 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups'], param_group['num_heads']) 59 | else: 60 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups']) 61 | if norm_group == None: 62 | norm_group = torch.norm(param_transform, dim=1) ** 2 63 | else: 64 | norm_group += torch.norm(param_transform, dim=1) ** 2 65 | group_sizes += param_group['num_groups'] 66 | param_group['importance_scores']['avg_magnitude'] = torch.sqrt(norm_group) / float(group_sizes + 1e-6) -------------------------------------------------------------------------------- /sanity_check/test_llamav1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | import unittest 4 | import os 5 | from transformers import LlamaConfig, LlamaTokenizer 6 | from backends import LlamaForCausalLM 7 | 8 | OUT_DIR = './cache' 9 | 10 | class TestLLAMAv1(unittest.TestCase): 11 | def test_sanity(self, dummy_input=None): 12 | llama_config = LlamaConfig() 13 | llama_config.num_hidden_layers = 4 14 | llama_config.num_attention_heads = 32 15 | llama_config.hidden_size = 4096 16 | llama_config.intermediate_size = 11008 17 | model = LlamaForCausalLM(llama_config) 18 | 19 | tokenizer = LlamaTokenizer.from_pretrained('huggyllama/llama-7b') 20 | tokenizer.pad_token_id = (0) 21 | tokenizer.padding_side = "left" 22 | tokenizer.save_pretrained(OUT_DIR) 23 | 24 | text = 'This is a test sentence of a very long string and random wording that is used to test dolly model.' 
* 7 25 | input_data = tokenizer(text, return_tensors='pt').input_ids 26 | 27 | oto = OTO(model, dummy_input=(input_data,), strict_out_nodes=True) 28 | 29 | oto.visualize(view=False, out_dir=OUT_DIR) 30 | 31 | oto.random_set_zero_groups() 32 | 33 | oto.construct_subnet( 34 | export_huggingface_format=False, 35 | export_float16=False, 36 | full_group_sparse_model_dir=OUT_DIR, 37 | compressed_model_dir=OUT_DIR 38 | ) 39 | 40 | text_1 = 'This is a test sentence of a very long string and random wording that is used to test dolly model.' * 7 41 | input_data_1 = tokenizer(text_1, return_tensors='pt').input_ids 42 | 43 | text_2 = 'This is a good test sentence of a pretty short string and wording that is used to test dolly model.' * 7 44 | input_data_2 = tokenizer(text_2, return_tensors='pt').input_ids 45 | 46 | full_model = torch.load(oto.full_group_sparse_model_path) 47 | compressed_model = torch.load(oto.compressed_model_path) 48 | full_output_1 = full_model(input_data_1.to(full_model.device)) 49 | full_output_2 = full_model(input_data_2.to(full_model.device)) 50 | compressed_output_1 = compressed_model(input_data_1.to(compressed_model.device)) 51 | compressed_output_2 = compressed_model(input_data_2.to(compressed_model.device)) 52 | max_output_diff_1 = torch.max(full_output_1.logits - compressed_output_1.logits).item() 53 | max_output_diff_2 = torch.max(full_output_2.logits - compressed_output_2.logits).item() 54 | max_output_diff_3 = torch.max(full_output_1.logits - compressed_output_2.logits).item() 55 | max_output_diff_4 = torch.max(full_output_2.logits - compressed_output_1.logits).item() 56 | print("Maximum output difference under the same inputs:") 57 | print(max_output_diff_1) 58 | 59 | print("Maximum output difference under the same inputs:") 60 | print(max_output_diff_2) 61 | 62 | print("Maximum output difference under the different inputs:") 63 | print(max_output_diff_3) 64 | 65 | print("Maximum output difference under the different inputs:") 66 | print(max_output_diff_4) 67 | 68 | full_model_size = os.stat(oto.full_group_sparse_model_path) 69 | compressed_model_size = os.stat(oto.compressed_model_path) 70 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 71 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 72 | 73 | self.assertLessEqual(max_output_diff_1, 3.0) 74 | self.assertLessEqual(max_output_diff_2, 3.0) 75 | self.assertLessEqual(max_output_diff_3, 6.0) 76 | self.assertLessEqual(max_output_diff_4, 6.0) 77 | -------------------------------------------------------------------------------- /sanity_check/test_bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | import unittest 4 | import os 5 | from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer 6 | from transformers import BertConfig, BertModel 7 | 8 | OUT_DIR = './cache' 9 | 10 | class TestBert(unittest.TestCase): 11 | def test_sanity(self, dummy_input=None): 12 | tokenizer = AutoTokenizer.from_pretrained( 13 | 'bert-base-uncased', 14 | cache_dir=OUT_DIR, 15 | ) 16 | 17 | model = AutoModelForQuestionAnswering.from_pretrained( 18 | 'bert-base-uncased', 19 | cache_dir=OUT_DIR, 20 | ) 21 | text = 'This is a test sentence of a very long string and random wording that is used to test dolly model.' 
* 7 22 | input_data = tokenizer(text, return_tensors='pt').input_ids 23 | oto = OTO(model, dummy_input=(input_data,), strict_out_nodes=True) 24 | 25 | for name, param in model.named_parameters(): 26 | print(name, param.shape, param.requires_grad) 27 | 28 | # Exclude emebdding 29 | oto.mark_unprunable_by_node_ids(['node-218']) 30 | # for node_group in oto._graph.node_groups.values(): 31 | # if 'node-218' in node_group.id: 32 | # node_group.is_prunable = False 33 | 34 | oto.visualize(view=False, out_dir=OUT_DIR) 35 | oto.random_set_zero_groups() 36 | 37 | oto.construct_subnet( 38 | export_huggingface_format=False, 39 | export_float16=False, 40 | full_group_sparse_model_dir=OUT_DIR, 41 | compressed_model_dir=OUT_DIR 42 | ) 43 | 44 | text_1 = 'This is a test sentence of a very long string and random wording that is used to test dolly model.' * 7 45 | input_data_1 = tokenizer(text_1, return_tensors='pt').input_ids 46 | 47 | text_2 = 'This is a good test sentence of a pretty short string and wording that is used to test dolly model.' * 7 48 | input_data_2 = tokenizer(text_2, return_tensors='pt').input_ids 49 | 50 | full_model = torch.load(oto.full_group_sparse_model_path) 51 | compressed_model = torch.load(oto.compressed_model_path) 52 | full_output_1 = full_model(input_data_1.to(full_model.device)) 53 | full_output_2 = full_model(input_data_2.to(full_model.device)) 54 | compressed_output_1 = compressed_model(input_data_1.to(compressed_model.device)) 55 | compressed_output_2 = compressed_model(input_data_2.to(compressed_model.device)) 56 | max_output_diff_1 = torch.max(full_output_1[0] - compressed_output_1[0]).item() 57 | max_output_diff_2 = torch.max(full_output_2[0] - compressed_output_2[0]).item() 58 | max_output_diff_3 = torch.max(full_output_1[0] - compressed_output_2[0]).item() 59 | max_output_diff_4 = torch.max(full_output_2[0] - compressed_output_1[0]).item() 60 | 61 | print("Maximum output difference under the same inputs:") 62 | print(max_output_diff_1) 63 | 64 | print("Maximum output difference under the same inputs:") 65 | print(max_output_diff_2) 66 | 67 | print("Maximum output difference under the different inputs:") 68 | print(max_output_diff_3) 69 | 70 | print("Maximum output difference under the different inputs:") 71 | print(max_output_diff_4) 72 | 73 | full_model_size = os.stat(oto.full_group_sparse_model_path) 74 | compressed_model_size = os.stat(oto.compressed_model_path) 75 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 76 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 77 | 78 | self.assertLessEqual(max_output_diff_1, 3.0) 79 | self.assertLessEqual(max_output_diff_2, 3.0) 80 | self.assertLessEqual(max_output_diff_3, 6.0) 81 | self.assertLessEqual(max_output_diff_4, 6.0) 82 | -------------------------------------------------------------------------------- /sanity_check/test_llamav2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | import unittest 4 | import os 5 | from transformers import LlamaConfig, LlamaTokenizer 6 | from backends import LlamaForCausalLM 7 | 8 | OUT_DIR = './cache' 9 | 10 | class TestLLAMAv2(unittest.TestCase): 11 | def test_sanity(self, dummy_input=None): 12 | # llama_config = LlamaConfig() 13 | # llama_config.num_hidden_layers = 4 14 | # llama_config.num_attention_heads = 32 15 | # llama_config.hidden_size = 4096 16 | # llama_config.intermediate_size = 11008 17 | # model = 
LlamaForCausalLM(llama_config) 18 | model = LlamaForCausalLM.from_pretrained( 19 | 'NousResearch/Llama-2-7b-hf', 20 | low_cpu_mem_usage=True 21 | ) 22 | tokenizer = LlamaTokenizer.from_pretrained('NousResearch/Llama-2-7b-hf') 23 | tokenizer.pad_token_id = (0) 24 | tokenizer.padding_side = "left" 25 | tokenizer.save_pretrained(OUT_DIR) 26 | 27 | text = 'Tell me what is Microsoft and Facebook. Explain their difference' 28 | input_data = tokenizer(text, return_tensors='pt').input_ids 29 | 30 | out_tokens = model.generate(input_data, max_length=100) 31 | print(tokenizer.batch_decode(out_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]) 32 | 33 | oto = OTO(model, dummy_input=(input_data,), strict_out_nodes=True) 34 | 35 | # oto.visualize(view=False, out_dir=OUT_DIR) 36 | 37 | oto.random_set_zero_groups() 38 | 39 | oto.construct_subnet( 40 | export_huggingface_format=False, 41 | export_float16=False, 42 | full_group_sparse_model_dir=OUT_DIR, 43 | compressed_model_dir=OUT_DIR 44 | ) 45 | 46 | text_1 = 'This is a test sentence of a very long string and random wording that is used to test dolly model.' * 7 47 | input_data_1 = tokenizer(text_1, return_tensors='pt').input_ids 48 | 49 | text_2 = 'This is a good test sentence of a pretty short string and wording that is used to test dolly model.' * 7 50 | input_data_2 = tokenizer(text_2, return_tensors='pt').input_ids 51 | 52 | full_model = torch.load(oto.full_group_sparse_model_path) 53 | compressed_model = torch.load(oto.compressed_model_path) 54 | full_output_1 = full_model(input_data_1.to(full_model.device)) 55 | full_output_2 = full_model(input_data_2.to(full_model.device)) 56 | compressed_output_1 = compressed_model(input_data_1.to(compressed_model.device)) 57 | compressed_output_2 = compressed_model(input_data_2.to(compressed_model.device)) 58 | max_output_diff_1 = torch.max(full_output_1.logits - compressed_output_1.logits).item() 59 | max_output_diff_2 = torch.max(full_output_2.logits - compressed_output_2.logits).item() 60 | max_output_diff_3 = torch.max(full_output_1.logits - compressed_output_2.logits).item() 61 | max_output_diff_4 = torch.max(full_output_2.logits - compressed_output_1.logits).item() 62 | print("Maximum output difference under the same inputs:") 63 | print(max_output_diff_1) 64 | 65 | print("Maximum output difference under the same inputs:") 66 | print(max_output_diff_2) 67 | 68 | print("Maximum output difference under the different inputs:") 69 | print(max_output_diff_3) 70 | 71 | print("Maximum output difference under the different inputs:") 72 | print(max_output_diff_4) 73 | 74 | full_model_size = os.stat(oto.full_group_sparse_model_path) 75 | compressed_model_size = os.stat(oto.compressed_model_path) 76 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 77 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 78 | 79 | self.assertLessEqual(max_output_diff_1, 3.0) 80 | self.assertLessEqual(max_output_diff_2, 3.0) 81 | self.assertLessEqual(max_output_diff_3, 6.0) 82 | self.assertLessEqual(max_output_diff_4, 6.0) 83 | -------------------------------------------------------------------------------- /sanity_check/test_llamav1_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from only_train_once import OTO 3 | import unittest 4 | import os 5 | from transformers import LlamaConfig, LlamaTokenizer 6 | from backends import LlamaForCausalLM 7 | from peft_lora.lora_model import 
LoraModel, LoraConfig 8 | 9 | OUT_DIR = './cache' 10 | 11 | class TestLLAMAv1LoRA(unittest.TestCase): 12 | def test_sanity(self, dummy_input=None): 13 | llama_config = LlamaConfig() 14 | llama_config.num_hidden_layers = 4 15 | llama_config.num_attention_heads = 32 16 | llama_config.hidden_size = 4096 17 | llama_config.intermediate_size = 11096 18 | model = LlamaForCausalLM(llama_config) 19 | 20 | tokenizer = LlamaTokenizer.from_pretrained('huggyllama/llama-7b') 21 | tokenizer.pad_token_id = (0) 22 | tokenizer.padding_side = "left" 23 | tokenizer.save_pretrained(OUT_DIR) 24 | 25 | text = 'This is a test sentence of a very long string and random wording that is used to test dolly model.' * 7 26 | input_data = tokenizer(text, return_tensors='pt').input_ids 27 | 28 | target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj'] 29 | lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=target_modules, lora_dropout=0.05, bias="none") 30 | model = LoraModel(model, lora_config) 31 | 32 | oto = OTO(model, dummy_input=(input_data,), strict_out_nodes=True) 33 | oto.visualize(out_dir=OUT_DIR) 34 | 35 | oto.random_set_zero_groups() 36 | 37 | oto.construct_subnet( 38 | merge_lora_to_base=True, 39 | export_huggingface_format=False, 40 | export_float16=False, 41 | full_group_sparse_model_dir=OUT_DIR, 42 | compressed_model_dir=OUT_DIR 43 | ) 44 | 45 | text_1 = 'This is a test sentence of a very long string and random wording that is used to test dolly model.' * 7 46 | input_data_1 = tokenizer(text_1, return_tensors='pt').input_ids 47 | 48 | text_2 = 'This is a good test sentence of a pretty short string and wording that is used to test dolly model.' * 7 49 | input_data_2 = tokenizer(text_2, return_tensors='pt').input_ids 50 | 51 | full_model = torch.load(oto.full_group_sparse_model_path) 52 | compressed_model = torch.load(oto.compressed_model_path) 53 | full_output_1 = full_model(input_data_1.to(full_model.device)) 54 | full_output_2 = full_model(input_data_2.to(full_model.device)) 55 | compressed_output_1 = compressed_model(input_data_1.to(compressed_model.device)) 56 | compressed_output_2 = compressed_model(input_data_2.to(compressed_model.device)) 57 | max_output_diff_1 = torch.max(full_output_1.logits - compressed_output_1.logits).item() 58 | max_output_diff_2 = torch.max(full_output_2.logits - compressed_output_2.logits).item() 59 | max_output_diff_3 = torch.max(full_output_1.logits - compressed_output_2.logits).item() 60 | max_output_diff_4 = torch.max(full_output_2.logits - compressed_output_1.logits).item() 61 | print("Maximum output difference under the same inputs:") 62 | print(max_output_diff_1) 63 | 64 | print("Maximum output difference under the same inputs:") 65 | print(max_output_diff_2) 66 | 67 | print("Maximum output difference under the different inputs:") 68 | print(max_output_diff_3) 69 | 70 | print("Maximum output difference under the different inputs:") 71 | print(max_output_diff_4) 72 | 73 | full_model_size = os.stat(oto.full_group_sparse_model_path) 74 | compressed_model_size = os.stat(oto.compressed_model_path) 75 | print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs") 76 | print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs") 77 | 78 | self.assertLessEqual(max_output_diff_1, 2.0) 79 | self.assertLessEqual(max_output_diff_2, 2.0) 80 | self.assertGreaterEqual(max_output_diff_3, 2.0) 81 | self.assertGreaterEqual(max_output_diff_4, 2.0) 82 | 
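# Illustrative sketch (not part of the repository sources): the importance-score functions in
# only_train_once/optimizer/importance_score flatten each parameter into a (num_groups, -1)
# matrix via tensor_transformation (defined in the file below) before taking row-wise norms.
# The tensor shape and variable names here are assumptions chosen purely for illustration.
import torch
from only_train_once.transform import tensor_transformation, TensorTransform

conv_weight = torch.randn(16, 8, 3, 3)  # (out_channels, in_channels, kH, kW)
flat = tensor_transformation(conv_weight, TensorTransform.BASIC, num_groups=16)
group_norms = torch.norm(flat, dim=1)   # one magnitude score per output channel
print(flat.shape, group_norms.shape)    # torch.Size([16, 72]) torch.Size([16])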
-------------------------------------------------------------------------------- /only_train_once/transform/tensor_transform.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | class TensorTransform(IntEnum): 4 | NO_UPDATE = 0 5 | NO_PRUNE = 1 6 | BASIC = 2 7 | ACCESSORY = 3 8 | MULTIHEAD_HEADDIM = 4 # Only affects the tensor itself 9 | MULTIHEAD_NUMHEAD = 5 # Only affects the tensor itself 10 | REVERSE_MULTIHEAD_HEADDIM = 6 # Only affects the tensor itself 11 | REVERSE_MULTIHEAD_NUMHEAD = 7 # Only affects the tensor itself 12 | AUXILIARY = 8 13 | TRANSPOSE = 9 14 | MULTIHEAD_NUMHEAD_SPREAD = 10 # Affect other nodes in the same node group 15 | REVERSE_MULTIHEAD_NUMHEAD_SPREAD = 11 # Affect other nodes in the same node group 16 | 17 | TOTAL = 12 18 | 19 | def is_spread_transformation(transformation_type): 20 | if transformation_type == TensorTransform.MULTIHEAD_NUMHEAD_SPREAD: 21 | return True 22 | elif transformation_type == TensorTransform.REVERSE_MULTIHEAD_NUMHEAD_SPREAD: 23 | return True 24 | else: 25 | return False 26 | 27 | SPREAD_TRANSFORM_MAP = { 28 | TensorTransform.MULTIHEAD_NUMHEAD_SPREAD: TensorTransform.MULTIHEAD_NUMHEAD 29 | } 30 | 31 | def tensor_transformation(tensor, transformation_type, num_groups=1, num_heads=1, head_dim=1): 32 | if transformation_type == TensorTransform.NO_UPDATE or \ 33 | transformation_type == TensorTransform.NO_PRUNE: 34 | return tensor 35 | elif transformation_type == TensorTransform.BASIC: 36 | return basic_transformation(tensor, num_groups) 37 | elif transformation_type == TensorTransform.ACCESSORY: 38 | return basic_transformation(tensor, num_groups) 39 | elif transformation_type == TensorTransform.MULTIHEAD_HEADDIM: 40 | return multihead_headdim_transformation(tensor, num_groups, num_heads) 41 | elif transformation_type == TensorTransform.MULTIHEAD_NUMHEAD: 42 | return multihead_numhead_transformation(tensor, num_groups) 43 | elif transformation_type == TensorTransform.MULTIHEAD_NUMHEAD_SPREAD: 44 | return multihead_numhead_transformation(tensor, num_groups) 45 | elif transformation_type == TensorTransform.REVERSE_MULTIHEAD_HEADDIM: 46 | return reverse_multihead_headdim_transformation(tensor, num_groups, num_heads) 47 | elif transformation_type == TensorTransform.REVERSE_MULTIHEAD_NUMHEAD: 48 | return reverse_multihead_numhead_transformation(tensor, num_groups, head_dim) 49 | elif transformation_type == TensorTransform.TRANSPOSE: 50 | return transpose_transformation(tensor, num_groups) 51 | 52 | def basic_transformation(tensor, num_groups=1): 53 | return tensor.view(num_groups, -1) 54 | 55 | def multihead_headdim_transformation(tensor, num_groups=1, num_heads=1): 56 | return tensor.view(num_heads, num_groups, -1).permute(1, 0, 2).contiguous().view(num_groups, -1) 57 | 58 | def multihead_numhead_transformation(tensor, num_groups=1): 59 | return tensor.view(num_groups, -1) 60 | 61 | def reverse_multihead_headdim_transformation(tensor, num_groups=1, num_heads=1): 62 | if tensor.numel() >= num_groups * num_heads: 63 | return tensor.view(num_groups, num_heads, -1).permute(1, 0, 2).contiguous().view(num_heads * num_groups, -1) 64 | else: 65 | if len(tensor.shape) == 1: 66 | return tensor.unsqueeze(1).repeat(1, num_heads).view(num_groups, num_heads, -1).permute(1, 0, 2).contiguous()\ 67 | .view(num_heads * num_groups, -1).squeeze() 68 | else: 69 | return tensor 70 | 71 | def reverse_multihead_numhead_transformation(tensor, num_groups=1, head_dim=1): 72 | if tensor.numel() >= 
num_groups * head_dim: 73 | raise NotImplementedError 74 | else: 75 | if len(tensor.shape) == 1: 76 | return tensor.unsqueeze(1).repeat(1, head_dim).view(num_groups * head_dim, -1).squeeze() 77 | else: 78 | return tensor 79 | 80 | def transpose_transformation(tensor, num_groups=1): 81 | if len(tensor.shape) == 1: 82 | return tensor.view(num_groups, -1) 83 | elif len(tensor.shape) == 2: 84 | return tensor.permute(1, 0).contiguous().view(num_groups, -1) 85 | elif len(tensor.shape) == 4: 86 | return tensor.permute(1, 0, 2, 3).contiguous().view(num_groups, -1) -------------------------------------------------------------------------------- /sanity_check/peft_lora/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"] 5 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { 6 | "t5": ["q", "v"], 7 | "mt5": ["q", "v"], 8 | "bart": ["q_proj", "v_proj"], 9 | "gpt2": ["c_attn"], 10 | "bloom": ["query_key_value"], 11 | "blip-2": ["q", "v", "q_proj", "v_proj"], 12 | "opt": ["q_proj", "v_proj"], 13 | "gptj": ["q_proj", "v_proj"], 14 | "gpt_neox": ["query_key_value"], 15 | "gpt_neo": ["q_proj", "v_proj"], 16 | "bert": ["query", "value"], 17 | "roberta": ["query", "value"], 18 | "xlm-roberta": ["query", "value"], 19 | "electra": ["query", "value"], 20 | "deberta-v2": ["query_proj", "value_proj"], 21 | "deberta": ["in_proj"], 22 | "layoutlm": ["query", "value"], 23 | "llama": ["q_proj", "v_proj"], 24 | "chatglm": ["query_key_value"], 25 | "starcoder": ["c_attn"], 26 | } 27 | 28 | def transpose(weight, fan_in_fan_out): 29 | return weight.T if fan_in_fan_out else weight 30 | 31 | class ModulesToSaveWrapper(torch.nn.Module): 32 | def __init__(self, module_to_save, adapter_name): 33 | super().__init__() 34 | self.original_module = module_to_save 35 | self.modules_to_save = torch.nn.ModuleDict({}) 36 | self.update(adapter_name) 37 | self.active_adapter = adapter_name 38 | 39 | def update(self, adapter_name): 40 | self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)})) 41 | 42 | def forward(self, *args, **kwargs): 43 | if self.active_adapter not in self.modules_to_save: 44 | return self.original_module(*args, **kwargs) 45 | return self.modules_to_save[self.active_adapter](*args, **kwargs) 46 | 47 | 48 | def prepare_model_for_int8_training( 49 | model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"] 50 | ): 51 | r""" 52 | This method wraps the entire protocol for preparing a model before running a training. 
This includes: 53 | 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm 54 | head to fp32 55 | 56 | Args: 57 | model, (`transformers.PreTrainedModel`): 58 | The loaded model from `transformers` 59 | """ 60 | loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False) 61 | 62 | for name, param in model.named_parameters(): 63 | # freeze base model's layers 64 | param.requires_grad = False 65 | 66 | if loaded_in_8bit: 67 | # cast layer norm in fp32 for stability for 8bit models 68 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 69 | param.data = param.data.to(torch.float32) 70 | 71 | if loaded_in_8bit and use_gradient_checkpointing: 72 | # For backward compatibility 73 | if hasattr(model, "enable_input_require_grads"): 74 | model.enable_input_require_grads() 75 | else: 76 | 77 | def make_inputs_require_grad(module, input, output): 78 | output.requires_grad_(True) 79 | 80 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 81 | 82 | # enable gradient checkpointing for memory efficiency 83 | model.gradient_checkpointing_enable() 84 | 85 | if hasattr(model, output_embedding_layer_name): 86 | output_embedding_layer = getattr(model, output_embedding_layer_name) 87 | input_dtype = output_embedding_layer.weight.dtype 88 | 89 | class CastOutputToFloat(torch.nn.Sequential): 90 | r""" 91 | Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted 92 | in fp32 93 | 94 | """ 95 | 96 | def forward(self, x): 97 | return super().forward(x.to(input_dtype)).to(torch.float32) 98 | 99 | setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) 100 | 101 | return model 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /sanity_check/backends/peft/tuners/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from dataclasses import dataclass, field 18 | 19 | import torch 20 | 21 | from ..utils import PeftType, PromptLearningConfig 22 | 23 | 24 | @dataclass 25 | class PrefixTuningConfig(PromptLearningConfig): 26 | """ 27 | This is the configuration class to store the configuration of a [`PrefixEncoder`]. 28 | 29 | Args: 30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 31 | prefix_projection (`bool`): Whether to project the prefix embeddings. 
32 | """ 33 | 34 | encoder_hidden_size: int = field( 35 | default=None, 36 | metadata={"help": "The hidden size of the encoder"}, 37 | ) 38 | prefix_projection: bool = field( 39 | default=False, 40 | metadata={"help": "Whether to project the prefix tokens"}, 41 | ) 42 | 43 | def __post_init__(self): 44 | self.peft_type = PeftType.PREFIX_TUNING 45 | 46 | 47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py 48 | # with some refactor 49 | class PrefixEncoder(torch.nn.Module): 50 | r""" 51 | The `torch.nn` model to encode the prefix. 52 | 53 | Args: 54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder. 55 | 56 | Example: 57 | 58 | ```py 59 | >>> from peft import PrefixEncoder, PrefixTuningConfig 60 | 61 | >>> config = PrefixTuningConfig( 62 | ... peft_type="PREFIX_TUNING", 63 | ... task_type="SEQ_2_SEQ_LM", 64 | ... num_virtual_tokens=20, 65 | ... token_dim=768, 66 | ... num_transformer_submodules=1, 67 | ... num_attention_heads=12, 68 | ... num_layers=12, 69 | ... encoder_hidden_size=768, 70 | ... ) 71 | >>> prefix_encoder = PrefixEncoder(config) 72 | ``` 73 | 74 | **Attributes**: 75 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder. 76 | - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if 77 | `prefix_projection` is `True`. 78 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings. 79 | 80 | Input shape: (`batch_size`, `num_virtual_tokens`) 81 | 82 | Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`) 83 | """ 84 | 85 | def __init__(self, config): 86 | super().__init__() 87 | self.prefix_projection = config.prefix_projection 88 | token_dim = config.token_dim 89 | num_layers = config.num_layers 90 | encoder_hidden_size = config.encoder_hidden_size 91 | num_virtual_tokens = config.num_virtual_tokens 92 | if self.prefix_projection and not config.inference_mode: 93 | # Use a two-layer MLP to encode the prefix 94 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) 95 | self.transform = torch.nn.Sequential( 96 | torch.nn.Linear(token_dim, encoder_hidden_size), 97 | torch.nn.Tanh(), 98 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim), 99 | ) 100 | else: 101 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) 102 | 103 | def forward(self, prefix: torch.Tensor): 104 | if self.prefix_projection: 105 | prefix_tokens = self.embedding(prefix) 106 | past_key_values = self.transform(prefix_tokens) 107 | else: 108 | past_key_values = self.embedding(prefix) 109 | return past_key_values 110 | -------------------------------------------------------------------------------- /only_train_once/optimizer/importance_score/cosine_similarity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from only_train_once.transform import tensor_transformation, TensorTransform 4 | 5 | def importance_score_by_cosine_similarity_dhspg(param_group): 6 | norm_params = None 7 | norm_grads = None 8 | params_grads_inner_prod = None 9 | for p_name, param, p_transform in zip(param_group['p_names'], param_group['params'], param_group['p_transform']): 10 | if p_name not in param_group['grad_variant']: 11 | continue 12 | grad = param_group['grad_variant'][p_name] 13 | param_transform = None 14 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 15 | param_transform = tensor_transformation(param, 
p_transform, param_group['num_groups'], param_group['num_heads']) 16 | else: 17 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups']) 18 | if norm_params == None: 19 | norm_params = torch.norm(param_transform, dim=1) ** 2 20 | else: 21 | norm_params += torch.norm(param_transform, dim=1) ** 2 22 | 23 | grad_transform = None 24 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 25 | grad_transform = tensor_transformation(grad, p_transform, param_group['num_groups'], param_group['num_heads']) 26 | else: 27 | grad_transform = tensor_transformation(grad, p_transform, param_group['num_groups']) 28 | if norm_grads == None: 29 | norm_grads = torch.norm(grad_transform, dim=1) ** 2 30 | else: 31 | norm_grads += torch.norm(grad_transform, dim=1) ** 2 32 | 33 | if params_grads_inner_prod == None: 34 | params_grads_inner_prod = torch.sum(param_transform * grad_transform, dim=1) 35 | else: 36 | params_grads_inner_prod += torch.sum(param_transform * grad_transform, dim=1) 37 | 38 | norm_params = torch.sqrt(norm_params) 39 | norm_grads = torch.sqrt(norm_grads) 40 | param_group['importance_scores']['cosine_similarity'] = params_grads_inner_prod / (norm_params + 1e-8) / (norm_grads + 1e-8) + 1 41 | 42 | def importance_score_by_cosine_similarity_lhspg(param_group, global_params): 43 | norm_params = None 44 | norm_grads = None 45 | params_grads_inner_prod = None 46 | for p_name, param, p_transform in zip(param_group['p_names'], param_group['params'], param_group['p_transform']): 47 | if 'lora_B' in p_name: 48 | lora_A_name = p_name.replace('lora_B', 'lora_A') 49 | lora_A = global_params[lora_A_name] 50 | lora_BA = torch.matmul(param, lora_A) 51 | original_param_name = p_name.split('lora_B')[0] + 'weight' 52 | original_param = global_params[original_param_name] 53 | 54 | param_transform = None 55 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 56 | param_transform = tensor_transformation(original_param, p_transform, param_group['num_groups'], param_group['num_heads']) 57 | else: 58 | param_transform = tensor_transformation(original_param, p_transform, param_group['num_groups']) 59 | if norm_params == None: 60 | norm_params = torch.norm(param_transform, dim=1) ** 2 61 | else: 62 | norm_params += torch.norm(param_transform, dim=1) ** 2 63 | 64 | grad_transform = None 65 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 66 | grad_transform = tensor_transformation(lora_BA, p_transform, param_group['num_groups'], param_group['num_heads']) 67 | else: 68 | grad_transform = tensor_transformation(lora_BA, p_transform, param_group['num_groups']) 69 | if norm_grads == None: 70 | norm_grads = torch.norm(grad_transform, dim=1) ** 2 71 | else: 72 | norm_grads += torch.norm(grad_transform, dim=1) ** 2 73 | 74 | if params_grads_inner_prod == None: 75 | params_grads_inner_prod = torch.sum(param_transform * grad_transform, dim=1) 76 | else: 77 | params_grads_inner_prod += torch.sum(param_transform * grad_transform, dim=1) 78 | 79 | norm_params = torch.sqrt(norm_params) 80 | norm_grads = torch.sqrt(norm_grads) 81 | param_group['importance_scores']['cosine_similarity'] = params_grads_inner_prod / (norm_params + 1e-8) / (norm_grads + 1e-8) + 1 -------------------------------------------------------------------------------- /sanity_check/backends/carn/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | def 
init_weights(modules): 8 | pass 9 | 10 | 11 | class MeanShift(nn.Module): 12 | def __init__(self, mean_rgb, sub): 13 | super(MeanShift, self).__init__() 14 | 15 | sign = -1 if sub else 1 16 | r = mean_rgb[0] * sign 17 | g = mean_rgb[1] * sign 18 | b = mean_rgb[2] * sign 19 | 20 | self.shifter = nn.Conv2d(3, 3, 1, 1, 0) 21 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) 22 | self.shifter.bias.data = torch.Tensor([r, g, b]) 23 | 24 | # Freeze the mean shift layer 25 | for params in self.shifter.parameters(): 26 | params.requires_grad = False 27 | 28 | def forward(self, x): 29 | x = self.shifter(x) 30 | return x 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | def __init__(self, 35 | in_channels, out_channels, 36 | ksize=3, stride=1, pad=1): 37 | super(BasicBlock, self).__init__() 38 | 39 | self.body = nn.Sequential( 40 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad), 41 | nn.ReLU(inplace=True) 42 | ) 43 | 44 | init_weights(self.modules) 45 | 46 | def forward(self, x): 47 | out = self.body(x) 48 | return out 49 | 50 | 51 | class ResidualBlock(nn.Module): 52 | def __init__(self, 53 | in_channels, out_channels): 54 | super(ResidualBlock, self).__init__() 55 | 56 | self.body = nn.Sequential( 57 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 58 | nn.ReLU(inplace=True), 59 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 60 | ) 61 | 62 | init_weights(self.modules) 63 | 64 | def forward(self, x): 65 | out = self.body(x) 66 | out = F.relu(out + x) 67 | return out 68 | 69 | 70 | class EResidualBlock(nn.Module): 71 | def __init__(self, 72 | in_channels, out_channels, 73 | group=1): 74 | super(EResidualBlock, self).__init__() 75 | 76 | self.body = nn.Sequential( 77 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group), 80 | nn.ReLU(inplace=True), 81 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 82 | ) 83 | 84 | init_weights(self.modules) 85 | 86 | def forward(self, x): 87 | out = self.body(x) 88 | out = F.relu(out + x) 89 | return out 90 | 91 | 92 | class UpsampleBlock(nn.Module): 93 | def __init__(self, 94 | n_channels, scale, multi_scale, 95 | group=1): 96 | super(UpsampleBlock, self).__init__() 97 | 98 | if multi_scale: 99 | self.up2 = _UpsampleBlock(n_channels, scale=2, group=group) 100 | self.up3 = _UpsampleBlock(n_channels, scale=3, group=group) 101 | self.up4 = _UpsampleBlock(n_channels, scale=4, group=group) 102 | else: 103 | self.up = _UpsampleBlock(n_channels, scale=scale, group=group) 104 | 105 | self.multi_scale = multi_scale 106 | 107 | def forward(self, x, scale): 108 | if self.multi_scale: 109 | if scale == 2: 110 | return self.up2(x) 111 | elif scale == 3: 112 | return self.up3(x) 113 | elif scale == 4: 114 | return self.up4(x) 115 | else: 116 | return self.up(x) 117 | 118 | 119 | class _UpsampleBlock(nn.Module): 120 | def __init__(self, 121 | n_channels, scale, 122 | group=1): 123 | super(_UpsampleBlock, self).__init__() 124 | 125 | modules = [] 126 | if scale == 2 or scale == 4 or scale == 8: 127 | for _ in range(int(math.log(scale, 2))): 128 | modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 129 | modules += [nn.PixelShuffle(2)] 130 | elif scale == 3: 131 | modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 132 | modules += [nn.PixelShuffle(3)] 133 | 134 | self.body = nn.Sequential(*modules) 135 | init_weights(self.modules) 136 | 137 | def forward(self, x): 138 | 
out = self.body(x) 139 | return out -------------------------------------------------------------------------------- /sanity_check/backends/peft/mapping.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .peft_model import ( 17 | PeftModel, 18 | PeftModelForCausalLM, 19 | PeftModelForSeq2SeqLM, 20 | PeftModelForSequenceClassification, 21 | PeftModelForTokenClassification, 22 | ) 23 | from .tuners import AdaLoraConfig, LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig 24 | from .utils import PromptLearningConfig 25 | 26 | 27 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { 28 | "SEQ_CLS": PeftModelForSequenceClassification, 29 | "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, 30 | "CAUSAL_LM": PeftModelForCausalLM, 31 | "TOKEN_CLS": PeftModelForTokenClassification, 32 | } 33 | 34 | PEFT_TYPE_TO_CONFIG_MAPPING = { 35 | "PROMPT_TUNING": PromptTuningConfig, 36 | "PREFIX_TUNING": PrefixTuningConfig, 37 | "P_TUNING": PromptEncoderConfig, 38 | "LORA": LoraConfig, 39 | "ADALORA": AdaLoraConfig, 40 | } 41 | 42 | 43 | def get_peft_config(config_dict): 44 | """ 45 | Returns a Peft config object from a dictionary. 46 | 47 | Args: 48 | config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters. 
49 | """ 50 | 51 | return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) 52 | 53 | 54 | def _prepare_prompt_learning_config(peft_config, model_config): 55 | if peft_config.num_layers is None: 56 | if "num_hidden_layers" in model_config: 57 | num_layers = model_config["num_hidden_layers"] 58 | elif "num_layers" in model_config: 59 | num_layers = model_config["num_layers"] 60 | elif "n_layer" in model_config: 61 | num_layers = model_config["n_layer"] 62 | else: 63 | raise ValueError("Please specify `num_layers` in `peft_config`") 64 | peft_config.num_layers = num_layers 65 | 66 | if peft_config.token_dim is None: 67 | if "hidden_size" in model_config: 68 | token_dim = model_config["hidden_size"] 69 | elif "n_embd" in model_config: 70 | token_dim = model_config["n_embd"] 71 | elif "d_model" in model_config: 72 | token_dim = model_config["d_model"] 73 | else: 74 | raise ValueError("Please specify `token_dim` in `peft_config`") 75 | peft_config.token_dim = token_dim 76 | 77 | if peft_config.num_attention_heads is None: 78 | if "num_attention_heads" in model_config: 79 | num_attention_heads = model_config["num_attention_heads"] 80 | elif "n_head" in model_config: 81 | num_attention_heads = model_config["n_head"] 82 | elif "num_heads" in model_config: 83 | num_attention_heads = model_config["num_heads"] 84 | elif "encoder_attention_heads" in model_config: 85 | num_attention_heads = model_config["encoder_attention_heads"] 86 | else: 87 | raise ValueError("Please specify `num_attention_heads` in `peft_config`") 88 | peft_config.num_attention_heads = num_attention_heads 89 | 90 | if getattr(peft_config, "encoder_hidden_size", None) is None: 91 | setattr(peft_config, "encoder_hidden_size", token_dim) 92 | 93 | return peft_config 94 | 95 | 96 | def get_peft_model(model, peft_config): 97 | """ 98 | Returns a Peft model object from a model and a config. 99 | 100 | Args: 101 | model ([`transformers.PreTrainedModel`]): Model to be wrapped. 102 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model. 
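    Example (an illustrative sketch; the checkpoint name and LoRA settings are placeholders):

    ```py
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import LoraConfig, get_peft_model

    >>> base_model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16)
    >>> peft_model = get_peft_model(base_model, lora_config)
    ```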
103 | """ 104 | model_config = model.config.to_dict() if hasattr(model.config, "to_dict") else model.config 105 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) 106 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( 107 | peft_config, PromptLearningConfig 108 | ): 109 | return PeftModel(model, peft_config) 110 | if isinstance(peft_config, PromptLearningConfig): 111 | peft_config = _prepare_prompt_learning_config(peft_config, model_config) 112 | return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config) 113 | -------------------------------------------------------------------------------- /sanity_check/backends/demo_group_conv_case1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class DepthConv(nn.Module): 5 | def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dilation=1, groups=None): 6 | super(DepthConv, self).__init__() 7 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels) 8 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), dilation=dilation, groups=1) 9 | 10 | def forward(self, x): 11 | return self.conv2(self.conv1(x)) 12 | 13 | normalizations = { 14 | 'bn': nn.BatchNorm2d, 15 | 'in': nn.InstanceNorm2d, 16 | } 17 | 18 | class DemoNetGroupConvCase1(nn.Module): 19 | def __init__(self, norm_type='in', affine=True, bias=True): 20 | super(DemoNetGroupConvCase1, self).__init__() 21 | self.conv_1 = DepthConv(6, 48, kernel_size=(5,5), stride=(2,2), padding=(2,2)) 22 | self.in_1 = normalizations[norm_type](48, affine=affine) 23 | 24 | self.leakyrelu = nn.LeakyReLU() 25 | 26 | self.conv_2 = DepthConv(48, 96, kernel_size=(3,3), stride=(2,2), padding=(1,1)) 27 | self.in_2 = normalizations[norm_type](96, affine=affine) 28 | 29 | self.conv_3 = DepthConv(96, 192, kernel_size=(3,3), stride=(2,2), padding=(1,1)) 30 | self.in_3 = normalizations[norm_type](192, affine=affine) 31 | 32 | self.conv_4 = DepthConv(192, 384, kernel_size=(3,3), stride=(2,2), padding=(1,1)) 33 | self.in_4 = normalizations[norm_type](384, affine=affine) 34 | 35 | self.conv_5 = DepthConv(384, 384, kernel_size=(3,3), stride=(2,2), padding=(1,1)) 36 | self.in_5 = normalizations[norm_type](384, affine=affine) 37 | 38 | self.conv_6 = DepthConv(832, 1536, kernel_size=(2,2), stride=(1,1), padding=(1,1), dilation=2) 39 | 40 | self.convt_7 = nn.ConvTranspose2d(384, 384, kernel_size=(3, 3), padding=1, output_padding=1, stride=2) 41 | self.in_7 = normalizations[norm_type](384, affine=affine) 42 | 43 | self.convt_8 = nn.ConvTranspose2d(768, 192, kernel_size=(3, 3), padding=1, output_padding=1, stride=2) 44 | self.in_8 = normalizations[norm_type](192, affine=affine) 45 | 46 | self.convt_9 = nn.ConvTranspose2d(384, 96, kernel_size=(3, 3), padding=1, output_padding=1, stride=2) 47 | self.in_9 = normalizations[norm_type](96, affine=affine) 48 | 49 | self.convt_10 = nn.ConvTranspose2d(192, 48, kernel_size=(3, 3), padding=1, output_padding=1, stride=2) 50 | self.in_10 = normalizations[norm_type](48, affine=affine) 51 | 52 | self.conv_11 = nn.ConvTranspose2d(96, 48, kernel_size=(3, 3), padding=1, output_padding=1, stride=2) 53 | self.in_11 = normalizations[norm_type](48, affine=affine) 54 | 55 | self.conv_12 = nn.Conv2d(48, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 
2)) 56 | 57 | self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 58 | self.gemm1 = nn.Linear(in_features=96, out_features=48, bias=True) 59 | self.gemm2 = nn.Linear(in_features=48, out_features=10, bias=True) 60 | 61 | self.in_debug = normalizations[norm_type](384, affine=affine) 62 | 63 | def forward(self, x_1, x_2, x_3, x_4): 64 | x = torch.cat([x_1, x_2], dim=1) 65 | x = self.leakyrelu(self.in_1(self.conv_1(x))) 66 | 67 | x_down_1 = self.leakyrelu(self.in_2(self.conv_2(x))) 68 | x_down_2 = self.leakyrelu(self.in_3(self.conv_3(x_down_1))) 69 | x_down_3 = self.leakyrelu(self.in_4(self.conv_4(x_down_2))) 70 | x_down_4 = self.leakyrelu(self.in_5(self.conv_5(x_down_3))) 71 | 72 | x_down_4 = torch.cat([x_4, x_down_4, x_3], dim=1) 73 | x_down_4 = self.conv_6(x_down_4) 74 | x_down_4_up, x_down_4_out = x_down_4[:, :384, ...], x_down_4[:, 384:, ...] 75 | 76 | # print(x_down_4) 77 | # return x_down_4 78 | # x_down_4_up, x_down_4_out = x_down_4, x_down_4 79 | 80 | x_up_1 = self.in_7(self.convt_7(x_down_4_up)) 81 | 82 | # return x_up_1 83 | 84 | x_up_1 = torch.cat([x_up_1, x_down_3], dim=1) 85 | 86 | x_up_2 = self.in_8(self.convt_8(x_up_1)) 87 | x_up_2 = torch.cat([x_up_2, x_down_2], dim=1) 88 | x_up_3 = self.in_9(self.convt_9(x_up_2)) 89 | x_up_3 = torch.cat([x_up_3, x_down_1], dim=1) 90 | x_up_4 = self.in_10(self.convt_10(x_up_3)) 91 | x_up_4 = torch.cat([x_up_4, x], dim=1) 92 | x_up_5 = self.in_11(self.conv_11(x_up_4)) 93 | 94 | x_out = self.conv_12(x_up_5) 95 | x_out = self.avg_pool(x_out) 96 | x_out = x_out.view(x_out.size(0), -1) 97 | 98 | return self.gemm2(self.gemm1(x_out)), x_down_4_out 99 | # return self.gemm2(self.gemm1(x_out)) 100 | # net = DemoNetGroupConvCase1() 101 | 102 | # dummy_input_1 = torch.rand(1, 3, 512, 512) 103 | # dummy_input_2 = torch.rand(1, 3, 512, 512) 104 | # dummy_input_3 = torch.rand(1, 384, 16, 16) 105 | # dummy_input_4 = torch.rand(1, 64, 16, 16) 106 | 107 | # outs = net(dummy_input_1, dummy_input_2, dummy_input_3, dummy_input_4) 108 | # print(outs[0].shape, outs[1].shape) 109 | -------------------------------------------------------------------------------- /sanity_check/backends/demonet_batchnorm_pruning.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 3 | import torch 4 | 5 | def spectral_norm(module, use_spect=True): 6 | """use spectral normal layer to stable the training process""" 7 | if use_spect: 8 | return SpectralNorm(module) 9 | else: 10 | return module 11 | 12 | class DemonetBatchnormPruning(nn.Module): 13 | def __init__(self, input_nc, base_nc, max_nc, encoder_layers, decoder_layers, nonlinearity, use_spect, size=256): 14 | super(DemonetBatchnormPruning, self).__init__() 15 | 16 | if size == 512: 17 | self.input_layer = nn.Sequential(nn.Conv2d(input_nc, base_nc, kernel_size=7, stride=2, padding=3), 18 | nn.Conv2d(base_nc, base_nc, kernel_size=7, stride=1, padding=3)) 19 | elif size == 256: 20 | self.input_layer = nn.Conv2d(input_nc, base_nc, kernel_size=7, stride=1, padding=3) 21 | elif size == 64: 22 | self.input_layer = nn.Conv2d(input_nc, base_nc, kernel_size=3, stride=1, padding=1) 23 | else: 24 | raise Exception('Input layer for the size is not defined: ', size) 25 | 26 | for i in range(encoder_layers): 27 | in_channels = min(base_nc * 2**i, max_nc) 28 | out_channels = min(base_nc * 2**(i+1), max_nc) 29 | model = ResBlock(in_channels, out_channels, out_channels, use_transpose=False, 30 | 
nonlinearity=nonlinearity, use_spect=use_spect) 31 | setattr(self, 'encoder' + str(i), model) 32 | 33 | for i in range(encoder_layers - decoder_layers, encoder_layers)[::-1]: 34 | in_channels = min(base_nc * (2 ** (i + 1)), max_nc) 35 | in_channels = in_channels * 2 if i != (encoder_layers - 1) else in_channels 36 | out_channels = min(base_nc * (2 ** i), max_nc) 37 | model = ResBlock(in_channels, out_channels, out_channels, use_transpose=True, 38 | nonlinearity=nonlinearity, use_spect=use_spect) 39 | setattr(self, 'decoder' + str(i), model) 40 | 41 | self.output_nc = out_channels * 2 42 | self.output_layer = nn.Conv2d(self.output_nc, self.output_nc, kernel_size=3, stride=1, padding=1) 43 | 44 | self.encoder_layers = encoder_layers 45 | self.decoder_layers = decoder_layers 46 | 47 | def forward(self, x): 48 | x = torch.cat(x, dim=1) 49 | out = self.input_layer(x) 50 | out_list = [out] 51 | for i in range(self.encoder_layers): 52 | model = getattr(self, 'encoder' + str(i)) 53 | out = model(out) 54 | out_list.append(out) 55 | 56 | out = out_list.pop() 57 | for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]: 58 | model = getattr(self, 'decoder' + str(i)) 59 | out = model(out) 60 | out = torch.cat([out, out_list.pop()], 1) 61 | 62 | out = self.output_layer(out) 63 | return out 64 | 65 | class ResBlock(nn.Module): 66 | def __init__(self, input_nc, output_nc, hidden_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), 67 | use_spect=False): 68 | super(ResBlock, self).__init__() 69 | # Attributes 70 | self.actvn = nonlinearity 71 | hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc 72 | 73 | kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1} 74 | if use_transpose: 75 | kwargs_up = {'kernel_size': 3, 'stride': 2, 'padding': 1, 'output_padding': 1} 76 | else: 77 | kwargs_up = {'kernel_size': 3, 'stride': 2, 'padding': 1} 78 | 79 | # create conv layers 80 | self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect) 81 | if use_transpose: 82 | self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect) 83 | self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect) 84 | else: 85 | self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect))#, 86 | # nn.Upsample(scale_factor=2)) 87 | self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect))#, 88 | # nn.Upsample(scale_factor=2)) 89 | # # define normalization layers 90 | self.norm_0 = nn.BatchNorm2d(input_nc) 91 | self.norm_1 = nn.BatchNorm2d(hidden_nc) 92 | self.norm_s = nn.BatchNorm2d(input_nc) 93 | 94 | def forward(self, x): 95 | x_s = self.shortcut(x) 96 | dx = self.conv_0(self.actvn(self.norm_0(x))) 97 | dx = self.conv_1(self.actvn(self.norm_1(dx))) 98 | out = x_s + dx 99 | return out 100 | 101 | def shortcut(self, x): 102 | x_s = self.conv_s(self.actvn(self.norm_s(x))) 103 | return x_s 104 | 105 | 106 | if __name__=="__main__": 107 | 108 | net = DemonetBatchnormPruning(13,32,256,5,3,nn.LeakyReLU(),False,256) 109 | input = torch.randn(1,13,256,256) 110 | out = net(input) 111 | print(out.shape) -------------------------------------------------------------------------------- /sanity_check/backends/peft/tuners/prompt_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import enum 17 | import math 18 | from dataclasses import dataclass, field 19 | from typing import Optional, Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptTuningInit(str, enum.Enum): 27 | TEXT = "TEXT" 28 | RANDOM = "RANDOM" 29 | 30 | 31 | @dataclass 32 | class PromptTuningConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`PromptEmbedding`]. 35 | 36 | Args: 37 | prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding. 38 | prompt_tuning_init_text (`str`, *optional*): 39 | The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`. 40 | tokenizer_name_or_path (`str`, *optional*): 41 | The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`. 42 | """ 43 | 44 | prompt_tuning_init: Union[PromptTuningInit, str] = field( 45 | default=PromptTuningInit.RANDOM, 46 | metadata={"help": "How to initialize the prompt tuning parameters"}, 47 | ) 48 | prompt_tuning_init_text: Optional[str] = field( 49 | default=None, 50 | metadata={ 51 | "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 52 | }, 53 | ) 54 | tokenizer_name_or_path: Optional[str] = field( 55 | default=None, 56 | metadata={ 57 | "help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 58 | }, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.PROMPT_TUNING 63 | 64 | 65 | class PromptEmbedding(torch.nn.Module): 66 | """ 67 | The model to encode virtual tokens into prompt embeddings. 68 | 69 | Args: 70 | config ([`PromptTuningConfig`]): The configuration of the prompt embedding. 71 | word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model. 72 | 73 | **Attributes**: 74 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding. 75 | 76 | Example: 77 | 78 | ```py 79 | >>> from peft import PromptEmbedding, PromptTuningConfig 80 | 81 | >>> config = PromptTuningConfig( 82 | ... peft_type="PROMPT_TUNING", 83 | ... task_type="SEQ_2_SEQ_LM", 84 | ... num_virtual_tokens=20, 85 | ... token_dim=768, 86 | ... num_transformer_submodules=1, 87 | ... num_attention_heads=12, 88 | ... num_layers=12, 89 | ... prompt_tuning_init="TEXT", 90 | ... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral", 91 | ... tokenizer_name_or_path="t5-base", 92 | ... 
) 93 | 94 | >>> # t5_model.shared is the word embeddings of the base model 95 | >>> prompt_embedding = PromptEmbedding(config, t5_model.shared) 96 | ``` 97 | 98 | Input Shape: (`batch_size`, `total_virtual_tokens`) 99 | 100 | Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) 101 | """ 102 | 103 | def __init__(self, config, word_embeddings): 104 | super().__init__() 105 | 106 | total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 107 | self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim) 108 | if config.prompt_tuning_init == PromptTuningInit.TEXT: 109 | from transformers import AutoTokenizer 110 | 111 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path) 112 | init_text = config.prompt_tuning_init_text 113 | init_token_ids = tokenizer(init_text)["input_ids"] 114 | # Trim or iterate until num_text_tokens matches total_virtual_tokens 115 | num_text_tokens = len(init_token_ids) 116 | if num_text_tokens > total_virtual_tokens: 117 | init_token_ids = init_token_ids[:total_virtual_tokens] 118 | elif num_text_tokens < total_virtual_tokens: 119 | num_reps = math.ceil(total_virtual_tokens / num_text_tokens) 120 | init_token_ids = init_token_ids * num_reps 121 | init_token_ids = init_token_ids[:total_virtual_tokens] 122 | 123 | word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone() 124 | word_embedding_weights = word_embedding_weights.to(torch.float32) 125 | self.embedding.weight = torch.nn.Parameter(word_embedding_weights) 126 | 127 | def forward(self, indices): 128 | # Just get embeddings 129 | prompt_embeddings = self.embedding(indices) 130 | return prompt_embeddings 131 | -------------------------------------------------------------------------------- /sanity_check/backends/densenet.py: -------------------------------------------------------------------------------- 1 | """dense net in pytorch 2 | 3 | 4 | 5 | [1] Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger. 6 | 7 | Densely Connected Convolutional Networks 8 | https://arxiv.org/abs/1608.06993v5 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | 16 | #"""Bottleneck layers. Although each layer only produces k 17 | #output feature-maps, it typically has many more inputs. 
It 18 | #has been noted in [37, 11] that a 1×1 convolution can be in- 19 | #troduced as bottleneck layer before each 3×3 convolution 20 | #to reduce the number of input feature-maps, and thus to 21 | #improve computational efficiency.""" 22 | class Bottleneck(nn.Module): 23 | def __init__(self, in_channels, growth_rate): 24 | super().__init__() 25 | #"""In our experiments, we let each 1×1 convolution 26 | #produce 4k feature-maps.""" 27 | inner_channel = 4 * growth_rate 28 | 29 | #"""We find this design especially effective for DenseNet and 30 | #we refer to our network with such a bottleneck layer, i.e., 31 | #to the BN-ReLU-Conv(1×1)-BN-ReLU-Conv(3×3) version of H ` , 32 | #as DenseNet-B.""" 33 | self.bottle_neck = nn.Sequential( 34 | nn.BatchNorm2d(in_channels), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(in_channels, inner_channel, kernel_size=1, bias=False), 37 | nn.BatchNorm2d(inner_channel), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d(inner_channel, growth_rate, kernel_size=3, padding=1, bias=False) 40 | ) 41 | 42 | def forward(self, x): 43 | return torch.cat([x, self.bottle_neck(x)], 1) 44 | 45 | #"""We refer to layers between blocks as transition 46 | #layers, which do convolution and pooling.""" 47 | class Transition(nn.Module): 48 | def __init__(self, in_channels, out_channels): 49 | super().__init__() 50 | #"""The transition layers used in our experiments 51 | #consist of a batch normalization layer and an 1×1 52 | #convolutional layer followed by a 2×2 average pooling 53 | #layer""". 54 | self.down_sample = nn.Sequential( 55 | nn.BatchNorm2d(in_channels), 56 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 57 | nn.AvgPool2d(2, stride=2) 58 | ) 59 | 60 | def forward(self, x): 61 | return self.down_sample(x) 62 | 63 | #DesneNet-BC 64 | #B stands for bottleneck layer(BN-RELU-CONV(1x1)-BN-RELU-CONV(3x3)) 65 | #C stands for compression factor(0<=theta<=1) 66 | class DenseNet(nn.Module): 67 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_class=100): 68 | super().__init__() 69 | self.growth_rate = growth_rate 70 | 71 | #"""Before entering the first dense block, a convolution 72 | #with 16 (or twice the growth rate for DenseNet-BC) 73 | #output channels is performed on the input images.""" 74 | inner_channels = 2 * growth_rate 75 | 76 | #For convolutional layers with kernel size 3×3, each 77 | #side of the inputs is zero-padded by one pixel to keep 78 | #the feature-map size fixed. 79 | self.conv1 = nn.Conv2d(3, inner_channels, kernel_size=3, padding=1, bias=False) 80 | 81 | self.features = nn.Sequential() 82 | 83 | for index in range(len(nblocks) - 1): 84 | self.features.add_module("dense_block_layer_{}".format(index), self._make_dense_layers(block, inner_channels, nblocks[index])) 85 | inner_channels += growth_rate * nblocks[index] 86 | 87 | #"""If a dense block contains m feature-maps, we let the 88 | #following transition layer generate θm output feature- 89 | #maps, where 0 < θ ≤ 1 is referred to as the compression 90 | #fac-tor. 
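            # For example, densenet121 below uses growth_rate=32 and nblocks=[6, 12, 24, 16]:
            # the first dense block grows inner_channels from 64 to 64 + 32*6 = 256, and this
            # transition layer then compresses it to int(0.5 * 256) = 128 channels.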
91 | out_channels = int(reduction * inner_channels) # int() will automatic floor the value 92 | self.features.add_module("transition_layer_{}".format(index), Transition(inner_channels, out_channels)) 93 | inner_channels = out_channels 94 | 95 | self.features.add_module("dense_block{}".format(len(nblocks) - 1), self._make_dense_layers(block, inner_channels, nblocks[len(nblocks)-1])) 96 | inner_channels += growth_rate * nblocks[len(nblocks) - 1] 97 | self.features.add_module('bn', nn.BatchNorm2d(inner_channels)) 98 | self.features.add_module('relu', nn.ReLU(inplace=True)) 99 | 100 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 101 | 102 | self.linear = nn.Linear(inner_channels, num_class) 103 | 104 | def forward(self, x): 105 | output = self.conv1(x) 106 | output = self.features(output) 107 | output = self.avgpool(output) 108 | output = output.view(output.size()[0], -1) 109 | output = self.linear(output) 110 | return output 111 | 112 | def _make_dense_layers(self, block, in_channels, nblocks): 113 | dense_block = nn.Sequential() 114 | for index in range(nblocks): 115 | dense_block.add_module('bottle_neck_layer_{}'.format(index), block(in_channels, self.growth_rate)) 116 | in_channels += self.growth_rate 117 | return dense_block 118 | 119 | def densenet121(): 120 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 121 | 122 | def densenet169(): 123 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 124 | 125 | def densenet201(): 126 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 127 | 128 | def densenet161(): 129 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) -------------------------------------------------------------------------------- /only_train_once/transform/ge.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is borrowed from hiddenlayer 3 | Licensed under the MIT License 4 | """ 5 | 6 | import re 7 | 8 | class GEParser(): 9 | def __init__(self, text): 10 | self.index = 0 11 | self.text = text 12 | 13 | def parse(self): 14 | return self.serial() or self.parallel() or self.expression() 15 | 16 | def parallel(self): 17 | index = self.index 18 | expressions = [] 19 | while len(expressions) == 0 or self.token("|"): 20 | e = self.expression() 21 | if not e: 22 | break 23 | expressions.append(e) 24 | if len(expressions) >= 2: 25 | return ParallelPattern(expressions) 26 | # No match. Reset index 27 | self.index = index 28 | 29 | def serial(self): 30 | index = self.index 31 | expressions = [] 32 | while len(expressions) == 0 or self.token(">"): 33 | e = self.expression() 34 | if not e: 35 | break 36 | expressions.append(e) 37 | 38 | if len(expressions) >= 2: 39 | return SerialPattern(expressions) 40 | self.index = index 41 | 42 | def expression(self): 43 | index = self.index 44 | 45 | if self.token("("): 46 | e = self.serial() or self.parallel() or self.op() 47 | if e and self.token(")"): 48 | return e 49 | self.index = index 50 | e = self.op() 51 | return e 52 | 53 | def op(self): 54 | t = self.re(r"\w+") 55 | if t: 56 | c = self.condition() 57 | return NodePattern(t, c) 58 | 59 | def condition(self): 60 | # TODO: not implemented yet. 
This function is a placeholder 61 | index = self.index 62 | if self.token("["): 63 | c = self.token("1x1") or self.token("3x3") 64 | if c: 65 | if self.token("]"): 66 | return c 67 | self.index = index 68 | 69 | def token(self, s): 70 | return self.re(r"\s*(" + re.escape(s) + r")\s*", 1) 71 | 72 | def string(self, s): 73 | if s == self.text[self.index:self.index+len(s)]: 74 | self.index += len(s) 75 | return s 76 | 77 | def re(self, regex, group=0): 78 | m = re.match(regex, self.text[self.index:]) 79 | if m: 80 | self.index += len(m.group(0)) 81 | return m.group(group) 82 | 83 | 84 | class NodePattern(): 85 | def __init__(self, op, condition=None): 86 | self.op = op 87 | self.condition = condition # TODO: not implemented yet 88 | 89 | def match(self, graph, node): 90 | # if node is a list, means there exists multiple edges, pattern does not allow. 91 | if isinstance(node, list): 92 | return [], None 93 | if self.op == node.op and not node._skip_pattern_search: 94 | # following nodes may be multiple, we only accept singleton 95 | following = graph.outgoing(node) 96 | if len(following) == 1: 97 | following = following[0] 98 | return [node], following 99 | else: 100 | return [], None 101 | 102 | 103 | class SerialPattern(): 104 | def __init__(self, patterns): 105 | self.patterns = patterns 106 | 107 | def match(self, graph, node): 108 | all_matches = [] 109 | for i, p in enumerate(self.patterns): 110 | matches, following = p.match(graph, node) 111 | if not matches: 112 | return [], None 113 | all_matches.extend(matches) 114 | if i < len(self.patterns) - 1: 115 | node = following # Might be more than one node 116 | return all_matches, following 117 | 118 | 119 | class ParallelPattern(): 120 | def __init__(self, patterns): 121 | self.patterns = patterns 122 | 123 | def match(self, graph, nodes): 124 | if not nodes: 125 | return [], None 126 | nodes = nodes if isinstance(nodes, list) else [nodes] 127 | # If a single node, assume we need to match with its siblings 128 | if len(nodes) == 1: 129 | nodes = graph.siblings(nodes[0]) 130 | else: 131 | # Verify all nodes have the same parent or all have no parent 132 | parents = [graph.incoming(n) for n in nodes] 133 | matches = [set(p) == set(parents[0]) for p in parents[1:]] 134 | if not all(matches): 135 | return [], None 136 | 137 | # TODO: If more nodes than patterns, we should consider 138 | # all permutations of the nodes 139 | if len(self.patterns) != len(nodes): 140 | return [], None 141 | 142 | patterns = self.patterns.copy() 143 | nodes = nodes.copy() 144 | all_matches = [] 145 | end_node = None 146 | for p in patterns: 147 | found = False 148 | for n in nodes: 149 | matches, following = p.match(graph, n) 150 | if matches: 151 | found = True 152 | nodes.remove(n) 153 | all_matches.extend(matches) 154 | # Verify all branches end in the same node 155 | if end_node: 156 | if end_node != following: 157 | return [], None 158 | else: 159 | end_node = following 160 | break 161 | if not found: 162 | return [], None 163 | return all_matches, end_node 164 | 165 | 166 | -------------------------------------------------------------------------------- /sanity_check/backends/peft/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .config import PeftType, PromptLearningConfig 17 | 18 | 19 | def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"): 20 | """ 21 | Get the state dict of the Peft model. 22 | 23 | Args: 24 | model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP, 25 | the model should be the underlying model/unwrapped model (i.e. model.module). 26 | state_dict (`dict`, *optional*, defaults to `None`): 27 | The state dict of the model. If not provided, the state dict of the model 28 | will be used. 29 | """ 30 | config = model.peft_config[adapter_name] 31 | if state_dict is None: 32 | state_dict = model.state_dict() 33 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA): 34 | # to_return = lora_state_dict(model, bias=model.peft_config.bias) 35 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` 36 | # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP 37 | bias = config.bias 38 | if bias == "none": 39 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} 40 | elif bias == "all": 41 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} 42 | elif bias == "lora_only": 43 | to_return = {} 44 | for k in state_dict: 45 | if "lora_" in k: 46 | to_return[k] = state_dict[k] 47 | bias_name = k.split("lora_")[0] + "bias" 48 | if bias_name in state_dict: 49 | to_return[bias_name] = state_dict[bias_name] 50 | else: 51 | raise NotImplementedError 52 | to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} 53 | if config.peft_type == PeftType.ADALORA: 54 | rank_pattern = config.rank_pattern 55 | if rank_pattern is not None: 56 | rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()} 57 | config.rank_pattern = rank_pattern 58 | to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name) 59 | elif isinstance(config, PromptLearningConfig): 60 | to_return = {} 61 | if config.inference_mode: 62 | prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight 63 | else: 64 | prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) 65 | to_return["prompt_embeddings"] = prompt_embeddings 66 | else: 67 | raise NotImplementedError 68 | if model.modules_to_save is not None: 69 | for key, value in state_dict.items(): 70 | if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): 71 | to_return[key.replace("modules_to_save.", "")] = value 72 | 73 | to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} 74 | return to_return 75 | 76 | 77 | def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"): 78 | """ 79 | Set the state dict of the Peft model. 80 | 81 | Args: 82 | model ([`PeftModel`]): The Peft model. 83 | peft_model_state_dict (`dict`): The state dict of the Peft model. 
84 | """ 85 | config = model.peft_config[adapter_name] 86 | state_dict = {} 87 | if model.modules_to_save is not None: 88 | for key, value in peft_model_state_dict.items(): 89 | if any(module_name in key for module_name in model.modules_to_save): 90 | for module_name in model.modules_to_save: 91 | if module_name in key: 92 | key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}") 93 | break 94 | state_dict[key] = value 95 | else: 96 | state_dict = peft_model_state_dict 97 | 98 | #print("config.peft_type: ".format(config.peft_type)) 99 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA): 100 | peft_model_state_dict = {} 101 | for k, v in state_dict.items(): 102 | if "lora_" in k: 103 | suffix = k.split("lora_")[1] 104 | if "." in suffix: 105 | suffix_to_replace = ".".join(suffix.split(".")[1:]) 106 | k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") 107 | else: 108 | k = f"{k}.{adapter_name}" 109 | peft_model_state_dict[k] = v 110 | else: 111 | peft_model_state_dict[k] = v 112 | if config.peft_type == PeftType.ADALORA: 113 | rank_pattern = config.rank_pattern 114 | if rank_pattern is not None: 115 | model.resize_modules_by_rank_pattern(rank_pattern, adapter_name) 116 | elif isinstance(config, PromptLearningConfig): 117 | peft_model_state_dict = state_dict 118 | else: 119 | raise NotImplementedError 120 | 121 | model.load_state_dict(peft_model_state_dict, strict=False) 122 | #exit() 123 | if isinstance(config, PromptLearningConfig): 124 | model.prompt_encoder[adapter_name].embedding.load_state_dict( 125 | {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True 126 | ) 127 | -------------------------------------------------------------------------------- /only_train_once/optimizer/importance_score/taylor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from only_train_once.transform import tensor_transformation, TensorTransform 4 | 5 | def importance_score_by_first_order_taylor_dhspg(param_group): 6 | params_grads_inner_prod = None 7 | # for param, grad, p_transform in zip(param_group['params'], param_group['grad_variant'], param_group['p_transform']): 8 | for p_name, param, p_transform in zip(param_group['p_names'], param_group['params'], param_group['p_transform']): 9 | if p_name not in param_group['grad_variant']: 10 | continue 11 | grad = param_group['grad_variant'][p_name] 12 | param_transform = None 13 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 14 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups'], param_group['num_heads']) 15 | else: 16 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups']) 17 | 18 | grad_transform = None 19 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 20 | grad_transform = tensor_transformation(grad, p_transform, param_group['num_groups'], param_group['num_heads']) 21 | else: 22 | grad_transform = tensor_transformation(grad, p_transform, param_group['num_groups']) 23 | 24 | if params_grads_inner_prod == None: 25 | params_grads_inner_prod = torch.sum(param_transform * grad_transform, dim=1) 26 | else: 27 | params_grads_inner_prod += torch.sum(param_transform * grad_transform, dim=1) 28 | param_group['importance_scores']['taylor_first_order'] = torch.abs(params_grads_inner_prod) 29 | 30 | def importance_score_by_second_order_taylor_dhspg(param_group): 31 | if 'taylor_first_order' in param_group['importance_scores']: 32 | 
param_group['importance_scores']['taylor_second_order'] = 0.5 * param_group['importance_scores']['taylor_first_order'] ** 2 33 | return 34 | 35 | params_grads_inner_prod = None 36 | # for param, grad, p_transform in zip(param_group['params'], param_group['grad_variant'], param_group['p_transform']): 37 | for p_name, param, p_transform in zip(param_group['p_names'], param_group['params'], param_group['p_transform']): 38 | if p_name not in param_group['grad_variant']: 39 | continue 40 | grad = param_group['grad_variant'][p_name] 41 | param_transform = None 42 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 43 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups'], param_group['num_heads']) 44 | else: 45 | param_transform = tensor_transformation(param, p_transform, param_group['num_groups']) 46 | 47 | grad_transform = None 48 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 49 | grad_transform = tensor_transformation(grad, p_transform, param_group['num_groups'], param_group['num_heads']) 50 | else: 51 | grad_transform = tensor_transformation(grad, p_transform, param_group['num_groups']) 52 | 53 | if params_grads_inner_prod == None: 54 | params_grads_inner_prod = torch.sum(param_transform * grad_transform, dim=1) 55 | else: 56 | params_grads_inner_prod += torch.sum(param_transform * grad_transform, dim=1) 57 | param_group['importance_scores']['taylor_second_order'] = 0.5 * params_grads_inner_prod ** 2 58 | 59 | def importance_score_by_first_order_taylor_lhspg(param_group, global_params): 60 | params_grads_inner_prod = None 61 | for p_name, param, p_transform in zip(param_group['p_names'], param_group['params'], param_group['p_transform']): 62 | if 'lora_B' in p_name: 63 | lora_A_name = p_name.replace('lora_B', 'lora_A') 64 | lora_A = global_params[lora_A_name] 65 | lora_BA = torch.matmul(param, lora_A) 66 | original_param_name = p_name.split('lora_B')[0] + 'weight' 67 | original_param = global_params[original_param_name] 68 | 69 | param_transform = None 70 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 71 | param_transform = tensor_transformation(original_param, p_transform, param_group['num_groups'], param_group['num_heads']) 72 | else: 73 | param_transform = tensor_transformation(original_param, p_transform, param_group['num_groups']) 74 | grad_transform = None 75 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 76 | grad_transform = tensor_transformation(lora_BA, p_transform, param_group['num_groups'], param_group['num_heads']) 77 | else: 78 | grad_transform = tensor_transformation(lora_BA, p_transform, param_group['num_groups']) 79 | 80 | if params_grads_inner_prod == None: 81 | params_grads_inner_prod = torch.sum(param_transform * grad_transform, dim=1) 82 | else: 83 | params_grads_inner_prod += torch.sum(param_transform * grad_transform, dim=1) 84 | 85 | param_group['importance_scores']['taylor_first_order'] = torch.abs(params_grads_inner_prod) 86 | 87 | def importance_score_by_second_order_taylor_lhspg(param_group, global_params): 88 | if 'taylor_first_order' in param_group['importance_scores']: 89 | param_group['importance_scores']['taylor_second_order'] = 0.5 * param_group['importance_scores']['taylor_first_order'] ** 2 90 | return 91 | 92 | params_grads_inner_prod = None 93 | for p_name, param, p_transform in zip(param_group['p_names'], param_group['params'], param_group['p_transform']): 94 | if 'lora_B' in p_name: 95 | lora_A_name = p_name.replace('lora_B', 'lora_A') 96 | lora_A = global_params[lora_A_name] 97 | lora_BA = 
torch.matmul(param, lora_A) 98 | original_param_name = p_name.split('lora_B')[0] + 'weight' 99 | original_param = global_params[original_param_name] 100 | 101 | param_transform = None 102 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 103 | param_transform = tensor_transformation(original_param, p_transform, param_group['num_groups'], param_group['num_heads']) 104 | else: 105 | param_transform = tensor_transformation(original_param, p_transform, param_group['num_groups']) 106 | grad_transform = None 107 | if p_transform == TensorTransform.MULTIHEAD_HEADDIM: 108 | grad_transform = tensor_transformation(lora_BA, p_transform, param_group['num_groups'], param_group['num_heads']) 109 | else: 110 | grad_transform = tensor_transformation(lora_BA, p_transform, param_group['num_groups']) 111 | 112 | if params_grads_inner_prod == None: 113 | params_grads_inner_prod = torch.sum(param_transform * grad_transform, dim=1) 114 | else: 115 | params_grads_inner_prod += torch.sum(param_transform * grad_transform, dim=1) 116 | 117 | param_group['importance_scores']['taylor_second_order'] = 0.5 * params_grads_inner_prod ** 2 -------------------------------------------------------------------------------- /sanity_check/backends/peft/tuners/p_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import enum 17 | import warnings 18 | from dataclasses import dataclass, field 19 | from typing import Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptEncoderReparameterizationType(str, enum.Enum): 27 | MLP = "MLP" 28 | LSTM = "LSTM" 29 | 30 | 31 | @dataclass 32 | class PromptEncoderConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`PromptEncoder`]. 35 | 36 | Args: 37 | encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]): 38 | The type of reparameterization to use. 39 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 40 | encoder_num_layers (`int`): The number of layers of the prompt encoder. 41 | encoder_dropout (`float`): The dropout probability of the prompt encoder. 
42 | """ 43 | 44 | encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field( 45 | default=PromptEncoderReparameterizationType.MLP, 46 | metadata={"help": "How to reparameterize the prompt encoder"}, 47 | ) 48 | encoder_hidden_size: int = field( 49 | default=None, 50 | metadata={"help": "The hidden size of the prompt encoder"}, 51 | ) 52 | encoder_num_layers: int = field( 53 | default=2, 54 | metadata={"help": "The number of layers of the prompt encoder"}, 55 | ) 56 | encoder_dropout: float = field( 57 | default=0.0, 58 | metadata={"help": "The dropout of the prompt encoder"}, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.P_TUNING 63 | 64 | 65 | # Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py 66 | # with some refactor 67 | class PromptEncoder(torch.nn.Module): 68 | """ 69 | The prompt encoder network that is used to generate the virtual token embeddings for p-tuning. 70 | 71 | Args: 72 | config ([`PromptEncoderConfig`]): The configuration of the prompt encoder. 73 | 74 | Example: 75 | 76 | ```py 77 | >>> from peft import PromptEncoder, PromptEncoderConfig 78 | 79 | >>> config = PromptEncoderConfig( 80 | ... peft_type="P_TUNING", 81 | ... task_type="SEQ_2_SEQ_LM", 82 | ... num_virtual_tokens=20, 83 | ... token_dim=768, 84 | ... num_transformer_submodules=1, 85 | ... num_attention_heads=12, 86 | ... num_layers=12, 87 | ... encoder_reparameterization_type="MLP", 88 | ... encoder_hidden_size=768, 89 | ... ) 90 | 91 | >>> prompt_encoder = PromptEncoder(config) 92 | ``` 93 | 94 | **Attributes**: 95 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt encoder. 96 | - **mlp_head** (`torch.nn.Sequential`) -- The MLP head of the prompt encoder if `inference_mode=False`. 97 | - **lstm_head** (`torch.nn.LSTM`) -- The LSTM head of the prompt encoder if `inference_mode=False` and 98 | `encoder_reparameterization_type="LSTM"`. 99 | - **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model. 100 | - **input_size** (`int`) -- The input size of the prompt encoder. 101 | - **output_size** (`int`) -- The output size of the prompt encoder. 102 | - **hidden_size** (`int`) -- The hidden size of the prompt encoder. 103 | - **total_virtual_tokens** (`int`): The total number of virtual tokens of the 104 | prompt encoder. 105 | - **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]): The encoder type of the prompt 106 | encoder. 
107 | 108 | 109 | Input shape: (`batch_size`, `total_virtual_tokens`) 110 | 111 | Output shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) 112 | """ 113 | 114 | def __init__(self, config): 115 | super().__init__() 116 | self.token_dim = config.token_dim 117 | self.input_size = self.token_dim 118 | self.output_size = self.token_dim 119 | self.hidden_size = config.encoder_hidden_size 120 | self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 121 | self.encoder_type = config.encoder_reparameterization_type 122 | 123 | # embedding 124 | self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim) 125 | if not config.inference_mode: 126 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 127 | lstm_dropout = config.encoder_dropout 128 | num_layers = config.encoder_num_layers 129 | # LSTM 130 | self.lstm_head = torch.nn.LSTM( 131 | input_size=self.input_size, 132 | hidden_size=self.hidden_size, 133 | num_layers=num_layers, 134 | dropout=lstm_dropout, 135 | bidirectional=True, 136 | batch_first=True, 137 | ) 138 | 139 | self.mlp_head = torch.nn.Sequential( 140 | torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2), 141 | torch.nn.ReLU(), 142 | torch.nn.Linear(self.hidden_size * 2, self.output_size), 143 | ) 144 | 145 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 146 | warnings.warn( 147 | f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." 148 | ) 149 | layers = [ 150 | torch.nn.Linear(self.input_size, self.hidden_size), 151 | torch.nn.ReLU(), 152 | torch.nn.Linear(self.hidden_size, self.hidden_size), 153 | torch.nn.ReLU(), 154 | torch.nn.Linear(self.hidden_size, self.output_size), 155 | ] 156 | self.mlp_head = torch.nn.Sequential(*layers) 157 | 158 | else: 159 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") 160 | 161 | def forward(self, indices): 162 | input_embeds = self.embedding(indices) 163 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 164 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]) 165 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 166 | output_embeds = self.mlp_head(input_embeds) 167 | else: 168 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") 169 | 170 | return output_embeds 171 | -------------------------------------------------------------------------------- /sanity_check/backends/peft/utils/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import enum 16 | import json 17 | import os 18 | from dataclasses import asdict, dataclass, field 19 | from typing import Optional, Union 20 | 21 | from huggingface_hub import hf_hub_download 22 | from transformers.utils import PushToHubMixin 23 | 24 | from .other import CONFIG_NAME 25 | 26 | 27 | class PeftType(str, enum.Enum): 28 | PROMPT_TUNING = "PROMPT_TUNING" 29 | P_TUNING = "P_TUNING" 30 | PREFIX_TUNING = "PREFIX_TUNING" 31 | LORA = "LORA" 32 | ADALORA = "ADALORA" 33 | 34 | 35 | class TaskType(str, enum.Enum): 36 | SEQ_CLS = "SEQ_CLS" 37 | SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" 38 | CAUSAL_LM = "CAUSAL_LM" 39 | TOKEN_CLS = "TOKEN_CLS" 40 | 41 | 42 | @dataclass 43 | class PeftConfigMixin(PushToHubMixin): 44 | r""" 45 | This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all 46 | PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to 47 | push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a 48 | directory. The method `from_pretrained` will load the configuration of your adapter model from a directory. 49 | 50 | Args: 51 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. 52 | """ 53 | peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."}) 54 | 55 | @property 56 | def __dict__(self): 57 | return asdict(self) 58 | 59 | def to_dict(self): 60 | return self.__dict__ 61 | 62 | def save_pretrained(self, save_directory, **kwargs): 63 | r""" 64 | This method saves the configuration of your adapter model in a directory. 65 | 66 | Args: 67 | save_directory (`str`): 68 | The directory where the configuration will be saved. 69 | kwargs (additional keyword arguments, *optional*): 70 | Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`] 71 | method. 72 | """ 73 | if os.path.isfile(save_directory): 74 | raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") 75 | 76 | os.makedirs(save_directory, exist_ok=True) 77 | 78 | output_dict = self.__dict__ 79 | output_path = os.path.join(save_directory, CONFIG_NAME) 80 | 81 | # save it 82 | with open(output_path, "w") as writer: 83 | writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) 84 | 85 | @classmethod 86 | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs): 87 | r""" 88 | This method loads the configuration of your adapter model from a directory. 89 | 90 | Args: 91 | pretrained_model_name_or_path (`str`): 92 | The directory or the Hub repository id where the configuration is saved. 93 | kwargs (additional keyword arguments, *optional*): 94 | Additional keyword arguments passed along to the child class initialization. 
95 | """ 96 | path = ( 97 | os.path.join(pretrained_model_name_or_path, subfolder) 98 | if subfolder is not None 99 | else pretrained_model_name_or_path 100 | ) 101 | if os.path.isfile(os.path.join(path, CONFIG_NAME)): 102 | config_file = os.path.join(path, CONFIG_NAME) 103 | else: 104 | try: 105 | config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder) 106 | except Exception: 107 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") 108 | 109 | loaded_attributes = cls.from_json_file(config_file) 110 | 111 | config = cls(**kwargs) 112 | 113 | for key, value in loaded_attributes.items(): 114 | if hasattr(config, key): 115 | setattr(config, key, value) 116 | 117 | return config 118 | 119 | @classmethod 120 | def from_json_file(cls, path_json_file, **kwargs): 121 | r""" 122 | Loads a configuration file from a json file. 123 | 124 | Args: 125 | path_json_file (`str`): 126 | The path to the json file. 127 | """ 128 | with open(path_json_file, "r") as file: 129 | json_object = json.load(file) 130 | 131 | return json_object 132 | 133 | 134 | @dataclass 135 | class PeftConfig(PeftConfigMixin): 136 | """ 137 | This is the base configuration class to store the configuration of a [`PeftModel`]. 138 | 139 | Args: 140 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. 141 | task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform. 142 | inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode. 143 | """ 144 | 145 | base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."}) 146 | peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"}) 147 | task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"}) 148 | inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"}) 149 | 150 | 151 | @dataclass 152 | class PromptLearningConfig(PeftConfig): 153 | """ 154 | This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or 155 | [`PromptTuning`]. 156 | 157 | Args: 158 | num_virtual_tokens (`int`): The number of virtual tokens to use. 159 | token_dim (`int`): The hidden embedding dimension of the base transformer model. 160 | num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model. 161 | num_attention_heads (`int`): The number of attention heads in the base transformer model. 162 | num_layers (`int`): The number of layers in the base transformer model. 
163 | """ 164 | 165 | num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"}) 166 | token_dim: int = field( 167 | default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"} 168 | ) 169 | num_transformer_submodules: Optional[int] = field( 170 | default=None, metadata={"help": "Number of transformer submodules"} 171 | ) 172 | num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"}) 173 | num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"}) 174 | -------------------------------------------------------------------------------- /only_train_once/graph/utils.py: -------------------------------------------------------------------------------- 1 | from torch import _C 2 | import torch._C._onnx as _C_onnx 3 | from torch.onnx import ( 4 | symbolic_helper 5 | ) 6 | import textwrap 7 | from torch.onnx._globals import GLOBALS 8 | 9 | def _is_constant_tensor_list(node): 10 | if node.kind() != "prim::Constant": 11 | return False 12 | output_type = node.output().type() 13 | if output_type.isSubtypeOf(_C.ListType.ofTensors()): 14 | return True 15 | if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())): 16 | return True 17 | 18 | def _split_tensor_list_constants(g, block): 19 | for node in block.nodes(): 20 | for subblock in node.blocks(): 21 | _split_tensor_list_constants(g, subblock) 22 | if _is_constant_tensor_list(node): 23 | inputs = [] 24 | for val in node.output().toIValue(): 25 | input = g.insertConstant(val) 26 | input.node().moveBefore(node) 27 | input.node().copyMetadata(node) 28 | inputs.append(input) 29 | 30 | lc = ( 31 | g.create("prim::ListConstruct", inputs) 32 | .insertBefore(node) 33 | .output() 34 | .setType(_C.ListType.ofTensors()) 35 | ) 36 | lc.node().copyMetadata(node) 37 | node.output().replaceAllUsesWith(lc) 38 | 39 | def _optimize_trace_graph_no_onnx_operator( 40 | graph: _C.Graph, 41 | operator_export_type: _C_onnx.OperatorExportTypes, 42 | _disable_torch_constant_prop: bool = False, 43 | fixed_batch_size: bool = False, 44 | params_dict=None, 45 | dynamic_axes=None, 46 | input_names=None, 47 | module=None, 48 | ): 49 | if params_dict is None: 50 | params_dict = {} 51 | 52 | # Inline everything 53 | _C._jit_pass_inline(graph) 54 | 55 | # Remove fork/wait nodes 56 | _C._jit_pass_inline_fork_wait(graph) 57 | _C._jit_pass_lint(graph) 58 | _C._jit_pass_onnx_autograd_function_process(graph) 59 | _C._jit_pass_lower_all_tuples(graph) 60 | 61 | # we now record some ops like ones/zeros 62 | # into a trace where we previously recorded constants. 
63 | # use constant prop to maintain our current level of onnx support 64 | # without implementing symbolics for all of them 65 | if _disable_torch_constant_prop is False: 66 | _C._jit_pass_constant_propagation(graph) 67 | 68 | _split_tensor_list_constants(graph, graph) 69 | # run dce to eliminate dead parts of the graph that might have been 70 | # left behind by things like symbolic_override 71 | _C._jit_pass_dce(graph) 72 | _C._jit_pass_lint(graph) 73 | 74 | # CSE should improve perf when Autocast is used with disabled cache 75 | # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092 76 | # Must run before _C._jit_pass_erase_number_types to prevent type substitution 77 | if _C._jit_pass_cse(graph): 78 | _C._jit_pass_onnx_lint(graph) 79 | 80 | _C._jit_pass_canonicalize_graph_fuser_ops(graph) 81 | _C._jit_pass_lint(graph) 82 | _C._jit_pass_peephole(graph, True) 83 | _C._jit_pass_fuse_addmm(graph) 84 | _C._jit_pass_lint(graph) 85 | 86 | _C._jit_pass_peephole(graph, True) 87 | _C._jit_pass_lower_all_tuples(graph) 88 | # in _jit_pass_onnx, symbolic functions are called for each node for conversion. 89 | # However, there are nodes that cannot be converted without additional context. 90 | # For example, the number of outputs from split (and whether it is static or dynamic) is unknown 91 | # until the point where it is unpacked by listUnpack node. 92 | # This pass does a preprocess, and prepares the nodes such that enough context can be received 93 | # by the symbolic function. 94 | _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) 95 | _C._jit_pass_onnx_preprocess(graph) 96 | 97 | # onnx does not support tuples, so try to remove them 98 | _C._jit_pass_lint(graph) 99 | 100 | # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0 101 | _C._jit_pass_prepare_division_for_onnx(graph) 102 | 103 | _C._jit_pass_onnx_remove_print(graph) 104 | _C._jit_pass_onnx_preprocess_caffe2(graph) 105 | 106 | symbolic_helper._quantized_ops.clear() 107 | # Unpack quantized weights for conv and linear ops and insert into graph. 108 | _C._jit_pass_onnx_unpack_quantized_weights( 109 | graph, params_dict, symbolic_helper.is_caffe2_aten_fallback() 110 | ) 111 | if symbolic_helper.is_caffe2_aten_fallback(): 112 | # Insert permutes before and after each conv op to ensure correct order. 113 | _C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict) 114 | 115 | # Find consecutive permutes that are no-ops and remove them. 
116 | _C._jit_pass_custom_pattern_based_rewrite_graph( 117 | textwrap.dedent( 118 | """\ 119 | graph(%Pi): 120 | %Pq = quantized::nhwc2nchw(%Pi) 121 | %Pr = quantized::nchw2nhwc(%Pq) 122 | return (%Pr)""" 123 | ), 124 | textwrap.dedent( 125 | """\ 126 | graph(%Ri): 127 | return (%Ri)""" 128 | ), 129 | graph, 130 | ) 131 | 132 | # onnx only supports tensors, so we turn all out number types into tensors 133 | _C._jit_pass_erase_number_types(graph) 134 | if GLOBALS.onnx_shape_inference: 135 | input_names = [] if input_names is None else input_names 136 | dynamic_axes = {} if dynamic_axes is None else dynamic_axes 137 | _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) 138 | _C._jit_pass_onnx_lint(graph) 139 | 140 | graph = _C._jit_pass_onnx(graph, operator_export_type) 141 | # except: 142 | # pass 143 | _C._jit_pass_onnx_lint(graph) 144 | _C._jit_pass_lint(graph) 145 | 146 | _C._jit_pass_onnx_scalar_type_analysis( 147 | graph, True, GLOBALS.export_onnx_opset_version 148 | ) 149 | _C._jit_pass_lint(graph) 150 | 151 | _C._jit_pass_onnx_peephole( 152 | graph, GLOBALS.export_onnx_opset_version, fixed_batch_size 153 | ) 154 | _C._jit_pass_lint(graph) 155 | 156 | # graph is not a valid jit graph anymore because types have been replaced 157 | # (e.g. int with Tensor), so it now contains operators that don't actually 158 | # exist. We can't run normal dead code elimination because it'd fail trying 159 | # to look up if an operator has side effects, but we can run a dead code 160 | # elimination variant that doesn't need to look up if an op has side effects. 161 | _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) 162 | _C._jit_pass_lint(graph) 163 | graph = _C._jit_pass_canonicalize(graph) 164 | _C._jit_pass_lint(graph) 165 | try: 166 | if GLOBALS.onnx_shape_inference: 167 | _C._jit_pass_onnx_graph_shape_type_inference( 168 | graph, params_dict, GLOBALS.export_onnx_opset_version 169 | ) 170 | except: 171 | pass 172 | return graph 173 | 174 | def _get_str_inside_parenthesis(str_to_processed, prefix_str=None): 175 | if not str_to_processed.startswith(prefix_str): 176 | return None 177 | stack = [] 178 | start_idx = len(prefix_str) + 1 179 | end_idx = -1 180 | for c in str_to_processed: 181 | if c == '(': 182 | stack.append(c) 183 | elif c == ')': 184 | stack.pop() 185 | end_idx += 1 186 | if len(stack) == 0 and end_idx > len(prefix_str): 187 | break 188 | return str_to_processed[start_idx : end_idx] 189 | 190 | def _get_tensor_shape(str_to_processed, prefix_str='Float'): 191 | # Parse output shape given the string of one torch node 192 | # Should have some better way for completing it 193 | output_str = _get_str_inside_parenthesis(str_to_processed, prefix_str=prefix_str) 194 | if output_str is None: 195 | return None 196 | output_str_splits = output_str.split(',') 197 | output_shapes = [] 198 | for item in output_str_splits: 199 | item = item.strip() 200 | if item.isnumeric(): 201 | output_shapes.append(int(item)) 202 | else: 203 | break 204 | return output_shapes 205 | 206 | MILLION = 1e6 207 | BILLION = 1e9 208 | 209 | def _scale_value(value, in_million=True, in_billion=False): 210 | if in_million: 211 | value /= float(MILLION) 212 | elif in_billion: 213 | value /= float(BILLION) 214 | return value --------------------------------------------------------------------------------
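A short usage sketch for the string-parsing helpers at the end of `only_train_once/graph/utils.py`. This is illustrative only: the `node_repr` string below is a hand-written stand-in for the textual form of a traced torch node, and the printed results assume the behavior implemented above (leading numeric dimensions are collected until the first non-numeric field, and `_scale_value` divides by 1e6 when `in_million=True`).

```py
# Hypothetical usage of the helpers defined in only_train_once/graph/utils.py.
from only_train_once.graph.utils import _get_tensor_shape, _scale_value

# Hand-written stand-in for the type string printed for a traced torch node.
node_repr = "Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], device=cpu)"

# Leading numeric dimensions are collected until the first non-numeric field.
print(_get_tensor_shape(node_repr, prefix_str="Float"))  # [1, 64, 56, 56]

# Raw parameter counts are rescaled to millions by default.
print(_scale_value(25557032))  # 25.557032
```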