├── README.md ├── examples ├── bert-sst2-gt.py ├── bert-sst2-training.py ├── t5-sst2-gt.py ├── t5-sst2-inf.py ├── t5-sst2-mlp.py ├── t5_cluster_example.py └── t5_select_example.py ├── faster_moefication ├── README.md ├── balanced_assignment │ ├── ba.cpp │ └── setup.py ├── main.py └── moefication.py └── moefication ├── adj.py ├── mlp_select_example.py ├── param_cluster_example.py ├── similarity_select_example.py ├── trans_gp.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | # MoEfication 3 | 4 | Source code for "[MoEfication: Transformer Feed-forward Layers are Mixtures of Experts](https://arxiv.org/abs/2110.01786)" and "[Exploring the Benefit of Activation Sparsity in Pre-training](https://openreview.net/forum?id=KfXXPCcobh)" 5 | 6 | **Important Update (2024/07/21)**: In addition to the original work, this repository now contains [subsequent research](https://openreview.net/forum?id=KfXXPCcobh) that provides a **faster and better parameter clustering method** by incorporating GPUs and initialization from previous results. We encourage users to use the new contributions when considering **MoEfication during training**. This method is detailed in the [faster_moefication](./faster_moefication/README.md) directory. 7 | 8 | **Update (2022/11/30):** We provide a simple example of using fastmoe for an efficient MoE implementation in the branch [fastmoe](https://github.com/thunlp/MoEfication/tree/fastmoe). We will soon provide instructions for transforming a MoEfied checkpoint into a fastmoe checkpoint. Stay tuned! 9 | 10 | ## Requirements: 11 | 12 | * Python3.8 13 | * torch==1.6.0 14 | * transformers==4.20.1 15 | * tqdm 16 | * scikit-learn 17 | * k_means_constrained 18 | * datasets 19 | * numpy 20 | * scipy 21 | 22 | ## Expert Construction 23 | 24 | For the parameter clustering split, we use balanced K-Means. The details of the implementation can be found in `param_cluster_example.py`. 25 | 26 | For the co-activation graph split, we first construct a co-activation graph with `adj.py`. For T5, the output graphs are named `encoder.blocks.0.ff.dense_relu_dense.wi.weight`, `encoder.blocks.1.ff.dense_relu_dense.wi.weight`, ..., `decoder.blocks.11.ff.dense_relu_dense.wi.weight`, which are the corresponding weight names. 27 | 28 | Then, we use [METIS](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) to split each graph into subgraphs. 29 | ``` 30 | gpmetis encoder.blocks.0.ff.dense_relu_dense.wi.weight num_expert 31 | ``` 32 | where `num_expert` is the number of experts. 33 | 34 | Finally, we balance the number of neurons in each expert. 35 | ``` 36 | # num_expert=128 37 | python trans_gp.py encoder.blocks.0.ff.dense_relu_dense.wi.weight.part.128 38 | ``` 39 | 40 | ## Expert Selection 41 | 42 | For similarity selection, we average the corresponding weight columns as the expert representation. The details of the implementation can be found in `similarity_select_example.py`. 43 | 44 | For MLP selection, we train a multi-layer perceptron (MLP), which takes the hidden state $x$ as input and predicts the sum of positive activation values in each expert. The details of the implementation can be found in `mlp_select_example.py`. 45 | 46 | ## T5 Examples 47 | 48 | We provide an example of T5-base on SST-2 in `examples`, including groundtruth selection and MLP selection based on parameter clustering.
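At inference time, the example scripts below replace each FFN forward pass with a masked version: experts are scored for the current token, the top-k experts are kept, and the activations of all other neurons are zeroed. A minimal sketch of the groundtruth (golden) selection rule implemented in `t5-sst2-gt.py` and `bert-sst2-gt.py` (the function name and shapes here are illustrative, not part of the repository API):

```python
import torch

def golden_topk_mask(h, labels, num_expert, k):
    # h: post-ReLU FFN activations, shape (num_tokens, d_ff)
    # labels: LongTensor of length d_ff mapping each neuron to its expert
    patterns = torch.stack([(labels == e).float() for e in range(num_expert)])
    score = h @ patterns.t()              # expert score = sum of its activations
    topk = score.topk(k, dim=-1).indices  # keep the k highest-scoring experts
    mask = patterns[topk].sum(dim=-2)     # union of the selected experts' neurons
    return h * mask                       # zero out neurons of unselected experts
```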
49 | 50 | First, you need to construct the experts by 51 | 52 | ``` 53 | python examples/t5_cluster_example.py 54 | ``` 55 | 56 | Then, you can directly evaluate groundtruth selection by 57 | 58 | ``` 59 | python examples/t5-sst2-gt.py 60 | ``` 61 | 62 | To use MLP selection, you need to train the MLP by 63 | 64 | ``` 65 | python examples/t5-sst2-inf.py 66 | python examples/t5_select_example.py 67 | ``` 68 | 69 | Then, you can evaluate the performance of MLP selection by 70 | 71 | ``` 72 | python examples/t5-sst2-mlp.py 73 | ``` 74 | 75 | ## BERT Examples 76 | 77 | We also provide an example of BERT-large on SST-2 in `examples`. The checkpoint of ReLU-based BERT is available [here](https://cloud.tsinghua.edu.cn/f/cce7d1c994904f0f81bd/?dl=1). 78 | 79 | You need to first download it and fine-tune it on SST-2 by 80 | 81 | ``` 82 | python examples/bert-sst2-training.py 83 | ``` 84 | 85 | Then, you need to construct the experts by 86 | 87 | ``` 88 | python moefication/param_cluster_example.py --model_path bert-sst2-bsz32/epoch_1.bin --res_path results/bert-sst2 --num-layer 24 --num-expert 128 --templates bert.encoder.layer.{}.intermediate.dense.weight 89 | ``` 90 | 91 | Finally, you can evaluate groundtruth selection by 92 | 93 | ``` 94 | python examples/bert-sst2-gt.py 95 | ``` 96 | 97 | ### Tips for Training ReLU-based BERT 98 | 99 | We use the pre-training script from [NVIDIA](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT). The only difference is that we replace the activation function with ReLU and set the bias of the intermediate layer to None. We initialize the model with the checkpoint of [BERT-Large-Uncased](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/dle/models/bert_large_pyt_ckpt_mode-pretrain) provided by NVIDIA. In our experiments, we found that training for around 200 steps is enough to achieve good performance.
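For reference, the same two modifications are what the fine-tuning script in this repository applies when loading the checkpoint (see `examples/bert-sst2-training.py`); a minimal sketch:

```python
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig.from_pretrained("bert-large-uncased")
config.hidden_act = 'relu'                 # replace GELU with ReLU
model = BertForSequenceClassification(config=config)
for layer in model.bert.encoder.layer:
    layer.intermediate.dense.bias = None   # remove the intermediate bias
```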
100 | 101 | ## Cite 102 | 103 | If you use the code, please cite this paper: 104 | 105 | ``` 106 | @inproceedings{ 107 | zhang2024exploring, 108 | title={Exploring the Benefit of Activation Sparsity in Pre-training}, 109 | author={Zhengyan Zhang and Chaojun Xiao and Qiujieli Qin and Yankai Lin and Zhiyuan Zeng and Xu Han and Zhiyuan Liu and Ruobing Xie and Maosong Sun and Jie Zhou}, 110 | booktitle={Proceedings of ICML}, 111 | year={2024}, 112 | } 113 | 114 | @inproceedings{zhang2022moefication, 115 | title={{MoEfication}: Transformer Feed-forward Layers are Mixtures of Experts}, 116 | author={Zhang, Zhengyan and Lin, Yankai and Liu, Zhiyuan and Li, Peng and Sun, Maosong and Zhou, Jie}, 117 | booktitle={Findings of ACL 2022}, 118 | year={2022} 119 | } 120 | ``` 121 | -------------------------------------------------------------------------------- /examples/bert-sst2-gt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | import tqdm 4 | import torch 5 | from datasets import load_dataset 6 | from transformers import BertTokenizer, BertForSequenceClassification, BertConfig 7 | from transformers.models.bert.modeling_bert import BertIntermediate 8 | import numpy as np 9 | 10 | batch_size = 8 11 | k = 20 12 | ckpt_path = 'bert-sst2-bsz32/epoch_1.bin' 13 | 14 | tokenizer = BertTokenizer.from_pretrained("bert-large-uncased") 15 | config = BertConfig.from_pretrained("bert-large-uncased") 16 | 17 | # transform BERT to relu-based BERT 18 | config.hidden_act = 'relu' 19 | config.num_labels = 2 20 | config.problem_type = "single_label_classification" 21 | model = BertForSequenceClassification(config=config) 22 | for x in model.bert.encoder.layer: 23 | x.intermediate.dense.bias = None 24 | 25 | res = model.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=False) 26 | print(res) 27 | model.cuda() 28 | 29 | sst2 = load_dataset('sst2') 30 | 31 | sst2_eval = sst2['validation'] 32 | eval_dataloaders = torch.utils.data.DataLoader(sst2_eval, batch_size=batch_size) 33 | 34 | def change_forward(model, k=20): 35 | 36 | def _forward(ffn_self, hidden_states): 37 | hidden_states = ffn_self.forward_old(hidden_states) 38 | 39 | if ffn_self.patterns is not None: 40 | # golden 41 | k = ffn_self.k 42 | bsz, seq_len, hidden_size = hidden_states.shape 43 | hidden_states_relu = hidden_states.clone() 44 | hidden_states_relu = hidden_states_relu.view(-1, hidden_size) 45 | score = torch.matmul(hidden_states_relu, ffn_self.patterns.transpose(0, 1)) 46 | labels = torch.topk(score, k=k, dim=-1)[1].view(bsz, seq_len, k) 47 | cur_mask = torch.nn.functional.embedding(labels, ffn_self.patterns).sum(-2) 48 | hidden_states[cur_mask == False] = 0 49 | 50 | return hidden_states 51 | 52 | def modify_ffn(ffn, path): 53 | assert type(ffn) == BertIntermediate 54 | labels = torch.load(path) 55 | cluster_num = max(labels)+1 56 | patterns = [] 57 | for i in range(cluster_num): 58 | patterns.append(np.array(labels) == i) 59 | ffn.patterns = torch.Tensor(patterns).cuda() 60 | ffn.k = k 61 | ffn.forward_old = ffn.forward 62 | ffn.forward = types.MethodType(_forward, ffn) 63 | 64 | # encoder 65 | for layer_idx, layer in enumerate(model.bert.encoder.layer): 66 | ffn = layer.intermediate 67 | path = os.path.join('results/bert-sst2', 'param_split', 'bert.encoder.layer.{}.intermediate.dense.weight'.format(layer_idx)) 68 | modify_ffn(ffn, path) 69 | 70 | change_forward(model, k) 71 | 72 | model.eval() 73 | correct = 0 74 | total = 0 75 | for batch in 
tqdm.tqdm(eval_dataloaders): 76 | inputs = tokenizer(batch['sentence'], return_tensors='pt', padding=True, truncation=True, max_length=128) 77 | labels = torch.tensor(batch['label']) 78 | inputs = {k: v.cuda() for k, v in inputs.items()} 79 | labels = labels.cuda() 80 | with torch.no_grad(): 81 | outputs = model(**inputs, labels=labels) 82 | logits = outputs.logits 83 | pred = logits.argmax(dim=-1) 84 | tmp_correct = (pred == labels).sum().item() 85 | correct += tmp_correct 86 | total += len(labels) 87 | print("Acc", correct * 1. / total) 88 | -------------------------------------------------------------------------------- /examples/bert-sst2-training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | import tqdm 4 | import torch 5 | from datasets import load_dataset 6 | from transformers import BertTokenizer, BertForSequenceClassification, BertConfig 7 | import numpy as np 8 | 9 | batch_size = 32 10 | gradient_accumulation_steps = 4 11 | mini_batch_size = batch_size // gradient_accumulation_steps 12 | folder = "bert-sst2-bsz32" 13 | ckpt_path = "relu-bert-large-uncased.bin" 14 | 15 | if not os.path.exists(folder): 16 | os.makedirs(folder) 17 | 18 | tokenizer = BertTokenizer.from_pretrained("bert-large-uncased") 19 | config = BertConfig.from_pretrained("bert-large-uncased") 20 | 21 | # transform BERT to relu-based BERT 22 | config.hidden_act = 'relu' 23 | config.num_labels = 2 24 | config.problem_type = "single_label_classification" 25 | model = BertForSequenceClassification(config=config) 26 | for x in model.bert.encoder.layer: 27 | x.intermediate.dense.bias = None 28 | 29 | res = model.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=False) 30 | print(res) 31 | model.cuda() 32 | 33 | sst2 = load_dataset('sst2') 34 | sst2_train = sst2['train'] 35 | 36 | optimizer = torch.optim.Adam(model.parameters(), lr=2e-5) 37 | all_step = len(sst2_train) // batch_size 38 | warmup_step = all_step // 10 39 | lr_lambda = lambda step: min(step / (warmup_step + 1e-8), 1.0) 40 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 41 | dataloaders = torch.utils.data.DataLoader(sst2_train, batch_size=mini_batch_size, shuffle=True) 42 | 43 | sst2_eval = sst2['validation'] 44 | eval_dataloaders = torch.utils.data.DataLoader(sst2_eval, batch_size=mini_batch_size) 45 | 46 | for epoch in range(3): 47 | model.train() 48 | optimizer.zero_grad() 49 | step = 0 50 | for batch in tqdm.tqdm(dataloaders): 51 | inputs = tokenizer(batch['sentence'], return_tensors='pt', padding=True, truncation=True, max_length=128) 52 | labels = torch.tensor(batch['label']) 53 | inputs = {k: v.cuda() for k, v in inputs.items()} 54 | labels = labels.cuda() 55 | outputs = model(**inputs, labels=labels) 56 | loss = outputs.loss 57 | loss.backward() 58 | print(loss.item()) 59 | 60 | step += 1 61 | if step % gradient_accumulation_steps == 0: 62 | optimizer.step() 63 | scheduler.step() 64 | optimizer.zero_grad() 65 | 66 | torch.save(model.state_dict(), '{}/epoch_{}.bin'.format(folder, epoch)) 67 | 68 | model.eval() 69 | correct = 0 70 | total = 0 71 | for batch in tqdm.tqdm(eval_dataloaders): 72 | inputs = tokenizer(batch['sentence'], return_tensors='pt', padding=True, truncation=True, max_length=128) 73 | labels = torch.tensor(batch['label']) 74 | inputs = {k: v.cuda() for k, v in inputs.items()} 75 | labels = labels.cuda() 76 | outputs = model(**inputs, labels=labels) 77 | logits = outputs.logits 78 | pred = logits.argmax(dim=-1) 79 
| tmp_correct = (pred == labels).sum().item() 80 | correct += tmp_correct 81 | total += len(labels) 82 | print("Acc", correct * 1. / total) 83 | -------------------------------------------------------------------------------- /examples/t5-sst2-gt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | import torch 4 | from datasets import load_dataset 5 | from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config 6 | import numpy as np 7 | from transformers.models.t5.modeling_t5 import T5DenseActDense 8 | 9 | # the number of selected experts for each token 10 | k=20 11 | 12 | tokenizer = T5Tokenizer.from_pretrained('t5-base') 13 | config = T5Config.from_pretrained('t5-base') 14 | model = T5ForConditionalGeneration.from_pretrained('t5-base').cuda() 15 | 16 | sst2 = load_dataset('sst2') 17 | sst2_dev = sst2['validation'] 18 | 19 | pred = [] 20 | 21 | def change_forward(model, k=20): 22 | 23 | def _forward(ffn_self, hidden_states): 24 | hidden_states = ffn_self.wi(hidden_states) 25 | hidden_states = ffn_self.act(hidden_states) 26 | 27 | if ffn_self.patterns is not None: 28 | # golden 29 | k = ffn_self.k 30 | bsz, seq_len, hidden_size = hidden_states.shape 31 | hidden_states_relu = hidden_states.clone() 32 | hidden_states_relu = hidden_states_relu.view(-1, hidden_size) 33 | score = torch.matmul(hidden_states_relu, ffn_self.patterns.transpose(0, 1)) 34 | labels = torch.topk(score, k=k, dim=-1)[1].view(bsz, seq_len, k) 35 | cur_mask = torch.nn.functional.embedding(labels, ffn_self.patterns).sum(-2) 36 | hidden_states[cur_mask == False] = 0 37 | 38 | hidden_states = ffn_self.dropout(hidden_states) 39 | hidden_states = ffn_self.wo(hidden_states) 40 | return hidden_states 41 | 42 | def modify_ffn(ffn, path): 43 | assert type(ffn) == T5DenseActDense 44 | labels = torch.load(path) 45 | cluster_num = max(labels)+1 46 | patterns = [] 47 | for i in range(cluster_num): 48 | patterns.append(np.array(labels) == i) 49 | ffn.patterns = torch.Tensor(patterns).cuda() 50 | ffn.k = k 51 | ffn.forward_old = ffn.forward 52 | ffn.forward = types.MethodType(_forward, ffn) 53 | 54 | # encoder 55 | for layer_idx, layer in enumerate(model.encoder.block): 56 | ffn = layer.layer[1].DenseReluDense 57 | path = os.path.join('results/t5-base', 'param_split', 'encoder.block.{}.layer.1.DenseReluDense.wi.weight'.format(layer_idx)) 58 | modify_ffn(ffn, path) 59 | 60 | #decoder 61 | for layer_idx, layer in enumerate(model.decoder.block): 62 | ffn = layer.layer[2].DenseReluDense 63 | path = os.path.join('results/t5-base', 'param_split', 'decoder.block.{}.layer.2.DenseReluDense.wi.weight'.format(layer_idx)) 64 | modify_ffn(ffn, path) 65 | 66 | change_forward(model, k) 67 | 68 | # sst2 evaluation 69 | for instance in sst2_dev: 70 | input_ids = tokenizer("sst2 sentence: "+instance['sentence'], return_tensors="pt").input_ids.cuda() 71 | dec_input_ids = tokenizer("", return_tensors="pt").input_ids.cuda()[:, :1] 72 | 73 | output = model(input_ids=input_ids, labels=dec_input_ids) 74 | 75 | pred.append(int(output.logits[:, 0, 1465].item() > output.logits[:, 0, 2841].item()) == instance['label']) 76 | 77 | print("Acc", sum(pred) * 1. 
/ len(pred), 'k', k) 78 | -------------------------------------------------------------------------------- /examples/t5-sst2-inf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | import tqdm 4 | import torch 5 | from datasets import load_dataset 6 | from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config 7 | import numpy as np 8 | from transformers.models.t5.modeling_t5 import T5DenseActDense 9 | 10 | tokenizer = T5Tokenizer.from_pretrained('t5-base') 11 | config = T5Config.from_pretrained('t5-base') 12 | model = T5ForConditionalGeneration.from_pretrained('t5-base').cuda() 13 | 14 | sst2 = load_dataset('sst2') 15 | sst2_train = sst2['train'] 16 | 17 | pred = [] 18 | 19 | def change_forward(model): 20 | 21 | def _forward(ffn_self, hidden_states): 22 | ffn_self.res.append(hidden_states.detach().cpu()) 23 | 24 | hidden_states = ffn_self.wi(hidden_states) 25 | hidden_states = ffn_self.act(hidden_states) 26 | 27 | hidden_states = ffn_self.dropout(hidden_states) 28 | hidden_states = ffn_self.wo(hidden_states) 29 | return hidden_states 30 | 31 | def modify_ffn(ffn, res): 32 | assert type(ffn) == T5DenseActDense 33 | ffn.res = res 34 | ffn.forward = types.MethodType(_forward, ffn) 35 | 36 | # encoder 37 | res = {} 38 | for layer_idx, layer in enumerate(model.encoder.block): 39 | ffn = layer.layer[1].DenseReluDense 40 | name = 'encoder.block.{}.layer.1.DenseReluDense.wi.weight'.format(layer_idx) 41 | res[name] = [] 42 | modify_ffn(ffn, res[name]) 43 | 44 | #decoder 45 | for layer_idx, layer in enumerate(model.decoder.block): 46 | ffn = layer.layer[2].DenseReluDense 47 | name = 'decoder.block.{}.layer.2.DenseReluDense.wi.weight'.format(layer_idx) 48 | res[name] = [] 49 | modify_ffn(ffn, res[name]) 50 | 51 | return res 52 | 53 | res = change_forward(model) 54 | 55 | # sst2 evaluation 56 | for idx, instance in enumerate(tqdm.tqdm(sst2_train)): 57 | if idx == 10000: 58 | break 59 | 60 | input_ids = tokenizer("sst2 sentence: "+instance['sentence'], return_tensors="pt").input_ids.cuda() 61 | dec_input_ids = tokenizer("", return_tensors="pt").input_ids.cuda()[:, :1] 62 | 63 | output = model(input_ids=input_ids, labels=dec_input_ids) 64 | 65 | pred.append(int(output.logits[:, 0, 1465].item() > output.logits[:, 0, 2841].item()) == instance['label']) 66 | 67 | print("Acc", sum(pred) * 1. 
/ len(pred)) 68 | 69 | for k, v in res.items(): 70 | v = [x.reshape(-1, x.shape[-1]) for x in v] 71 | v = torch.cat(v, dim=0) 72 | torch.save(v, 'results/t5-base/'+k) -------------------------------------------------------------------------------- /examples/t5-sst2-mlp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | import torch 4 | from datasets import load_dataset 5 | from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config 6 | import numpy as np 7 | from transformers.models.t5.modeling_t5 import T5DenseActDense 8 | 9 | k=20 10 | 11 | tokenizer = T5Tokenizer.from_pretrained('t5-base') 12 | config = T5Config.from_pretrained('t5-base') 13 | model = T5ForConditionalGeneration.from_pretrained('t5-base').cuda() 14 | 15 | sst2 = load_dataset('sst2') 16 | sst2_dev = sst2['validation'] 17 | 18 | pred = [] 19 | 20 | def change_forward(model, k=20): 21 | 22 | def _forward(ffn_self, hidden_states): 23 | 24 | bsz, seq_len, hidden_size = hidden_states.shape 25 | hidden_states_mlp = hidden_states.clone().detach() 26 | hidden_states_mlp = hidden_states_mlp.view(-1, hidden_size) 27 | 28 | hidden_states_mlp = hidden_states_mlp / torch.norm(hidden_states_mlp, dim=-1).unsqueeze(-1) 29 | score = ffn_self.mlp(hidden_states_mlp) 30 | 31 | labels = torch.topk(score, k=k, dim=-1)[1].view(bsz, seq_len, k) 32 | cur_mask = torch.nn.functional.embedding(labels, ffn_self.patterns).sum(-2) 33 | 34 | hidden_states = ffn_self.wi(hidden_states) 35 | hidden_states = ffn_self.act(hidden_states) 36 | hidden_states[cur_mask == False] = 0 37 | 38 | hidden_states = ffn_self.dropout(hidden_states) 39 | hidden_states = ffn_self.wo(hidden_states) 40 | return hidden_states 41 | 42 | def modify_ffn(ffn, path): 43 | assert type(ffn) == T5DenseActDense 44 | labels = torch.load(path) 45 | cluster_num = max(labels)+1 46 | patterns = [] 47 | for i in range(cluster_num): 48 | patterns.append(np.array(labels) == i) 49 | ffn.patterns = torch.Tensor(patterns).cuda() 50 | ffn.k = k 51 | ffn.mlp = torch.load(path+'_input_compl').cuda() 52 | ffn.forward_old = ffn.forward 53 | ffn.forward = types.MethodType(_forward, ffn) 54 | 55 | # encoder 56 | for layer_idx, layer in enumerate(model.encoder.block): 57 | ffn = layer.layer[1].DenseReluDense 58 | path = os.path.join('results/t5-base', 'param_split', 'encoder.block.{}.layer.1.DenseReluDense.wi.weight'.format(layer_idx)) 59 | modify_ffn(ffn, path) 60 | 61 | #decoder 62 | for layer_idx, layer in enumerate(model.decoder.block): 63 | ffn = layer.layer[2].DenseReluDense 64 | path = os.path.join('results/t5-base', 'param_split', 'decoder.block.{}.layer.2.DenseReluDense.wi.weight'.format(layer_idx)) 65 | modify_ffn(ffn, path) 66 | 67 | change_forward(model, k) 68 | 69 | # sst2 evaluation 70 | for instance in sst2_dev: 71 | input_ids = tokenizer("sst2 sentence: "+instance['sentence'], return_tensors="pt").input_ids.cuda() 72 | dec_input_ids = tokenizer("", return_tensors="pt").input_ids.cuda()[:, :1] 73 | 74 | output = model(input_ids=input_ids, labels=dec_input_ids) 75 | 76 | pred.append(int(output.logits[:, 0, 1465].item() > output.logits[:, 0, 2841].item()) == instance['label']) 77 | 78 | print("Acc", sum(pred) * 1. 
/ len(pred), 'k', k) -------------------------------------------------------------------------------- /examples/t5_cluster_example.py: -------------------------------------------------------------------------------- 1 | from ast import arg 2 | from re import template 3 | import os 4 | import sys 5 | import numpy as np 6 | import torch 7 | import argparse 8 | import tqdm 9 | from transformers import T5ForConditionalGeneration 10 | 11 | sys.path.append('moefication') 12 | import utils 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument('--model_name', type=str, default='t5-base', help='model name in huggingface model hub') 17 | parser.add_argument('--res_path', type=str, default='results/t5-base/', help='path to store the results of moefication') 18 | parser.add_argument('--num-layer', type=int, default=12, help='number of layers') 19 | parser.add_argument('--num-expert', type=int, default=96, help='number of experts') 20 | parser.add_argument('--templates', type=str, default='encoder.block.{}.layer.1.DenseReluDense.wi.weight,decoder.block.{}.layer.2.DenseReluDense.wi.weight', help='weight names of the first linear layer in each FFN (use comma to separate multiple templates)') 21 | 22 | args = parser.parse_args() 23 | if not os.path.exists(args.res_path): 24 | os.makedirs(args.res_path) 25 | 26 | model = T5ForConditionalGeneration.from_pretrained(args.model_name) 27 | torch.save(model.state_dict(), os.path.join(args.res_path, 'model.pt')) 28 | 29 | config = utils.ModelConfig(os.path.join(args.res_path, 'model.pt'), args.res_path, split_num=args.num_expert) 30 | 31 | templates = args.templates.split(',') 32 | 33 | for template in templates: 34 | for i in tqdm.tqdm(range(args.num_layer)): 35 | split = utils.ParamSplit(config, template, i) 36 | split.split() 37 | split.cnt() 38 | split.save() 39 | -------------------------------------------------------------------------------- /examples/t5_select_example.py: -------------------------------------------------------------------------------- 1 | from tempfile import template 2 | import numpy as np 3 | import torch 4 | import sys 5 | import argparse 6 | from transformers import T5ForConditionalGeneration 7 | import os 8 | 9 | sys.path.append('moefication') 10 | import utils 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--model_name', type=str, default='t5-base', help='model name in huggingface model hub') 15 | parser.add_argument('--res_path', type=str, default='results/t5-base/', help='path to store the results of moefication') 16 | parser.add_argument('--num-layer', type=int, default=12, help='number of layers') 17 | parser.add_argument('--num-expert', type=int, default=96, help='number of experts') 18 | parser.add_argument('--templates', type=str, default='encoder.block.{}.layer.1.DenseReluDense.wi.weight,decoder.block.{}.layer.2.DenseReluDense.wi.weight', help='weight names of the first linear layer in each FFN (use comma to separate multiple templates)') 19 | 20 | args = parser.parse_args() 21 | 22 | model = T5ForConditionalGeneration.from_pretrained(args.model_name) 23 | torch.save(model.state_dict(), os.path.join(args.res_path, 'model.pt')) 24 | 25 | config = utils.ModelConfig(os.path.join(args.res_path, 'model.pt'), args.res_path, split_num=args.num_expert) 26 | 27 | templates = args.templates.split(',') 28 | for template in templates: 29 | for i in range(args.num_layer): 30 | center = utils.MLPCenter(config, template, '{}/param_split/{}'.format(args.res_path, template.format(i)), i) 31 | 
center.cal_center() 32 | -------------------------------------------------------------------------------- /faster_moefication/README.md: -------------------------------------------------------------------------------- 1 | # Faster MoEfication 2 | 3 | Source code for the ICML 2024 paper "[Exploring the Benefit of Activation Sparsity in Pre-training](https://openreview.net/forum?id=KfXXPCcobh)" 4 | 5 | Faster MoEfication improves upon the parameter clustering method in MoEfication by leveraging multi-GPU computation to significantly accelerate the parameter clustering process. The approach consists of two main parts: 6 | 7 | 1. Initial clustering using the k-means implementation provided by faiss-gpu. 8 | 2. Application of a balanced allocation algorithm to the k-means results to achieve a balanced cluster structure, yielding the final MoEfication outcome. 9 | 10 | This method facilitates the use of MoEfication during the training process. In our ICML paper, we obtained a model that performs well under both dense and MoE sparse computation by alternating between dense and MoE sparse training. The MoE conversion in this process used the Faster MoEfication method. 11 | 12 | ## Requirements: 13 | 14 | * Python3.8 15 | * torch 16 | * tqdm 17 | * scikit-learn 18 | * numpy 19 | * faiss-gpu 20 | 21 | Besides, users need to install our custom allocation algorithm by 22 | 23 | ``` 24 | cd balanced_assignment/ 25 | python setup.py install 26 | ``` 27 | 28 | ## kmeans_balanced 29 | 30 | The main function interface of Faster MoEfication is `kmeans_balanced`, which performs balanced k-means clustering on the input matrix. 31 | 32 | Function signature (see `moefication.py`): 33 | ```python 34 | def kmeans_balanced(dat, k, size, niter=100, nredo=1): 35 | ... 36 | ``` 37 | 38 | Main parameters: 39 | 40 | 1. dat: The input matrix to be partitioned (one row per neuron). 41 | 2. k: The number of clusters to create. 42 | 3. size: The number of rows in each cluster. 43 | 44 | This function first applies k-means clustering using `faiss-gpu`, then adjusts the clusters to ensure balanced sizes using our custom allocation algorithm. 45 | 46 | ## Usage Example 47 | 48 | We provide a usage example in `main.py` that simulates the MoEfication process for an eight-layer network. This example demonstrates the full workflow of Faster MoEfication: 49 | 50 | 1. Parameter Distribution: The script starts by distributing parameters from a single GPU to eight different GPUs. 51 | 2. Layer-wise Clustering: It then performs clustering on each layer independently. 52 | 3. Result Gathering: Finally, it gathers the results back to a single GPU. 53 | 54 | This process completes the MoEfication of the model, leveraging multi-GPU parallelism for increased efficiency. 55 | 56 | To run the example: 57 | 58 | ```bash 59 | torchrun --nproc_per_node=8 main.py 60 | ``` 61 | 62 | The script will report the total time of this process. 63 | 64 | ## Acknowledgement 65 | 66 | Our custom allocation algorithm is inspired by the expert allocation algorithm implemented in [BASE Layers](https://arxiv.org/abs/2103.16716). We are grateful to the authors for their innovative approach, which has significantly influenced our work.
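For reference, a minimal sketch of how the compiled `balanced_assignment` extension is invoked, mirroring `kmeans_single` in `moefication.py` (the tensor sizes and iteration budget below are illustrative, not fixed by the library):

```python
import torch
import balanced_assignment  # built via balanced_assignment/setup.py

n, k, size = 6144, 32, 192  # neurons, clusters, neurons per cluster
# scores[i, j]: benefit of assigning neuron i to cluster j (here: negated distance)
scores = -torch.cdist(torch.rand(n, 768).cuda(), torch.rand(k, 768).cuda())
sorted_id, n_iter = balanced_assignment.balanced_assignment(scores, 2 * size)
labels = torch.empty(n, dtype=torch.long, device='cuda')
labels[sorted_id] = torch.arange(n, device='cuda') // size  # balanced cluster labels
```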
67 | 68 | ## Cite 69 | 70 | If you use the code, please cite this paper: 71 | 72 | ``` 73 | @inproceedings{ 74 | zhang2024exploring, 75 | title={Exploring the Benefit of Activation Sparsity in Pre-training}, 76 | author={Zhengyan Zhang and Chaojun Xiao and Qiujieli Qin and Yankai Lin and Zhiyuan Zeng and Xu Han and Zhiyuan Liu and Ruobing Xie and Maosong Sun and Jie Zhou}, 77 | booktitle={Proceedings of ICML}, 78 | year={2024}, 79 | } 80 | 81 | @inproceedings{zhang2022moefication, 82 | title={{MoEfication}: Transformer Feed-forward Layers are Mixtures of Experts}, 83 | author={Zhang, Zhengyan and Lin, Yankai and Liu, Zhiyuan and Li, Peng and Sun, Maosong and Zhou, Jie}, 84 | booktitle={Findings of ACL 2022}, 85 | year={2022} 86 | } 87 | ``` -------------------------------------------------------------------------------- /faster_moefication/balanced_assignment/ba.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | /* 10 | C++ code for solving the linear assignment problem. 11 | Based on the Auction Algorithm from 12 | https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the 13 | implementation from: https://github.com/bkj/auction-lap Adapted to be more 14 | efficient when each worker is looking for k jobs instead of 1. 15 | */ 16 | #include <torch/extension.h> 17 | #include <torch/torch.h> 18 | using namespace torch::indexing; 19 | std::pair<torch::Tensor, int> balanced_assignment(torch::Tensor job_and_worker_to_score, int max_iterations) { 20 | // int max_iterations = 100; 21 | // torch::Tensor epsilon = 22 | // (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50; 23 | torch::Tensor epsilon = 24 | (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / max_iterations * 2; 25 | 26 | epsilon.clamp_min_(1e-04); 27 | torch::Tensor worker_and_job_to_score = 28 | job_and_worker_to_score.detach().transpose(0, 1).contiguous(); 29 | int num_workers = worker_and_job_to_score.size(0); 30 | int num_jobs = worker_and_job_to_score.size(1); 31 | auto device = worker_and_job_to_score.device(); 32 | int jobs_per_worker = num_jobs / num_workers; 33 | torch::Tensor value = worker_and_job_to_score.clone(); 34 | int counter = 0; 35 | torch::Tensor max_value = worker_and_job_to_score.max(); 36 | 37 | torch::Tensor bid_indices; 38 | torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs}); 39 | torch::Tensor bids = 40 | worker_and_job_to_score.new_empty({num_workers, num_jobs}); 41 | torch::Tensor bid_increments = 42 | worker_and_job_to_score.new_empty({num_workers, jobs_per_worker}); 43 | torch::Tensor top_values = 44 | worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1}); 45 | torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs}); 46 | 47 | torch::Tensor top_index = top_values.to(torch::kLong); 48 | torch::Tensor high_bidders = top_index.new_empty({num_jobs}); 49 | torch::Tensor have_bids = high_bidders.to(torch::kBool); 50 | torch::Tensor jobs_indices = 51 | torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device)); 52 | torch::Tensor true_tensor = 53 | torch::ones({1}, torch::dtype(torch::kBool).device(device)); 54 | 55 | while (true) { 56 | bids.zero_(); 57 | torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1); 58 | 59 | // Each worker bids the difference in value between that
job and the k+1th 60 | // job 61 | torch::sub_out( 62 | bid_increments, 63 | top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}), 64 | top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1)); 65 | 66 | bid_increments.add_(epsilon); 67 | bids.scatter_( 68 | 1, 69 | top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}), 70 | bid_increments); 71 | 72 | if (counter < max_iterations && counter > 0) { 73 | // Put in a minimal bid to retain items from the last round if no-one else 74 | // bids for them this round 75 | bids.view(-1).index_put_({bid_indices}, epsilon); 76 | } 77 | 78 | // Find the highest bidding worker per job 79 | torch::max_out(high_bids, high_bidders, bids, 0); 80 | torch::gt_out(have_bids, high_bids, 0); 81 | 82 | if (have_bids.all().item<bool>()) { 83 | // All jobs were bid for 84 | break; 85 | } 86 | 87 | // Make popular items more expensive 88 | cost.add_(high_bids); 89 | torch::sub_out(value, worker_and_job_to_score, cost); 90 | 91 | bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids}); 92 | 93 | if (counter < max_iterations) { 94 | // Make sure that this item will be in the winning worker's top-k next 95 | // time. 96 | value.view(-1).index_put_({bid_indices}, max_value); 97 | } else { 98 | // Suboptimal approximation that converges quickly from current solution 99 | value.view(-1).index_put_( 100 | {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices})); 101 | } 102 | 103 | counter += 1; 104 | } 105 | 106 | return {top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}) 107 | .reshape(-1), counter}; 108 | } 109 | 110 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 111 | m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment"); 112 | } -------------------------------------------------------------------------------- /faster_moefication/balanced_assignment/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='balanced_assignment', 6 | ext_modules=[ 7 | CUDAExtension('balanced_assignment', [ 8 | 'ba.cpp', 9 | # add other source files if needed 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /faster_moefication/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | import torch.distributed 7 | import torch.multiprocessing as mp 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from moefication import kmeans_balanced 12 | 13 | dist.init_process_group(backend="nccl") 14 | rank = torch.distributed.get_rank() 15 | torch.cuda.set_device(rank) 16 | 17 | if rank == 0: 18 | w_in_s = [torch.rand(768*8, 768).cuda() for _ in range(8)] 19 | w_out_s = [torch.rand(768, 768*8).cuda() for _ in range(8)] 20 | else: 21 | w_in_s = [torch.zeros(768*8, 768).cuda() for _ in range(8)] 22 | 23 | torch.distributed.barrier() 24 | start_time = time.time() 25 | 26 | for w_in in w_in_s: 27 | dist.broadcast(w_in, src=0) 28 | 29 | centers, labels, obj = kmeans_balanced(w_in_s[rank].cpu().numpy(), 32, 192) 30 | print(labels) 31 | 32 | if rank == 0: 33 | gather_list = [torch.zeros_like(labels) for _ in range(8)] 34 | else: 35 |
gather_list = None 36 | 37 | dist.gather(labels, gather_list=gather_list, dst=0) 38 | 39 | torch.distributed.barrier() 40 | end_time = time.time() 41 | 42 | if rank == 0: 43 | print("Clustering time:", (end_time - start_time)) 44 | 45 | for layer, labels in enumerate(gather_list): 46 | labels = labels.cpu() 47 | w_in = w_in_s[layer].cpu()  # use this layer's weights (was a stale `w_in` left over from the broadcast loop) 48 | w_out = w_out_s[layer].cpu() 49 | tmp_in = [] 50 | tmp_in_norm = [] 51 | for i in range(32): 52 | tmp_in.append(w_in[labels == i, :]) 53 | tmp_in_norm.append(torch.nn.functional.normalize(w_in[labels == i, :], dim=-1).numpy())  # row-normalized, as in split_ckpt 54 | tmp_in = torch.stack(tmp_in, dim=0) 55 | tmp_in_norm = np.stack(tmp_in_norm, axis=0) 56 | 57 | tmp_out = [] 58 | for i in range(32): 59 | tmp_out.append(w_out[:, labels == i].transpose(0, 1)) 60 | tmp_out = torch.cat(tmp_out, dim=0) 61 | 62 | # wg = tmp_in.mean(1) 63 | wg = tmp_in_norm.mean(1) 64 | wg = torch.tensor(wg, dtype=tmp_in.dtype, device=tmp_in.device) 65 | wg_norm = wg / torch.norm(wg, dim=-1, keepdim=True) * tmp_in[0, 0, :].norm() 66 | 67 | end_time = time.time() 68 | print("Total time:", (end_time - start_time)) -------------------------------------------------------------------------------- /faster_moefication/moefication.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed 3 | import tqdm 4 | import random 5 | import balanced_assignment 6 | from faiss import Kmeans 7 | import time 8 | 9 | def cal_inertia_(data, centroids, labels): 10 | dist_ = torch.cdist(data, centroids) 11 | dist = dist_.gather(1, labels.unsqueeze(1)).squeeze(1) 12 | dist = dist.pow(2) 13 | return dist.sum(), dist_ 14 | 15 | def update_centers_(data, labels, centroids): 16 | for i in range(centroids.shape[0]): 17 | centroids[i] = data[labels == i].mean(0) 18 | 19 | 20 | 21 | def kmeans_single(dat, k, size, niter=100, init=None): 22 | hidden_size = dat.shape[1] 23 | kmeans = Kmeans(hidden_size, k, gpu=1, nredo=1, seed=0) 24 | dat_cuda = torch.tensor(dat).cuda() 25 | if init is not None: 26 | cluster_centers_ = torch.zeros(k, hidden_size).cuda() 27 | update_centers_(dat_cuda, init, cluster_centers_) 28 | cluster_centers_ = cluster_centers_.cpu().numpy() 29 | else: 30 | cluster_centers_ = None 31 | kmeans.train(dat, init_centroids=cluster_centers_) 32 | 33 | dat = dat_cuda 34 | cluster_centers_ = torch.tensor(kmeans.centroids).cuda() 35 | 36 | labels_raw = torch.arange(dat.shape[0]).cuda() 37 | 38 | obj = 0 39 | last_centers = None 40 | labels = None 41 | 42 | dists = torch.cdist(dat, cluster_centers_) * -1 43 | for i in range(niter): 44 | # E step 45 | sorted_id, _ = balanced_assignment.balanced_assignment(dists, size*2) 46 | 47 | tmp_labels = torch.ones_like(sorted_id) 48 | tmp_labels[sorted_id] = torch.div(labels_raw, size, rounding_mode='floor') 49 | 50 | # M step 51 | last_centers = cluster_centers_.clone() 52 | update_centers_(dat, tmp_labels, cluster_centers_) 53 | 54 | tmp_obj, dists = cal_inertia_(dat, cluster_centers_, tmp_labels) 55 | tmp_obj = tmp_obj.item() 56 | dists = dists * -1 57 | if obj == 0 or tmp_obj < obj: 58 | obj = tmp_obj 59 | labels = tmp_labels 60 | 61 | return last_centers, labels, obj 62 | 63 | def kmeans_balanced(dat, k, size, niter=100, nredo=1): 64 | best_obj = 0 65 | cluster_centers = None 66 | labels = None 67 | for i in range(nredo): 68 | tmp_cluster_centers, tmp_labels, tmp_obj = kmeans_single(dat, k, size, niter) 69 | if best_obj == 0 or tmp_obj < best_obj: 70 | best_obj = tmp_obj 71 | cluster_centers = tmp_cluster_centers 72 | labels = tmp_labels 73 | return
cluster_centers, labels, best_obj 74 | 75 | source_in_temp_orig = 'layers.{}.ffn.ffn.w_in.w_0.weight' # 'layers.{}.ffn.ffn.w_in.w.weight' 76 | source_out_temp_orig = 'layers.{}.ffn.ffn.w_out.weight' 77 | target_in_temp_orig = 'layers.{}.ffn.ffn.mlp.batched_fc1_w' #'layers.{}.ffn.ffn.experts.batched_fc1_w' 78 | target_out_temp_orig = 'layers.{}.ffn.ffn.mlp.batched_fc2_w' #'layers.{}.ffn.ffn.experts.batched_fc2_w' 79 | target_wg_temp_orig = 'layers.{}.ffn.ffn.router.wg' #'layers.{}.ffn.ffn.gate.wg' 80 | # layers = [i for i in range(12)] 81 | # model_types = ['encoder.', 'decoder.'] 82 | 83 | 84 | def init_ckpt(ckpt, split_num, layers, model_types): 85 | for model_type in model_types: 86 | source_in_temp = model_type + source_in_temp_orig 87 | source_out_temp = model_type + source_out_temp_orig 88 | target_in_temp = model_type + target_in_temp_orig 89 | target_out_temp = model_type + target_out_temp_orig 90 | target_wg_temp = model_type + target_wg_temp_orig 91 | 92 | for layer in tqdm.tqdm(layers): 93 | w_in = ckpt[source_in_temp.format(layer)] 94 | w_out = ckpt[source_out_temp.format(layer)] 95 | 96 | ckpt[target_in_temp.format(layer)] = w_in 97 | ckpt[target_out_temp.format(layer)] = w_out 98 | 99 | wg = w_in.view(split_num, -1, w_in.shape[-1]).mean(dim=1) 100 | wg_norm = wg / torch.norm(wg, dim=-1, keepdim=True) * w_in[0, :].norm() 101 | 102 | ckpt[target_wg_temp.format(layer)] = wg_norm.transpose(0, 1) 103 | 104 | del ckpt[source_in_temp.format(layer)] 105 | del ckpt[source_out_temp.format(layer)] 106 | 107 | return ckpt 108 | 109 | import sklearn.preprocessing 110 | import numpy as np 111 | 112 | def split_ckpt(ckpt, split_num, layers, model_types, additional=False, structures=None): 113 | 114 | permutes = [] 115 | idx = 0 116 | for model_type in model_types: 117 | if not additional: 118 | source_in_temp = model_type + source_in_temp_orig 119 | source_out_temp = model_type + source_out_temp_orig 120 | else: 121 | source_in_temp = model_type + target_in_temp_orig 122 | source_out_temp = model_type + target_out_temp_orig 123 | target_in_temp = model_type + target_in_temp_orig 124 | target_out_temp = model_type + target_out_temp_orig 125 | target_wg_temp = model_type + target_wg_temp_orig 126 | 127 | for layer in tqdm.tqdm(layers): 128 | w_in = ckpt[source_in_temp.format(layer)] 129 | w_out = ckpt[source_out_temp.format(layer)] 130 | hidden_size = w_in.shape[0] 131 | expert_size = hidden_size // split_num 132 | w_in_ = sklearn.preprocessing.normalize(w_in.float().numpy()) 133 | 134 | if structures is None: 135 | centers, labels, obj = kmeans_balanced(w_in_, split_num, expert_size) 136 | # centers, labels, obj = kmeans_balanced(w_in.float().numpy(), split_num, expert_size) 137 | else: 138 | labels = structures[idx] 139 | idx += 1 140 | 141 | labels = torch.tensor(labels).cpu() 142 | tmp_in = [] 143 | tmp_in_norm = [] 144 | for i in range(split_num): 145 | tmp_in.append(w_in[labels == i, :]) 146 | tmp_in_norm.append(w_in_[labels == i, :]) 147 | tmp_in = torch.stack(tmp_in, dim=0) 148 | tmp_in_norm = np.stack(tmp_in_norm, axis=0) 149 | 150 | tmp_out = [] 151 | for i in range(split_num): 152 | tmp_out.append(w_out[:, labels == i].transpose(0, 1)) 153 | tmp_out = torch.cat(tmp_out, dim=0) 154 | 155 | # wg = tmp_in.mean(1) 156 | wg = tmp_in_norm.mean(1) 157 | wg = torch.tensor(wg, dtype=tmp_in.dtype, device=tmp_in.device) 158 | wg_norm = wg / torch.norm(wg, dim=-1, keepdim=True) * tmp_in[0, 0, :].norm() 159 | 160 | ckpt[target_wg_temp.format(layer)] = wg_norm.transpose(0, 1) 161 | 162 | if not additional: 163
| tmp_in = tmp_in.view(-1, tmp_in.shape[-1]) 164 | 165 | ckpt[target_in_temp.format(layer)] = tmp_in 166 | ckpt[target_out_temp.format(layer)] = tmp_out 167 | 168 | del ckpt[source_in_temp.format(layer)] 169 | del ckpt[source_out_temp.format(layer)] 170 | else: 171 | permute = [] 172 | for i in range(split_num): 173 | permute.append((labels == i).nonzero().squeeze(1)) 174 | permute = torch.cat(permute, dim=0) 175 | permutes.append(permute) 176 | 177 | if additional: 178 | return ckpt, permutes 179 | else: 180 | return ckpt 181 | 182 | def get_labels(x, k, init=None): 183 | centers, labels, obj = kmeans_single(x, k, x.shape[0] // k, init=init) 184 | labels = labels.cpu().numpy() 185 | return labels, obj 186 | 187 | def cal_structure(ckpt, split_num, layers, model_types, prev_structure): 188 | labels = [] 189 | for model_type in model_types: 190 | source_in_temp = model_type + source_in_temp_orig 191 | for layer in tqdm.tqdm(layers): 192 | w_in = ckpt[source_in_temp.format(layer)] 193 | # a = w_in.float().numpy() 194 | a = sklearn.preprocessing.normalize(w_in.float().numpy()) 195 | 196 | if prev_structure is None: 197 | l, obj = get_labels(a, split_num) 198 | else: 199 | l, obj = get_labels(a, split_num, init=prev_structure[len(labels)]) 200 | l_random, obj_ = get_labels(a, split_num) 201 | if obj_ < obj: 202 | l = l_random 203 | labels.append(l) 204 | return labels 205 | 206 | def merge_ckpt(ckpt, layers, model_types): 207 | for model_type in model_types: 208 | source_in_temp = model_type + source_in_temp_orig 209 | source_out_temp = model_type + source_out_temp_orig 210 | target_in_temp = model_type + target_in_temp_orig 211 | target_out_temp = model_type + target_out_temp_orig 212 | target_wg_temp = model_type + target_wg_temp_orig 213 | 214 | for layer in tqdm.tqdm(layers): 215 | w_in = ckpt[target_in_temp.format(layer)] 216 | w_out = ckpt[target_out_temp.format(layer)] 217 | # wg = ckpt[target_wg_temp.format(layer)] 218 | 219 | # w_in = w_in.reshape(-1, w_in.shape[-1]) 220 | w_out = w_out.transpose(0, 1) 221 | # w_out = w_out.reshape(w_out.shape[0], -1) 222 | 223 | ckpt[source_in_temp.format(layer)] = w_in 224 | ckpt[source_out_temp.format(layer)] = w_out 225 | 226 | del ckpt[target_in_temp.format(layer)] 227 | del ckpt[target_out_temp.format(layer)] 228 | del ckpt[target_wg_temp.format(layer)] 229 | 230 | return ckpt 231 | 232 | def split_ckpt_random(ckpt, split_num, layers, model_types, additional=False, structures=None): 233 | for model_type in model_types: 234 | source_in_temp = model_type + source_in_temp_orig 235 | source_out_temp = model_type + source_out_temp_orig 236 | target_in_temp = model_type + target_in_temp_orig 237 | target_out_temp = model_type + target_out_temp_orig 238 | target_wg_temp = model_type + target_wg_temp_orig 239 | 240 | for layer in tqdm.tqdm(layers): 241 | w_in = ckpt[source_in_temp.format(layer)] 242 | w_out = ckpt[source_out_temp.format(layer)] 243 | hidden_size = w_in.shape[0] 244 | expert_size = hidden_size // split_num 245 | # equally split 246 | labels = [i for i in range(split_num) for _ in range(expert_size)] 247 | # shuffle 248 | random.shuffle(labels) 249 | labels = torch.tensor(labels) 250 | 251 | tmp_in = [] 252 | for i in range(split_num): 253 | tmp_in.append(w_in[labels == i, :]) 254 | tmp_in = torch.stack(tmp_in, dim=0) 255 | 256 | tmp_out = [] 257 | for i in range(split_num): 258 | tmp_out.append(w_out[:, labels == i]) 259 | tmp_out = torch.stack(tmp_out, dim=0) 260 | 261 | ckpt[target_in_temp.format(layer)] = tmp_in 262 | 
ckpt[target_out_temp.format(layer)] = tmp_out 263 | 264 | wg = tmp_in.mean(1) 265 | wg = torch.randn_like(wg) 266 | ckpt[target_wg_temp.format(layer)] = wg.transpose(0, 1) 267 | 268 | del ckpt[source_in_temp.format(layer)] 269 | del ckpt[source_out_temp.format(layer)] 270 | 271 | return ckpt 272 | -------------------------------------------------------------------------------- /moefication/adj.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | import sys 6 | import os 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('--model_path', type=str, default='results/t5-base/ckpt.bin', help='path to the model checkpoint') 12 | parser.add_argument('--res_path', type=str, default='results/t5-base/', help='path to store the results of moefication') 13 | parser.add_argument('--num-layer', type=int, default=12, help='number of layers') 14 | parser.add_argument('--templates', type=str, default='encoder.blocks.{}.ff.dense_relu_dense.wi.weight,decoder.blocks.{}.ff.dense_relu_dense.wi.weight', help='weight names of the first linear layer in each FFN (use comma to separate multiple templates)') 15 | parser.add_argument('--num-gpu', type=int, default=4, help='number of gpus') 16 | 17 | args = parser.parse_args() 18 | 19 | num_layer = args.num_layer 20 | batch_size = 8 21 | max_instance = 200000 22 | 23 | def run(proc_id): 24 | proc_id, model_path, res_path, template, num_gpu = proc_id 25 | cuda_dev = torch.device('cuda:{}'.format(proc_id)) 26 | for layer in range(num_layer): 27 | layer_id = layer 28 | if layer_id % num_gpu != proc_id: 29 | continue 30 | 31 | ffn = torch.tensor(utils.load_ffn_weight(model_path, template, layer)) 32 | hidden = utils.load_hidden_states(res_path, template.format(layer))  # hidden states are saved under the weight name 33 | hidden = torch.cat(hidden, 0).transpose(1, 2).reshape(-1, 4096) 34 | 35 | cnt = 0 36 | adj = torch.zeros(ffn.shape[0], ffn.shape[0], device=cuda_dev).float() 37 | ffn = ffn.to(cuda_dev).transpose(0, 1) 38 | for i in tqdm.tqdm(range(hidden.shape[0]//batch_size)): 39 | with torch.no_grad(): 40 | dat = hidden[i*batch_size:(i+1)*batch_size].to(cuda_dev) 41 | res = torch.nn.functional.relu(torch.matmul(dat, ffn)).unsqueeze(-1) 42 | res = torch.clamp(torch.bmm(res, res.transpose(1, 2)).sum(0), max=1) 43 | adj += res 44 | 45 | cnt += batch_size 46 | if cnt > max_instance: 47 | break 48 | del hidden 49 | 50 | adj = adj.cpu().numpy() 51 | target = os.path.join(res_path, template.format(layer)) 52 | 53 | threshold = 0 54 | pos = 10 55 | while threshold == 0: 56 | assert pos != 110 57 | threshold = np.percentile(adj.reshape(-1), pos) 58 | pos += 10 59 | print("threshold", threshold, layer_id, pos, adj.max()) 60 | threshold = threshold * 0.99 61 | adj /= threshold 62 | 63 | with open(target, "w") as fout: 64 | edges = 0 65 | for i in range(adj.shape[0]): 66 | cnt = 0 67 | for j in range(adj.shape[1]): 68 | if i == j or adj[i, j] < 1: 69 | pass 70 | else: 71 | cnt += 1 72 | edges += cnt 73 | assert edges > 0 74 | fout.write("{} {} {}\n".format(adj.shape[0], edges // 2, "001")) 75 | for i in range(adj.shape[0]): 76 | vec = [] 77 | for j in range(adj.shape[1]): 78 | if i == j or adj[i, j] < 1: 79 | pass 80 | else: 81 | val = int(adj[i, j]) 82 | vec.append([j+1, val]) 83 | fout.write(" ".join(["{} {}".format(x[0], x[1]) for x in vec]) + "\n") 84 | 85 | import multiprocessing 86 | templates = args.templates.split(',') 87 | 88 | for template in templates: 89 | pool =
multiprocessing.Pool(processes=args.num_gpu) 90 | pool.map(run, [(i, args.model_path, args.res_path, template, args.num_gpu) for i in range(args.num_gpu)]) 91 | pool.close() 92 | pool.join() 93 | 94 | -------------------------------------------------------------------------------- /moefication/mlp_select_example.py: -------------------------------------------------------------------------------- 1 | from tempfile import template 2 | import utils 3 | import numpy as np 4 | import torch 5 | import sys 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--model_path', type=str, default='results/t5-base/ckpt.bin', help='path to the model checkpoint') 11 | parser.add_argument('--res_path', type=str, default='results/t5-base/', help='path to store the results of moefication') 12 | parser.add_argument('--num-layer', type=int, default=12, help='number of layers') 13 | parser.add_argument('--num-expert', type=int, default=96, help='number of experts') 14 | parser.add_argument('--templates', type=str, default='encoder.block.{}.layer.1.DenseReluDense.wi.weight,decoder.block.{}.layer.2.DenseReluDense.wi.weight', help='weight names of the first linear layer in each FFN (use comma to separate multiple templates)') 15 | 16 | args = parser.parse_args() 17 | 18 | config = utils.ModelConfig(args.model_path, args.res_path, split_num=args.num_expert) 19 | 20 | templates = args.templates.split(',') 21 | for template in templates: 22 | for i in range(args.num_layer): 23 | center = utils.MLPCenter(config, template, '{}/param_split/{}'.format(args.res_path, template.format(i)), i) 24 | center.cal_center() 25 | -------------------------------------------------------------------------------- /moefication/param_cluster_example.py: -------------------------------------------------------------------------------- 1 | from ast import arg 2 | from re import template 3 | import os 4 | import utils 5 | import numpy as np 6 | import torch 7 | import argparse 8 | import tqdm 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--model_path', type=str, default='results/t5-base/ckpt.bin', help='path to the model checkpoint') 13 | parser.add_argument('--res_path', type=str, default='results/t5-base/', help='path to store the results of moefication') 14 | parser.add_argument('--num-layer', type=int, default=12, help='number of layers') 15 | parser.add_argument('--num-expert', type=int, default=96, help='number of experts') 16 | parser.add_argument('--templates', type=str, default='encoder.block.{}.layer.1.DenseReluDense.wi.weight,decoder.block.{}.layer.2.DenseReluDense.wi.weight', help='weight names of the first linear layer in each FFN (use comma to separate multiple templates)') 17 | 18 | args = parser.parse_args() 19 | if not os.path.exists(args.res_path): 20 | os.makedirs(args.res_path) 21 | 22 | config = utils.ModelConfig(args.model_path, args.res_path, split_num=args.num_expert) 23 | 24 | templates = args.templates.split(',') 25 | 26 | for template in templates: 27 | for i in tqdm.tqdm(range(args.num_layer)): 28 | split = utils.ParamSplit(config, template, i) 29 | split.split() 30 | split.cnt() 31 | split.save() 32 | -------------------------------------------------------------------------------- /moefication/similarity_select_example.py: -------------------------------------------------------------------------------- 1 | from re import template 2 | import utils 3 | import numpy as np 4 | import torch 5 | import argparse 6 | 7 | parser = 
argparse.ArgumentParser() 8 | 9 | parser.add_argument('--model_path', type=str, default='results/t5-base/ckpt.bin', help='path to the model checkpoint') 10 | parser.add_argument('--res_path', type=str, default='results/t5-base/', help='path to store the results of moefication') 11 | parser.add_argument('--num-layer', type=int, default=12, help='number of layers') 12 | parser.add_argument('--num-expert', type=int, default=96, help='number of experts') 13 | parser.add_argument('--templates', type=str, default='encoder.blocks.{}.ff.dense_relu_dense.wi.weight,decoder.blocks.{}.ff.dense_relu_dense.wi.weight', help='weight names of the first linear layer in each FFN (use comma to separate multiple templates)') 14 | 15 | args = parser.parse_args() 16 | 17 | config = utils.ModelConfig(args.model_path, args.res_path, split_num=args.num_expert) 18 | 19 | templates = args.templates.split(',') 20 | for template in templates: 21 | for i in range(args.num_layer): 22 | center = utils.ParamCenter(config, template, '{}/param_split/{}'.format(args.res_path, template.format(i)), i)  # pass template and layer index, matching BlockCenter's signature 23 | center.cal_center() 24 | center.save() 25 | -------------------------------------------------------------------------------- /moefication/trans_gp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from collections import defaultdict 4 | import random 5 | import os 6 | 7 | filename = sys.argv[1] 8 | 9 | labels = [] 10 | 11 | with open(filename) as fin: 12 | d = defaultdict(list) 13 | for i, line in enumerate(fin): 14 | labels.append(int(line.strip())) 15 | d[labels[-1]].append(i) 16 | 17 | need_move = [] 18 | 19 | for i in range(max(d.keys())+1): 20 | if i not in d: 21 | d[i] = [] 22 | print(len(labels), len(d.keys())) 23 | 24 | num = len(labels) // len(d.keys()) 25 | for k, v in d.items(): 26 | if len(v) > num: 27 | random.shuffle(v) 28 | for i in range(num, len(v)): 29 | need_move.append(v[i]) 30 | d[k] = v[:num] 31 | 32 | print("need_move", need_move) 33 | 34 | random.shuffle(need_move) 35 | for k, v in d.items(): 36 | if len(v) < num: 37 | pos = num-len(v) 38 | v += need_move[:pos] 39 | need_move = need_move[pos:] 40 | for x in v: 41 | labels[x] = k 42 | 43 | vec = os.path.basename(filename).split('.')[:-2] 44 | target = '.'.join(vec) 45 | 46 | save_folder = os.path.join(os.path.dirname(filename), 'gp_split') 47 | 48 | if not os.path.exists(save_folder): 49 | os.makedirs(save_folder) 50 | 51 | torch.save(labels, os.path.join(save_folder, target)) 52 | 53 | from collections import Counter 54 | 55 | print(Counter(labels)) 56 | -------------------------------------------------------------------------------- /moefication/utils.py: -------------------------------------------------------------------------------- 1 | from typing import DefaultDict 2 | import sys 3 | import torch 4 | import os 5 | import tqdm 6 | from collections import Counter 7 | import sklearn 8 | from sklearn.linear_model import LogisticRegression 9 | from sklearn.preprocessing import MultiLabelBinarizer 10 | import numpy as np 11 | from k_means_constrained import KMeansConstrained 12 | 13 | def get_layer_num(filename): 14 | model = torch.load(filename, map_location='cpu')['module'] 15 | enc_keys = [x for x in model.keys() if 'ff.dense_relu_dense.wi.weight' in x and 'encoder' in x] 16 | dec_keys = [x for x in model.keys() if 'ff.dense_relu_dense.wi.weight' in x and 'decoder' in x] 17 | 18 | enc_nums = [int(x.split('.')[2]) for x in enc_keys] 19 | dec_nums = [int(x.split('.')[2]) for x in dec_keys]
20 | 21 | return max(enc_nums)+1, max(dec_nums)+1 22 | 23 | def load_ffn_weight(filename, template, layer): 24 | 25 | model = torch.load(filename, map_location='cpu') 26 | key = template.format(layer) 27 | 28 | return model[key].numpy() 29 | 30 | def load_hidden_states(folder, filename): 31 | target = os.path.join(folder, filename) 32 | vecs = torch.load(target) 33 | return vecs 34 | 35 | class ModelConfig: 36 | 37 | def __init__(self, filename, folder, split_num): 38 | self.filename = filename 39 | self.folder = folder 40 | self.split_num = split_num 41 | 42 | class LayerSplit: 43 | 44 | def __init__(self, config : ModelConfig, template, layer=0): 45 | self.config = config 46 | self.layer = layer 47 | self.template = template 48 | 49 | def split(self): 50 | pass 51 | 52 | def save(self): 53 | save_folder = os.path.join(self.config.folder, self.type) 54 | 55 | if not os.path.exists(save_folder): 56 | os.makedirs(save_folder) 57 | 58 | filename = os.path.join(save_folder, self.template.format(self.layer)) 59 | torch.save(self.labels, filename) 60 | 61 | def cnt(self): 62 | print(Counter(self.labels)) 63 | 64 | def load_param(self): 65 | self.ffn_weight = load_ffn_weight(self.config.filename, self.template, self.layer) 66 | self.neuron_num = self.ffn_weight.shape[0] 67 | self.split_size = self.neuron_num // self.config.split_num 68 | assert self.split_size * self.config.split_num == self.neuron_num 69 | 70 | class RandomSplit(LayerSplit): 71 | 72 | def __init__(self, config: ModelConfig, layer=0, is_encoder=True): 73 | super().__init__(config, layer=layer, is_encoder=is_encoder) 74 | self.type = 'random_split' 75 | 76 | def split(self): 77 | self.load_param() 78 | 79 | self.labels = [i // self.split_size for i in range(self.neuron_num)] 80 | 81 | class ParamSplit(LayerSplit): 82 | 83 | def __init__(self, config: ModelConfig, template, layer=0): 84 | super().__init__(config, template=template, layer=layer) 85 | self.type = 'param_split' 86 | 87 | def split(self): 88 | self.load_param() 89 | ffn_weight_norm = sklearn.preprocessing.normalize(self.ffn_weight) 90 | 91 | kmeans = KMeansConstrained(n_clusters=self.config.split_num, size_min=self.split_size, size_max=self.split_size, random_state=0).fit(ffn_weight_norm, None) 92 | 93 | self.labels = [x for x in kmeans.labels_] 94 | 95 | class BlockCenter: 96 | 97 | def __init__(self, config, template, filename, layer): 98 | self.config = config 99 | self.filename = filename 100 | self.labels = torch.load(filename) 101 | self.template = template 102 | 103 | self.layer = layer 104 | 105 | def cal_center(self): 106 | pass 107 | 108 | def save(self): 109 | print(self.centers.shape) 110 | torch.save(self.centers, "{}_{}".format(self.filename, self.type)) 111 | self.save_acc() 112 | 113 | def save_acc(self): 114 | with open("{}_{}_acc".format(self.filename, self.type), 'w') as fout: 115 | fout.write(str(self.acc)) 116 | 117 | class RandomCenter(BlockCenter): 118 | 119 | def __init__(self, config, filename): 120 | super().__init__(config, filename) 121 | self.type = "random" 122 | 123 | def cal_center(self): 124 | ffn_weight = load_ffn_weight(self.config.filename, self.layer, self.is_encoder) 125 | ffn_weight_norm = ffn_weight 126 | 127 | d = {} 128 | for i, x in enumerate(self.labels): 129 | if x not in d: 130 | d[x] = ffn_weight_norm[i, :] 131 | centers = sorted(list(d.items()), key=lambda x: x[0]) 132 | 133 | self.centers = sklearn.preprocessing.normalize(np.array([x[1] for x in centers])) 134 | self.acc = 0 135 | 136 | class ParamCenter(BlockCenter): 
class BlockCenter:

    def __init__(self, config, template, filename, layer):
        self.config = config
        self.filename = filename
        self.labels = torch.load(filename)
        self.template = template
        self.layer = layer

    def cal_center(self):
        pass

    def save(self):
        print(self.centers.shape)
        torch.save(self.centers, "{}_{}".format(self.filename, self.type))
        self.save_acc()

    def save_acc(self):
        with open("{}_{}_acc".format(self.filename, self.type), 'w') as fout:
            fout.write(str(self.acc))

class RandomCenter(BlockCenter):

    def __init__(self, config, template, filename, layer):
        # The original __init__ dropped the template and layer that BlockCenter requires.
        super().__init__(config, template, filename, layer)
        self.type = "random"

    def cal_center(self):
        # Pick the first neuron of each expert as its (normalized) representation.
        ffn_weight = load_ffn_weight(self.config.filename, self.template, self.layer)
        ffn_weight_norm = ffn_weight

        d = {}
        for i, x in enumerate(self.labels):
            if x not in d:
                d[x] = ffn_weight_norm[i, :]
        centers = sorted(list(d.items()), key=lambda x: x[0])

        self.centers = sklearn.preprocessing.normalize(np.array([x[1] for x in centers]))
        self.acc = 0

class ParamCenter(BlockCenter):

    def __init__(self, config, template, filename, layer):
        # The original __init__ passed (config, filename, layer) to BlockCenter,
        # shifting every argument by one position.
        super().__init__(config, template, filename, layer)
        self.type = "param"

    def cal_center(self):
        # Expert representation: the mean of the normalized weight rows in each cluster.
        ffn_weight = load_ffn_weight(self.config.filename, self.template, self.layer)
        ffn_weight_norm = sklearn.preprocessing.normalize(ffn_weight)

        centers = []
        num_blocks = max(self.labels) + 1
        for i in range(num_blocks):
            centers.append(ffn_weight_norm[np.array(self.labels) == i, :].mean(0))

        centers = np.array(centers)
        self.centers = centers

        centers = torch.tensor(centers).cuda().float().unsqueeze(0)

        # One 0/1 column per expert, marking which neurons it owns.
        patterns = []
        for i in range(num_blocks):
            patterns.append(np.array(self.labels) == i)
        patterns = torch.tensor(np.array(patterns)).cuda().float().transpose(0, 1)  # ffn_size, num_blocks

        # Evaluate on the last 10% of the cached hidden states: overlap between the
        # top-25 experts by true activation mass and the top-25 by center similarity.
        acc = []
        hiddens = load_hidden_states(self.config.folder, self.template.format(self.layer))
        hiddens = torch.cat(hiddens, 0).float()
        hiddens = hiddens.view(-1, hiddens.shape[-1])
        hiddens = hiddens / torch.norm(hiddens, dim=-1).unsqueeze(-1)
        num = hiddens.shape[0]

        ffn_weight = torch.tensor(ffn_weight).cuda().transpose(0, 1).float()
        for i in range(num // 10 * 9, num, 512):
            with torch.no_grad():
                input = hiddens[i:i+512, :].cuda()
                acts = torch.relu(torch.matmul(input, ffn_weight))  # 512, ffn_size
                scores = torch.matmul(acts, patterns)  # 512, num_blocks
                labels = torch.topk(scores, k=25, dim=-1)[1]

                input = input / torch.norm(input, dim=-1).unsqueeze(-1)
                dist = -1 * torch.norm(input.unsqueeze(1).expand(-1, num_blocks, -1) - centers, dim=-1)
                pred = torch.topk(dist, k=25, dim=-1)[1]

            for x, y in zip(labels, pred):
                x = set(x.cpu().numpy())
                y = set(y.cpu().numpy())
                acc.append(len(x & y) / 25)
        print("param acc", np.mean(acc))
        sys.stdout.flush()
        self.acc = np.mean(acc)

class MLPCenter(BlockCenter):

    def __init__(self, config, template, filename, layer):
        super().__init__(config, template, filename, layer)
        self.type = "input_compl"

    def cal_center(self):
        ffn_weight = load_ffn_weight(self.config.filename, self.template, self.layer)
        ffn_weight_norm_ = sklearn.preprocessing.normalize(ffn_weight)
        centers = []
        num_blocks = max(self.labels) + 1
        for i in range(num_blocks):
            centers.append(ffn_weight_norm_[np.array(self.labels) == i, :].mean(0))
        centers = np.array(centers)  # num_blocks, hidden_size

        ffn_weight = torch.tensor(ffn_weight).cuda().transpose(0, 1).float()
        patterns = []
        for i in range(num_blocks):
            patterns.append(np.array(self.labels) == i)
        patterns = torch.tensor(np.array(patterns)).cuda().float().transpose(0, 1)

        hiddens = load_hidden_states(self.config.folder, self.template.format(self.layer))
        # The cached hidden states are stored as a list of tensors; ParamCenter
        # concatenates them first, and the same step appears to be needed here
        # for the tensor indexing below to work.
        hiddens = torch.cat(hiddens, 0).float()
        hiddens = hiddens.view(-1, hiddens.shape[-1])
        hiddens = hiddens / torch.norm(hiddens, dim=-1).unsqueeze(-1)

        # The router MLP: hidden state -> one score per expert.
        model = torch.nn.Sequential(torch.nn.Linear(hiddens.shape[-1], num_blocks, bias=False),
                                    torch.nn.Tanh(),
                                    torch.nn.Linear(num_blocks, num_blocks, bias=False))

        def weights_init(m):
            if isinstance(m, torch.nn.Linear):
                if m.weight.shape[-1] == hiddens.shape[-1]:
                    m.weight.data = torch.from_numpy(centers).float()
                else:
                    m.weight.data = torch.eye(m.weight.data.shape[0])
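        # Initialization scheme (reader's note on the code above): the first layer
        # starts from the normalized cluster centers and the second from the
        # identity, so before any training the router already ranks experts by
        # center similarity and is then fine-tuned from there.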
        model.apply(weights_init)

        model.cuda()

        optim = torch.optim.Adam(model.parameters(), lr=0.01)

        # Soft targets: the positive activation mass falling in each expert,
        # rescaled to [0, 1] and trained with a binary cross-entropy loss.
        loss_func = torch.nn.BCEWithLogitsLoss()

        # Track the two best validation accuracies and the epochs they came from.
        save_acc = [0, 0]
        save_epoch = [-1, -1]

        self.centers = model

        # First 90% of the hidden states for training, the last 10% for validation.
        train_hiddens = hiddens[:hiddens.shape[0] // 10 * 9, :]

        for epoch in range(30):
            train_hiddens = train_hiddens[torch.randperm(train_hiddens.size()[0])]

            pbar = tqdm.tqdm(range(0, train_hiddens.shape[0], 512))
            for i in pbar:
                model.zero_grad()

                input = train_hiddens[i:i+512, :].float().clone().detach().cuda()
                with torch.no_grad():
                    acts = torch.relu(torch.matmul(input, ffn_weight)).float()
                    scores = torch.matmul(acts, patterns)
                    scores /= scores.max()
                pred = model(input)
                loss = loss_func(pred.view(-1), scores.view(-1))

                loss.backward()
                optim.step()

                pbar.set_description("loss: {:.4f}".format(loss.item()))

            # Validation: how much of the total activation mass the predicted
            # top-20% experts cover, relative to the true scores.
            acc = []

            for i in range(hiddens.shape[0] // 10 * 9, hiddens.shape[0], 512):
                with torch.no_grad():
                    input = hiddens[i:i+512, :].float().cuda()
                    acts = torch.relu(torch.matmul(input, ffn_weight)).float()  # 512, ffn_size

                    scores = torch.matmul(acts, patterns)  # 512, num_blocks
                    mask, labels = torch.topk(scores, k=int(num_blocks * 0.2), dim=-1)
                    mask = mask > 0

                    pred = model(input)
                    pred = torch.topk(pred, k=int(num_blocks * 0.2), dim=-1)[1]

                for x, m, s in zip(pred, mask, scores):
                    if m.sum().item() == 0:
                        continue
                    x = sum([s[xx] for xx in x.cpu()]).item()
                    y = s.sum().item()
                    acc.append(x / y)

            cur_acc = np.mean(acc)
            if cur_acc > save_acc[0]:
                self.del_ckpt(save_epoch[1])
                save_acc = [cur_acc, save_acc[0]]
                save_epoch = [epoch, save_epoch[0]]
                print("input compl center acc", np.mean(acc))
                self.acc = save_acc[1]
                sys.stdout.flush()
                self.save(epoch)
            elif cur_acc > save_acc[1]:
                self.del_ckpt(save_epoch[1])
                save_acc = [save_acc[0], cur_acc]
                save_epoch = [save_epoch[0], epoch]
                print("input compl center acc", np.mean(acc))
                self.acc = save_acc[1]
                sys.stdout.flush()
                self.save(epoch)

        # Keep the second-best checkpoint as the final router, consistent with
        # self.acc, which is also set to the second-best accuracy above.
        os.system("rm -rf {}_{}_{}".format(self.filename, self.type, save_epoch[0]))
        os.system("cp {0}_{1}_{2} {0}_{1}".format(self.filename, self.type, save_epoch[1]))
        os.system("rm {0}_{1}_{2}".format(self.filename, self.type, save_epoch[1]))

    def del_ckpt(self, epoch):
        os.system("rm -rf {}_{}_{}".format(self.filename, self.type, epoch))

    def save(self, epoch):
        print("input compl center save")
        torch.save(self.centers, "{}_{}_{}".format(self.filename, self.type, epoch))
        self.save_acc()
--------------------------------------------------------------------------------
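As a reading aid (not part of the repository), here is a minimal sketch of how the classes in `moefication/utils.py` fit together for a single FFN layer. The checkpoint path, result folder, and template string are placeholders; a CUDA device and cached hidden states (dumped under `config.folder` with the same weight name, e.g. by the inference scripts) are assumed for the center classes.

```
import moefication.utils as utils

# Placeholder paths and template; adjust to your checkpoint and model.
config = utils.ModelConfig('results/t5-base/ckpt.bin', 'results/t5-base/', split_num=96)
template = 'encoder.blocks.{}.ff.dense_relu_dense.wi.weight'

# 1. Expert construction: balanced k-means over the rows of the first FFN matrix.
split = utils.ParamSplit(config, template, layer=0)
split.split()
split.save()  # writes the label file under results/t5-base/param_split/

# 2. Expert selection, either from average-weight centers (similarity selection) ...
labels_path = 'results/t5-base/param_split/' + template.format(0)
center = utils.ParamCenter(config, template, labels_path, 0)
center.cal_center()
center.save()

# ... or from a small trained router (MLP selection); cal_center trains the MLP
# and saves its best checkpoints itself.
mlp = utils.MLPCenter(config, template, labels_path, 0)
mlp.cal_center()
```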