├── Examples
│   └── example_multi_layer_based_model
│       └── distill_bert.py
├── LICENSE
├── README.rst
├── knowledge_distillation
│   ├── Evaluator
│   │   ├── __init__.py
│   │   ├── ievaluator.py
│   │   └── multi_layer_based_distillation_evaluator.py
│   ├── Loss
│   │   ├── __init__.py
│   │   ├── cosine_similarity_loss.py
│   │   ├── loss_functions.py
│   │   └── multi_layer_based_distillation_loss.py
│   ├── __init__.py
│   └── knowledge_distillation.py
└── setup.py
/Examples/example_multi_layer_based_model/distill_bert.py:
--------------------------------------------------------------------------------
1 | """
2 | This is a simple example of distilling a small BERT from a bigger BERT.
3 |
4 | """
5 | # import packages
6 | import torch
7 | import logging
8 | import numpy as np
9 | from transformers import BertModel, BertConfig
10 | from torch.utils.data import DataLoader, RandomSampler, TensorDataset
11 |
12 | from knowledge_distillation import KnowledgeDistiller, MultiLayerBasedDistillationLoss
13 | from knowledge_distillation import MultiLayerBasedDistillationEvaluator
14 |
15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
16 | # Some global variables
17 | train_batch_size = 40
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 | learning_rate = 1e-5
20 | num_epoch = 10
21 |
22 | # define student and teacher model
23 | # Teacher Model
24 | bert_config = BertConfig(num_hidden_layers=12, hidden_size=60, intermediate_size=60, output_hidden_states=True,
25 | output_attentions=True)
26 | teacher_model = BertModel(bert_config)
27 | # Student Model
28 | bert_config = BertConfig(num_hidden_layers=3, hidden_size=60, intermediate_size=60, output_hidden_states=True,
29 | output_attentions=True)
30 | student_model = BertModel(bert_config)
31 |
32 | ### Train data loader
33 | input_ids = torch.LongTensor(np.random.randint(100, 1000, (100000, 50)))
34 | attention_mask = torch.LongTensor(np.ones((100000, 50)))
35 | token_type_ids = torch.LongTensor(np.zeros((100000, 50)))
36 | train_data = TensorDataset(input_ids, attention_mask, token_type_ids)
37 | train_sampler = RandomSampler(train_data)
38 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
39 |
40 |
41 | ### Train data adaptor
42 | ### It is a function that turns batch_data (from train_dataloader) into the inputs of teacher_model and student_model.
43 | ### You can define your own train_data_adaptor. Remember the inputs must be device and batch_data.
44 | ### The output is either a dict or a tuple, but it must be consistent with your model's input.
45 | def train_data_adaptor(device, batch_data):
46 | batch_data = tuple(t.to(device) for t in batch_data)
47 | batch_data_dict = {"input_ids": batch_data[0],
48 | "attention_mask": batch_data[1],
49 | "token_type_ids": batch_data[2], }
50 | # In this case, the teacher and student use the same input
51 | return batch_data_dict, batch_data_dict
52 |
53 |
54 | ### The loss model is the key component of the distillation.
55 | ### We have already provided a general loss model for distilling multiple BERT layers.
56 | ### In most cases, you can use this model directly.
57 | ### First, we should define a distill_config, which indicates how to compute the loss between teacher and student.
58 | ### distill_config is a list; each item indicates how to calculate one loss term.
59 | ### It also defines which output of which layer is used to calculate that loss.
60 | ### It should be consistent with your output_adaptor.
61 | distill_config = [
62 | # compute a loss between the teacher's and the student's embedding_layer embeddings
63 | {"teacher_layer_name": "embedding_layer", "teacher_layer_output_name": "embedding",
64 | "student_layer_name": "embedding_layer", "student_layer_output_name": "embedding",
65 | "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0
66 | },
67 | # compute a loss between the teacher's bert_layer12 hidden_states and the student's bert_layer3 hidden_states
68 | {"teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "hidden_states",
69 | "student_layer_name": "bert_layer3", "student_layer_output_name": "hidden_states",
70 | "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0
71 | },
72 | {"teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "attention",
73 | "student_layer_name": "bert_layer3", "student_layer_output_name": "attention",
74 | "loss": {"loss_function": "attention_mse_with_mask", "args": {}}, "weight": 1.0
75 | },
76 | {"teacher_layer_name": "pred_layer", "teacher_layer_output_name": "pooler_output",
77 | "student_layer_name": "pred_layer", "student_layer_output_name": "pooler_output",
78 | "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
79 | },
80 | ]
81 |
82 |
83 | ### teacher_output_adaptor and student_output_adaptor
84 | ### In most cases, a model's output is a tuple. However, this package requires the output to be a dict,
85 | ### like: { "layer_name": {"output_name": value}, ... }
86 | ### Hence, the output adaptor turns your model's output into such a dict.
87 | ### In this case, the teacher and the student can share one adaptor.
88 | def output_adaptor(model_output):
89 | last_hidden_state, pooler_output, hidden_states, attentions = model_output
90 | output = {"embedding_layer": {"embedding": hidden_states[0]}}
91 | for idx in range(len(attentions)):
92 | output["bert_layer" + str(idx + 1)] = {"hidden_states": hidden_states[idx + 1],
93 | "attention": attentions[idx]}
94 | output["pred_layer"] = {"pooler_output": pooler_output}
95 | return output
96 |
97 |
98 | # loss_model
99 | loss_model = MultiLayerBasedDistillationLoss(distill_config=distill_config,
100 | teacher_output_adaptor=output_adaptor,
101 | student_output_adaptor=output_adaptor)
102 | # optimizer
103 | param_optimizer = list(student_model.named_parameters())
104 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
105 | optimizer_grouped_parameters = [
106 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
107 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
108 | ]
109 | optimizer = torch.optim.Adam(params=optimizer_grouped_parameters, lr=learning_rate)
110 | # evaluator
111 | # This is a basic evaluator; it outputs the loss value and saves models.
112 | # You can define your own evaluator class that implements the interface IEvaluator.
113 |
114 | evaluator = MultiLayerBasedDistillationEvaluator(save_dir="save_model", save_step=1000, print_loss_step=20)
115 | # Get a KnowledgeDistiller
116 | distiller = KnowledgeDistiller(teacher_model=teacher_model, student_model=student_model,
117 | train_dataloader=train_dataloader, dev_dataloader=None,
118 | train_data_adaptor=train_data_adaptor, dev_data_adaptor=None,
119 | device=device, loss_model=loss_model, optimizer=optimizer,
120 | evaluator=evaluator, num_epoch=num_epoch)
121 | # start the distillation
122 | distiller.distillate()
123 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014 Fernando M. F. Nogueira,
4 | Guillaume Lemaitre,
5 | Dayvid Victor
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | KnowledgeDistillation
2 | ======================
3 |
4 | Update
5 | ------------
6 | **July, 2020**
7 |
8 | **Knowledge Distillation** has been used in deep learning for about two years
9 | and is still at an early stage of development.
10 | So far, many distillation methods have been proposed; due to the complexity and diversity of these methods,
11 | it is hard to integrate all of them into one framework. Hence, I think this package is more suitable for beginners.
12 |
13 | This package mainly contains two parts:
14 |
15 | 1. Distillation of MultiLayerBasedModel
16 | 2. Other distillation methods
17 |
18 | This is the last update for the distillation of MultiLayerBasedModel; other distillation methods will be added in succession.
19 | When **Knowledge Distillation** is mature enough, I will integrate them into one framework.
20 |
21 |
22 | **March, 2020**
23 |
24 | - Users can now define their own loss functions. The requirements for a loss function can be found in the API document; a minimal sketch is shown at the end of this update note.
25 |
26 | - Added more built-in loss functions (**mse_with_mask** and **attention_mse_with_mask**).
27 |
28 | - Unified the hidden loss and the prediction loss; the key "type" has been removed from distill_config.
29 |
30 | - The device information is now removed from the loss value.
31 |
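For illustration only, a user-defined loss function can follow the same signature as the built-in ones in
`knowledge_distillation/Loss/loss_functions.py`; how a custom function is registered is described in the API
document (not reproduced here). The function name below is hypothetical and the snippet is a sketch, not part
of the package::

    import torch.nn.functional as F

    def l1_example(teacher_output, student_output, teacher_input=None, student_input=None):
        # plain mean absolute error between the matched teacher and student layer outputs
        return F.l1_loss(teacher_output, student_output, reduction='mean')
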
32 | Introduction
33 | ------------
34 |
35 | What is knowledge distillation?
36 | :::::::::::::::::::::::::::::::::::::::::
37 | **Knowledge Distillation** is a model compression method in which a small model is trained
38 | to mimic a pre-trained, larger model (or an ensemble of models). Recently, many models have achieved SOTA performance.
39 | However, their billions of parameters make them computationally expensive and inefficient in terms of both memory
40 | consumption and latency. Hence, it is often necessary to derive a small model from a large model by using knowledge
41 | distillation.
42 |
43 | Knowledge distillation's training setting is sometimes referred to as "teacher-student",
44 | where the large model is the teacher and the small model is the student.
45 | The method was first proposed by Bucila et al.
46 | and generalized by Hinton et al.
47 |
48 | Introduction of KnowledgeDistillation Package
49 | :::::::::::::::::::::::::::::::::::::::::::::::
50 | **KnowledgeDistillation** is a knowledge distillation framework. You can distill your own model
51 | by using this toolkit. The framework is highly abstract, and many distillation methods can be implemented with it.
52 | Besides, we also provide a distillation of MultiLayerBasedModel, since many models consist of multiple layers.
53 |
54 | Usage
55 | --------
56 |
57 | To use the package, you should define these objects:
58 |
59 | * **Teacher Model** (large model, trained)
60 | * **Student Model** (small model, untrained)
61 | * **Data loader**, a generator or iterator that yields training or dev data, for example `torch.utils.data.DataLoader`
62 | * **Train data adaptor**, a function that turns batch_data (from train_dataloader) into the inputs of teacher_model and student_model
63 | * **Distill config**, a list in which each item indicates how to calculate one loss term and which output of which layer it is computed on
64 | * **Output adaptor**, a function that turns your model's output into the dict-object output that the distiller requires
65 | * **Evaluator**, a class with an evaluate function that defines when and how to save your student model (a minimal custom sketch is shown after this list)
66 |
67 |
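The built-in MultiLayerBasedDistillationEvaluator outputs the loss and saves the student model. If you need
different behaviour, you can implement the IEvaluator interface yourself. The class below is a minimal,
hypothetical sketch; its evaluate arguments mirror how KnowledgeDistiller calls the evaluator after every
training step::

    import logging
    from knowledge_distillation.Evaluator import IEvaluator

    class LoggingEvaluator(IEvaluator):
        """A hypothetical evaluator that only logs the training loss."""

        def __init__(self, print_loss_step=100):
            self.print_loss_step = print_loss_step
            self.logger = logging.getLogger(__name__)

        def evaluate(self, teacher_model, student_model, dev_data, dev_data_adaptor, epoch, step, loss_value):
            # called by KnowledgeDistiller once per training step
            if step > 0 and step % self.print_loss_step == 0:
                self.logger.info("epoch:%s step:%s loss:%s", epoch, step, loss_value)
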
68 | Installation
69 | ---------------
70 | Requirements
71 | ::::::::::::::::::
72 | - Python >= 3.6
73 | - PyTorch >= 1.1.0
74 | - NumPy
75 | - Transformers >= 2.0 (optional, used by some examples)
76 |
77 | Install from PyPI
78 | ::::::::::::::::::
79 |
80 | **KnowledgeDistillation** is currently available on PyPI and you can
81 | install it via pip::
82 |
83 | pip install -U KnowledgeDistillation
84 |
85 | Install from GitHub
86 | ::::::::::::::::::::::::::::::
87 | If you prefer, you can clone it and run the setup.py file. Use the following
88 | command to get a copy from GitHub::
89 |
90 | git clone https://github.com/DunZhang/KnowledgeDistillation.git
91 |
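Then install the package from the cloned directory (a standard setuptools install; the commands below assume the
default clone location)::

    cd KnowledgeDistillation
    python setup.py install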
92 |
93 | How to Contribute
94 | ------------------
95 | You are welcome to add examples for the latest knowledge distillation methods. There is no need to add an example if the authors
96 | have provided an official implementation. An example should be simple and easy to execute. Hence, I suggest making some fake data for your example.
97 |
98 | A simple example
99 | ----------------
100 | The following example is the same as Examples/example_multi_layer_based_model/distill_bert.py::
101 |
102 | # import packages
103 | import torch
104 | import logging
105 | import numpy as np
106 | from transformers import BertModel, BertConfig
107 | from torch.utils.data import DataLoader, RandomSampler, TensorDataset
108 |
109 | from knowledge_distillation import KnowledgeDistiller, MultiLayerBasedDistillationLoss
110 | from knowledge_distillation import MultiLayerBasedDistillationEvaluator
111 |
112 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
113 | # Some global variables
114 | train_batch_size = 40
115 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116 | learning_rate = 1e-5
117 | num_epoch = 10
118 |
119 | # define student and teacher model
120 | # Teacher Model
121 | bert_config = BertConfig(num_hidden_layers=12, hidden_size=60, intermediate_size=60, output_hidden_states=True,
122 | output_attentions=True)
123 | teacher_model = BertModel(bert_config)
124 | # Student Model
125 | bert_config = BertConfig(num_hidden_layers=3, hidden_size=60, intermediate_size=60, output_hidden_states=True,
126 | output_attentions=True)
127 | student_model = BertModel(bert_config)
128 |
129 | ### Train data loader
130 | input_ids = torch.LongTensor(np.random.randint(100, 1000, (100000, 50)))
131 | attention_mask = torch.LongTensor(np.ones((100000, 50)))
132 | token_type_ids = torch.LongTensor(np.zeros((100000, 50)))
133 | train_data = TensorDataset(input_ids, attention_mask, token_type_ids)
134 | train_sampler = RandomSampler(train_data)
135 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
136 |
137 |
138 | ### Train data adaptor
139 | ### It is a function that turns batch_data (from train_dataloader) into the inputs of teacher_model and student_model.
140 | ### You can define your own train_data_adaptor. Remember the inputs must be device and batch_data.
141 | ### The output is either a dict or a tuple, but it must be consistent with your model's input.
142 | def train_data_adaptor(device, batch_data):
143 | batch_data = tuple(t.to(device) for t in batch_data)
144 | batch_data_dict = {"input_ids": batch_data[0],
145 | "attention_mask": batch_data[1],
146 | "token_type_ids": batch_data[2], }
147 | # In this case, the teacher and student use the same input
148 | return batch_data_dict, batch_data_dict
149 |
150 |
151 | ### The loss model is the key component of the distillation.
152 | ### We have already provided a general loss model for distilling multiple BERT layers.
153 | ### In most cases, you can use this model directly.
154 | ### First, we should define a distill_config, which indicates how to compute the loss between teacher and student.
155 | ### distill_config is a list; each item indicates how to calculate one loss term.
156 | ### It also defines which output of which layer is used to calculate that loss.
157 | ### It should be consistent with your output_adaptor.
158 | distill_config = [
159 | # compute a loss between the teacher's and the student's embedding_layer embeddings
160 | {"teacher_layer_name": "embedding_layer", "teacher_layer_output_name": "embedding",
161 | "student_layer_name": "embedding_layer", "student_layer_output_name": "embedding",
162 | "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0
163 | },
164 | # compute a loss between the teacher's bert_layer12 hidden_states and the student's bert_layer3 hidden_states
165 | {"teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "hidden_states",
166 | "student_layer_name": "bert_layer3", "student_layer_output_name": "hidden_states",
167 | "loss": {"loss_function": "mse_with_mask", "args": {}}, "weight": 1.0
168 | },
169 | {"teacher_layer_name": "bert_layer12", "teacher_layer_output_name": "attention",
170 | "student_layer_name": "bert_layer3", "student_layer_output_name": "attention",
171 | "loss": {"loss_function": "attention_mse_with_mask", "args": {}}, "weight": 1.0
172 | },
173 | {"teacher_layer_name": "pred_layer", "teacher_layer_output_name": "pooler_output",
174 | "student_layer_name": "pred_layer", "student_layer_output_name": "pooler_output",
175 | "loss": {"loss_function": "mse", "args": {}}, "weight": 1.0
176 | },
177 | ]
178 |
179 |
180 | ### teacher_output_adaptor and student_output_adaptor
181 | ### In most cases, a model's output is a tuple. However, this package requires the output to be a dict,
182 | ### like: { "layer_name": {"output_name": value}, ... }
183 | ### Hence, the output adaptor turns your model's output into such a dict.
184 | ### In this case, the teacher and the student can share one adaptor.
185 | def output_adaptor(model_output):
186 | last_hidden_state, pooler_output, hidden_states, attentions = model_output
187 | output = {"embedding_layer": {"embedding": hidden_states[0]}}
188 | for idx in range(len(attentions)):
189 | output["bert_layer" + str(idx + 1)] = {"hidden_states": hidden_states[idx + 1],
190 | "attention": attentions[idx]}
191 | output["pred_layer"] = {"pooler_output": pooler_output}
192 | return output
193 |
194 |
195 | # loss_model
196 | loss_model = MultiLayerBasedDistillationLoss(distill_config=distill_config,
197 | teacher_output_adaptor=output_adaptor,
198 | student_output_adaptor=output_adaptor)
199 | # optimizer
200 | param_optimizer = list(student_model.named_parameters())
201 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
202 | optimizer_grouped_parameters = [
203 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
204 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
205 | ]
206 | optimizer = torch.optim.Adam(params=optimizer_grouped_parameters, lr=learning_rate)
207 | # evaluator
208 | # This is a basic evaluator; it outputs the loss value and saves models.
209 | # You can define your own evaluator class that implements the interface IEvaluator.
210 |
211 | evaluator = MultiLayerBasedDistillationEvaluator(save_dir="save_model", save_step=1000, print_loss_step=20)
212 | # Get a KnowledgeDistiller
213 | distiller = KnowledgeDistiller(teacher_model=teacher_model, student_model=student_model,
214 | train_dataloader=train_dataloader, dev_dataloader=None,
215 | train_data_adaptor=train_data_adaptor, dev_data_adaptor=None,
216 | device=device, loss_model=loss_model, optimizer=optimizer,
217 | evaluator=evaluator, num_epoch=num_epoch)
218 | # start the distillation
219 | distiller.distillate()
220 |
221 |
--------------------------------------------------------------------------------
/knowledge_distillation/Evaluator/__init__.py:
--------------------------------------------------------------------------------
1 | from .ievaluator import IEvaluator
2 | from .multi_layer_based_distillation_evaluator import MultiLayerBasedDistillationEvaluator
3 |
--------------------------------------------------------------------------------
/knowledge_distillation/Evaluator/ievaluator.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class IEvaluator(metaclass=abc.ABCMeta):
5 | @abc.abstractmethod
6 | def evaluate(self, *args, **kwargs):
7 | pass
8 |
--------------------------------------------------------------------------------
/knowledge_distillation/Evaluator/multi_layer_based_distillation_evaluator.py:
--------------------------------------------------------------------------------
1 | from .ievaluator import IEvaluator
2 | import os
3 | import logging
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | class MultiLayerBasedDistillationEvaluator(IEvaluator):
9 | """
10 | A simple evaluator: outputs the loss and saves the model.
11 | You can define your own evaluator class that implements the interface IEvaluator.
12 | """
13 |
14 | def __init__(self, save_dir, save_step=None, print_loss_step=20):
15 | """
16 |
17 | :param save_dir: output directory
18 | :param save_step: save the student model every (save_step) steps
19 | :param print_loss_step: output loss value every (print_loss_step) steps
20 | """
21 | super(MultiLayerBasedDistillationEvaluator, self).__init__()
22 | self.save_step = save_step
23 | self.print_loss_step = print_loss_step
24 | self.save_dir = os.path.abspath(save_dir)
25 | self.loss_fw = None
26 |
27 | os.makedirs(self.save_dir, exist_ok=True)
28 | if save_dir and save_step:
29 | self.loss_fw = open(os.path.join(save_dir, "LossValue.txt"), "w", encoding="utf8")
30 |
31 | def evaluate(self, teacher_model, student_model, dev_data, dev_data_adaptor, epoch, step, loss_value):
32 | if step > 0 and step % self.print_loss_step == 0:
33 | logger.info("epoch:{},\tstep:{},\tloss value:{}".format(epoch, step, loss_value))
34 | if self.save_dir and self.save_step:
35 | self.loss_fw.write("epoch:{},\tstep:{},\tloss value:{}\n".format(epoch, step, loss_value))
36 | self.loss_fw.flush()
37 | if self.save_dir and self.save_step:
38 | if step > 0 and step % self.save_step == 0:
39 | save_path = os.path.join(self.save_dir, str(epoch) + "-" + str(step))
40 | os.makedirs(save_path, exist_ok=True)
41 | student_model.save_pretrained(save_path)
42 |
43 |
44 | if __name__ == "__main__":
45 | logger.info("epoch:{}\tstep:{}\tloss value:{}".format(1, 2, 3))
46 |
--------------------------------------------------------------------------------
/knowledge_distillation/Loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .multi_layer_based_distillation_loss import MultiLayerBasedDistillationLoss
2 | from .cosine_similarity_loss import CosineSimilarityLoss
3 | from .loss_functions import mse, mse_with_mask, attention_mse_with_mask
4 |
--------------------------------------------------------------------------------
/knowledge_distillation/Loss/cosine_similarity_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class CosineSimilarityLoss(nn.Module):
6 | def __init__(self, ):
7 | super(CosineSimilarityLoss, self).__init__()
8 |
9 | def forward(self, x1, x2):
10 | return 0.5 - 0.5 * torch.cosine_similarity(x1, x2)
11 |
--------------------------------------------------------------------------------
/knowledge_distillation/Loss/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def mse_with_mask(teacher_output, student_output, teacher_input=None, student_input=None):
6 | mask = teacher_input["attention_mask"]
7 | mask = mask.to(student_output)
8 | # number of valid (non-padding) tokens, times hidden_size
9 | valid_count = mask.sum() * student_output.size(-1)
10 | loss = (F.mse_loss(teacher_output, student_output, reduction='none') * mask.unsqueeze(-1)).sum() / valid_count
11 | return loss
12 |
13 |
14 | def attention_mse_with_mask(teacher_output, student_output, teacher_input=None, student_input=None):
15 | mask = teacher_input["attention_mask"]
16 | mask = mask.to(student_output).unsqueeze(1).expand(-1, student_output.size(1), -1) # (bs, num_of_heads, len)
17 | valid_count = torch.pow(mask.sum(dim=2), 2).sum()  # valid (query, key) positions per attention map, summed over batch and heads
18 | loss = (F.mse_loss(student_output, teacher_output, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(
19 | 2)).sum() / valid_count
20 | return loss
21 |
22 |
23 | def mse(teacher_output, student_output, teacher_input=None, student_input=None):
24 | return F.mse_loss(teacher_output, student_output, reduction='mean')
25 |
--------------------------------------------------------------------------------
/knowledge_distillation/Loss/multi_layer_based_distillation_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .cosine_similarity_loss import CosineSimilarityLoss
3 | from .loss_functions import mse_with_mask, mse, attention_mse_with_mask
4 |
5 |
6 | class MultiLayerBasedDistillationLoss(nn.Module):
7 | def __init__(self, distill_config=None, teacher_output_adaptor=None, student_output_adaptor=None):
8 | super(MultiLayerBasedDistillationLoss, self).__init__()
9 | self.distill_config = distill_config
10 | self.teacher_output_adaptor = teacher_output_adaptor
11 | self.student_output_adaptor = student_output_adaptor
12 | self.loss_functions = {"mse": mse, "mse_with_mask": mse_with_mask,
13 | "attention_mse_with_mask": attention_mse_with_mask,
14 | "cross_entropy": nn.CrossEntropyLoss(), "cos": CosineSimilarityLoss()}
15 |
16 | def forward(self, teacher_output, student_output, teacher_input_data, student_input_data):
17 | teacher_adaptor_output = self.teacher_output_adaptor(teacher_output)
18 | student_adaptor_output = self.student_output_adaptor(student_output)
19 | loss = 0
20 | for distill_info in self.distill_config:
21 | tmp_teacher_output = teacher_adaptor_output[distill_info["teacher_layer_name"]][
22 | distill_info["teacher_layer_output_name"]]
23 | tmp_student_output = student_adaptor_output[distill_info["student_layer_name"]][
24 | distill_info["student_layer_output_name"]]
25 | tmp_loss = self.loss_functions[distill_info["loss"]["loss_function"]](tmp_teacher_output,
26 | tmp_student_output,
27 | teacher_input_data,
28 | student_input_data,
29 | **distill_info["loss"]["args"])
30 | tmp_loss *= distill_info["weight"]
31 | loss += tmp_loss
32 |
33 | # student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att), student_att)  # set values below -100 to 0
34 | return loss
35 |
--------------------------------------------------------------------------------
/knowledge_distillation/__init__.py:
--------------------------------------------------------------------------------
1 | from .knowledge_distillation import KnowledgeDistiller
2 | from .Loss import MultiLayerBasedDistillationLoss
3 | from .Evaluator import MultiLayerBasedDistillationEvaluator
4 |
--------------------------------------------------------------------------------
/knowledge_distillation/knowledge_distillation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class KnowledgeDistiller():
5 | def __init__(self, teacher_model, student_model, train_dataloader, dev_dataloader, device, loss_model, optimizer,
6 | evaluator, num_epoch, train_data_adaptor, dev_data_adaptor):
7 | # send model to device
8 | self.teacher_model = teacher_model.to(device)
9 | self.student_model = student_model.to(device)
10 | self.train_dataloader = train_dataloader
11 | self.dev_dataloader = dev_dataloader
12 | self.device = device
13 | self.loss_model = loss_model
14 | self.optimizer = optimizer
15 | self.evaluator = evaluator
16 | self.num_epoch = num_epoch
17 | self.train_data_adaptor = train_data_adaptor
18 | self.dev_data_adaptor = dev_data_adaptor
19 |
20 | def distillate(self):
21 | # do not train the teacher model
22 | self.teacher_model.eval()
23 | # train student model
24 | self.student_model.train()
25 | for epoch in range(self.num_epoch):
26 | for step, batch_data in enumerate(self.train_dataloader):
27 | # get input data for teacher model and student model
28 | teacher_batch_data, student_batch_data = self.train_data_adaptor(self.device, batch_data)
29 | # get the teacher output without computing gradients
30 | with torch.no_grad():
31 | if isinstance(teacher_batch_data, dict):
32 | teacher_output = self.teacher_model.forward(**teacher_batch_data)
33 | else:
34 | teacher_output = self.teacher_model.forward(*teacher_batch_data)
35 | # get student output
36 | if isinstance(student_batch_data, dict):
37 | student_output = self.student_model.forward(**student_batch_data)
38 | else:
39 | student_output = self.student_model.forward(*student_batch_data)
40 | # get loss
41 | loss = self.loss_model.forward(teacher_output, student_output, teacher_batch_data, student_batch_data)
42 | loss.backward()
43 | self.optimizer.step()
44 | self.optimizer.zero_grad()
45 | # evaluate and save model
46 | self.evaluator.evaluate(self.teacher_model, self.student_model, self.dev_dataloader,
47 | self.dev_data_adaptor, epoch, step, loss)
48 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='KnowledgeDistillation',
5 | version="1.0.4",
6 | description=('A general knowledge distillation framework'),
7 | long_description=open('README.rst').read(),
8 | author='ZhangDun',
9 | author_email='dunnzhang0@gmail.com',
10 | maintainer='ZhangDun',
11 | maintainer_email='dunnzhang0@gmail.com',
12 | license='MIT',
13 | packages=find_packages(),
14 | url='https://github.com/DunZhang/KnowledgeDistillation',
15 | install_requires=['torch>0.4.0'], python_requires='>=3.6',
16 | classifiers=[
17 | 'Development Status :: 2 - Pre-Alpha',
18 | 'Operating System :: Microsoft :: Windows',
19 | 'Operating System :: POSIX',
20 | 'Operating System :: Unix',
21 | 'Operating System :: MacOS',
22 | 'Intended Audience :: Developers',
23 | 'Intended Audience :: Science/Research',
24 | 'Programming Language :: Python :: 3',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence'
27 | ],
28 | keywords="Transformer Networks BERT XLNet PyTorch NLP deep learning"
29 | )
30 |
--------------------------------------------------------------------------------