├── .idea
│   ├── .gitignore
│   ├── KnowledgeDistillation.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── Examples
│   └── example_multi_layer_based_model
│       └── distill_bert.py
├── LICENSE
├── README.rst
├── knowledge_distillation
│   ├── Evaluator
│   │   ├── __init__.py
│   │   ├── ievaluator.py
│   │   └── multi_layer_based_distillation_evaluator.py
│   ├── Loss
│   │   ├── __init__.py
│   │   ├── cosine_similarity_loss.py
│   │   ├── loss_functions.py
│   │   └── multi_layer_based_distillation_loss.py
│   ├── __init__.py
│   └── knowledge_distillation.py
└── setup.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/workspace.xml
--------------------------------------------------------------------------------

/Examples/example_multi_layer_based_model/distill_bert.py:
--------------------------------------------------------------------------------
"""
This is a simple example of distilling a small BERT with a bigger BERT.
"""
# import packages
import torch
import logging
import numpy as np
from transformers import BertModel, BertConfig
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

from knowledge_distillation import KnowledgeDistiller, MultiLayerBasedDistillationLoss
from knowledge_distillation import MultiLayerBasedDistillationEvaluator

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Some global variables
train_batch_size = 40
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 1e-5
num_epoch = 10

# define student and teacher model
# Teacher Model
bert_config = BertConfig(num_hidden_layers=12, hidden_size=60, intermediate_size=60, output_hidden_states=True,
                         output_attentions=True)
teacher_model = BertModel(bert_config)
# Student Model
bert_config = BertConfig(num_hidden_layers=3, hidden_size=60, intermediate_size=60, output_hidden_states=True,
                         output_attentions=True)
student_model = BertModel(bert_config)
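
# NOTE (illustrative addition, not part of the original example): both models above are
# randomly initialized so that the script runs without any downloads. For a real
# distillation you would load an already trained teacher instead; with transformers this
# could look like the following, where the checkpoint path is hypothetical:
# teacher_model = BertModel.from_pretrained("path/to/trained/teacher",
#                                           output_hidden_states=True, output_attentions=True)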

### Train data loader
input_ids = torch.LongTensor(np.random.randint(100, 1000, (100000, 50)))
attention_mask = torch.LongTensor(np.ones((100000, 50)))
token_type_ids = torch.LongTensor(np.zeros((100000, 50)))
train_data = TensorDataset(input_ids, attention_mask, token_type_ids)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)


### Train data adaptor
### It is a function that turns batch_data (from train_dataloader) into the inputs of teacher_model and student_model.
### You can define your own train_data_adaptor. Remember the inputs must be device and batch_data.
### The output is either a dict or a tuple, but it must be consistent with your model's input.
def train_data_adaptor(device, batch_data):
    batch_data = tuple(t.to(device) for t in batch_data)
    batch_data_dict = {"input_ids": batch_data[0],
                       "attention_mask": batch_data[1],
                       "token_type_ids": batch_data[2], }
    # In this case, the teacher and student use the same input
    return batch_data_dict, batch_data_dict
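
# NOTE (illustrative addition, not part of the original example): KnowledgeDistiller unpacks
# a dict input with ** and a tuple input with *, so an adaptor for a model that takes
# positional arguments could simply `return batch_data, batch_data`. Keep in mind that the
# built-in mse_with_mask / attention_mse_with_mask losses read teacher_input["attention_mask"],
# so they expect the dict form shown above.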

### The loss model is the key component of the distillation.
### We have already provided a general loss model for distilling multiple BERT layers.
### In most cases, you can directly use this model.
#### First, we should define a distill_config which indicates how to compute the loss between teacher and student.
#### distill_config is a list; each item indicates how to calculate one loss term.
#### It also defines which output of which layer to use for that loss.
#### It should be consistent with your output_adaptor.
distill_config = [
    # compute a loss between the teacher's and the student's embedding_layer embeddings
    {"teacher_layer_name": "embedding_layer", "teacher_layer_output_name": "embedding",
     "student_layer_name": "embedding_layer", "student_layer_output_name": "embedding",
     "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0
     },
    # compute a loss between the teacher's bert_layer12 hidden_states and the student's bert_layer3 hidden_states
    {"teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "hidden_states",
     "student_layer_name": "bert_layer3", "student_layer_output_name": "hidden_states",
     "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0
     },
    {"teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "attention",
     "student_layer_name": "bert_layer3", "student_layer_output_name": "attention",
     "loss": {"loss_function": "attention_mse_with_mask", "args": {}}, "weight": 1.0
     },
    {"teacher_layer_name": "pred_layer", "teacher_layer_output_name": "pooler_output",
     "student_layer_name": "pred_layer", "student_layer_output_name": "pooler_output",
     "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
     },
]


### teacher_output_adaptor and student_output_adaptor
### In most cases, a model's output is a tuple. However, this package needs the output to be a dict,
### like: {"layer_name": {"output_name": value}, ...}
### Hence, the output adaptor turns your model's output into such a dict.
### In this case, the teacher and student can share one adaptor.
def output_adaptor(model_output):
    last_hidden_state, pooler_output, hidden_states, attentions = model_output
    output = {"embedding_layer": {"embedding": hidden_states[0]}}
    for idx in range(len(attentions)):
        output["bert_layer" + str(idx + 1)] = {"hidden_states": hidden_states[idx + 1],
                                               "attention": attentions[idx]}
    output["pred_layer"] = {"pooler_output": pooler_output}
    return output


# loss_model
loss_model = MultiLayerBasedDistillationLoss(distill_config=distill_config,
                                             teacher_output_adaptor=output_adaptor,
                                             student_output_adaptor=output_adaptor)
# optimizer
param_optimizer = list(student_model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = torch.optim.Adam(params=optimizer_grouped_parameters, lr=learning_rate)
# evaluator
# This is a basic evaluator; it logs the loss value and saves models.
# You can define your own evaluator class that implements the interface IEvaluator.
evaluator = MultiLayerBasedDistillationEvaluator(save_dir="save_model", save_step=1000, print_loss_step=20)
# Get a KnowledgeDistiller
distiller = KnowledgeDistiller(teacher_model=teacher_model, student_model=student_model,
                               train_dataloader=train_dataloader, dev_dataloader=None,
                               train_data_adaptor=train_data_adaptor, dev_data_adaptor=None,
                               device=device, loss_model=loss_model, optimizer=optimizer,
                               evaluator=evaluator, num_epoch=num_epoch)
# start distillation
distiller.distillate()
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2014 Fernando M. F. Nogueira,
                   Guillaume Lemaitre,
                   Dayvid Victor

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.rst:
--------------------------------------------------------------------------------
KnowledgeDistillation
======================

Update
------------
**July, 2020**

**Knowledge Distillation** has been used in deep learning for only a few years and is still at an early stage of
development. Many distillation methods have been proposed so far, and because of their complexity and diversity it is
hard to integrate all of them into one framework. Hence, this package is mainly aimed at beginners.

This package mainly contains two parts:

1. Distillation of MultiLayerBasedModel
2. Other distillation methods

This is the last update for the distillation of MultiLayerBasedModel. Other distillation methods will be added in succession.
When **Knowledge Distillation** is mature enough, I will integrate them into one framework.


**March, 2020**

- Users can now define their own loss functions. The requirements for a loss function can be found in the API document.

- Added more built-in loss functions (**mse_with_mask** and **attention_mse_with_mask**).

- Unified the hidden loss and the prediction loss; the key "type" has been removed from distill_config.

- The device information has been removed from the loss value.

Introduction
------------

What is knowledge distillation?
:::::::::::::::::::::::::::::::::::::::::
**Knowledge Distillation** is a model compression method in which a small model is trained
to mimic a pre-trained, larger model (or an ensemble of models). Recently, many models have achieved SOTA performance.
However, their billions of parameters make them computationally expensive and inefficient in terms of both memory
consumption and latency. Hence, it is often necessary to derive a small model from a large model by means of knowledge
distillation.

This training setting is sometimes referred to as "teacher-student",
where the large model is the teacher and the small model is the student.
The method was first proposed by Bucila et al. and generalized by Hinton et al.

Introduction of KnowledgeDistillation Package
:::::::::::::::::::::::::::::::::::::::::::::::
**KnowledgeDistillation** is a knowledge distillation framework. You can distill your own model
with this toolkit. The framework is highly abstract, so many distillation methods can be implemented on top of it.
Besides, we also provide a ready-to-use distillation of MultiLayerBasedModel, since many models consist of multiple layers.

Usage
--------

To use the package, you should define these objects:

* **Teacher Model** (large model, trained)
* **Student Model** (small model, untrained)
* **Data loader**, a generator or iterator that yields training data or dev data, for example `torch.utils.data.DataLoader`
* **Train data adaptor**, a function that turns batch_data (from train_dataloader) into the inputs of teacher_model and student_model
* **Distill config**, a list in which each item indicates how to calculate one loss term and which output of which layer to use
* **Output adaptor**, a function that turns your model's output into the dict-style output the distiller requires
* **Evaluator**, a class with an evaluate function; it defines when and how to save your student model (a minimal sketch is shown right below)
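
For example, a minimal custom evaluator only needs to implement the `IEvaluator` interface. The sketch below is
illustrative (the class name is made up); its `evaluate` method receives the same arguments that `KnowledgeDistiller`
passes to the built-in `MultiLayerBasedDistillationEvaluator` after every training step::

    from knowledge_distillation.Evaluator import IEvaluator

    class PrintLossEvaluator(IEvaluator):
        def evaluate(self, teacher_model, student_model, dev_data, dev_data_adaptor, epoch, step, loss_value):
            # called by KnowledgeDistiller after every training step
            if step % 100 == 0:
                print("epoch {}, step {}, loss {}".format(epoch, step, loss_value))

The dev_data and dev_data_adaptor arguments are the natural place to compute your own dev-set metrics; the built-in
evaluator ignores them and only logs the loss and saves checkpoints.
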
Installation
---------------
Requirements
::::::::::::::::::
- Python >= 3.6
- PyTorch >= 1.1.0
- NumPy
- Transformers >= 2.0 (optional, used by some examples)

Install from PyPI
::::::::::::::::::

**KnowledgeDistillation** is currently available in the PyPI repository and you can
install it via pip::

    pip install -U KnowledgeDistillation

Install from GitHub
::::::::::::::::::::::::::::::
If you prefer, you can clone the repository and run the setup.py file. Use the following
command to get a copy from GitHub::

    git clone https://github.com/DunZhang/KnowledgeDistillation.git


How to Contribute
------------------
Examples of the latest knowledge distillation methods are welcome. There is no need to add an example if the authors
have already provided an official implementation. An example should be simple and easy to run, so I suggest generating
some fake data for it.

A simple example
----------------
A simple example::

    # import packages
    import torch
    import logging
    import numpy as np
    from transformers import BertModel, BertConfig
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    from knowledge_distillation import KnowledgeDistiller, MultiLayerBasedDistillationLoss
    from knowledge_distillation import MultiLayerBasedDistillationEvaluator

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Some global variables
    train_batch_size = 40
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    learning_rate = 1e-5
    num_epoch = 10

    # define student and teacher model
    # Teacher Model
    bert_config = BertConfig(num_hidden_layers=12, hidden_size=60, intermediate_size=60, output_hidden_states=True,
                             output_attentions=True)
    teacher_model = BertModel(bert_config)
    # Student Model
    bert_config = BertConfig(num_hidden_layers=3, hidden_size=60, intermediate_size=60, output_hidden_states=True,
                             output_attentions=True)
    student_model = BertModel(bert_config)

    ### Train data loader
    input_ids = torch.LongTensor(np.random.randint(100, 1000, (100000, 50)))
    attention_mask = torch.LongTensor(np.ones((100000, 50)))
    token_type_ids = torch.LongTensor(np.zeros((100000, 50)))
    train_data = TensorDataset(input_ids, attention_mask, token_type_ids)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)


    ### Train data adaptor
    ### It is a function that turns batch_data (from train_dataloader) into the inputs of teacher_model and student_model.
    ### You can define your own train_data_adaptor. Remember the inputs must be device and batch_data.
    ### The output is either a dict or a tuple, but it must be consistent with your model's input.
    def train_data_adaptor(device, batch_data):
        batch_data = tuple(t.to(device) for t in batch_data)
        batch_data_dict = {"input_ids": batch_data[0],
                           "attention_mask": batch_data[1],
                           "token_type_ids": batch_data[2], }
        # In this case, the teacher and student use the same input
        return batch_data_dict, batch_data_dict


    ### The loss model is the key component of the distillation.
    ### We have already provided a general loss model for distilling multiple BERT layers.
    ### In most cases, you can directly use this model.
    #### First, we should define a distill_config which indicates how to compute the loss between teacher and student.
    #### distill_config is a list; each item indicates how to calculate one loss term.
    #### It also defines which output of which layer to use for that loss.
    #### It should be consistent with your output_adaptor.
    distill_config = [
        # compute a loss between the teacher's and the student's embedding_layer embeddings
        {"teacher_layer_name": "embedding_layer", "teacher_layer_output_name": "embedding",
         "student_layer_name": "embedding_layer", "student_layer_output_name": "embedding",
         "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0
         },
        # compute a loss between the teacher's bert_layer12 hidden_states and the student's bert_layer3 hidden_states
        {"teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "hidden_states",
         "student_layer_name": "bert_layer3", "student_layer_output_name": "hidden_states",
         "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0
         },
        {"teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "attention",
         "student_layer_name": "bert_layer3", "student_layer_output_name": "attention",
         "loss": {"loss_function": "attention_mse_with_mask", "args": {}}, "weight": 1.0
         },
        {"teacher_layer_name": "pred_layer", "teacher_layer_output_name": "pooler_output",
         "student_layer_name": "pred_layer", "student_layer_output_name": "pooler_output",
         "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
         },
    ]
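
    # NOTE (illustrative addition, not part of the original example): with the output_adaptor
    # defined below, the teacher exposes bert_layer1..12 and the student bert_layer1..3, so
    # intermediate layers can be matched as well, e.g. teacher bert_layer8 -> student bert_layer2:
    # distill_config.append(
    #     {"teacher_layer_name": "bert_layer8", "teacher_layer_output_name": "hidden_states",
    #      "student_layer_name": "bert_layer2", "student_layer_output_name": "hidden_states",
    #      "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0})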

    ### teacher_output_adaptor and student_output_adaptor
    ### In most cases, a model's output is a tuple. However, this package needs the output to be a dict,
    ### like: {"layer_name": {"output_name": value}, ...}
    ### Hence, the output adaptor turns your model's output into such a dict.
    ### In this case, the teacher and student can share one adaptor.
    def output_adaptor(model_output):
        last_hidden_state, pooler_output, hidden_states, attentions = model_output
        output = {"embedding_layer": {"embedding": hidden_states[0]}}
        for idx in range(len(attentions)):
            output["bert_layer" + str(idx + 1)] = {"hidden_states": hidden_states[idx + 1],
                                                   "attention": attentions[idx]}
        output["pred_layer"] = {"pooler_output": pooler_output}
        return output


    # loss_model
    loss_model = MultiLayerBasedDistillationLoss(distill_config=distill_config,
                                                 teacher_output_adaptor=output_adaptor,
                                                 student_output_adaptor=output_adaptor)
    # optimizer
    param_optimizer = list(student_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = torch.optim.Adam(params=optimizer_grouped_parameters, lr=learning_rate)
    # evaluator
    # This is a basic evaluator; it logs the loss value and saves models.
    # You can define your own evaluator class that implements the interface IEvaluator.
    evaluator = MultiLayerBasedDistillationEvaluator(save_dir="save_model", save_step=1000, print_loss_step=20)
    # Get a KnowledgeDistiller
    distiller = KnowledgeDistiller(teacher_model=teacher_model, student_model=student_model,
                                   train_dataloader=train_dataloader, dev_dataloader=None,
                                   train_data_adaptor=train_data_adaptor, dev_data_adaptor=None,
                                   device=device, loss_model=loss_model, optimizer=optimizer,
                                   evaluator=evaluator, num_epoch=num_epoch)
    # start distillation
    distiller.distillate()

--------------------------------------------------------------------------------

/knowledge_distillation/Evaluator/__init__.py:
--------------------------------------------------------------------------------
from .ievaluator import IEvaluator
from .multi_layer_based_distillation_evaluator import MultiLayerBasedDistillationEvaluator
--------------------------------------------------------------------------------

/knowledge_distillation/Evaluator/ievaluator.py:
--------------------------------------------------------------------------------
import abc


class IEvaluator(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def evaluate(self, *args, **kwargs):
        pass
--------------------------------------------------------------------------------

/knowledge_distillation/Evaluator/multi_layer_based_distillation_evaluator.py:
--------------------------------------------------------------------------------
from .ievaluator import IEvaluator
import os
import logging

logger = logging.getLogger(__name__)


class MultiLayerBasedDistillationEvaluator(IEvaluator):
    """
    A simple evaluator: it logs the loss value and saves the student model.
    You can define your own evaluator class that implements the interface IEvaluator.
    """

    def __init__(self, save_dir, save_step=None, print_loss_step=20):
        """
        :param save_dir: output directory
        :param save_step: frequency of saving the student model (in steps)
        :param print_loss_step: log the loss value every print_loss_step steps
        """
        super(MultiLayerBasedDistillationEvaluator, self).__init__()
        self.save_step = save_step
        self.print_loss_step = print_loss_step
        self.save_dir = os.path.abspath(save_dir)
        self.loss_fw = None

        os.makedirs(self.save_dir, exist_ok=True)
        if save_dir and save_step:
            self.loss_fw = open(os.path.join(save_dir, "LossValue.txt"), "w", encoding="utf8")

    def evaluate(self, teacher_model, student_model, dev_data, dev_data_adaptor, epoch, step, loss_value):
        if step > 0 and step % self.print_loss_step == 0:
            logger.info("epoch:{},\tstep:{},\tloss value:{}".format(epoch, step, loss_value))
            if self.save_dir and self.save_step:
                self.loss_fw.write("epoch:{},\tstep:{},\tloss value:{}\n".format(epoch, step, loss_value))
                self.loss_fw.flush()
        if self.save_dir and self.save_step:
            if step > 0 and step % self.save_step == 0:
                save_path = os.path.join(self.save_dir, str(epoch) + "-" + str(step))
                os.makedirs(save_path, exist_ok=True)
                student_model.save_pretrained(save_path)


if __name__ == "__main__":
    logger.info("epoch:{}\tstep:{}\tloss value:{}".format(1, 2, 3))
--------------------------------------------------------------------------------

/knowledge_distillation/Loss/__init__.py:
--------------------------------------------------------------------------------
from .multi_layer_based_distillation_loss import MultiLayerBasedDistillationLoss
from .cosine_similarity_loss import CosineSimilarityLoss
from .loss_functions import mse, mse_with_mask, attention_mse_with_mask
--------------------------------------------------------------------------------

/knowledge_distillation/Loss/cosine_similarity_loss.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class CosineSimilarityLoss(nn.Module):
    def __init__(self, ):
        super(CosineSimilarityLoss, self).__init__()

    def forward(self, x1, x2):
        return 0.5 - 0.5 * torch.cosine_similarity(x1, x2)
--------------------------------------------------------------------------------

/knowledge_distillation/Loss/loss_functions.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def mse_with_mask(teacher_output, student_output, teacher_input=None, student_input=None):
    # mask out padding positions using the teacher's attention_mask
    mask = teacher_input["attention_mask"]
    mask = mask.to(student_output)
    # number of valid tokens * hidden_size
    valid_count = mask.sum() * student_output.size(-1)
    loss = (F.mse_loss(teacher_output, student_output, reduction='none') * mask.unsqueeze(-1)).sum() / valid_count
    return loss


def attention_mse_with_mask(teacher_output, student_output, teacher_input=None, student_input=None):
    mask = teacher_input["attention_mask"]
    mask = mask.to(student_output).unsqueeze(1).expand(-1, student_output.size(1), -1)  # (bs, num_of_heads, len)
    valid_count = torch.pow(mask.sum(dim=2), 2).sum()
    loss = (F.mse_loss(student_output, teacher_output, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(
        2)).sum() / valid_count
    return loss


def mse(teacher_output, student_output, teacher_input=None, student_input=None):
    return F.mse_loss(teacher_output, student_output, reduction='mean')
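

# --- Illustrative sketch, not part of the original package -------------------
# The README notes that users can define their own loss functions. Judging from the
# built-ins above, such a function is assumed to take
# (teacher_output, student_output, teacher_input=None, student_input=None) plus optional
# keyword arguments (filled in from the "args" entry of distill_config) and return a
# scalar tensor; see the API document mentioned in the README for how to plug it in.
# A hypothetical temperature-scaled soft-target loss could look like this:
def soft_target_kl(teacher_output, student_output, teacher_input=None, student_input=None, temperature=1.0):
    teacher_prob = F.softmax(teacher_output / temperature, dim=-1)
    student_log_prob = F.log_softmax(student_output / temperature, dim=-1)
    # "batchmean" matches the mathematical definition of KL divergence
    return F.kl_div(student_log_prob, teacher_prob, reduction="batchmean") * (temperature ** 2)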
--------------------------------------------------------------------------------

/knowledge_distillation/Loss/multi_layer_based_distillation_loss.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from .cosine_similarity_loss import CosineSimilarityLoss
from .loss_functions import mse_with_mask, mse, attention_mse_with_mask


class MultiLayerBasedDistillationLoss(nn.Module):
    def __init__(self, distill_config=None, teacher_output_adaptor=None, student_output_adaptor=None):
        super(MultiLayerBasedDistillationLoss, self).__init__()
        self.distill_config = distill_config
        self.teacher_output_adaptor = teacher_output_adaptor
        self.student_output_adaptor = student_output_adaptor
        self.loss_functions = {"mse": mse, "mse_with_mask": mse_with_mask,
                               "attention_mse_with_mask": attention_mse_with_mask,
                               "cross_entropy": nn.CrossEntropyLoss(), "cos": CosineSimilarityLoss()}

    def forward(self, teacher_output, student_output, teacher_input_data, student_input_data):
        teacher_adaptor_output = self.teacher_output_adaptor(teacher_output)
        student_adaptor_output = self.student_output_adaptor(student_output)
        loss = 0
        for distill_info in self.distill_config:
            tmp_teacher_output = teacher_adaptor_output[distill_info["teacher_layer_name"]][
                distill_info["teacher_layer_output_name"]]
            tmp_student_output = student_adaptor_output[distill_info["student_layer_name"]][
                distill_info["student_layer_output_name"]]
            tmp_loss = self.loss_functions[distill_info["loss"]["loss_function"]](tmp_teacher_output,
                                                                                  tmp_student_output,
                                                                                  teacher_input_data,
                                                                                  student_input_data,
                                                                                  **distill_info["loss"]["args"])
            tmp_loss *= distill_info["weight"]
            loss += tmp_loss

        # student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att), student_att)  # set attention values below -100 to zero
        return loss
--------------------------------------------------------------------------------

/knowledge_distillation/__init__.py:
--------------------------------------------------------------------------------
from .knowledge_distillation import KnowledgeDistiller
from .Loss import MultiLayerBasedDistillationLoss
from .Evaluator import MultiLayerBasedDistillationEvaluator
--------------------------------------------------------------------------------

/knowledge_distillation/knowledge_distillation.py:
--------------------------------------------------------------------------------
import torch


class KnowledgeDistiller():
    def __init__(self, teacher_model, student_model, train_dataloader, dev_dataloader, device, loss_model, optimizer,
                 evaluator, num_epoch, train_data_adaptor, dev_data_adaptor):
        # send models to device
        self.teacher_model = teacher_model.to(device)
        self.student_model = student_model.to(device)
        self.train_dataloader = train_dataloader
        self.dev_dataloader = dev_dataloader
        self.device = device
        self.loss_model = loss_model
        self.optimizer = optimizer
        self.evaluator = evaluator
        self.num_epoch = num_epoch
        self.train_data_adaptor = train_data_adaptor
        self.dev_data_adaptor = dev_data_adaptor

    def distillate(self):
        # the teacher model is not trained
        self.teacher_model.eval()
        # train the student model
        self.student_model.train()
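        # (descriptive note) one optimizer step per batch: the teacher runs under
        # torch.no_grad(), the student is updated from loss_model's combined loss,
        # and evaluator.evaluate is called after every step.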
        for epoch in range(self.num_epoch):
            for step, batch_data in enumerate(self.train_dataloader):
                # get input data for the teacher model and the student model
                teacher_batch_data, student_batch_data = self.train_data_adaptor(self.device, batch_data)
                # get the teacher output without computing gradients
                with torch.no_grad():
                    if isinstance(teacher_batch_data, dict):
                        teacher_output = self.teacher_model.forward(**teacher_batch_data)
                    else:
                        teacher_output = self.teacher_model.forward(*teacher_batch_data)
                # get the student output
                if isinstance(student_batch_data, dict):
                    student_output = self.student_model.forward(**student_batch_data)
                else:
                    student_output = self.student_model.forward(*student_batch_data)
                # get the loss
                loss = self.loss_model.forward(teacher_output, student_output, teacher_batch_data, student_batch_data)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                # evaluate and save the model
                self.evaluator.evaluate(self.teacher_model, self.student_model, self.dev_dataloader,
                                        self.dev_data_adaptor, epoch, step, loss)
--------------------------------------------------------------------------------

/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name='KnowledgeDistillation',
    version="1.0.4",
    description='A general knowledge distillation framework',
    long_description=open('README.rst').read(),
    author='ZhangDun',
    author_email='dunnzhang0@gmail.com',
    maintainer='ZhangDun',
    maintainer_email='dunnzhang0@gmail.com',
    license='MIT',
    packages=find_packages(),
    url='https://github.com/DunZhang/KnowledgeDistillation',
    # 'python' is not a pip-installable requirement; the interpreter version is
    # declared via python_requires instead
    install_requires=['torch>0.4.0'],
    python_requires='>=3.6',
    classifiers=[
        'Development Status :: 2 - Pre-Alpha',
        'Operating System :: Microsoft :: Windows',
        'Operating System :: POSIX',
        'Operating System :: Unix',
        'Operating System :: MacOS',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Topic :: Scientific/Engineering :: Artificial Intelligence'
    ],
    keywords="Transformer Networks BERT XLNet PyTorch NLP deep learning"
)
--------------------------------------------------------------------------------