├── .gitignore
├── __init__.py
├── data-bpe.zip
├── model.py
├── readme.md
├── task.py
└── view.py

/.gitignore:
--------------------------------------------------------------------------------
data-bpe
data-bin
__pycache__
checkpoints
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
from . import task
from . import model
--------------------------------------------------------------------------------
/data-bpe.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/voidmagic/parameter-differentiation/00b8338132db7f356d09f49ab89557f6caf73964/data-bpe.zip
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from fairseq.models import register_model, register_model_architecture
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from fairseq.models.transformer import base_architecture


@register_model('parameter_differentiation_model')
class ParameterDifferentiationModel(MultilingualTransformerModel):
    def __init__(self, encoders, decoders):
        super().__init__(encoders, decoders)
        shared_model = self.models[self.keys[0]]
        for key in self.keys[1:]:
            # share encoder
            for layer_idx in range(len(shared_model.encoder.layers)):
                self.models[key].encoder.layers[layer_idx].self_attn.k_proj = shared_model.encoder.layers[layer_idx].self_attn.k_proj
                self.models[key].encoder.layers[layer_idx].self_attn.v_proj = shared_model.encoder.layers[layer_idx].self_attn.v_proj
                self.models[key].encoder.layers[layer_idx].self_attn.q_proj = shared_model.encoder.layers[layer_idx].self_attn.q_proj
                self.models[key].encoder.layers[layer_idx].self_attn.out_proj = shared_model.encoder.layers[layer_idx].self_attn.out_proj
                self.models[key].encoder.layers[layer_idx].fc1 = shared_model.encoder.layers[layer_idx].fc1
                self.models[key].encoder.layers[layer_idx].fc2 = shared_model.encoder.layers[layer_idx].fc2
                self.models[key].encoder.layers[layer_idx].self_attn_layer_norm = shared_model.encoder.layers[layer_idx].self_attn_layer_norm
                self.models[key].encoder.layers[layer_idx].final_layer_norm = shared_model.encoder.layers[layer_idx].final_layer_norm

            # share decoder
            for layer_idx in range(len(shared_model.decoder.layers)):
                self.models[key].decoder.layers[layer_idx].self_attn.k_proj = shared_model.decoder.layers[layer_idx].self_attn.k_proj
                self.models[key].decoder.layers[layer_idx].self_attn.v_proj = shared_model.decoder.layers[layer_idx].self_attn.v_proj
                self.models[key].decoder.layers[layer_idx].self_attn.q_proj = shared_model.decoder.layers[layer_idx].self_attn.q_proj
                self.models[key].decoder.layers[layer_idx].self_attn.out_proj = shared_model.decoder.layers[layer_idx].self_attn.out_proj
                self.models[key].decoder.layers[layer_idx].encoder_attn.k_proj = shared_model.decoder.layers[layer_idx].encoder_attn.k_proj
                self.models[key].decoder.layers[layer_idx].encoder_attn.v_proj = shared_model.decoder.layers[layer_idx].encoder_attn.v_proj
                self.models[key].decoder.layers[layer_idx].encoder_attn.q_proj = shared_model.decoder.layers[layer_idx].encoder_attn.q_proj
                self.models[key].decoder.layers[layer_idx].encoder_attn.out_proj = shared_model.decoder.layers[layer_idx].encoder_attn.out_proj
                self.models[key].decoder.layers[layer_idx].fc1 = shared_model.decoder.layers[layer_idx].fc1
                self.models[key].decoder.layers[layer_idx].fc2 = shared_model.decoder.layers[layer_idx].fc2
                self.models[key].decoder.layers[layer_idx].self_attn_layer_norm = shared_model.decoder.layers[layer_idx].self_attn_layer_norm
                self.models[key].decoder.layers[layer_idx].encoder_attn_layer_norm = shared_model.decoder.layers[layer_idx].encoder_attn_layer_norm
                self.models[key].decoder.layers[layer_idx].final_layer_norm = shared_model.decoder.layers[layer_idx].final_layer_norm

    @classmethod
    def build_model(cls, args, task):
        model = super(ParameterDifferentiationModel, cls).build_model(args, task)
        encoders = {key: model.models[key].encoder for key in model.keys}
        decoders = {key: model.models[key].decoder for key in model.keys}
        return cls(encoders, decoders)


@register_model_architecture("parameter_differentiation_model", "parameter_differentiation_tiny_model")
def base_parameter_differentiation_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 128)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256)
    args.encoder_layers = getattr(args, "encoder_layers", 3)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 3)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    base_architecture(args)
--------------------------------------------------------------------------------
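Note: `MultilingualTransformerModel` builds one encoder and one decoder per language pair; the constructor above then ties the layer-internal modules together simply by assigning the same `nn.Module` object to every language pair, so a single set of parameters is shared until `auto_split` (in `view.py`) later replaces it for some pairs. A minimal sketch of this sharing-by-assignment mechanism, independent of fairseq (the `TinyLayer` class and the toy language-pair keys are illustrative only):

```
import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    # stand-in for a transformer layer with a single projection
    def __init__(self, dim=4):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.fc1(x)

# one sub-model per language pair, as in self.models[key]
models = nn.ModuleDict({"es-en": TinyLayer(), "pt-en": TinyLayer()})

# share by assignment: both language pairs now hold the *same* Linear object
models["pt-en"].fc1 = models["es-en"].fc1
assert models["pt-en"].fc1 is models["es-en"].fc1

# shared parameters are counted once and receive gradients from both pairs
print(sum(p.numel() for p in models.parameters()))  # 20, not 40

loss = models["es-en"](torch.ones(1, 4)).sum() + models["pt-en"](torch.ones(1, 4)).sum()
loss.backward()
print(models["es-en"].fc1.weight.grad.abs().sum() > 0)  # gradient accumulated from both terms
```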
/readme.md:
--------------------------------------------------------------------------------
The implementation of [Parameter Differentiation based Multilingual Neural Machine Translation](https://arxiv.org/abs/2112.13619).


# Requirements

```
pip install fairseq==0.10.2
conda install scikit-learn
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch
```

# Usage

1. Prepare data following [fairseq](https://github.com/pytorch/fairseq/tree/main/examples/translation#multilingual-translation):

```
unzip data-bpe.zip

mkdir -p data-bin && cut -f1 data-bpe/bpe.vocab | tail -n +4 | sed "s/$/ 100/g" > data-bin/dict.en.txt

for lang in es pt; do
    fairseq-preprocess --source-lang en --target-lang $lang \
        --trainpref data-bpe/train.en-$lang \
        --validpref data-bpe/valid.en-$lang \
        --testpref data-bpe/test.en-$lang \
        --destdir data-bin \
        --srcdict data-bin/dict.en.txt \
        --tgtdict data-bin/dict.en.txt
done
```

The `cut`/`tail`/`sed` pipeline converts the SentencePiece vocabulary `data-bpe/bpe.vocab` into a fairseq dictionary: it keeps only the token column, drops the first three lines (SentencePiece's special tokens, which fairseq adds on its own), and appends a dummy frequency of 100. The same joint dictionary is used for the source and target side of both language pairs.

2. Training:

Multilingual NMT baseline (fully shared encoders and decoders):
```
fairseq-train data-bin --user-dir . --max-tokens 4096 --max-update 20000 \
    --task multilingual_translation --lang-pairs es-en,pt-en \
    --arch parameter_differentiation_tiny_model --share-all-embeddings --share-encoders --share-decoders \
    --lr-scheduler inverse_sqrt --optimizer adam --lr 0.0015 --validate-interval 4
```

Parameter differentiation based MNMT:
```
fairseq-train data-bin --user-dir . --max-tokens 4096 --max-update 20000 \
    --task parameter_differentiation_task --lang-pairs es-en,pt-en \
    --arch parameter_differentiation_tiny_model --share-all-embeddings \
    --lr-scheduler inverse_sqrt --optimizer adam --lr 0.0015 --validate-interval 4
```

3. Decoding:
```
fairseq-generate data-bin --user-dir . --max-tokens 4096 --quiet \
    --task parameter_differentiation_task --lang-pairs es-en,pt-en \
    --remove-bpe sentencepiece --source-lang es --target-lang en \
    --path checkpoints/checkpoint_last.pt

fairseq-generate data-bin --user-dir . --max-tokens 4096 --quiet \
    --task parameter_differentiation_task --lang-pairs es-en,pt-en \
    --remove-bpe sentencepiece --source-lang pt --target-lang en \
    --path checkpoints/checkpoint_last.pt
```
--------------------------------------------------------------------------------
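Note: with `parameter_differentiation_task`, training starts from fully shared encoder/decoder layers; at the beginning of every validation epoch the task accumulates per-language-pair gradients on the validation set and duplicates at most two shared modules whose gradients disagree most, so the parameter count grows over training. Each split is reported in the training log (`Split shared parameters: ...`, `num. model params before/after: ...`). A small sketch for pulling these messages out afterwards, assuming the training output was redirected to a file named `train.log` (the file name is an assumption, not something the commands above create):

```
# print every split performed during training, plus the parameter counts around it
with open("train.log") as f:
    for line in f:
        # messages emitted by task.py / view.py when a shared module is duplicated
        if "Split shared parameters" in line or "num. model params" in line:
            print(line.rstrip())
```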
/task.py:
--------------------------------------------------------------------------------
import logging
from fairseq import utils
from .view import ModelView
from fairseq.trainer import Trainer
from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask


logger = logging.getLogger(__name__)


@register_task('parameter_differentiation_task')
class ParameterDifferentiationTask(MultilingualTranslationTask):
    _view: ModelView = None

    @property
    def view(self):
        if self._view is None:
            self._view = ModelView(get_trainer().model)
        return self._view

    def record_gradient(self, model):
        logger.info("Start accumulating gradient")
        criterion = get_trainer().get_criterion()
        model.eval()  # disable dropout
        for lang_pair, dataset in self.dataset(self.args.valid_subset).datasets.items():
            batch_iterator = self.get_batch_iterator(
                dataset=dataset, max_tokens=self.args.max_tokens_valid, seed=self.args.seed).next_epoch_itr()
            model.zero_grad()
            for sample in batch_iterator:
                sample = utils.move_to_cuda(sample)
                loss, _, _ = criterion(model.models[lang_pair], sample)
                loss = loss / len(batch_iterator)  # average the gradient over the validation batches
                loss.backward()
            self.view.accum_gradient(lang_pair)
            model.zero_grad()
        model.train()  # enable dropout
        logger.info("End accumulating gradient")

    def begin_valid_epoch(self, epoch, model):
        self.record_gradient(model)
        logger.info("num. model params before: {}".format(sum(p.numel() for p in model.parameters())))
        _ = list(self.view.auto_split())
        logger.info("num. model params after: {}".format(sum(p.numel() for p in model.parameters())))
        self.view.clear_gradient()
        # drop the old optimizer so the trainer rebuilds it over the (possibly enlarged) parameter set
        get_trainer()._optimizer = None


def get_trainer() -> Trainer:
    # The task has no direct handle on the running Trainer, so look it up among live objects.
    import gc
    for obj in gc.get_objects():
        if isinstance(obj, Trainer):
            return obj
--------------------------------------------------------------------------------
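The optimizer reset at the end of `begin_valid_epoch` matters because an optimizer built before a split holds references only to the old parameters: the modules duplicated by `auto_split` would otherwise never be updated. A minimal PyTorch sketch of the problem and the fix, using the same share-then-split pattern as `model.py` / `view.py` (the toy language-pair keys are illustrative only):

```
import copy
import torch
import torch.nn as nn

# fully shared at first: both language pairs point at the same Linear
model = nn.ModuleDict({"es-en": nn.Linear(4, 4)})
model["pt-en"] = model["es-en"]

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(len(optimizer.param_groups[0]["params"]))  # 2 tensors (one shared weight, one shared bias)

# "split": give pt-en its own copy of the shared module
model["pt-en"] = copy.deepcopy(model["es-en"])

# the existing optimizer still tracks only the old tensors, so the new copy would
# never receive updates; rebuild the optimizer over the enlarged parameter set
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(len(optimizer.param_groups[0]["params"]))  # 4 tensors
```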
/view.py:
--------------------------------------------------------------------------------
import copy
import torch
import logging
import torch.nn as nn
from sklearn.cluster import AgglomerativeClustering


logger = logging.getLogger(__name__)


def name2module(module, name):
    # return the chain of modules along the dotted `name`, starting with `module` itself
    def _generator(_module):
        for part in name.split('.'):
            _module = getattr(_module, part)
            yield _module
    return [module] + list(_generator(module))


def get_module_names(model: nn.Module):
    # names of the layer-internal modules (the ones tied together in model.py)
    all_param_names = [name for name in dict(model.named_parameters()).keys() if 'layers' in name]
    all_param_names = [param.rsplit('.', 1)[0] for param in all_param_names if 'weight' in param]
    return all_param_names


def extract_gradient_from_module(module):
    grads = [p.grad.view(-1) for _, p in sorted(module.named_parameters(), key=lambda pair: pair[0])]
    return torch.cat(grads).data.cpu()


class ModelView:
    def __init__(self, model):
        self.model = model
        self.container = {name: model.keys for name in get_module_names(model)}
        self.gradients = {lang_pair: {} for lang_pair in model.keys}

    def clear_gradient(self):
        self.gradients = {lang_pair: {} for lang_pair in self.model.keys}

    def accum_gradient(self, lang_pair):
        cur_model = self.model.models[lang_pair]
        for name in get_module_names(cur_model):
            module_tree = name2module(cur_model, name)
            grad = extract_gradient_from_module(module_tree[-1])
            self.gradients[lang_pair][name] = grad + self.gradients[lang_pair].get(name, 0)

    def auto_split(self):
        logger.info('Detect split parameters by grad')
        # from the accumulated gradients, compute the divergence of each shared module
        divergences = {}
        for name, lang_pairs in self.container.items():
            # `name` is the full name of a module; `lang_pairs` are all language pairs that share it.
            # Replacing the lang pair inside `name` with any lang pair from `lang_pairs` points to the same module.
            short_name = ".".join(name.split('.')[2:])  # name: 'models.en-de.encoder.layers.0', short_name: 'encoder.layers.0'
            module_gradients = {lang_pair: self.gradients[lang_pair][short_name] for lang_pair in lang_pairs}
            divergences[name] = calculate_div(module_gradients)

        # sort by distance from large to small; modules with a non-positive distance
        # (including the -1 sentinel for modules used by a single language pair) are skipped
        sorted_divergences = [d for d in sorted(divergences.items(), key=lambda item: -item[1][1]) if d[1][1] > 0]
        for best_name, (best_lang_pairs, best_score) in sorted_divergences[:2]:
            logger.info('Split shared parameters: {}'.format(best_name))
            logger.info('This parameter is shared by {}'.format(','.join(best_lang_pairs[0] + best_lang_pairs[1])))
            logger.info('After split: {} {}'.format(','.join(best_lang_pairs[0]), ','.join(best_lang_pairs[1])))
            logger.info('Cosine distance is {}'.format(best_score))
            yield self.split_module(best_name, best_lang_pairs)

    def split_module(self, module_to_split, split_lang_pairs):
        # 1. Update the container. The old parameters keep the language pairs in split_lang_pairs[0];
        #    the language pair already present in the module name must stay in that cluster.
        if module_to_split.split(".")[1] in split_lang_pairs[1]:
            split_lang_pairs[0], split_lang_pairs[1] = split_lang_pairs[1], split_lang_pairs[0]

        self.container[module_to_split] = split_lang_pairs[0]
        # The new parameters take split_lang_pairs[1][0] as the base of their name.
        new_name = ".".join([module_to_split.split(".")[0], split_lang_pairs[1][0]] + module_to_split.split(".")[2:])
        self.container[new_name] = split_lang_pairs[1]

        # 2. Create the new parameters by copying the module that is being split.
        module_tree = name2module(self.model, module_to_split)
        new_module = copy.deepcopy(module_tree[-1]).cuda()

        # 3. Assign the new module to the language pairs in the second cluster;
        #    the language pairs in the first cluster keep the original parameters.
        for lang_pair in split_lang_pairs[1]:
            module_name = ".".join([module_to_split.split(".")[0], lang_pair] + module_to_split.split(".")[2:])
            module_tree = name2module(self.model, module_name)
            setattr(module_tree[-2], module_name.split(".")[-1], new_module)
        return new_name, module_to_split


def calculate_div(module_gradients):
    """
    For a module shared by L language pairs, `module_gradients` maps each language pair to its
    accumulated gradient on that module. The language pairs are clustered into two clusters by the
    cosine distance of their gradients, and the clusters are returned together with the
    inter-cluster distance.
    :param module_gradients: dict of {lang_pair: gradient}
    :return: [[cluster_0], [cluster_1]], distance
    """
    if len(module_gradients) < 2:
        return [], -1

    cluster = AgglomerativeClustering(linkage='average', affinity='cosine', n_clusters=2, compute_distances=True)
    lang_pairs, gradients = zip(*module_gradients.items())
    labels = cluster.fit_predict(torch.stack(gradients).numpy() + 1e-5)
    cluster_0 = [lang_pair for lang_pair, label in zip(lang_pairs, labels) if label == 0]
    cluster_1 = [lang_pair for lang_pair, label in zip(lang_pairs, labels) if label == 1]
    return [cluster_0, cluster_1], cluster.distances_[-1]
--------------------------------------------------------------------------------
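For reference, `calculate_div` can be exercised in isolation with toy gradients: language pairs whose gradients point in similar directions land in the same cluster, and the returned value is the (average-linkage) cosine distance between the two clusters. The sketch below assumes it is run from the repository root with a scikit-learn version that still accepts the `affinity` argument; the language-pair names and vectors are made up:

```
import torch
from view import calculate_div

# toy accumulated gradients for one shared module (in the real code these are the
# flattened, concatenated .grad tensors collected by ModelView.accum_gradient)
module_gradients = {
    "es-en": torch.tensor([1.0, 1.0, 0.0, 0.0]),
    "pt-en": torch.tensor([0.9, 1.1, 0.0, 0.0]),
    "fr-en": torch.tensor([0.0, 0.0, 1.0, 1.0]),
}

(cluster_0, cluster_1), distance = calculate_div(module_gradients)
print(cluster_0, cluster_1)  # es-en and pt-en end up together, fr-en on its own
print(distance)              # inter-cluster cosine distance, close to 1.0 here
```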