├── README.md
├── configs
│   └── demo
│       ├── baseline.yml
│       ├── fine-tune.yml
│       └── multilingual.yml
├── fairseq_code
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── batched_sampled_multi_epoch_dataset.py
│   │   └── multilingual_data_manager.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── transformer_mask_base_model.py
│   │   ├── transformer_with_mask.py
│   │   └── utils.py
│   ├── tasks
│   │   ├── __init__.py
│   │   └── mask_translation_multi_simple_epoch.py
│   ├── toolbox
│   │   ├── __init__.py
│   │   ├── calculate_sad.py
│   │   ├── calculate_sad_from_config.py
│   │   ├── count_flops_utils
│   │   │   ├── __init__.py
│   │   │   ├── attach_new_forward.py
│   │   │   ├── module_wrapper.py
│   │   │   └── registry.py
│   │   ├── generate_from_config.py
│   │   ├── generate_mask_from_config.py
│   │   ├── generate_mask_from_softthreshold.py
│   │   ├── generate_to_count_flops.py
│   │   └── util.py
│   └── utils
│       ├── common.py
│       ├── file_operation.py
│       └── logging.py
├── requirements.txt
├── scripts
│   ├── data processing
│   │   ├── deduplicate.sh
│   │   ├── learn_and_encode_spm.sh
│   │   └── preprocessing.sh
│   ├── evaluate.sh
│   ├── install.sh
│   └── train.sh
└── toolbox
    ├── __init__.py
    ├── add_new_mask.py
    ├── cal_similarity.py
    ├── generate_mask.py
    ├── generate_random_mask.py
    ├── merge_mask.py
    └── train.py

/README.md:
--------------------------------------------------------------------------------
# LaSS: Learning Language Specific Sub-network for Multilingual Machine Translation

This is the repo for the ACL 2021 paper Learning Language Specific Sub-network for Multilingual Machine Translation.

[paper](https://arxiv.org/abs/2105.09259)


## Introduction

LaSS, short for **La**nguage **S**pecific **S**ub-network, is a single unified multilingual MT model. LaSS alleviates the well-known parameter interference issue in multilingual MT by allocating one sub-network to each language pair. Extensive experiments demonstrate the efficacy of LaSS and its strong generalization performance in different scenarios.


## Prerequisites

```
pip3 install -r requirements.txt
```

## Pipeline

The pipeline contains 4 steps:
1. Train a vanilla multilingual baseline
2. Fine-tune the baseline for each language pair
3. Obtain the masks from the fine-tuned models
4. Continue training the vanilla multilingual baseline with the obtained masks

### Data Processing

Before training, you need to prepare the data. In general, data processing contains the following steps:
* Data filtering
* Data deduplication
* Learning/applying a joint BPE vocabulary
* Data cleaning

For the IWSLT data used in the paper, we directly use [this script](https://github.com/RayeRen/multilingual-kd-pytorch/blob/master/data/iwslt/raw/prepare-iwslt14.sh).

For WMT, we collect data from the official WMT website. For details please refer to the appendix of our paper.

We provide some [data preprocessing scripts](scripts/data%20processing) for reference.

### Multilingual baseline

We first train a vanilla multilingual baseline.

```
bash scripts/train.sh --config baseline.yml
```

### Fine-tune the baseline

After obtaining the vanilla multilingual baseline, we fine-tune it for each language pair.

```
bash scripts/train.sh --config fine-tune.yml
```

After fine-tuning, we obtain n models, where n is the number of language pairs. A minimal sketch of scripting these per-pair runs is shown below.
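Since every run goes through `scripts/train.sh --config <yml>`, fine-tuning all pairs can be scripted as a loop over per-pair configs. The sketch below is one possible way to do this and is not part of the repo: the per-pair `data_bin`/`save_dir` naming and the use of PyYAML are assumptions to adapt to your own layout.

```
# Illustrative sketch (not part of the repo): derive one fine-tune config per
# language pair from the demo template and launch each run via scripts/train.sh.
import subprocess
import yaml  # pip install pyyaml

lang_pairs = ["de-en", "en-de", "ar-en", "en-ar"]  # extend to all pairs in baseline.yml

with open("configs/demo/fine-tune.yml") as f:
    template = yaml.safe_load(f)

for pair in lang_pairs:
    cfg = dict(template)
    cfg["lang_pairs"] = pair                                 # fine-tune on a single pair
    cfg["data_bin"] = f"data-bin/{pair.replace('-', '2')}"   # assumed per-pair binarized data
    cfg["save_dir"] = f"checkpoints/finetune/{pair}"         # assumed output location
    cfg_path = f"configs/demo/fine-tune.{pair}.yml"
    with open(cfg_path, "w") as out:
        yaml.safe_dump(cfg, out)
    subprocess.run(["bash", "scripts/train.sh", "--config", cfg_path], check=True)
```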
### Obtain the masks

For each language pair, we prune the lowest α percent of weights to obtain its sub-network.

```
python3 toolbox/generate_mask.py --checkpoint-path xx --mask-path /path/to/destination --gen-mask-with-prob --mask-prob α --gen-part all --exclude-output-proj
```


### Training with masks

The last step is to continue training the vanilla multilingual model with the obtained masks.

```
bash scripts/train.sh --config multilingual.yml
```

The YAML configs mentioned above can be found [here](configs/demo).


### Evaluation

You can evaluate the trained model with the following script:
```
bash scripts/evaluate.sh --config config.yml --checkpoint-name xxx --lang-pairs x-y --evaluate-bin /path/to/your/data
```

* `--config` is the training config.
* `--lang-pairs` is optional. If it is not provided, the script evaluates all the language pairs in the config.
* `--evaluate-bin` is also optional. If it is not provided, the script loads the data from `data_bin` in the config.


--------------------------------------------------------------------------------
/configs/demo/baseline.yml:
--------------------------------------------------------------------------------

data_bin: data-bin/multilingual-kd_iwslt_cased/multilingual
save_dir: checkpoints/iwslt/baseline/multilingual/multilingual_transformer_iwslt_arch_dp0.1_nodecay_bigbatch_normal_lr
tensorboard_logdir: tensorboard_logdir/iwslt/baseline/multilingual/multilingual_transformer_iwslt_arch_dp0.1_nodecay_bigbatch_normal_lr
#hdfs_yml_config: yml_config/iwslt/baseline/multilingual/multilingual_transformer_iwslt_arch_dp0.1_nodecay_bigbatch_normal_lr.yml

langs: fa,he,pl,it,ar,es,de,nl,en
lang_pairs: ar-en,de-en,en-ar,en-de,en-es,en-fa,en-he,en-it,en-nl,en-pl,es-en,fa-en,he-en,it-en,nl-en,pl-en

arch: transformer_iwslt_de_en
share_decoder_input_output_embed: true
dropout: 0.1

task: translation_multi_simple_epoch
sampling_method: temperature
sampling_temperature: 2
encoder_langtok: src
decoder_langtok: true

criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
optimizer: adam
adam_betas: "'(0.9, 0.98)'"
clip_norm: 0.0
lr: 5e-4
lr_scheduler: inverse_sqrt
warmup_updates: 4000
weight_decay: 0.0

# 4 gpus
max_tokens: 16384
update_freq: 4

max_update: 160000
fp16: true


patience: 20
save_interval_updates: 1000
keep_interval_updates: 20
no_epoch_checkpoints: true
seed: 22
log_format: simple
log_interval: 20
--------------------------------------------------------------------------------
/configs/demo/fine-tune.yml:
--------------------------------------------------------------------------------

data_bin: data-bin/multilingual-kd_iwslt_cased/de2en
save_dir: checkpoints/iwslt/baseline/multilingual_finetuning/de2en_fine-tuning
tensorboard_logdir: tensorboard_logdir/iwslt/baseline/multilingual_finetuning/de2en_fine-tuning
#hdfs_yml_config: yml_config/iwslt/baseline/multilingual_finetuning/de2en_fine-tuning.yml
restore-file: checkpoints/iwslt/baseline/multilingual/multilingual_transformer_iwslt_arch_dp0.1_nodecay_bigbatch_normal_lr/checkpoint_best.pt


reset_dataloader: true
reset_meters: true
reset_optimizer: true
reset_lr_scheduler: true 13 | 14 | langs: fa,he,pl,it,ar,es,de,nl,en 15 | lang_pairs: de-en 16 | 17 | arch: transformer_iwslt_de_en 18 | share_decoder_input_output_embed: true 19 | dropout: 0.3 20 | 21 | 22 | task: translation_multi_simple_epoch 23 | sampling_method: temperature 24 | sampling_temperature: 2 25 | encoder_langtok: src 26 | decoder_langtok: true 27 | 28 | criterion: label_smoothed_cross_entropy 29 | label_smoothing: 0.1 30 | optimizer: adam 31 | adam_betas: "'(0.9, 0.98)'" 32 | clip_norm: 0.0 33 | lr: 5e-4 34 | lr_scheduler: inverse_sqrt 35 | warmup_updates: 4000 36 | weight_decay: 0.0 37 | 38 | max_tokens: 16384 39 | update_freq: 1 40 | 41 | max_update: 160000 42 | fp16: true 43 | 44 | patience: 30 45 | save_interval: 1000 46 | save_interval_updates: 500 47 | keep_interval_updates: 15 48 | no_epoch_checkpoints: true 49 | seed: 22 50 | log_format: simple 51 | log_interval: 20 52 | 53 | #seed: 1234 54 | #log_format: simple 55 | #log_interval: 50 56 | #patience: 15 57 | #keep_last_epochs: 20 58 | # 59 | #eval-bleu: true 60 | #eval-tokenized-bleu: true 61 | #eval-bleu-args: "'{\"beam\": 5, \"max_len_a\": 1.2, \"max_len_b\": 10}'" 62 | #eval-bleu-remove-bpe: true 63 | #eval-bleu-print-samples: true 64 | #best-checkpoint-metric: bleu 65 | #maximize-best-checkpoint-metric: true 66 | 67 | 68 | -------------------------------------------------------------------------------- /configs/demo/multilingual.yml: -------------------------------------------------------------------------------- 1 | 2 | data_bin: data-bin/multilingual-kd_iwslt_cased/multilingual 3 | save_dir: checkpoints/iwslt/baseline/multilingual_static_mask/multilingual_static_mask_prob${prob} 4 | tensorboard_logdir: tensorboard_logdir/iwslt/baseline/multilingual_static_mask/multilingual_static_mask_prob${prob} 5 | #hdfs_yml_config: yml_config/iwslt/baseline/multilingual_static_mask/multilingual_static_mask_prob${prob}.yml 6 | restore-file: checkpoints/iwslt/baseline/multilingual/multilingual_transformer_iwslt_arch_dp0.1_nodecay_bigbatch_normal_lr/checkpoint_best.pt 7 | 8 | reset_dataloader: true 9 | reset_meters: true 10 | reset_optimizer: true 11 | reset_lr_scheduler: true 12 | 13 | langs: fa,he,pl,it,ar,es,de,nl,en 14 | lang_pairs: ar-en,de-en,en-ar,en-de,en-es,en-fa,en-he,en-it,en-nl,en-pl,es-en,fa-en,he-en,it-en,nl-en,pl-en 15 | 16 | arch: transformer_iwslt_arch_with_mask 17 | share_decoder_input_output_embed: true 18 | dropout: 0.1 19 | 20 | 21 | task: mask_translation_multi_simple_epoch 22 | no-mask-output-project: true 23 | no-save-static-mask: true 24 | mask_path: "'{\"de-en\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/de2en_ft_mask_prob_${prob}.pt\", 25 | \"en-de\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/en2de_ft_mask_prob_${prob}.pt\", 26 | \"ar-en\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/ar2en_ft_mask_prob_${prob}.pt\", 27 | \"en-ar\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/en2ar_ft_mask_prob_${prob}.pt\", 28 | \"en-es\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/en2es_ft_mask_prob_${prob}.pt\", 29 | \"es-en\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/es2en_ft_mask_prob_${prob}.pt\", 30 | \"en-fa\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/en2fa_ft_mask_prob_${prob}.pt\", 31 | \"fa-en\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/fa2en_ft_mask_prob_${prob}.pt\", 32 | \"en-he\": 
\"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/en2he_ft_mask_prob_${prob}.pt\", 33 | \"he-en\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/he2en_ft_mask_prob_${prob}.pt\", 34 | \"en-it\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/en2it_ft_mask_prob_${prob}.pt\", 35 | \"it-en\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/it2en_ft_mask_prob_${prob}.pt\", 36 | \"en-nl\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/en2nl_ft_mask_prob_${prob}.pt\", 37 | \"nl-en\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/nl2en_ft_mask_prob_${prob}.pt\", 38 | \"en-pl\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/en2pl_ft_mask_prob_${prob}.pt\", 39 | \"pl-en\": \"checkpoints/iwslt/baseline/multilingual_finetuning/abs_mask_path/pl2en_ft_mask_prob_${prob}.pt\" 40 | }'" 41 | sampling_method: temperature 42 | sampling_temperature: 2 43 | encoder_langtok: src 44 | decoder_langtok: true 45 | 46 | criterion: label_smoothed_cross_entropy 47 | label_smoothing: 0.1 48 | optimizer: adam 49 | adam_betas: "'(0.9, 0.98)'" 50 | clip_norm: 0.0 51 | lr: 5e-4 52 | lr_scheduler: inverse_sqrt 53 | warmup_updates: 4000 54 | weight_decay: 0.0 55 | 56 | # 4 gpu 57 | max_tokens: 16384 58 | update_freq: 4 59 | 60 | max_update: 160000 61 | fp16: true 62 | ddp_backend: no_c10d 63 | 64 | patience: 20 65 | save_interval: 500 66 | save_interval_updates: 500 67 | validate_interval: 500 68 | validate_interval_updates: 500 69 | keep_interval_updates: 15 70 | no_epoch_checkpoints: true 71 | seed: 23 72 | log_format: simple 73 | log_interval: 20 -------------------------------------------------------------------------------- /fairseq_code/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .tasks import * 3 | from .models import * 4 | -------------------------------------------------------------------------------- /fairseq_code/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .batched_sampled_multi_epoch_dataset import * 2 | from .multilingual_data_manager import * -------------------------------------------------------------------------------- /fairseq_code/datasets/batched_sampled_multi_epoch_dataset.py: -------------------------------------------------------------------------------- 1 | from fairseq.data import SampledMultiEpochDataset 2 | import cytoolz as toolz 3 | import more_itertools 4 | 5 | 6 | class BatchedSampledMultiEpochDataset(SampledMultiEpochDataset): 7 | """ 8 | The only difference compared with SampledMultiEpochDataset is 9 | batch size. This dataset will only group data from one dataset 10 | to one batch. 11 | """ 12 | 13 | def _group_indices_by_dataset_index(self, indices): 14 | return toolz.groupby(lambda x: self._get_dataset_and_index(x)[0], indices) 15 | 16 | def batch_by_size(self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1): 17 | batches = [] 18 | for _, grouped_indices in self._group_indices_by_dataset_index(indices).items(): 19 | # Group indices by the dataset. 
20 | batches.append( 21 | super().batch_by_size(grouped_indices, max_tokens, max_sentences, required_batch_size_multiple) 22 | ) 23 | return list(more_itertools.flatten(batches)) 24 | 25 | def collater(self, samples, **extra_args): 26 | if len(samples) == 0: 27 | return None 28 | # Add language to the batch 29 | batch = super().collater(samples, **extra_args) 30 | assert len(set(sample[0] for sample in samples)) == 1 31 | key = self.keys[samples[0][0]] 32 | # The format of key {data_category}:{src}-{tgt} 33 | src_lang, tgt_lang = key.split(":")[1].strip().split("-") 34 | batch["src_lang"] = src_lang 35 | batch["tgt_lang"] = tgt_lang 36 | return batch 37 | -------------------------------------------------------------------------------- /fairseq_code/datasets/multilingual_data_manager.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import cytoolz as toolz 4 | import more_itertools 5 | 6 | from fairseq.data import SampledMultiEpochDataset 7 | from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager as OldMultilingualDatasetManager 8 | from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat 9 | from fairseq.data.multilingual.sampled_multi_dataset import SampledMultiDataset as OldSampledMultiDataset 10 | 11 | from .batched_sampled_multi_epoch_dataset import BatchedSampledMultiEpochDataset 12 | 13 | 14 | class SampledMultiDataset(OldSampledMultiDataset): 15 | 16 | def _group_indices_by_dataset_index(self, indices): 17 | return toolz.groupby(lambda x: self._get_dataset_and_index(x)[0], indices) 18 | 19 | def batch_by_size(self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1): 20 | batches_list = [] 21 | self.batched_size_dict = {} 22 | for k, grouped_indices in self._group_indices_by_dataset_index(indices).items(): 23 | # Group indices by the dataset. 24 | batches = super().batch_by_size(grouped_indices, max_tokens, max_sentences, required_batch_size_multiple) 25 | batches_list.append( 26 | batches 27 | ) 28 | self.batched_size_dict[k] = len(batches) 29 | return list(more_itertools.flatten(batches_list)) 30 | 31 | def collater(self, samples, **extra_args): 32 | if len(samples) == 0: 33 | return None 34 | # Add language to the batch 35 | batch = super().collater(samples, **extra_args) 36 | assert len(batch) == 1 37 | key = list(batch.keys())[0] 38 | # The format of key {data_category}:{src}-{tgt} 39 | src_lang, tgt_lang = key.split(":")[1].strip().split("-") 40 | batch = batch[key] 41 | batch["src_lang"] = src_lang 42 | batch["tgt_lang"] = tgt_lang 43 | return batch 44 | 45 | 46 | class MultilingualDatasetManager(OldMultilingualDatasetManager): 47 | @classmethod 48 | def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method): 49 | return cls(args, lang_pairs, langs, dicts, sampling_method) 50 | 51 | def load_into_concat_dataset(self, split, datasets, data_param_list): 52 | return SampledMultiDataset( 53 | OrderedDict(datasets), 54 | sampling_ratios=None, 55 | eval_key=None, 56 | collate_format="ordered_dict", 57 | virtual_size=None, 58 | split=split, 59 | ) 60 | 61 | def load_sampled_multi_epoch_dataset(self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs): 62 | # Datasets is a list of tuple with type Tuple[str, FairseqDataset], the string is 63 | # a key in data_param_list attribute. 
64 | # Data_param_list is a list of dict, the dict contains {"key" ,...} 65 | datasets, data_param_list = self.load_split_datasets( 66 | split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs 67 | ) 68 | if training and split == getattr(self.args, "train_subset", None): 69 | sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) 70 | return BatchedSampledMultiEpochDataset( 71 | OrderedDict(datasets), 72 | epoch=epoch, 73 | shard_epoch=shard_epoch, 74 | # valid and test datasets will be degenerate to concating datasets: 75 | sampling_ratios=sample_ratios, 76 | eval_key=None, 77 | collate_format=CollateFormat.single, 78 | virtual_size=self.args.virtual_data_size, 79 | split=split, 80 | virtual_epoch_size=self.args.virtual_epoch_size, 81 | # if not using lang_tok altering, simplified to use the same collater 82 | shared_collater=self._shared_collater(), 83 | ) 84 | else: 85 | return self.load_into_concat_dataset(split, datasets, data_param_list) 86 | 87 | -------------------------------------------------------------------------------- /fairseq_code/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer_with_mask import * 2 | -------------------------------------------------------------------------------- /fairseq_code/models/transformer_mask_base_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | from fairseq.models.transformer import TransformerModel 4 | 5 | 6 | def _catalog_shared_params(module, memo=None, prefix=""): 7 | if memo is None: 8 | first_call = True 9 | memo = {} 10 | else: 11 | first_call = False 12 | for name, param in module._parameters.items(): 13 | param_prefix = prefix + ("." if prefix else "") + name 14 | if param not in memo: 15 | memo[param] = [] 16 | memo[param].append(param_prefix) 17 | for name, m in module._modules.items(): 18 | if m is None: 19 | continue 20 | if not prefix.endswith("."): 21 | submodule_prefix = prefix + ("." 
if prefix else "") + name 22 | else: 23 | submodule_prefix = prefix + name 24 | _catalog_shared_params(m, memo, submodule_prefix) 25 | if first_call: 26 | return [x for x in memo.values() if len(x) > 1] 27 | 28 | 29 | class TransformerMaskBaseModel(TransformerModel, metaclass=ABCMeta): 30 | 31 | @staticmethod 32 | def add_args(parser): 33 | super(TransformerMaskBaseModel, TransformerMaskBaseModel).add_args(parser) 34 | 35 | @staticmethod 36 | def lang_pair(src: str, tgt: str): 37 | return f"{src}-{tgt}" 38 | 39 | def _format_name(self, k): 40 | return k.replace(".", "|") 41 | 42 | def upgrade_state_dict_named(self, state_dict, name): 43 | # Add the share parameter to the different name to match the pytorch api 44 | shared_names = [sorted(t) for t in _catalog_shared_params(self, prefix=name)] 45 | for names in shared_names: 46 | for _name in names[1:]: 47 | if names[0] not in state_dict: 48 | continue 49 | if _name not in state_dict: 50 | state_dict[_name] = state_dict[names[0]] 51 | 52 | def state_dict(self, destination=None, prefix='', keep_vars=False): 53 | state_dict = super().state_dict(destination, prefix, keep_vars) 54 | shared_names = [sorted(t) for t in _catalog_shared_params(self, prefix=prefix)] 55 | for names in shared_names: 56 | for name in names[1:]: 57 | # Remove the shared parameter's other name 58 | del state_dict[name] 59 | return state_dict 60 | 61 | @abstractmethod 62 | def patch_all_mask(self, src_lang, tgt_lang): 63 | pass 64 | -------------------------------------------------------------------------------- /fairseq_code/models/transformer_with_mask.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from collections import defaultdict 4 | 5 | import typing 6 | 7 | from fairseq.file_io import PathManager 8 | from fairseq.models import register_model, register_model_architecture 9 | 10 | import torch.nn as nn 11 | import torch 12 | 13 | import cytoolz as toolz 14 | 15 | from .transformer_mask_base_model import TransformerMaskBaseModel 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | @register_model("transformer_with_mask") 21 | class TransformerWithMaskModel(TransformerMaskBaseModel): 22 | """ 23 | This is a transformer model weight mask. 24 | The mask is described with a dict, the key is like the key in the state_dict. 25 | The mask will mask the weight in the state_dict with the same key. 
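    For example (an illustrative shape, not a file shipped with this repo), a loaded mask
    dict may look like {"encoder.layers.0.fc1.weight": <0/1 tensor shaped like that weight>, ...};
    when stored on the model, the keys are reformatted by replacing "." with "|" (see _format_name).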
26 | """ 27 | 28 | @staticmethod 29 | def add_args(parser): 30 | super(TransformerWithMaskModel, TransformerWithMaskModel).add_args(parser) 31 | parser.add_argument("--mask-path", help="A json dict of the path to all language direction") 32 | parser.add_argument("--no-save-static-mask", action="store_true") 33 | parser.add_argument("--mask-embedding", action="store_true", 34 | help="Mask the embedding module") 35 | parser.add_argument("--no-mask-output-project", action="store_true", help="No mask the output project") 36 | 37 | @classmethod 38 | def build_model(cls, args, task): 39 | model = super().build_model(args, task) 40 | mask_dict = getattr(args, "mask_path", None) 41 | if mask_dict is not None: 42 | mask_dict = json.loads(mask_dict) 43 | for k, mask_path in mask_dict.items(): 44 | with PathManager.open(mask_path, "rb") as f: 45 | mask = torch.load( 46 | f, 47 | map_location=( 48 | lambda s, _: torch.serialization.default_restore_location(s, "cpu") 49 | ), 50 | ) 51 | model.add_language_mask(k, mask) 52 | # model.language_mask.to(model.device) 53 | 54 | return model 55 | 56 | def __init__(self, args, encoder, decoder): 57 | super().__init__(args, encoder, decoder) 58 | self.no_save_static_mask = getattr(args, "no_save_static_mask", False) 59 | self.language_mask = nn.ModuleDict({}) 60 | 61 | def add_language_mask(self, k, v): 62 | if k in self.language_mask: 63 | return 64 | self.language_mask[k] = nn.ParameterDict({self._format_name(a): nn.Parameter(b, requires_grad=False) 65 | for a, b in v.items()}) 66 | 67 | def upgrade_state_dict_named(self, state_dict, name): 68 | # expand the self.language_mask to match the language mask in state_dict 69 | 70 | loaded_mask_name = [] 71 | for k in state_dict.keys(): 72 | if k.startswith("language_mask"): 73 | loaded_mask_name.append(k) 74 | grouped_loaded_mask_name = toolz.groupby(lambda x: x.strip().split(".")[1], loaded_mask_name) 75 | for language, keys in grouped_loaded_mask_name.items(): 76 | # self.language_mask[language] = nn.ParameterDict({k.strip().split(".")[2]: state_dict[k] for k in keys}) 77 | self.add_language_mask(language, {k.strip().split(".")[2]: state_dict[k] for k in keys}) 78 | # self.language_mask.to(self.device) 79 | 80 | # expand the self.language_mask to match the language mask in state_dict 81 | # Also need to add the language_mask to the state_dict if not in state_dict 82 | for k, v in self.language_mask.items(): 83 | for i_k, d in v.items(): 84 | key = f"language_mask.{k}.{i_k}" 85 | if key in self.language_mask: 86 | pass 87 | else: 88 | state_dict[key] = d 89 | 90 | # move global weight to weight, all weight will be moved to global weight when it need used 91 | for k in filter(lambda x: "global_weight" in x, list(state_dict.keys())): 92 | new_k = k.replace("global_weight", "weight") 93 | state_dict[new_k] = state_dict[k] 94 | del state_dict[k] 95 | 96 | super().upgrade_state_dict_named(state_dict, name) 97 | 98 | def patch_all_mask(self, src_lang: str, tgt_lang: str): 99 | threshold_percent = self._patch_static_mask(src_lang, tgt_lang) 100 | return self, threshold_percent 101 | 102 | def _patch_static_mask(self, src_lang, tgt_lang): 103 | lang_pair = self.lang_pair(src=src_lang, tgt=tgt_lang) 104 | assert lang_pair in self.language_mask 105 | 106 | total_number = 0 107 | un_mask_number = 0 108 | 109 | for name, c in self.named_modules(): 110 | if isinstance(c, nn.Linear) or (self.args.mask_embedding and isinstance(c, nn.Embedding)): 111 | if getattr(self.args, "no_mask_output_project", False) and 
"output_projection" == name: 112 | continue 113 | if not hasattr(c, "global_weight"): 114 | c.global_weight = c.weight 115 | del c.weight 116 | mask_key = name + ".weight" 117 | mask_key = self._format_name(mask_key) 118 | if mask_key in self.language_mask[lang_pair]: 119 | mask = self.language_mask[lang_pair][mask_key] 120 | total_number += mask.numel() 121 | un_mask_number += mask.sum().item() 122 | c.weight = mask * c.global_weight 123 | # def repopulate_weight(mod, _): 124 | # nonlocal mask 125 | # mod.weight = mask * mod.global_weight 126 | # c.register_forward_pre_hook(repopulate_weight) 127 | else: 128 | c.weight = c.global_weight 129 | 130 | return un_mask_number / total_number 131 | 132 | def state_dict(self, destination=None, prefix='', keep_vars=False): 133 | state_dict = super().state_dict(destination, prefix, keep_vars) 134 | 135 | # Remove the static dict from checkpoint 136 | if self.no_save_static_mask: 137 | for k in list(state_dict.keys()): 138 | if k.startswith("language_mask"): 139 | del state_dict[k] 140 | return state_dict 141 | 142 | 143 | @register_model_architecture("transformer_with_mask", "transformer_with_mask") 144 | def base_architecture(args): 145 | from fairseq.models.transformer import base_architecture as transformer_base_architecture 146 | transformer_base_architecture(args) 147 | args.mask_path = getattr(args, "mask_path", None) 148 | args.mask_embedding = getattr(args, "mask_embedding", False) 149 | args.no_mask_output_project = getattr(args, "no_mask_output_project", False) 150 | 151 | 152 | @register_model_architecture("transformer_with_mask", "transformer_iwslt_arch_with_mask") 153 | def transformer_iwslt_arch_with_mask(args): 154 | from fairseq.models.transformer import transformer_iwslt_de_en 155 | transformer_iwslt_de_en(args) 156 | base_architecture(args) 157 | 158 | 159 | @register_model_architecture("transformer_with_mask", "transformer_vaswani_wmt_en_fr_big_with_mask") 160 | def transformer_vaswani_wmt_en_fr_big_with_mask(args): 161 | from fairseq.models.transformer import transformer_vaswani_wmt_en_fr_big 162 | transformer_vaswani_wmt_en_fr_big(args) 163 | base_architecture(args) 164 | 165 | 166 | @register_model_architecture("transformer_with_mask", "transformer_wmt_en_de_big_t2t_with_mask") 167 | def transformer_wmt_en_de_big_t2t_with_mask(args): 168 | from fairseq.models.transformer import transformer_wmt_en_de_big_t2t 169 | transformer_wmt_en_de_big_t2t(args) 170 | base_architecture(args) 171 | 172 | 173 | @register_model_architecture("transformer_with_mask", "architecture_for_test") 174 | def architecture_for_test(args): 175 | args.encoder_embed_dim = 256 176 | args.encoder_ffn_embed_dim = 1024 177 | args.decoder_ffn_embed_dim = 1024 178 | args.encoder_layers = 3 179 | args.decoder_layers = 3 180 | args.encoder_attention_heads = 4 181 | args.decoder_attention_heads = 4 182 | args.encoder_normalize_before = True 183 | args.decoder_normalize_before = True 184 | args.share_all_embeddings = True 185 | 186 | args.soft_threshold = True 187 | args.soft_threshold_level = "vector_shared_g" 188 | args.soft_threshold_init_bias = -12800 189 | 190 | # args.mask_path = """{ 191 | # "ar-en":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 192 | # "de-en":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 193 | # "en-ar":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 194 | # "en-de":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 195 | # "en-es":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 
196 | # "en-fa":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 197 | # "en-he":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 198 | # "en-it":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 199 | # "en-nl":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 200 | # "en-pl":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 201 | # "es-en":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 202 | # "fa-en":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 203 | # "he-en":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 204 | # "it-en":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 205 | # "nl-en":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt", 206 | # "pl-en":"/data00/home/wuliwei.000/test_data_sparse_sharing/mask.pt" 207 | # }""" 208 | base_architecture(args) 209 | -------------------------------------------------------------------------------- /fairseq_code/models/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | from torch import nn as nn 5 | 6 | 7 | def loop_linear_module_for_model(m, no_mask_output_project=False, with_name=False, prefix="", 8 | skip_pattern=None): 9 | 10 | if skip_pattern is not None: 11 | pattern = re.compile(skip_pattern) 12 | else: 13 | pattern = None 14 | for n, c in m.named_children(): 15 | 16 | if no_mask_output_project and n == "output_projection": 17 | continue 18 | 19 | now_name = n if len(prefix) == 0 else prefix + "." + n 20 | 21 | if skip_pattern is not None: 22 | if now_name == skip_pattern: 23 | continue 24 | if pattern.match(now_name) is not None: 25 | continue 26 | 27 | if isinstance(c, nn.Linear): 28 | if with_name: 29 | yield now_name, c 30 | else: 31 | yield c 32 | yield from loop_linear_module_for_model(c, no_mask_output_project=no_mask_output_project, 33 | with_name=with_name, prefix=now_name, skip_pattern=skip_pattern) 34 | 35 | 36 | def get_row_mask(weight): 37 | return torch.all(weight == 0, dim=1) 38 | 39 | 40 | -------------------------------------------------------------------------------- /fairseq_code/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .mask_translation_multi_simple_epoch import MaskTranslationMultiSimpleEpochTask 2 | -------------------------------------------------------------------------------- /fairseq_code/tasks/mask_translation_multi_simple_epoch.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | from fairseq.logging import metrics 5 | from fairseq.tasks import register_task 6 | from fairseq.tasks.translation_multi_simple_epoch import TranslationMultiSimpleEpochTask 7 | 8 | from ..datasets import MultilingualDatasetManager 9 | 10 | 11 | @register_task("mask_translation_multi_simple_epoch") 12 | class MaskTranslationMultiSimpleEpochTask(TranslationMultiSimpleEpochTask): 13 | 14 | def __init__(self, args, langs, dicts, training): 15 | super().__init__(args, langs, dicts, training) 16 | self.data_manager = MultilingualDatasetManager.setup_data_manager( 17 | args, self.lang_pairs, langs, dicts, self.sampling_method 18 | ) 19 | 20 | def valid_step(self, sample, model, criterion): 21 | with torch.no_grad(): 22 | model.eval() 23 | model.patch_all_mask(src_lang=sample['src_lang'], tgt_lang=sample['tgt_lang']) 24 | return super().valid_step(sample, model, criterion) 25 | 26 | 
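    # Both valid_step (above) and train_step (below) first call model.patch_all_mask()
    # with the batch's "src_lang"/"tgt_lang" (attached by the datasets in
    # fairseq_code/datasets), which swaps every nn.Linear weight to
    # mask * global_weight for that language pair (see transformer_with_mask.py),
    # so the single shared model runs each step as a language-specific sub-network.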
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): 27 | src_lang = sample['src_lang'] 28 | tgt_lang = sample['tgt_lang'] 29 | model.train() 30 | model.set_num_updates(update_num) 31 | model.patch_all_mask(src_lang=src_lang, tgt_lang=tgt_lang) 32 | res = super().train_step(sample, model, criterion, optimizer, update_num, ignore_grad) 33 | return res 34 | 35 | def build_generator(self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None): 36 | for m in models: 37 | m.eval() 38 | models = [m.patch_all_mask( 39 | src_lang=self.args.source_lang, 40 | tgt_lang=self.args.target_lang 41 | )[0] for m in models] 42 | return super().build_generator(models, args, seq_gen_cls, extra_gen_cls_kwargs) 43 | -------------------------------------------------------------------------------- /fairseq_code/toolbox/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NLP-Playground/LaSS/83be1881e9c25e26b598401d73867170af5adeba/fairseq_code/toolbox/__init__.py -------------------------------------------------------------------------------- /fairseq_code/toolbox/calculate_sad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Translate pre-processed data with a trained model. 8 | """ 9 | 10 | import ast 11 | import json 12 | import logging 13 | import os 14 | import re 15 | import sys 16 | from argparse import Namespace 17 | from collections import defaultdict 18 | 19 | import numpy as np 20 | import torch 21 | import typing 22 | from fairseq import checkpoint_utils, options, tasks, utils 23 | from fairseq.dataclass.utils import convert_namespace_to_omegaconf 24 | from fairseq.file_io import PathManager 25 | from omegaconf import DictConfig 26 | 27 | from ..models.utils import get_row_mask 28 | 29 | 30 | def main(cfg: DictConfig): 31 | 32 | if isinstance(cfg, Namespace): 33 | cfg = convert_namespace_to_omegaconf(cfg) 34 | 35 | assert cfg.common_eval.path is not None, "--path required for generation!" 
36 | assert ( 37 | not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam 38 | ), "--sampling requires --nbest to be equal to --beam" 39 | assert ( 40 | cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" 41 | ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" 42 | 43 | if cfg.common_eval.results_path is not None: 44 | os.makedirs(cfg.common_eval.results_path, exist_ok=True) 45 | output_path = os.path.join( 46 | cfg.common_eval.results_path, 47 | "generate-{}.txt".format(cfg.dataset.gen_subset), 48 | ) 49 | with open(output_path, "w", buffering=1, encoding="utf-8") as h: 50 | return _main(cfg, h) 51 | else: 52 | return _main(cfg, sys.stdout) 53 | 54 | 55 | def get_symbols_to_strip_from_output(generator): 56 | if hasattr(generator, "symbols_to_strip_from_output"): 57 | return generator.symbols_to_strip_from_output 58 | else: 59 | return {generator.eos} 60 | 61 | 62 | def loop_directory(directory: str) -> typing.List[str]: 63 | names = PathManager.ls(directory) 64 | pattern = re.compile("checkpoint_.*_(.*).pt") 65 | names = map(lambda x: (x, pattern.match(x)), names) 66 | names = filter(lambda x: x[1] is not None, names) 67 | names = sorted(names, key=lambda x: int(x[1].group(1))) 68 | return [f"{directory}/{name[0]}" for name in names] 69 | 70 | 71 | def _main(cfg: DictConfig, output_file): 72 | logging.basicConfig( 73 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 74 | datefmt="%Y-%m-%d %H:%M:%S", 75 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 76 | stream=output_file, 77 | ) 78 | logger = logging.getLogger("fairseq_cli.generate") 79 | 80 | utils.import_user_module(cfg.common) 81 | 82 | if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: 83 | cfg.dataset.max_tokens = 12000 84 | logger.info(cfg) 85 | 86 | # Fix seed for stochastic decoding 87 | if cfg.common.seed is not None and not cfg.generation.no_seed_provided: 88 | np.random.seed(cfg.common.seed) 89 | utils.set_torch_seed(cfg.common.seed) 90 | 91 | use_cuda = torch.cuda.is_available() and not cfg.common.cpu 92 | 93 | # Load dataset splits 94 | task = tasks.setup_task(cfg.task) 95 | task.load_dataset(cfg.dataset.gen_subset) 96 | 97 | # Set dictionaries 98 | try: 99 | src_dict = getattr(task, "source_dictionary", None) 100 | except NotImplementedError: 101 | src_dict = None 102 | tgt_dict = task.target_dictionary 103 | 104 | overrides = ast.literal_eval(cfg.common_eval.model_overrides) 105 | 106 | lang_pairs = task.args.lang_pairs 107 | # Load ensemble 108 | cached_weights = None 109 | sad_list = defaultdict(list) 110 | for checkpoint_path in loop_directory(cfg.common_eval.path): 111 | logger.info("loading model(s) from {}".format(checkpoint_path)) 112 | models, _model_args = checkpoint_utils.load_model_ensemble( 113 | [checkpoint_path], 114 | arg_overrides=overrides, 115 | task=task, 116 | suffix=cfg.checkpoint.checkpoint_suffix, 117 | strict=(cfg.checkpoint.checkpoint_shard_count == 1), 118 | num_shards=cfg.checkpoint.checkpoint_shard_count, 119 | ) 120 | model = models[0] 121 | if use_cuda: 122 | model = model.cuda() 123 | model.eval() 124 | 125 | if cached_weights is None: 126 | prev_none = True 127 | cached_weights = defaultdict(list) 128 | else: 129 | prev_none = False 130 | for lang_pair in lang_pairs: 131 | new_weights = [] 132 | src_lang, tgt_lang = lang_pair.split("-") 133 | model.patch_all_mask(src_lang=src_lang, tgt_lang=tgt_lang) 134 | for linear_module in model.model_loop_iter(): 135 | new_weights.append(linear_module.weight) 
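            # The "SAD" accumulated below is a sum of absolute differences over row masks:
            # get_row_mask() flags rows of a weight matrix that are entirely zero, and the
            # per-pair score counts how many rows flip between masked and unmasked from one
            # checkpoint to the next, i.e. how much that pair's sub-network is still changing.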
136 | if not prev_none: 137 | sad = 0 138 | for old_weight, now_weight in zip(cached_weights[lang_pair], new_weights): 139 | old_mask = get_row_mask(old_weight) 140 | now_mask = get_row_mask(now_weight) 141 | sad += torch.abs(old_mask.int()-now_mask.int()).sum() 142 | sad_list[lang_pair].append(sad.item()) 143 | 144 | cached_weights[lang_pair] = new_weights 145 | 146 | with PathManager.open(cfg.task.target_path, "w") as f: 147 | json.dump(sad_list, f) 148 | 149 | 150 | def cli_main(): 151 | parser = get_calculate_sad_parser() 152 | args = options.parse_args_and_arch(parser) 153 | main(args) 154 | 155 | 156 | def get_calculate_sad_parser(): 157 | parser = options.get_generation_parser() 158 | parser.add_argument("--target-path", type=str, help="The path to final sad list ") 159 | return parser 160 | 161 | 162 | if __name__ == "__main__": 163 | cli_main() 164 | -------------------------------------------------------------------------------- /fairseq_code/toolbox/calculate_sad_from_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Translate pre-processed data with a trained model. 8 | """ 9 | 10 | import argparse 11 | 12 | from fairseq import options 13 | 14 | from .calculate_sad import get_calculate_sad_parser 15 | from .util import get_all_task_options, load_config 16 | 17 | 18 | def main_from_config(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--config", type=str, required=True) 21 | parser.add_argument("--target-path", type=str, required=True) 22 | args = parser.parse_args() 23 | config = args.config 24 | config = load_config(config) 25 | param_args = get_all_task_options(config) 26 | param_args.extend( 27 | ["--target-path", 28 | args.target_path, 29 | "--path", 30 | config["save_dir"]] 31 | ) 32 | sad_parser = get_calculate_sad_parser() 33 | args = options.parse_args_and_arch(sad_parser, input_args=param_args) 34 | from .calculate_sad import main as sad_main 35 | sad_main(args) 36 | 37 | 38 | if __name__ == "__main__": 39 | main_from_config() 40 | -------------------------------------------------------------------------------- /fairseq_code/toolbox/count_flops_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NLP-Playground/LaSS/83be1881e9c25e26b598401d73867170af5adeba/fairseq_code/toolbox/count_flops_utils/__init__.py -------------------------------------------------------------------------------- /fairseq_code/toolbox/count_flops_utils/attach_new_forward.py: -------------------------------------------------------------------------------- 1 | import types 2 | from typing import Dict 3 | 4 | # from fvcore.nn import FlopCountAnalysis 5 | # import fvcore 6 | from fvcore import nn as fv_nn 7 | from torch import Tensor 8 | import torch 9 | 10 | from .module_wrapper import TracingAdapter 11 | 12 | 13 | _IGNORED_OPS = { 14 | "aten::add", 15 | "aten::add_", 16 | "aten::argmax", 17 | "aten::argsort", 18 | "aten::batch_norm", 19 | "aten::constant_pad_nd", 20 | "aten::div", 21 | "aten::div_", 22 | "aten::exp", 23 | "aten::log2", 24 | "aten::max_pool2d", 25 | "aten::meshgrid", 26 | "aten::mul", 27 | "aten::mul_", 28 | "aten::neg", 29 | "aten::nonzero_numpy", 30 | "aten::reciprocal", 31 | "aten::rsub", 32 | 
"aten::sigmoid", 33 | "aten::sigmoid_", 34 | "aten::softmax", 35 | "aten::sort", 36 | "aten::sqrt", 37 | "aten::sub", 38 | "torchvision::nms", 39 | } 40 | 41 | 42 | class FlopCountAnalysis(fv_nn.FlopCountAnalysis): 43 | """ 44 | Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models. 45 | """ 46 | 47 | def __init__(self, model, inputs): 48 | """ 49 | Args: 50 | model (nn.Module): 51 | inputs (Any): inputs of the given model. Does not have to be tuple of tensors. 52 | """ 53 | wrapper = TracingAdapter(model, inputs, allow_non_tensor=True) 54 | super().__init__(wrapper, wrapper.flattened_inputs) 55 | self.set_op_handle(**{k: None for k in _IGNORED_OPS}) 56 | 57 | 58 | @torch.no_grad() 59 | def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): 60 | inputs = [sample] 61 | for k in ["prefix_tokens", "constraints", "bos_token"]: 62 | if k in kwargs: 63 | inputs.append(kwargs[k]) 64 | else: 65 | inputs.append(None) 66 | inputs = tuple(inputs) 67 | flops = FlopCountAnalysis(self, inputs) 68 | return flops 69 | 70 | 71 | def attach_new_generate_method_to_generator(generator): 72 | generator.generate = types.MethodType( 73 | generate, 74 | generator 75 | ) 76 | 77 | def count_flops(model, **kwargs): 78 | inputs = [] 79 | for k in ["src_tokens", "src_lengths", "prev_output_tokens", 80 | "return_all_hiddens", "features_only", "alignment_layer", "alignment_heads"]: 81 | if k in kwargs: 82 | inputs.append(kwargs[k]) 83 | else: 84 | inputs.append(None) 85 | inputs = tuple(inputs) 86 | flops = FlopCountAnalysis(model, inputs) 87 | return flops -------------------------------------------------------------------------------- /fairseq_code/toolbox/count_flops_utils/module_wrapper.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from contextlib import contextmanager, ExitStack 3 | from dataclasses import dataclass 4 | from typing import Callable, List, Optional, Tuple 5 | from unittest import mock 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from .registry import _convert_target_to_string, locate 11 | 12 | 13 | @contextmanager 14 | def patch_builtin_len(modules=()): 15 | """ 16 | Patch the builtin len() function of a few detectron2 modules 17 | to use __len__ instead, because __len__ does not convert values to 18 | integers and therefore is friendly to tracing. 19 | Args: 20 | modules (list[stsr]): names of extra modules to patch len(), in 21 | addition to those in detectron2. 22 | """ 23 | 24 | def _new_len(obj): 25 | return obj.__len__() 26 | 27 | with ExitStack() as stack: 28 | MODULES = list(modules) 29 | ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES] 30 | for m in ctxs: 31 | m.side_effect = _new_len 32 | yield 33 | 34 | 35 | @dataclass 36 | class Schema: 37 | """ 38 | A Schema defines how to flatten a possibly hierarchical object into tuple of 39 | primitive objects, so it can be used as inputs/outputs of PyTorch's tracing. 40 | 41 | PyTorch does not support tracing a function that produces rich output 42 | structures (e.g. dict, Instances, Boxes). To trace such a function, we 43 | flatten the rich object into tuple of tensors, and return this tuple of tensors 44 | instead. Meanwhile, we also need to know how to "rebuild" the original object 45 | from the flattened results, so we can evaluate the flattened results. 
46 | A Schema defines how to flatten an object, and while flattening it, it records 47 | necessary schemas so that the object can be rebuilt using the flattened outputs. 48 | 49 | The flattened object and the schema object is returned by ``.flatten`` classmethod. 50 | Then the original object can be rebuilt with the ``__call__`` method of schema. 51 | 52 | A Schema is a dataclass that can be serialized easily. 53 | """ 54 | 55 | # inspired by FetchMapper in tensorflow/python/client/session.py 56 | 57 | @classmethod 58 | def flatten(cls, obj): 59 | raise NotImplementedError 60 | 61 | def __call__(self, values): 62 | raise NotImplementedError 63 | 64 | @staticmethod 65 | def _concat(values): 66 | ret = () 67 | sizes = [] 68 | for v in values: 69 | assert isinstance(v, tuple), "Flattened results must be a tuple" 70 | ret = ret + v 71 | sizes.append(len(v)) 72 | return ret, sizes 73 | 74 | @staticmethod 75 | def _split(values, sizes): 76 | if len(sizes): 77 | expected_len = sum(sizes) 78 | assert ( 79 | len(values) == expected_len 80 | ), f"Values has length {len(values)} but expect length {expected_len}." 81 | ret = [] 82 | for k in range(len(sizes)): 83 | begin, end = sum(sizes[:k]), sum(sizes[: k + 1]) 84 | ret.append(values[begin:end]) 85 | return ret 86 | 87 | 88 | @dataclass 89 | class ListSchema(Schema): 90 | schemas: List[Schema] # the schemas that define how to flatten each element in the list 91 | sizes: List[int] # the flattened length of each element 92 | 93 | def __call__(self, values): 94 | values = self._split(values, self.sizes) 95 | if len(values) != len(self.schemas): 96 | raise ValueError( 97 | f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!" 98 | ) 99 | values = [m(v) for m, v in zip(self.schemas, values)] 100 | return list(values) 101 | 102 | @classmethod 103 | def flatten(cls, obj): 104 | res = [flatten_to_tuple(k) for k in obj] 105 | values, sizes = cls._concat([k[0] for k in res]) 106 | return values, cls([k[1] for k in res], sizes) 107 | 108 | 109 | @dataclass 110 | class TupleSchema(ListSchema): 111 | def __call__(self, values): 112 | return tuple(super().__call__(values)) 113 | 114 | 115 | @dataclass 116 | class IdentitySchema(Schema): 117 | def __call__(self, values): 118 | return values[0] 119 | 120 | @classmethod 121 | def flatten(cls, obj): 122 | return (obj,), cls() 123 | 124 | 125 | @dataclass 126 | class DictSchema(ListSchema): 127 | keys: List[str] 128 | 129 | def __call__(self, values): 130 | values = super().__call__(values) 131 | return dict(zip(self.keys, values)) 132 | 133 | @classmethod 134 | def flatten(cls, obj): 135 | for k in obj.keys(): 136 | if not isinstance(k, str): 137 | raise KeyError("Only support flattening dictionaries if keys are str.") 138 | keys = sorted(obj.keys()) 139 | values = [obj[k] for k in keys] 140 | ret, schema = ListSchema.flatten(values) 141 | return ret, cls(schema.schemas, schema.sizes, keys) 142 | 143 | 144 | @dataclass 145 | # class InstancesSchema(DictSchema): 146 | # def __call__(self, values): 147 | # image_size, fields = values[-1], values[:-1] 148 | # fields = super().__call__(fields) 149 | # return Instances(image_size, **fields) 150 | # 151 | # @classmethod 152 | # def flatten(cls, obj): 153 | # ret, schema = super().flatten(obj.get_fields()) 154 | # size = obj.image_size 155 | # if not isinstance(size, torch.Tensor): 156 | # size = torch.tensor(size) 157 | # return ret + (size,), schema 158 | 159 | 160 | @dataclass 161 | class TensorWrapSchema(Schema): 162 | """ 163 | For classes 
that are simple wrapper of tensors, e.g. 164 | Boxes, RotatedBoxes, BitMasks 165 | """ 166 | 167 | class_name: str 168 | 169 | def __call__(self, values): 170 | return locate(self.class_name)(values[0]) 171 | 172 | @classmethod 173 | def flatten(cls, obj): 174 | return (obj.tensor,), cls(_convert_target_to_string(type(obj))) 175 | 176 | 177 | # if more custom structures needed in the future, can allow 178 | # passing in extra schemas for custom types 179 | def flatten_to_tuple(obj): 180 | """ 181 | Flatten an object so it can be used for PyTorch tracing. 182 | Also returns how to rebuild the original object from the flattened outputs. 183 | 184 | Returns: 185 | res (tuple): the flattened results that can be used as tracing outputs 186 | schema: an object with a ``__call__`` method such that ``schema(res) == obj``. 187 | It is a pure dataclass that can be serialized. 188 | """ 189 | schemas = [ 190 | ((str, bytes), IdentitySchema), 191 | (list, ListSchema), 192 | (tuple, TupleSchema), 193 | (collections.abc.Mapping, DictSchema), 194 | ] 195 | for klass, schema in schemas: 196 | if isinstance(obj, klass): 197 | F = schema 198 | break 199 | else: 200 | F = IdentitySchema 201 | 202 | return F.flatten(obj) 203 | 204 | 205 | class TracingAdapter(nn.Module): 206 | """ 207 | A model may take rich input/output format (e.g. dict or custom classes), 208 | but `torch.jit.trace` requires tuple of tensors as input/output. 209 | This adapter flattens input/output format of a model so it becomes traceable. 210 | 211 | It also records the necessary schema to rebuild model's inputs/outputs from flattened 212 | inputs/outputs. 213 | 214 | Example: 215 | :: 216 | outputs = model(inputs) # inputs/outputs may be rich structure 217 | adapter = TracingAdapter(model, inputs) 218 | 219 | # can now trace the model, with adapter.flattened_inputs, or another 220 | # tuple of tensors with the same length and meaning 221 | traced = torch.jit.trace(adapter, adapter.flattened_inputs) 222 | 223 | # traced model can only produce flattened outputs (tuple of tensors) 224 | flattened_outputs = traced(*adapter.flattened_inputs) 225 | # adapter knows the schema to convert it back (new_outputs == outputs) 226 | new_outputs = adapter.outputs_schema(flattened_outputs) 227 | """ 228 | 229 | flattened_inputs: Tuple[torch.Tensor] = None 230 | """ 231 | Flattened version of inputs given to this class's constructor. 232 | """ 233 | 234 | inputs_schema: Schema = None 235 | """ 236 | Schema of the inputs given to this class's constructor. 237 | """ 238 | 239 | outputs_schema: Schema = None 240 | """ 241 | Schema of the output produced by calling the given model with inputs. 242 | """ 243 | 244 | def __init__( 245 | self, 246 | model: nn.Module, 247 | inputs, 248 | inference_func: Optional[Callable] = None, 249 | allow_non_tensor: bool = False, 250 | ): 251 | """ 252 | Args: 253 | model: an nn.Module 254 | inputs: An input argument or a tuple of input arguments used to call model. 255 | After flattening, it has to only consist of tensors. 256 | inference_func: a callable that takes (model, *inputs), calls the 257 | model with inputs, and return outputs. By default it 258 | is ``lambda model, *inputs: model(*inputs)``. Can be override 259 | if you need to call the model differently. 260 | allow_non_tensor: allow inputs/outputs to contain non-tensor objects. 
261 | This option will filter out non-tensor objects to make the 262 | model traceable, but ``inputs_schema``/``outputs_schema`` cannot be 263 | used anymore because inputs/outputs cannot be rebuilt from pure tensors. 264 | This is useful when you're only interested in the single trace of 265 | execution (e.g. for flop count), but not interested in 266 | generalizing the traced graph to new inputs. 267 | """ 268 | super().__init__() 269 | if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)): 270 | model = model.module 271 | self.model = model 272 | if not isinstance(inputs, tuple): 273 | inputs = (inputs,) 274 | self.inputs = inputs 275 | self.allow_non_tensor = allow_non_tensor 276 | 277 | if inference_func is None: 278 | inference_func = lambda model, *inputs: model(*inputs) # noqa 279 | self.inference_func = inference_func 280 | 281 | self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs) 282 | 283 | if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs): 284 | return 285 | if self.allow_non_tensor: 286 | self.flattened_inputs = tuple( 287 | [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)] 288 | ) 289 | self.inputs_schema = None 290 | else: 291 | for input in self.flattened_inputs: 292 | if not isinstance(input, torch.Tensor): 293 | raise ValueError( 294 | "Inputs for tracing must only contain tensors. " 295 | f"Got a {type(input)} instead." 296 | ) 297 | 298 | def forward(self, *args: torch.Tensor): 299 | with torch.no_grad(), patch_builtin_len(): 300 | if self.inputs_schema is not None: 301 | inputs_orig_format = self.inputs_schema(args) 302 | else: 303 | if args != self.flattened_inputs: 304 | raise ValueError( 305 | "TracingAdapter does not contain valid inputs_schema." 306 | " So it cannot generalize to other inputs and must be" 307 | " traced with `.flattened_inputs`." 308 | ) 309 | inputs_orig_format = self.inputs 310 | 311 | outputs = self.inference_func(self.model, *inputs_orig_format) 312 | flattened_outputs, schema = flatten_to_tuple(outputs) 313 | 314 | flattened_output_tensors = tuple( 315 | [x for x in flattened_outputs if isinstance(x, torch.Tensor)] 316 | ) 317 | if len(flattened_output_tensors) < len(flattened_outputs): 318 | if self.allow_non_tensor: 319 | flattened_outputs = flattened_output_tensors 320 | self.outputs_schema = None 321 | else: 322 | raise ValueError( 323 | "Model cannot be traced because some model outputs " 324 | "cannot flatten to tensors." 325 | ) 326 | else: # schema is valid 327 | if self.outputs_schema is None: 328 | self.outputs_schema = schema 329 | else: 330 | assert self.outputs_schema == schema, ( 331 | "Model should always return outputs with the same " 332 | "structure so it can be traced!" 333 | ) 334 | return flattened_outputs 335 | 336 | def _create_wrapper(self, traced_model): 337 | """ 338 | Return a function that has an input/output interface the same as the 339 | original model, but it calls the given traced model under the hood. 340 | """ 341 | 342 | def forward(*args): 343 | flattened_inputs, _ = flatten_to_tuple(args) 344 | flattened_outputs = traced_model(*flattened_inputs) 345 | return self.outputs_schema(flattened_outputs) 346 | 347 | return forward -------------------------------------------------------------------------------- /fairseq_code/toolbox/count_flops_utils/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | from typing import Any 4 | import pydoc 5 | from fvcore.common.registry import Registry # for backward compatibility. 6 | 7 | """ 8 | ``Registry`` and `locate` provide ways to map a string (typically found 9 | in config files) to callable objects. 10 | """ 11 | 12 | __all__ = ["Registry", "locate"] 13 | 14 | 15 | def _convert_target_to_string(t: Any) -> str: 16 | """ 17 | Inverse of ``locate()``. 18 | Args: 19 | t: any object with ``__module__`` and ``__qualname__`` 20 | """ 21 | module, qualname = t.__module__, t.__qualname__ 22 | 23 | # Compress the path to this object, e.g. ``module.submodule._impl.class`` 24 | # may become ``module.submodule.class``, if the later also resolves to the same 25 | # object. This simplifies the string, and also is less affected by moving the 26 | # class implementation. 27 | module_parts = module.split(".") 28 | for k in range(1, len(module_parts)): 29 | prefix = ".".join(module_parts[:k]) 30 | candidate = f"{prefix}.{qualname}" 31 | try: 32 | if locate(candidate) is t: 33 | return candidate 34 | except ImportError: 35 | pass 36 | return f"{module}.{qualname}" 37 | 38 | 39 | def locate(name: str) -> Any: 40 | """ 41 | Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, 42 | such as "module.submodule.class_name". 43 | Raise Exception if it cannot be found. 44 | """ 45 | obj = pydoc.locate(name) 46 | 47 | # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly 48 | # by pydoc.locate. Try a private function from hydra. 49 | if obj is None: 50 | try: 51 | # from hydra.utils import get_method - will print many errors 52 | from hydra.utils import _locate 53 | except ImportError as e: 54 | raise ImportError(f"Cannot dynamically locate object {name}!") from e 55 | else: 56 | obj = _locate(name) # it raises if fails 57 | 58 | return obj -------------------------------------------------------------------------------- /fairseq_code/toolbox/generate_from_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script will evaluate the model. 
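
One way to invoke it (illustrative; the config name, checkpoint name and language pair
below are placeholders):

    python3 -m fairseq_code.toolbox.generate_from_config \
        --config configs/demo/multilingual.yml \
        --checkpoint-name checkpoint_best.pt --lang-pairs de-en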
3 | 4 | The generate.log will be saved at the ${save_dir}/results/${checkpoint_name}/${src}-${tgt}/generate.log 5 | 6 | """ 7 | 8 | 9 | import argparse 10 | import io 11 | from contextlib import redirect_stderr, redirect_stdout 12 | from copy import deepcopy 13 | 14 | from fairseq import options 15 | from tqdm import tqdm 16 | 17 | from fairseq_code.utils import file_operation 18 | from .util import load_config, get_all_task_options 19 | 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--config", type=str, required=True) 24 | parser.add_argument("--checkpoint-name", type=str, default="checkpoint_last.pt") 25 | parser.add_argument("--lang-pairs", type=str, default=None, 26 | help="THe format is the same as fairseq multilingual translaiton task") 27 | parser.add_argument("--evaluate-bin", type=str, default=None, 28 | help="The bin to evaluate") 29 | parser.add_argument("--count-inference-flops", action="store_true", 30 | help="Count the inference flops") 31 | return parser.parse_known_args() 32 | 33 | 34 | def main(): 35 | args, generate_args = get_args() 36 | config = args.config 37 | checkpoint_name = args.checkpoint_name 38 | lang_pairs = args.lang_pairs 39 | count_inference_flops = args.count_inference_flops 40 | 41 | config = load_config(config) 42 | 43 | if args.evaluate_bin is not None: 44 | config['data_bin'] = args.evaluate_bin 45 | 46 | param_args = get_all_task_options(config) 47 | checkpoint_path = file_operation.join_paths( 48 | config["save_dir"], 49 | checkpoint_name 50 | ) 51 | 52 | if lang_pairs is None: 53 | lang_pairs = config['lang_pairs'] 54 | lang_pairs = lang_pairs.strip().split(",") 55 | 56 | for lang_pair in tqdm(lang_pairs): 57 | src, tgt = lang_pair.strip().split("-") 58 | 59 | generate_log = io.StringIO() 60 | 61 | with redirect_stdout(generate_log), redirect_stderr(generate_log): 62 | generate(checkpoint_path, param_args, src=src, tgt=tgt, count_inference_flops=count_inference_flops, 63 | other_args=generate_args) 64 | 65 | generate_log.seek(0) 66 | generate_log = generate_log.read() 67 | print(generate_log) 68 | if count_inference_flops: 69 | lang_pair += "-flops" 70 | with file_operation.open_file( 71 | file_operation.join_paths( 72 | config["save_dir"], 73 | "results", 74 | checkpoint_name, 75 | lang_pair, 76 | "generate.log" 77 | ), 78 | "w" 79 | ) as f: 80 | f.write(generate_log) 81 | 82 | 83 | def generate(checkpoint_path, param_args, src, tgt, count_inference_flops=False, other_args=None): 84 | param_args = deepcopy(param_args) 85 | param_args.extend( 86 | [ 87 | "--path", 88 | checkpoint_path, 89 | "-s", src, 90 | "-t", tgt, 91 | ] 92 | ) 93 | if other_args is not None: 94 | param_args.extend(other_args) 95 | if other_args is not None and "--max-tokens" not in other_args: 96 | param_args.extend(["--max-tokens", "12000"]) 97 | parser = options.get_generation_parser() 98 | args = options.parse_args_and_arch(parser, input_args=param_args) 99 | if count_inference_flops: 100 | from .generate_to_count_flops import main as generate_main 101 | else: 102 | from fairseq_cli.generate import main as generate_main 103 | generate_main(args) 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /fairseq_code/toolbox/generate_mask_from_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script will evaluate the model. 
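In practice this entry point does not run generation itself: it loads the checkpoint
named in the config, extracts the binary mask (positions of non-zero weights) for one
language pair via generate_mask_from_softthreshold, and saves it to --target-path.

Illustrative invocation (paths are examples):

    python3 -m fairseq_code.toolbox.generate_mask_from_config \
        --config configs/demo/multilingual.yml \
        --checkpoint-name checkpoint_best.pt \
        --lang-pair de-en \
        --target-path /path/to/masks/de-en.pt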
3 | 4 | The generate.log will be saved at the ${save_dir}/results/${checkpoint_name}/${src}-${tgt}/generate.log 5 | 6 | """ 7 | 8 | 9 | import argparse 10 | from copy import deepcopy 11 | 12 | from fairseq import options 13 | 14 | from fairseq_code.utils import file_operation 15 | from .util import load_config, get_all_task_options 16 | from .generate_mask_from_softthreshold import get_parser 17 | from .generate_mask_from_softthreshold import main as generate_mask_main 18 | 19 | 20 | def get_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--config", type=str, required=True) 23 | parser.add_argument("--checkpoint-name", type=str, default="checkpoint_last.pt") 24 | parser.add_argument("--lang-pair", type=str, default=None, 25 | help="THe format is the same as fairseq multilingual translaiton task") 26 | parser.add_argument("--target-path", type=str, required=True, 27 | help="The target path of the mask") 28 | return parser.parse_args() 29 | 30 | 31 | def main(): 32 | args = get_args() 33 | config = args.config 34 | checkpoint_name = args.checkpoint_name 35 | lang_pair = args.lang_pair 36 | 37 | config = load_config(config) 38 | param_args = get_all_task_options(config) 39 | checkpoint_path = file_operation.join_paths( 40 | config["save_dir"], 41 | checkpoint_name 42 | ) 43 | 44 | src, tgt = lang_pair.strip().split("-") 45 | generate(checkpoint_path, param_args, src=src, tgt=tgt, 46 | target_path=args.target_path) 47 | 48 | 49 | def generate(checkpoint_path, param_args, src, tgt, target_path): 50 | param_args = deepcopy(param_args) 51 | param_args.extend( 52 | [ 53 | "--path", 54 | checkpoint_path, 55 | "-s", src, 56 | "-t", tgt, 57 | "--dest", target_path, 58 | ] 59 | ) 60 | parser = get_parser() 61 | args = options.parse_args_and_arch(parser, input_args=param_args) 62 | generate_mask_main(args) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /fairseq_code/toolbox/generate_mask_from_softthreshold.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Translate pre-processed data with a trained model. 8 | """ 9 | 10 | import ast 11 | import logging 12 | import math 13 | import os 14 | import sys 15 | from argparse import Namespace 16 | from itertools import chain 17 | 18 | import numpy as np 19 | import torch 20 | from fairseq import checkpoint_utils, options, scoring, tasks, utils 21 | from fairseq.checkpoint_utils import torch_persistent_save 22 | from fairseq.data import encoders 23 | from fairseq.dataclass.utils import convert_namespace_to_omegaconf 24 | from fairseq.logging import progress_bar 25 | from fairseq.logging.meters import StopwatchMeter, TimeMeter 26 | from omegaconf import DictConfig 27 | 28 | from torch import nn 29 | 30 | def loop_linear_module_for_model(m, no_mask_output_project=False, with_name=False, prefix=""): 31 | for n, c in m.named_children(): 32 | 33 | if no_mask_output_project and n == "output_projection": 34 | continue 35 | now_name = n if len(prefix) == 0 else prefix + "." 
+ n 36 | 37 | if isinstance(c, nn.Linear): 38 | if with_name: 39 | yield now_name, c 40 | else: 41 | yield c 42 | yield from loop_linear_module_for_model(c, no_mask_output_project=no_mask_output_project, 43 | with_name=with_name, prefix=now_name) 44 | 45 | 46 | def main(cfg: DictConfig): 47 | 48 | if isinstance(cfg, Namespace): 49 | cfg = convert_namespace_to_omegaconf(cfg) 50 | 51 | assert cfg.common_eval.path is not None, "--path required for generation!" 52 | assert ( 53 | not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam 54 | ), "--sampling requires --nbest to be equal to --beam" 55 | assert ( 56 | cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" 57 | ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" 58 | 59 | if cfg.common_eval.results_path is not None: 60 | os.makedirs(cfg.common_eval.results_path, exist_ok=True) 61 | output_path = os.path.join( 62 | cfg.common_eval.results_path, 63 | "generate-{}.txt".format(cfg.dataset.gen_subset), 64 | ) 65 | with open(output_path, "w", buffering=1, encoding="utf-8") as h: 66 | return _main(cfg, h) 67 | else: 68 | return _main(cfg, sys.stdout) 69 | 70 | 71 | def get_symbols_to_strip_from_output(generator): 72 | if hasattr(generator, "symbols_to_strip_from_output"): 73 | return generator.symbols_to_strip_from_output 74 | else: 75 | return {generator.eos} 76 | 77 | 78 | def _main(cfg: DictConfig, output_file): 79 | logging.basicConfig( 80 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 81 | datefmt="%Y-%m-%d %H:%M:%S", 82 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 83 | stream=output_file, 84 | ) 85 | logger = logging.getLogger("fairseq_cli.generate") 86 | 87 | utils.import_user_module(cfg.common) 88 | 89 | if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: 90 | cfg.dataset.max_tokens = 12000 91 | logger.info(cfg) 92 | 93 | # Fix seed for stochastic decoding 94 | if cfg.common.seed is not None and not cfg.generation.no_seed_provided: 95 | np.random.seed(cfg.common.seed) 96 | utils.set_torch_seed(cfg.common.seed) 97 | 98 | use_cuda = torch.cuda.is_available() and not cfg.common.cpu 99 | 100 | # Load dataset splits 101 | task = tasks.setup_task(cfg.task) 102 | task.load_dataset(cfg.dataset.gen_subset) 103 | 104 | # Set dictionaries 105 | try: 106 | src_dict = getattr(task, "source_dictionary", None) 107 | except NotImplementedError: 108 | src_dict = None 109 | tgt_dict = task.target_dictionary 110 | 111 | overrides = ast.literal_eval(cfg.common_eval.model_overrides) 112 | 113 | # Load ensemble 114 | logger.info("loading model(s) from {}".format(cfg.common_eval.path)) 115 | models, _model_args = checkpoint_utils.load_model_ensemble( 116 | utils.split_paths(cfg.common_eval.path), 117 | arg_overrides=overrides, 118 | task=task, 119 | suffix=cfg.checkpoint.checkpoint_suffix, 120 | strict=(cfg.checkpoint.checkpoint_shard_count == 1), 121 | num_shards=cfg.checkpoint.checkpoint_shard_count, 122 | ) 123 | 124 | model = models[0] 125 | source_lang = cfg['task'].source_lang 126 | target_lang = cfg['task'].target_lang 127 | model.eval() 128 | model.patch_all_mask(source_lang, target_lang) 129 | mask_dict = {} 130 | if hasattr(model, "model_loop_iter"): 131 | itr = model.model_loop_iter(with_name=True) 132 | else: 133 | itr = loop_linear_module_for_model(model, no_mask_output_project=True, with_name=True) 134 | for name, m in itr: 135 | if name.startswith("soft_threshold"): 136 | continue 137 | mask_dict[name] = (m.weight != 0) 138 | 
torch_persistent_save(mask_dict, cfg['task'].dest) 139 | 140 | 141 | def cli_main(): 142 | parser = get_parser() 143 | args = options.parse_args_and_arch(parser) 144 | main(args) 145 | 146 | 147 | def get_parser(): 148 | parser = options.get_generation_parser() 149 | group = parser.add_argument_group("generate_mask") 150 | group.add_argument("--dest", type=str) 151 | return parser 152 | 153 | 154 | if __name__ == "__main__": 155 | cli_main() 156 | -------------------------------------------------------------------------------- /fairseq_code/toolbox/generate_to_count_flops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Translate pre-processed data with a trained model. 8 | """ 9 | 10 | import ast 11 | import logging 12 | import math 13 | import os 14 | import sys 15 | from argparse import Namespace 16 | from itertools import chain 17 | import types 18 | 19 | import numpy as np 20 | import torch 21 | from fairseq import checkpoint_utils, options, scoring, tasks, utils 22 | from fairseq.data import encoders 23 | from fairseq.dataclass.utils import convert_namespace_to_omegaconf 24 | from fairseq.logging import progress_bar, metrics 25 | from fairseq.logging.meters import StopwatchMeter, TimeMeter 26 | from omegaconf import DictConfig 27 | 28 | from .count_flops_utils.attach_new_forward import count_flops 29 | 30 | 31 | def main(cfg: DictConfig): 32 | 33 | if isinstance(cfg, Namespace): 34 | cfg = convert_namespace_to_omegaconf(cfg) 35 | 36 | assert cfg.common_eval.path is not None, "--path required for generation!" 
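    # The sanity checks below mirror fairseq_cli.generate; actual decoding is replaced
    # by per-batch FLOPs accounting in _main() further down.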
37 | assert ( 38 | not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam 39 | ), "--sampling requires --nbest to be equal to --beam" 40 | assert ( 41 | cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" 42 | ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" 43 | 44 | if cfg.common_eval.results_path is not None: 45 | os.makedirs(cfg.common_eval.results_path, exist_ok=True) 46 | output_path = os.path.join( 47 | cfg.common_eval.results_path, 48 | "generate-{}.txt".format(cfg.dataset.gen_subset), 49 | ) 50 | with open(output_path, "w", buffering=1, encoding="utf-8") as h: 51 | return _main(cfg, h) 52 | else: 53 | return _main(cfg, sys.stdout) 54 | 55 | 56 | def get_symbols_to_strip_from_output(generator): 57 | if hasattr(generator, "symbols_to_strip_from_output"): 58 | return generator.symbols_to_strip_from_output 59 | else: 60 | return {generator.eos} 61 | 62 | 63 | def get_new_set_masked_weight_fn(model): 64 | old_fn = model.set_masked_weight 65 | def set_masked_weight(self, linear_module, masked_weight): 66 | masked_weight = masked_weight.detach() 67 | old_fn(linear_module, masked_weight) 68 | return set_masked_weight 69 | 70 | 71 | def _main(cfg: DictConfig, output_file): 72 | logging.basicConfig( 73 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 74 | datefmt="%Y-%m-%d %H:%M:%S", 75 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 76 | stream=output_file, 77 | ) 78 | logger = logging.getLogger("fairseq_cli.generate") 79 | 80 | utils.import_user_module(cfg.common) 81 | 82 | if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: 83 | cfg.dataset.max_tokens = 12000 84 | logger.info(cfg) 85 | 86 | # Fix seed for stochastic decoding 87 | if cfg.common.seed is not None and not cfg.generation.no_seed_provided: 88 | np.random.seed(cfg.common.seed) 89 | utils.set_torch_seed(cfg.common.seed) 90 | 91 | use_cuda = torch.cuda.is_available() and not cfg.common.cpu 92 | 93 | # Load dataset splits 94 | task = tasks.setup_task(cfg.task) 95 | task.load_dataset(cfg.dataset.gen_subset) 96 | 97 | # Set dictionaries 98 | try: 99 | src_dict = getattr(task, "source_dictionary", None) 100 | except NotImplementedError: 101 | src_dict = None 102 | tgt_dict = task.target_dictionary 103 | 104 | overrides = ast.literal_eval(cfg.common_eval.model_overrides) 105 | 106 | # Load ensemble 107 | logger.info("loading model(s) from {}".format(cfg.common_eval.path)) 108 | models, _model_args = checkpoint_utils.load_model_ensemble( 109 | utils.split_paths(cfg.common_eval.path), 110 | arg_overrides=overrides, 111 | task=task, 112 | suffix=cfg.checkpoint.checkpoint_suffix, 113 | strict=(cfg.checkpoint.checkpoint_shard_count == 1), 114 | num_shards=cfg.checkpoint.checkpoint_shard_count, 115 | ) 116 | 117 | if cfg.generation.lm_path is not None: 118 | overrides["data"] = cfg.task.data 119 | 120 | try: 121 | lms, _ = checkpoint_utils.load_model_ensemble( 122 | [cfg.generation.lm_path], arg_overrides=overrides, task=None 123 | ) 124 | except: 125 | logger.warning( 126 | f"Failed to load language model! 
Please make sure that the language model dict is the same " 127 | f"as target dict and is located in the data dir ({cfg.task.data})" 128 | ) 129 | raise 130 | 131 | assert len(lms) == 1 132 | else: 133 | lms = [None] 134 | 135 | # Optimize ensemble for generation 136 | for model in chain(models, lms): 137 | if model is None: 138 | continue 139 | if cfg.common.fp16: 140 | model.half() 141 | if use_cuda and not cfg.distributed_training.pipeline_model_parallel: 142 | model.cuda() 143 | model.prepare_for_inference_(cfg) 144 | 145 | # Load alignment dictionary for unknown word replacement 146 | # (None if no unknown word replacement, empty if no path to align dictionary) 147 | align_dict = utils.load_align_dict(cfg.generation.replace_unk) 148 | 149 | # Load dataset (possibly sharded) 150 | itr = task.get_batch_iterator( 151 | dataset=task.dataset(cfg.dataset.gen_subset), 152 | max_tokens=cfg.dataset.max_tokens, 153 | max_sentences=cfg.dataset.batch_size, 154 | max_positions=utils.resolve_max_positions( 155 | task.max_positions(), *[m.max_positions() for m in models] 156 | ), 157 | ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, 158 | required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, 159 | seed=cfg.common.seed, 160 | num_shards=cfg.distributed_training.distributed_world_size, 161 | shard_id=cfg.distributed_training.distributed_rank, 162 | num_workers=cfg.dataset.num_workers, 163 | data_buffer_size=cfg.dataset.data_buffer_size, 164 | ).next_epoch_itr(shuffle=False) 165 | progress = progress_bar.progress_bar( 166 | itr, 167 | log_format=cfg.common.log_format, 168 | log_interval=cfg.common.log_interval, 169 | default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), 170 | ) 171 | 172 | # Initialize generator 173 | gen_timer = StopwatchMeter() 174 | 175 | extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight} 176 | for model in models: 177 | model.set_masked_weight = types.MethodType( 178 | get_new_set_masked_weight_fn(model), 179 | model 180 | ) 181 | 182 | generator = task.build_generator( 183 | models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs 184 | ) 185 | # attach_new_generate_method_to_generator(generator) 186 | 187 | # Handle tokenization and BPE 188 | tokenizer = encoders.build_tokenizer(cfg.tokenizer) 189 | bpe = encoders.build_bpe(cfg.bpe) 190 | 191 | def decode_fn(x): 192 | if bpe is not None: 193 | x = bpe.decode(x) 194 | if tokenizer is not None: 195 | x = tokenizer.decode(x) 196 | return x 197 | 198 | model = models[0] 199 | 200 | flops_key_set = {"total"} 201 | with metrics.aggregate(new_root=True) as agg: 202 | for sample in progress: 203 | sample = utils.move_to_cuda(sample) if use_cuda else sample 204 | if "net_input" not in sample: 205 | continue 206 | 207 | with torch.no_grad(): 208 | flops = count_flops( 209 | model, 210 | **sample['net_input'], 211 | ) 212 | metrics.log_scalar("total", float(flops.total())) 213 | for k, v in flops.by_operator().items(): 214 | flops_key_set.add(k) 215 | metrics.log_scalar(k, float(v)) 216 | 217 | # for k in flops_key_set: 218 | # v = metrics.get_meter("default", k) 219 | # print(f"Flops of {k}: {v}") 220 | for k, v in agg.get_smoothed_values().items(): 221 | print(f"Flops of {k}: {v}") 222 | 223 | 224 | def cli_main(): 225 | parser = options.get_generation_parser() 226 | args = options.parse_args_and_arch(parser) 227 | main(args) 228 | 229 | 230 | if __name__ == "__main__": 231 | cli_main() 232 | 
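# Illustrative use: this module is normally reached through generate_from_config.py
# rather than invoked directly, e.g. (the config path is an example)
#
#   python3 -m fairseq_code.toolbox.generate_from_config \
#       --config configs/demo/multilingual.yml --count-inference-flops
#
# which substitutes the FLOPs-counting main() above for fairseq_cli.generate.main.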
-------------------------------------------------------------------------------- /fairseq_code/toolbox/util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | from fairseq import tasks 5 | from fairseq.file_io import PathManager 6 | 7 | 8 | def get_all_task_options(config): 9 | task = tasks.get_task(config['task']) 10 | parser = argparse.ArgumentParser() 11 | task.add_args(parser) 12 | option_names = list(parser._option_string_actions.keys()) 13 | option_names = filter(lambda x: x.startswith("--"), option_names) 14 | option_names = map(lambda x: x.lstrip("-"), option_names) 15 | option_names = list(option_names) 16 | option_names += ["bpe", "tokenizer"] 17 | param_args = [config["data_bin"], "--task", config['task']] 18 | for option_name in option_names: 19 | key_in_config = option_name.replace("-", "_") 20 | if key_in_config not in config: 21 | continue 22 | option_value = config[key_in_config] 23 | if isinstance(option_value, bool): 24 | if option_value: 25 | param_args.append( 26 | f"--{option_name}" 27 | ) 28 | else: 29 | param_args.append( 30 | f"--{option_name}" 31 | ) 32 | param_args.append( 33 | str(option_value) 34 | ) 35 | return param_args 36 | 37 | 38 | def load_config(config_path): 39 | with PathManager.open(config_path, "r") as f: 40 | config = yaml.safe_load(f) 41 | return config -------------------------------------------------------------------------------- /fairseq_code/utils/common.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import argparse 4 | import os 5 | import subprocess 6 | import typing 7 | import traceback 8 | 9 | 10 | from .logging import get_logger 11 | 12 | 13 | logger = get_logger() 14 | 15 | 16 | def python_interpreter(): 17 | res = os.environ.get('PYTHON_INTERPRETER', 'python3') 18 | logger.info("python interpreter:{}".format(res)) 19 | return res 20 | 21 | 22 | def command(cmd, *args, raise_exception=True, print_error=True, 23 | use_system=True): 24 | if len(args) != 0: 25 | cmd = cmd.format(*args) 26 | logger.info(cmd) 27 | if not use_system: 28 | popen = subprocess.Popen( 29 | cmd, 30 | shell=True, 31 | stderr=subprocess.PIPE, 32 | encoding="utf-8", 33 | ) 34 | _, err = popen.communicate() 35 | if popen.returncode != 0: 36 | if print_error: 37 | sys.stderr.write(err) 38 | if raise_exception: 39 | raise RuntimeError(f"The cmd '{cmd}' run failed") 40 | else: 41 | popen.stderr.close() 42 | return_code = popen.returncode 43 | else: 44 | return_code = os.system(cmd) 45 | if return_code != 0 and raise_exception: 46 | raise RuntimeError(f"The cmd '{cmd}' run failed") 47 | 48 | return return_code 49 | 50 | 51 | def popen_command(cmd: typing.Union[typing.List[str], str], shell: bool=False, 52 | stdin=None): 53 | logger.info(cmd) 54 | pipe = subprocess.Popen(cmd, 55 | stdin=stdin, 56 | stdout=subprocess.PIPE, 57 | stderr=subprocess.PIPE, 58 | encoding="utf-8", 59 | shell=shell) 60 | std, err = pipe.communicate(stdin) 61 | if pipe.returncode != 0: 62 | sys.stderr.write(err) 63 | raise RuntimeError(f"The cmd {cmd} run failed") 64 | else: 65 | res = std.strip().split("\n") 66 | return res 67 | 68 | 69 | _cpu_number = None 70 | 71 | 72 | def get_cpu_num(): 73 | return _cpu_number 74 | 75 | 76 | def pipe(*cmd, need_output=False): 77 | p = None 78 | for i, c in enumerate(cmd): 79 | in_pipe = None 80 | if p is not None: 81 | in_pipe = p.stdout 82 | 83 | out_pipe = subprocess.PIPE 84 | if i == len(cmd) - 1 and not need_output: 85 | 
out_pipe = subprocess.DEVNULL 86 | _p = subprocess.Popen(c, shell=True, 87 | stdin=in_pipe, 88 | stdout=out_pipe) 89 | if in_pipe is not None: 90 | p.stdout.close() 91 | p = _p 92 | if need_output: 93 | return p.communicate()[0] 94 | else: 95 | p.wait() 96 | 97 | 98 | def print_exception(e: Exception): 99 | traceback.print_tb(e.__traceback__) 100 | 101 | -------------------------------------------------------------------------------- /fairseq_code/utils/file_operation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shlex 3 | from abc import abstractmethod 4 | 5 | import tempfile 6 | from io import StringIO 7 | from datetime import datetime 8 | 9 | from pathlib import Path 10 | 11 | import subprocess 12 | import warnings 13 | from contextlib import contextmanager 14 | from pathlib import Path 15 | 16 | import typing 17 | 18 | from .common import command, popen_command 19 | from .logging import get_logger 20 | import glob as _glob 21 | import portalocker 22 | 23 | 24 | def file_lock(path: str): # type: ignore 25 | """ 26 | A file lock. Once entered, it is guaranteed that no one else holds the 27 | same lock. Others trying to enter the lock will block for 30 minutes and 28 | raise an exception. 29 | 30 | This is useful to make sure workers don't cache files to the same location. 31 | 32 | Args: 33 | path (str): a path to be locked. This function will create a lock named 34 | `path + ".lock"` 35 | 36 | Examples: 37 | 38 | filename = "/path/to/file" 39 | with file_lock(filename): 40 | if not os.path.isfile(filename): 41 | do_create_file() 42 | """ 43 | dirname = os.path.dirname(path) 44 | try: 45 | os.makedirs(dirname, exist_ok=True) 46 | except OSError: 47 | # makedir is not atomic. Exceptions can happen when multiple workers try 48 | # to create the same dir, despite exist_ok=True. 49 | # When this happens, we assume the dir is created and proceed to creating 50 | # the lock. If failed to create the directory, the next line will raise 51 | # exceptions. 
52 | pass 53 | return portalocker.Lock(path + ".lock", timeout=1800) # type: ignore 54 | 55 | 56 | logger = get_logger() 57 | 58 | 59 | class PathManagerRegister(object): 60 | def __init__(self): 61 | self.d = {} 62 | self.local_path_manager = LocalPathManager 63 | 64 | def register(self, c): 65 | assert issubclass(c, PathManager) 66 | self.d[c.prefix] = c 67 | 68 | def get_path_manager(self, path: str): 69 | for k, c in self.d.items(): 70 | if path.startswith(k): 71 | return c 72 | return self.local_path_manager 73 | 74 | 75 | hdfs_prefix = "hdfs://" 76 | 77 | 78 | def is_hdfs_path(p: str): 79 | if p.startswith(hdfs_prefix): 80 | return True 81 | else: 82 | return False 83 | 84 | 85 | class PathManager(object): 86 | def __init__(self, *path: str): 87 | self.path = self._join_paths(*path) 88 | 89 | @staticmethod 90 | def build(*path: str): 91 | assert len(path) > 0, \ 92 | f"The path input can not be empty" 93 | return path_manager_register.get_path_manager(path[0])(*path) 94 | 95 | @abstractmethod 96 | def _join_paths(self, *path: str): 97 | pass 98 | 99 | @abstractmethod 100 | def make_dir(self): 101 | pass 102 | 103 | @abstractmethod 104 | def dir_exists(self): 105 | pass 106 | 107 | @abstractmethod 108 | def file_exists(self): 109 | pass 110 | 111 | def check_file_exist(self): 112 | path = self.path 113 | if not self.file_exists(): 114 | raise ValueError(f"The file {path} is not existed") 115 | 116 | def check_dir_exist(self): 117 | path = self.path 118 | if not self.dir_exists(): 119 | raise ValueError(f"directory {path} does not exist") 120 | 121 | @abstractmethod 122 | def ls_dir(self): 123 | pass 124 | 125 | @abstractmethod 126 | def file_lines_number(self): 127 | pass 128 | 129 | @abstractmethod 130 | def remove_file(self): 131 | pass 132 | 133 | @abstractmethod 134 | def sync_file_to(self, target_path: str, source_no_existed_ignore=False, 135 | overwrite_target=True): 136 | pass 137 | 138 | @abstractmethod 139 | def symlink_to(self, target, force_empty=False): 140 | pass 141 | 142 | def join_paths(self): 143 | return self.path 144 | 145 | @abstractmethod 146 | def split(self): 147 | pass 148 | 149 | def build_contain_dir(self): 150 | """ 151 | create the directory contains the file self.path 152 | """ 153 | container_dir = self.split()[0] 154 | if len(container_dir.strip()) == 0: 155 | return 156 | PathManager.build(container_dir).make_dir() 157 | 158 | @abstractmethod 159 | def open(self, mode): 160 | pass 161 | 162 | @abstractmethod 163 | def cat(self): 164 | """ 165 | When use this function, please ensure read all strings 166 | from the stream. 
167 | :return: A stream of string 168 | """ 169 | pass 170 | 171 | @abstractmethod 172 | def glob(self): 173 | pass 174 | 175 | @abstractmethod 176 | def modified_time(self): 177 | pass 178 | 179 | @abstractmethod 180 | def append_to_file(self, s: str): 181 | pass 182 | 183 | 184 | class LocalPathManager(PathManager): 185 | 186 | def append_to_file(self, s: str): 187 | s = shlex.quote(s) 188 | command(f"echo {s} >> {self.path}") 189 | 190 | def modified_time(self): 191 | return datetime.fromtimestamp( 192 | os.path.getmtime(self.path) 193 | ) 194 | 195 | def glob(self): 196 | return _glob.glob(self.path) 197 | 198 | @contextmanager 199 | def cat(self): 200 | logger.info(f"To cat the path {self.path}") 201 | p = subprocess.Popen( 202 | f"cat {self.path}", 203 | stdout=subprocess.PIPE, 204 | shell=True, 205 | stderr=subprocess.PIPE, 206 | encoding="utf-8", 207 | ) 208 | try: 209 | yield p.stdout 210 | finally: 211 | for _ in p.stdout: 212 | pass 213 | error_msg = p.stderr.readlines() 214 | p.wait() 215 | if p.returncode != 0: 216 | raise ValueError("".join(error_msg)) 217 | else: 218 | p.stderr.close() 219 | 220 | @contextmanager 221 | def open(self, mode): 222 | if self.file_exists(): 223 | pass 224 | elif self.dir_exists(): 225 | raise ValueError(f"Can not open the directory {self.path}") 226 | elif "r" in mode: 227 | raise ValueError(f"Can not read a non-existed file! {self.path}") 228 | else: 229 | self.build_contain_dir() 230 | with open(self.path, mode) as f: 231 | yield f 232 | 233 | def _join_paths(self, *path: str): 234 | return os.path.join(*path) 235 | 236 | def make_dir(self): 237 | command(f"mkdir -p {self.path}") 238 | 239 | def dir_exists(self): 240 | path = self.path 241 | if os.path.exists(path): 242 | if os.path.isdir(path): 243 | return True 244 | return False 245 | 246 | def file_exists(self): 247 | path = self.path 248 | if os.path.exists(path): 249 | if os.path.isfile(path): 250 | return True 251 | return False 252 | 253 | def ls_dir(self): 254 | self.check_dir_exist() 255 | return os.listdir(self.path) 256 | 257 | def file_lines_number(self): 258 | file_path = self.path 259 | check_file_exist(file_path) 260 | p = subprocess.Popen(f"wc -l {file_path}", shell=True, stdout=subprocess.PIPE) 261 | file_length = p.stdout.readlines()[0] 262 | return int(file_length.strip().split()[0]) 263 | 264 | def remove_file(self): 265 | # self.check_file_exist() 266 | assert self.dir_exists() or self.file_exists() 267 | command(f"rm -r {self.path}") 268 | 269 | def sync_file_to(self, target_path: str, source_no_existed_ignore=False, 270 | overwrite_target=True): 271 | target_path = PathManager.build(target_path) 272 | target_path.build_contain_dir() 273 | t_path = target_path.path 274 | if isinstance(target_path, LocalPathManager): 275 | command(f"cp {self.path} {t_path}") 276 | elif isinstance(target_path, HdfsPathManager): 277 | overwrite = "-f" 278 | command(f"hadoop fs -put {overwrite} {self.path} {t_path}", use_system=False) 279 | 280 | def symlink_to(self, target, force_empty=False): 281 | target_path = self.build(target) 282 | assert isinstance(target_path, self.__class__) 283 | source = self.path 284 | if os.path.exists(target): 285 | if force_empty: 286 | raise ValueError(f"The target {target} existed") 287 | assert Path(target).resolve() == Path(source).resolve() 288 | return 289 | make_dir(os.path.split(target)[0]) 290 | os.symlink( 291 | source, 292 | target, 293 | ) 294 | 295 | def split(self): 296 | return os.path.split(self.path) 297 | 298 | def sync_to_local(self): 299 | 
return self.path 300 | 301 | 302 | class HdfsPathManager(PathManager): 303 | prefix: str = hdfs_prefix 304 | 305 | def append_to_file(self, s: str): 306 | s = shlex.quote(s) 307 | command(f"echo {s} | hadoop fs -appendToFile - {self.path}") 308 | 309 | def modified_time(self): 310 | res = popen_command(["hadoop", "fs", "-stat", "%y", self.path]) 311 | return datetime.strptime(res[0].strip(), "%Y-%m-%d %H:%M:%S") 312 | 313 | def glob(self): 314 | res = popen_command(["hadoop", "fs", "-ls", self.path]) 315 | # res = res[1:] 316 | # print(res) 317 | res = map(lambda x: x.strip().split(), res) 318 | res = filter(lambda x: len(x) == 8, res) 319 | res = [t[-1] for t in res] 320 | # res = [t.split("/")[-1] for t in res] 321 | return res 322 | 323 | @contextmanager 324 | def cat(self): 325 | logger.info(f"To cat the path {self.path}") 326 | p = subprocess.Popen( 327 | f"hadoop fs -text {self.path}", 328 | shell=True, 329 | stdout=subprocess.PIPE, 330 | stderr=subprocess.PIPE, 331 | encoding="utf-8", 332 | ) 333 | try: 334 | yield p.stdout 335 | finally: 336 | # read all lines in stdout and stderr 337 | # Avoid the deadlock for wait 338 | for l in p.stdout: 339 | pass 340 | p.stdout.close() 341 | error_msg = "".join(p.stderr.readlines()) 342 | p.stderr.close() 343 | p.wait() 344 | if p.returncode == 0: 345 | pass 346 | else: 347 | raise ValueError(error_msg) 348 | 349 | 350 | @staticmethod 351 | def _cache_path(p): 352 | home = str(Path.home()) 353 | return LocalPathManager(home, "__cached_dir", p[len(hdfs_prefix):]) 354 | 355 | @staticmethod 356 | def _cache_persist_path(p): 357 | home = str(Path.home()) 358 | return LocalPathManager(home, "__cached_persist_dir", p[len(hdfs_prefix):]) 359 | 360 | @contextmanager 361 | def open(self, mode): 362 | cache_path = self._cache_path(self.path) 363 | if self.file_exists(): 364 | sync_file(cache_path.path, self.path) 365 | elif self.dir_exists(): 366 | raise ValueError("Can not open a directory") 367 | elif "r" in mode: 368 | raise ValueError(f"Can not open a non-existed file {self.path}") 369 | else: 370 | cache_path.build_contain_dir() 371 | try: 372 | with cache_path.open(mode) as f: 373 | yield f 374 | if "a" in mode or "w" in mode: 375 | sync_file(self.path, cache_path.path) 376 | finally: 377 | cache_path.remove_file() 378 | 379 | def _join_paths(self, *path: str): 380 | return "/".join(path) 381 | 382 | def make_dir(self): 383 | command(f"hadoop fs -mkdir -p {self.path}", use_system=False) 384 | 385 | def dir_exists(self): 386 | res = command(f"hadoop fs -test -d {self.path}", raise_exception=False, print_error=False, use_system=False) 387 | if res != 0: 388 | return False 389 | else: 390 | return True 391 | 392 | def file_exists(self): 393 | res = command(f"hadoop fs -test -f {self.path}", raise_exception=False, print_error=False, use_system=False) 394 | if res != 0: 395 | return False 396 | else: 397 | return True 398 | 399 | def ls_dir(self): 400 | try: 401 | res = popen_command(["hadoop", "fs", "-ls", self.path]) 402 | except RuntimeError as e: 403 | return [] 404 | # res = res[1:] 405 | # print(res) 406 | res = map(lambda x: x.strip().split(), res) 407 | res = filter(lambda x: len(x) == 8, res) 408 | res = [t[-1] for t in res] 409 | res = [t.split("/")[-1] for t in res] 410 | return res 411 | 412 | def file_lines_number(self): 413 | res = popen_command([f"hadoop fs -cat {self.path} | wc -l"], shell=True) 414 | return int(res[0]) 415 | 416 | def remove_file(self): 417 | command(f"hadoop fs -rm -f {self.path}", use_system=False) 418 | 419 | def 
sync_file_to(self, target_path: str, source_no_existed_ignore=False, 420 | overwrite_target=True): 421 | target_path = PathManager.build(target_path) 422 | target_path.build_contain_dir() 423 | 424 | if overwrite_target and target_path.file_exists(): 425 | target_path.remove_file() 426 | 427 | t_path = target_path.path 428 | 429 | if isinstance(target_path, HdfsPathManager): 430 | command(f"hadoop fs -cp {self.path} {t_path}", use_system=False) 431 | elif isinstance(target_path, LocalPathManager): 432 | command(f"hadoop fs -copyToLocal {self.path} {t_path}", use_system=False) 433 | 434 | def symlink_to(self, target, force_empty=False): 435 | raise ValueError("The hdfs path can not create symbol link") 436 | 437 | def split(self): 438 | paths = self.path.split("/") 439 | return "/".join(paths[:-1]), paths[-1] 440 | 441 | def sync_to_local(self): 442 | local_path = self._cache_persist_path(self.path) 443 | with file_lock(local_path.path): 444 | sync_file(local_path.path, self.path) 445 | return local_path.path 446 | 447 | 448 | path_manager_register = PathManagerRegister() 449 | path_manager_register.register(HdfsPathManager) 450 | 451 | 452 | def make_dir(*path: str): 453 | """ 454 | This method will recursively create the directory 455 | :param path: A variable length parameter. The function will use the os.path.join to the path list 456 | :return: 457 | """ 458 | PathManager.build(*path).make_dir() 459 | 460 | 461 | def dir_exists(path: str): 462 | return PathManager.build(path).dir_exists() 463 | 464 | 465 | def file_exists(path: str): 466 | return PathManager.build(path).file_exists() 467 | 468 | 469 | def check_file_exist(path: str): 470 | PathManager.build(path).check_file_exist() 471 | 472 | 473 | def check_dir_exist(path: str): 474 | PathManager.build(path).check_dir_exist() 475 | 476 | 477 | def ls_dir(path: str): 478 | return PathManager.build(path).ls_dir() 479 | 480 | 481 | def file_lines_number(file_path: str): 482 | return PathManager.build(file_path).file_lines_number() 483 | 484 | 485 | def split_file(file_name, splited_dir, buckets): 486 | # with open(file_name, "r") as f: 487 | # lines = f.readlines() 488 | # lines = more_itertools.distribute(buckets, lines) 489 | # file_name = os.path.split(file_name)[1] 490 | # res = [] 491 | # logger.info("TO split file:{}".format(file_name)) 492 | # for idx, l in enumerate(lines): 493 | # split_path = os.path.join(splited_dir, "{}-{}".format(file_name, idx)) 494 | # with open(split_path, "w") as f: 495 | # l = map(lambda x: x.strip(), l) 496 | # f.write("\n".join(l)) 497 | # logger.info("split to {}".format(split_path)) 498 | # res.append(split_path) 499 | # return res 500 | buckets = int(buckets) 501 | name = os.path.split(file_name)[1] 502 | target_file_pattern = os.path.join(splited_dir, f"{name}.") 503 | command(f"split -n l/{buckets} {file_name} {target_file_pattern}") 504 | return _glob.glob(target_file_pattern+"*") 505 | 506 | 507 | def merge_files(target_file, source_files): 508 | logger.info("Begin merge {} to {}".format(" ".join(source_files), target_file)) 509 | target_dir = os.path.split(target_file)[0] 510 | if not dir_exists(target_dir): 511 | warnings.warn("The target dir {} is not existed, it will be created".format(target_file)) 512 | make_dir(target_dir) 513 | 514 | os.system("cat {} > {}".format(" ".join(source_files), target_file)) 515 | logger.info("End merge {} to {}".format(" ".join(source_files), target_file)) 516 | 517 | 518 | def remove_files(*args: str): 519 | for path in args: 520 | 
PathManager.build(path).remove_file() 521 | 522 | 523 | def temp_dir(d: str=None): 524 | if d is not None: 525 | assert isinstance(PathManager.build(d), LocalPathManager) 526 | if not dir_exists(d): 527 | make_dir(d) 528 | return tempfile.TemporaryDirectory(dir=d) 529 | 530 | 531 | @contextmanager 532 | def temp_file(d=None): 533 | with temp_dir(d) as t_d: 534 | yield os.path.join(t_d, "tmp") 535 | 536 | 537 | def _sync_file_ignore_existed(target_path: str, source_path: str, source_no_existed_ignore: bool = False, 538 | overwrite_old_file: bool = False, overwrite_target: bool = False): 539 | target_path = PathManager.build(target_path) 540 | source_path = PathManager.build(source_path) 541 | if source_path.file_exists(): 542 | if target_path.file_exists(): 543 | if overwrite_target: 544 | source_path.sync_file_to(target_path.path, ) 545 | elif overwrite_old_file and source_path.modified_time() > target_path.modified_time(): 546 | source_path.sync_file_to(target_path.path, ) 547 | else: 548 | logger.info(f"File {target_path.path} is existed. Skip it.") 549 | return 550 | elif target_path.dir_exists(): 551 | raise ValueError(f"The source path {source_path.path} is a file," 552 | f"but the synced target path {target_path.path} is a directory") 553 | else: 554 | source_path.sync_file_to(target_path.path,) 555 | elif source_path.dir_exists(): 556 | for sub_name in source_path.ls_dir(): 557 | _sync_file_ignore_existed( 558 | target_path=join_paths(target_path.path, sub_name), 559 | source_path=join_paths(source_path.path, sub_name), 560 | source_no_existed_ignore=source_no_existed_ignore, 561 | overwrite_old_file=overwrite_old_file, 562 | overwrite_target=overwrite_target, 563 | ) 564 | else: 565 | if not source_no_existed_ignore: 566 | raise ValueError(f"The source path {source_path.path} does not exist!") 567 | 568 | 569 | def sync_file(target_path: str, source_path: str, source_no_existed_ignore: bool = False, 570 | overwrite_target: bool = True, overwrite_old_file: bool = False): 571 | """ 572 | Copy the content in the source path to the target path. 573 | The target content will just at the target path, even if target path is a directory. 574 | It is different from cp command action. 575 | :param target_path: 576 | :param source_path: 577 | :param source_no_existed_ignore: If set and the source path does not exit, do nothing. 578 | :param overwrite_target: If set, overwrite the target. 579 | :param overwrite_old_file: If set, only overwrite the old file by the time stamp. 580 | This is not always successful, if two nodes use different time zone. 581 | :return: 582 | """ 583 | _sync_file_ignore_existed(target_path=target_path, source_path=source_path, 584 | source_no_existed_ignore=source_no_existed_ignore, 585 | overwrite_old_file=overwrite_old_file, 586 | overwrite_target=overwrite_target) 587 | 588 | 589 | def symlink(source, target, force_empty=False): 590 | """ 591 | :param target: The symbolic path 592 | :param source: The real path 593 | :param force_empty: If true, the target source should be empty. 
594 | If False and target exists, source should point to the same as the target 595 | :return: None 596 | """ 597 | PathManager.build(source).symlink_to(target, force_empty=force_empty) 598 | 599 | 600 | def join_paths(*paths): 601 | return PathManager.build(*paths).path 602 | 603 | 604 | def path_split(path: str): 605 | return PathManager.build(path).split() 606 | 607 | 608 | @contextmanager 609 | def open_file(path: str, mode: str): 610 | with PathManager.build(path).open(mode) as f: 611 | yield f 612 | 613 | 614 | @contextmanager 615 | def cat(path): 616 | with PathManager.build(path).cat() as f: 617 | yield f 618 | 619 | 620 | def glob(path): 621 | return PathManager.build(path).glob() 622 | 623 | 624 | def build_contain_dir(path): 625 | PathManager.build(path).build_contain_dir() 626 | 627 | 628 | def home(): 629 | return str(Path.home()) 630 | 631 | 632 | def modified_time(path): 633 | return PathManager.build(path).modified_time() 634 | 635 | 636 | def append_to_file(path: str, s: str): 637 | PathManager.build(path).append_to_file(s) 638 | 639 | 640 | def sync_to_local(path: str) -> str: 641 | return PathManager.build(path).sync_to_local() 642 | -------------------------------------------------------------------------------- /fairseq_code/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def get_logger(): 5 | name = "logger" 6 | if hasattr(get_logger, name): 7 | return getattr(get_logger, name) 8 | logger = logging.getLogger(__name__) 9 | logger.setLevel(level=logging.INFO) 10 | formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", 11 | "%Y-%m-%d %H:%M:%S") 12 | console = logging.StreamHandler() 13 | console.setFormatter(formatter) 14 | logger.addHandler(console) 15 | 16 | setattr(get_logger, name, logger) 17 | return logger -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sacremoses 2 | psutil 3 | more_itertools 4 | cytoolz 5 | cffi 6 | cython 7 | editdistance 8 | numpy 9 | regex 10 | sacrebleu 11 | tqdm 12 | portalocker 13 | hydra-core==1.0.3 14 | torch==1.8 15 | -------------------------------------------------------------------------------- /scripts/data processing/deduplicate.sh: -------------------------------------------------------------------------------- 1 | paste ${prefix}/normalize/train/en.norm ${prefix}/normalize/train/${lang}.norm | awk '!x[$0]++' > ${prefix}/deduplicate/train/train.norm.dedup 2 | echo "keeping $(wc -l ${prefix}/deduplicate/train/train.norm.dedup) bitext out of $(wc -l ${prefix}/normalize/train/en.norm)" 3 | cut -f1 ${prefix}/deduplicate/train/train.norm.dedup > ${prefix}/deduplicate/train/en.norm.dedup 4 | cut -f2 ${prefix}/deduplicate/train/train.norm.dedup > ${prefix}/deduplicate/train/${lang}.norm.dedup 5 | rm ${prefix}/deduplicate/train/train.norm.dedup -------------------------------------------------------------------------------- /scripts/data processing/learn_and_encode_spm.sh: -------------------------------------------------------------------------------- 1 | # learn bpe 2 | 3 | python -u fairseq/scripts/spm_train.py --input=/path/to/your/file --model_prefix=spm.bpe --vocab_size=64000 --character_coverage=1.0 --model_type=bpe --num_threads=45 --shuffle_input_sentence --train_extremely_large_corpus 4 | 5 | # apply bpe 6 | python fairseq/scripts/spm_encode.py --model spm.bpe.model \ 7 | --output_format=piece 
--inputs en-${lang}/clean/train/train.full.${lang} en-${lang}/clean/train/train.full.en \ 8 | --outputs ${train_path}/train.full.bpe.${lang} ${train_path}/train.full.bpe.en \ 9 | --min-len 1 --max-len 256 10 | 11 | # restrict length ratio 12 | SCRIPTS=mosesdecoder/scripts 13 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 14 | perl $CLEAN -ratio 3.0 ${prefix}/norm.dedup.spm en ${lang} ${prefix}/norm.dedep.spm.clean 1 256 -------------------------------------------------------------------------------- /scripts/data processing/preprocessing.sh: -------------------------------------------------------------------------------- 1 | echo 'Cloning Moses github repository (for tokenization scripts)...' 2 | git clone https://github.com/moses-smt/mosesdecoder.git 3 | 4 | echo 'Cloning WMT16 scripts...' 5 | git clone https://github.com/rsennrich/wmt16-scripts.git 6 | 7 | SCRIPTS=mosesdecoder/scripts 8 | RO_SCRIPTS=wmt16-scripts 9 | NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl 10 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl 11 | REPLACE_UNICODE_PUNCT=$SCRIPTS/tokenizer/replace-unicode-punctuation.perl 12 | 13 | NORMALIZE_IU_SPELLING=normalize-iu-spelling.pl 14 | 15 | REMOVE_DIACRITICS=$RO_SCRIPTS/remove-diacritics.py 16 | NORMALIZE_ROMANIAN=$RO_SCRIPTS/normalise-romanian.py 17 | 18 | if [ ${lang} == "zh" ]; then 19 | cat ${infile} \ 20 | | ${REPLACE_UNICODE_PUNCT} \ 21 | | ${NORM_PUNC} -l ${lang} \ 22 | | ${REM_NON_PRINT_CHAR} \ 23 | | hanzi-convert - -s \ 24 | > ${outfile} 25 | elif [ ${lang} == "ro" ]; then 26 | cat ${infile} \ 27 | | ${REPLACE_UNICODE_PUNCT} \ 28 | | ${NORM_PUNC} -l ${lang} \ 29 | | ${REM_NON_PRINT_CHAR} \ 30 | | ${NORMALIZE_ROMANIAN} \ 31 | | ${REMOVE_DIACRITICS} \ 32 | > ${outfile} 33 | elif [ ${lang} == "iu" ]; then 34 | cat ${infile} \ 35 | | ${REPLACE_UNICODE_PUNCT} \ 36 | | ${NORM_PUNC} -l ${lang} \ 37 | | ${REM_NON_PRINT_CHAR} \ 38 | | perl ${NORMALIZE_IU_SPELLING} \ 39 | > ${outfile} 40 | else 41 | cat ${infile} \ 42 | | ${REPLACE_UNICODE_PUNCT} \ 43 | | ${NORM_PUNC} -l ${lang} \ 44 | | ${REM_NON_PRINT_CHAR} \ 45 | > ${outfile} 46 | fi 47 | -------------------------------------------------------------------------------- /scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # repo_dir: root directory of the project 4 | repo_dir="$( cd "$( dirname "$0" )" && pwd )"/.. 5 | cd "${repo_dir}" 6 | echo "==== Working directory: ====" >&2 7 | echo "${repo_dir}" >&2 8 | echo "============================" >&2 9 | 10 | bash scripts/install.sh 11 | 12 | pip uninstall numpy 13 | pip install numpy 14 | 15 | python3 -m fairseq_code.toolbox.generate_from_config $@ 16 | 17 | -------------------------------------------------------------------------------- /scripts/install.sh: -------------------------------------------------------------------------------- 1 | pip3 install numpy 2 | 3 | pip3 install -U portalocker 4 | 5 | git clone https://github.com/pytorch/fairseq.git 6 | cd fairseq 7 | git checkout 6f847c8654d56b4d1b1fbacec027f47419426ddb 8 | pip3 install -e . 9 | cd .. 10 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # repo_dir: root directory of the project 4 | repo_dir="$( cd "$( dirname "$0" )" && pwd )"/.. 
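# Typical invocation (see README): bash scripts/train.sh --config <config.yml>
# All arguments are passed through unchanged to toolbox/train.py via "$@" below.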
5 | cd "${repo_dir}" 6 | echo "==== Working directory: ====" >&2 7 | echo "${repo_dir}" >&2 8 | echo "============================" >&2 9 | 10 | bash scripts/install.sh 11 | 12 | # pip uninstall numpy 13 | # pip install numpy 14 | 15 | export MKL_THREADING_LAYER=GNU 16 | export PYTHONPATH="." 17 | 18 | export NCCL_IB_DISABLE=0 19 | export NCCL_IB_GID_INDEX=3 20 | export NCCL_SOCKET_IFNAME=eth0 21 | 22 | python3 toolbox/train.py $@ 23 | -------------------------------------------------------------------------------- /toolbox/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NLP-Playground/LaSS/83be1881e9c25e26b598401d73867170af5adeba/toolbox/__init__.py -------------------------------------------------------------------------------- /toolbox/add_new_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import argparse 4 | 5 | import torch 6 | from fairseq.checkpoint_utils import torch_persistent_save 7 | from fairseq.file_io import PathManager 8 | from more_itertools import flatten, collapse 9 | from pprint import pprint 10 | import json 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--new-mask-name", type=str, help="like en-zh") 16 | parser.add_argument("--new-mask-path", type=str) 17 | parser.add_argument("--input-ckp", type=str) 18 | parser.add_argument("--output-ckp", type=str) 19 | 20 | return parser.parse_args() 21 | 22 | 23 | def main(): 24 | args = get_args() 25 | pprint(args) 26 | new_mask_name = args.new_mask_name 27 | new_mask_path = args.new_mask_path 28 | model = torch.load(args.input_ckp) 29 | 30 | mask_path_dict = json.loads(model['cfg']['model'].mask_path) 31 | if new_mask_name in mask_path_dict: 32 | raise ValueError 33 | 34 | mask_path_dict[new_mask_name] = new_mask_path 35 | 36 | mask_str = json.dumps(mask_path_dict) 37 | 38 | model['cfg']['model'].mask_path = mask_str 39 | 40 | torch_persistent_save(model, args.output_ckp) 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /toolbox/cal_similarity.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from fairseq.checkpoint_utils import torch_persistent_save 5 | from fairseq.file_io import PathManager 6 | from more_itertools import flatten, collapse 7 | from pprint import pprint 8 | 9 | 10 | # dict_keys(['encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.v_proj.weight', 'encoder.layers.0.self_attn.q_proj.weight', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0.self_attn_layer_norm.weight', 'encoder.layers.0.fc1.weight', 'encoder.layers.0.fc2.weight', 'encoder.layers.0.final_layer_norm.weight', 'encoder.layers.1.self_attn.k_proj.weight', 'encoder.layers.1.self_attn.v_proj.weight', 'encoder.layers.1.self_attn.q_proj.weight', 'encoder.layers.1.self_attn.out_proj.weight', 'encoder.layers.1.self_attn_layer_norm.weight', 'encoder.layers.1.fc1.weight', 'encoder.layers.1.fc2.weight', 'encoder.layers.1.final_layer_norm.weight', 'encoder.layers.2.self_attn.k_proj.weight', 'encoder.layers.2.self_attn.v_proj.weight', 'encoder.layers.2.self_attn.q_proj.weight', 'encoder.layers.2.self_attn.out_proj.weight', 'encoder.layers.2.self_attn_layer_norm.weight', 'encoder.layers.2.fc1.weight', 'encoder.layers.2.fc2.weight', 
'encoder.layers.2.final_layer_norm.weight', 'encoder.layers.3.self_attn.k_proj.weight', 'encoder.layers.3.self_attn.v_proj.weight', 'encoder.layers.3.self_attn.q_proj.weight', 'encoder.layers.3.self_attn.out_proj.weight', 'encoder.layers.3.self_attn_layer_norm.weight', 'encoder.layers.3.fc1.weight', 'encoder.layers.3.fc2.weight', 'encoder.layers.3.final_layer_norm.weight', 'encoder.layers.4.self_attn.k_proj.weight', 'encoder.layers.4.self_attn.v_proj.weight', 'encoder.layers.4.self_attn.q_proj.weight', 'encoder.layers.4.self_attn.out_proj.weight', 'encoder.layers.4.self_attn_layer_norm.weight', 'encoder.layers.4.fc1.weight', 'encoder.layers.4.fc2.weight', 'encoder.layers.4.final_layer_norm.weight', 'encoder.layers.5.self_attn.k_proj.weight', 'encoder.layers.5.self_attn.v_proj.weight', 'encoder.layers.5.self_attn.q_proj.weight', 'encoder.layers.5.self_attn.out_proj.weight', 'encoder.layers.5.self_attn_layer_norm.weight', 'encoder.layers.5.fc1.weight', 'encoder.layers.5.fc2.weight', 'encoder.layers.5.final_layer_norm.weight', 'decoder.layers.0.self_attn.k_proj.weight', 'decoder.layers.0.self_attn.v_proj.weight', 'decoder.layers.0.self_attn.q_proj.weight', 'decoder.layers.0.self_attn.out_proj.weight', 'decoder.layers.0.self_attn_layer_norm.weight', 'decoder.layers.0.encoder_attn.k_proj.weight', 'decoder.layers.0.encoder_attn.v_proj.weight', 'decoder.layers.0.encoder_attn.q_proj.weight', 'decoder.layers.0.encoder_attn.out_proj.weight', 'decoder.layers.0.encoder_attn_layer_norm.weight', 'decoder.layers.0.fc1.weight', 'decoder.layers.0.fc2.weight', 'decoder.layers.0.final_layer_norm.weight', 'decoder.layers.1.self_attn.k_proj.weight', 'decoder.layers.1.self_attn.v_proj.weight', 'decoder.layers.1.self_attn.q_proj.weight', 'decoder.layers.1.self_attn.out_proj.weight', 'decoder.layers.1.self_attn_layer_norm.weight', 'decoder.layers.1.encoder_attn.k_proj.weight', 'decoder.layers.1.encoder_attn.v_proj.weight', 'decoder.layers.1.encoder_attn.q_proj.weight', 'decoder.layers.1.encoder_attn.out_proj.weight', 'decoder.layers.1.encoder_attn_layer_norm.weight', 'decoder.layers.1.fc1.weight', 'decoder.layers.1.fc2.weight', 'decoder.layers.1.final_layer_norm.weight', 'decoder.layers.2.self_attn.k_proj.weight', 'decoder.layers.2.self_attn.v_proj.weight', 'decoder.layers.2.self_attn.q_proj.weight', 'decoder.layers.2.self_attn.out_proj.weight', 'decoder.layers.2.self_attn_layer_norm.weight', 'decoder.layers.2.encoder_attn.k_proj.weight', 'decoder.layers.2.encoder_attn.v_proj.weight', 'decoder.layers.2.encoder_attn.q_proj.weight', 'decoder.layers.2.encoder_attn.out_proj.weight', 'decoder.layers.2.encoder_attn_layer_norm.weight', 'decoder.layers.2.fc1.weight', 'decoder.layers.2.fc2.weight', 'decoder.layers.2.final_layer_norm.weight', 'decoder.layers.3.self_attn.k_proj.weight', 'decoder.layers.3.self_attn.v_proj.weight', 'decoder.layers.3.self_attn.q_proj.weight', 'decoder.layers.3.self_attn.out_proj.weight', 'decoder.layers.3.self_attn_layer_norm.weight', 'decoder.layers.3.encoder_attn.k_proj.weight', 'decoder.layers.3.encoder_attn.v_proj.weight', 'decoder.layers.3.encoder_attn.q_proj.weight', 'decoder.layers.3.encoder_attn.out_proj.weight', 'decoder.layers.3.encoder_attn_layer_norm.weight', 'decoder.layers.3.fc1.weight', 'decoder.layers.3.fc2.weight', 'decoder.layers.3.final_layer_norm.weight', 'decoder.layers.4.self_attn.k_proj.weight', 'decoder.layers.4.self_attn.v_proj.weight', 'decoder.layers.4.self_attn.q_proj.weight', 'decoder.layers.4.self_attn.out_proj.weight', 
'decoder.layers.4.self_attn_layer_norm.weight', 'decoder.layers.4.encoder_attn.k_proj.weight', 'decoder.layers.4.encoder_attn.v_proj.weight', 'decoder.layers.4.encoder_attn.q_proj.weight', 'decoder.layers.4.encoder_attn.out_proj.weight', 'decoder.layers.4.encoder_attn_layer_norm.weight', 'decoder.layers.4.fc1.weight', 'decoder.layers.4.fc2.weight', 'decoder.layers.4.final_layer_norm.weight', 'decoder.layers.5.self_attn.k_proj.weight', 'decoder.layers.5.self_attn.v_proj.weight', 'decoder.layers.5.self_attn.q_proj.weight', 'decoder.layers.5.self_attn.out_proj.weight', 'decoder.layers.5.self_attn_layer_norm.weight', 'decoder.layers.5.encoder_attn.k_proj.weight', 'decoder.layers.5.encoder_attn.v_proj.weight', 'decoder.layers.5.encoder_attn.q_proj.weight', 'decoder.layers.5.encoder_attn.out_proj.weight', 'decoder.layers.5.encoder_attn_layer_norm.weight', 'decoder.layers.5.fc1.weight', 'decoder.layers.5.fc2.weight', 'decoder.layers.5.final_layer_norm.weight', 'decoder.output_projection.weight']) 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--mask1-path", type=str) 16 | parser.add_argument("--mask2-path", type=str) 17 | parser.add_argument("--which-part", choices=['encoder', 'decoder', 'all']) 18 | parser.add_argument("--which-layer", default=None, choices=["0", "1", "2", "3", "4", "5"], help="start from 0") 19 | parser.add_argument("--query", nargs='+', 20 | choices=["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.out_proj", 21 | "encoder_attn.q", "encoder_attn.k", "encoder_attn.v", "encoder_attn.out_proj", 22 | "fc"], help="input query for specific component") 23 | parser.add_argument("--include-output-project", action="store_true") 24 | return parser.parse_args() 25 | 26 | 27 | def _cal_similarity(mask1_weight, mask2_weight): 28 | w1_size = mask1_weight.numel() 29 | w2_size = mask2_weight.numel() 30 | 31 | assert w1_size == w2_size, "Size mismatch" 32 | 33 | total_one_num = (mask1_weight == 1).sum().item() 34 | overlap_one = (mask1_weight & mask2_weight).sum().item() 35 | 36 | return {"total": total_one_num, "overlap": overlap_one} 37 | 38 | 39 | def load_mask(mask_path): 40 | with PathManager.open(mask_path, "rb") as f: 41 | mask_state = torch.load( 42 | f, 43 | map_location=( 44 | lambda s, _: torch.serialization.default_restore_location(s, "cpu") 45 | ), 46 | ) 47 | return mask_state 48 | 49 | 50 | def main(): 51 | args = get_args() 52 | pprint(args) 53 | query = args.query 54 | part = args.which_part 55 | layer = args.which_layer 56 | mask1_path = args.mask1_path 57 | mask2_path = args.mask2_path 58 | print("calculate similarity with {part} in layer {layer}, the query {query}".format(part=part, layer=layer, 59 | query=query)) 60 | 61 | mask1_state = load_mask(mask1_path) 62 | mask2_state = load_mask(mask2_path) 63 | 64 | # calculate 65 | total_cnt = 0 66 | cnt = 0 67 | layer_cond = "" if layer is None else "layers." + layer 68 | part_cond = "" if part == "all" else part + "." 
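    # layer_cond / part_cond are plain substring filters over parameter names:
    # e.g. --which-part encoder --which-layer 3 keeps only keys containing both
    # "encoder." and "layers.3", so only those masks enter the overlap statistics.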
69 | print(layer_cond, part_cond) 70 | for k, v in mask1_state.items(): 71 | if k not in mask2_state: 72 | raise ValueError("key {} is missing from the second mask".format(k)) 73 | if "layer_norm" in k: 74 | continue 75 | if "projection" in k and not args.include_output_project: 76 | continue 77 | 78 | result = None 79 | if layer_cond in k and part_cond in k: 80 | if query is not None: 81 | for q in query: 82 | if q in k: 83 | print(k) 84 | result = _cal_similarity(v, mask2_state[k]) 85 | break 86 | else: 87 | print(k) 88 | result = _cal_similarity(v, mask2_state[k]) 89 | 90 | if result is not None: 91 | total_cnt += result['total'] 92 | cnt += result['overlap'] 93 | 94 | print("Total Cnt is {total}, Overlap Cnt is {overlap}, the similarity of {path1} and {path2} is {p:.2f}%".format( 95 | total=total_cnt, overlap=cnt, path1=mask1_path, path2=mask2_path, p=cnt / total_cnt * 100 if total_cnt else 0.0)) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | 101 | -------------------------------------------------------------------------------- /toolbox/generate_mask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from fairseq.checkpoint_utils import torch_persistent_save 5 | from fairseq.file_io import PathManager 6 | from more_itertools import flatten, collapse 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--checkpoint-path", type=str, 12 | help="The checkpoint path to generate the mask") 13 | parser.add_argument("--mask-path", type=str, 14 | help="The output mask path") 15 | parser.add_argument("--gen-mask-with-prob", action="store_true") 16 | parser.add_argument("--gen-random-mask", action="store_true") 17 | parser.add_argument("--mask-prob", type=float) 18 | parser.add_argument('--gen-part', type=str, choices=["encoder", "decoder", "all"], required=True) 19 | parser.add_argument("--include-embedding", action="store_true") 20 | parser.add_argument("--exclude-output-proj", action="store_true") 21 | return parser.parse_args() 22 | 23 | 24 | 25 | def gen_each_random_mask_with_prob(weight, p): 26 | """ 27 | 28 | :param weight: a tensor 29 | :param p: probability 30 | :return: 31 | """ 32 | mask = torch.rand_like(weight.float()) > p 33 | 34 | return mask 35 | 36 | 37 | def main(): 38 | args = get_args() 39 | checkpoint_path = args.checkpoint_path 40 | mask_path = args.mask_path 41 | 42 | with PathManager.open(checkpoint_path, "rb") as f: 43 | state = torch.load( 44 | f, 45 | map_location=( 46 | lambda s, _: torch.serialization.default_restore_location(s, "cpu") 47 | ), 48 | ) 49 | gen_part = args.gen_part 50 | p = args.mask_prob 51 | if args.gen_mask_with_prob or args.gen_random_mask: 52 | mask_dict = gen_mask_with_prob(state, p, gen_part, random_gen=args.gen_random_mask, 53 | include_embedding=args.include_embedding, exclude_output_proj=args.exclude_output_proj) 54 | else: 55 | raise NotImplementedError 56 | 57 | torch_persistent_save(mask_dict, mask_path) 58 | 59 | 60 | def gen_each_mask_with_prob(weight, p): 61 | """ 62 | 63 | :param weight: a tensor 64 | :param p: probability 65 | :return: 66 | """ 67 | total_size = weight.nelement() 68 | kth = int(p * total_size) 69 | weight = torch.abs(weight.float()) 70 | kth_element, _ = torch.kthvalue(weight.view(-1), k=kth, dim=0) 71 | kth_element = kth_element.tolist() # float 72 | mask = weight > kth_element 73 | 74 | return mask 75 | 76 | 77 | def gen_embedding_mask_with_prob(weight, p): 78 | # masking the embedding is a little different: each row of the embedding matrix (one token vector) is treated as its own linear layer and pruned independently.
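# concretely: the p-quantile of absolute values is taken per row (dim=1) rather than over the whole matrix, so every embedding vector keeps roughly the same (1 - p) fraction of its largest-magnitude entries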
79 | weight = torch.abs(weight.float()) 80 | dim = weight.size(1) 81 | kth = int(p * dim) 82 | kth_element, _ = torch.kthvalue(weight, k=kth, dim=1, keepdim=True) # kth_element: (num_vec,1) 83 | mask = weight > kth_element 84 | 85 | return mask 86 | 87 | 88 | def gen_mask_with_prob(state, p, gen_part="all", random_gen=False, include_embedding=False,exclude_output_proj=False): 89 | """ 90 | generate mask with probability 91 | :return: mask_dict 92 | """ 93 | mask_dict = {} 94 | gen_func = gen_each_random_mask_with_prob if random_gen else gen_each_mask_with_prob 95 | for k, v in state['model'].items(): 96 | if "weight" in k and "embed" not in k.lower() and "layer_norm" not in k: 97 | if gen_part == "all": 98 | mask_dict[k] = gen_func(v, p) 99 | elif gen_part == "encoder": 100 | if "encoder" in k: 101 | mask_dict[k] = gen_func(v, p) 102 | elif gen_part == "decoder": 103 | if "decoder" in k: 104 | mask_dict[k] = gen_func(v, p) 105 | else: 106 | raise NotImplementedError 107 | if include_embedding and "weight" in k and "embed" in k.lower(): 108 | mask_dict[k] = gen_embedding_mask_with_prob(v, p) 109 | 110 | if exclude_output_proj: 111 | for k in list(state['model'].keys()): 112 | if 'output_projection' in k: 113 | print("Delete output projection") 114 | del state['model'][k] 115 | return mask_dict 116 | 117 | 118 | if __name__ == '__main__': 119 | main() 120 | -------------------------------------------------------------------------------- /toolbox/generate_random_mask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from fairseq.checkpoint_utils import torch_persistent_save 5 | from fairseq.file_io import PathManager 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--checkpoint-path", type=str, 11 | help="The checkpoint path to generate the mask") 12 | parser.add_argument("--mask-path", type=str, 13 | help="The output mask path") 14 | return parser.parse_args() 15 | 16 | 17 | def random_like(x): 18 | x = x.float().new_ones(x.size()) * 0.5 19 | return torch.bernoulli(x).bool() 20 | 21 | 22 | def main(): 23 | args = get_args() 24 | checkpoint_path = args.checkpoint_path 25 | mask_path = args.mask_path 26 | 27 | with PathManager.open(checkpoint_path, "rb") as f: 28 | state = torch.load( 29 | f, 30 | map_location=( 31 | lambda s, _: torch.serialization.default_restore_location(s, "cpu") 32 | ), 33 | ) 34 | 35 | mask_dict = {} 36 | for k, v in state['model'].items(): 37 | if "weight" in k: 38 | mask_dict[k] = random_like(v) 39 | 40 | torch_persistent_save(mask_dict, mask_path) 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /toolbox/merge_mask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from fairseq.checkpoint_utils import torch_persistent_save 5 | from fairseq.file_io import PathManager 6 | from more_itertools import flatten, collapse 7 | from pprint import pprint 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--encoder-mask", type=str) 13 | parser.add_argument("--decoder-mask", type=str) 14 | parser.add_argument("--output-mask", type=str) 15 | 16 | return parser.parse_args() 17 | 18 | 19 | def load_mask(mask_path): 20 | with PathManager.open(mask_path, "rb") as f: 21 | mask_state = torch.load( 22 | f, 23 | map_location=( 24 | lambda s, _: 
torch.serialization.default_restore_location(s, "cpu") 25 | ), 26 | ) 27 | return mask_state 28 | 29 | 30 | def main(): 31 | args = get_args() 32 | pprint(args) 33 | encoder_mask_state = load_mask(args.encoder_mask) 34 | decoder_mask_state = load_mask(args.decoder_mask) 35 | 36 | output_mask_dict = {} 37 | for k, v in encoder_mask_state.items(): 38 | if "encoder." in k: 39 | print(k) 40 | output_mask_dict[k] = v 41 | 42 | for k, v in decoder_mask_state.items(): 43 | if "decoder." in k: 44 | print(k) 45 | output_mask_dict[k] = v 46 | 47 | torch_persistent_save(output_mask_dict, args.output_mask) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /toolbox/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import re 4 | 5 | import yaml 6 | 7 | import argparse 8 | 9 | from fairseq_code.utils import file_operation 10 | from fairseq_code.utils.common import command 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--config", nargs="+", type=str) 16 | parser.add_argument("--force-reset", action="store_true") 17 | return parser.parse_args() 18 | 19 | 20 | def get_user_dir_path(): 21 | d = file_operation.path_split(__file__)[0] 22 | d = file_operation.path_split(d)[0] 23 | return file_operation.join_paths(d, "fairseq_code") 24 | 25 | 26 | def pop_from_config(config, key): 27 | if key in config: 28 | return config.pop(key) 29 | key = key.replace("_", "-") 30 | if key in config: 31 | return config.pop(key) 32 | key = key.replace("-", "_") 33 | if key in config: 34 | return config.pop(key) 35 | return None 36 | 37 | 38 | def _get_environment_variable(name: str): 39 | return os.environ.get(name) 40 | 41 | 42 | def no_restore_if_checkpoint_last_exists(config, force_reset): 43 | save_dir = config['save_dir'] if 'save_dir' in config else config['save-dir'] 44 | if file_operation.file_exists( 45 | file_operation.join_paths( 46 | save_dir, "checkpoint_last.pt" 47 | ) 48 | ): 49 | if "restore-file" in config: 50 | del config["restore-file"] 51 | if "restore_file" in config: 52 | del config["restore_file"] 53 | if not force_reset: 54 | for key in list(config.keys()): 55 | if key.startswith("reset"): 56 | del config[key] 57 | 58 | 59 | def train(args): 60 | config_paths = args.config 61 | force_reset = args.force_reset 62 | config = {} 63 | for config_path in config_paths: 64 | with file_operation.open_file(config_path, "r") as f: 65 | config = {**yaml.safe_load(f), **config} 66 | 67 | no_restore_if_checkpoint_last_exists(config, force_reset) 68 | args = [] 69 | for k, v in config.items(): 70 | if k == "data_bin": 71 | args.append(v) 72 | continue 73 | k = k.replace("_", "-") 74 | if isinstance(v, bool): 75 | if v: 76 | k = f"--{k}" 77 | else: 78 | k = "" 79 | else: 80 | k = f"--{k} {v}" 81 | args.append(k) 82 | args = " ".join(args) 83 | 84 | command( 85 | f"fairseq-train --user-dir {get_user_dir_path()} {args}" 86 | ) 87 | 88 | 89 | if __name__ == "__main__": 90 | train(get_args()) 91 | --------------------------------------------------------------------------------
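The toolbox scripts above compose into a simple mask-analysis workflow. Below is a minimal sketch; the checkpoint and mask paths, the 0.3 pruning ratio, and the de2en / es2en naming are illustrative assumptions rather than values from the repo, while the flags themselves are the ones defined in `generate_mask.py` and `cal_similarity.py` above:

```
# Prune two fine-tuned checkpoints into language-specific masks (hypothetical paths).
python3 toolbox/generate_mask.py --checkpoint-path checkpoints/de2en/checkpoint_best.pt \
    --mask-path masks/de2en.pt --gen-mask-with-prob --mask-prob 0.3 --gen-part all --exclude-output-proj
python3 toolbox/generate_mask.py --checkpoint-path checkpoints/es2en/checkpoint_best.pt \
    --mask-path masks/es2en.pt --gen-mask-with-prob --mask-prob 0.3 --gen-part all --exclude-output-proj

# Measure how much the two sub-networks overlap, restricted to the encoder.
python3 toolbox/cal_similarity.py --mask1-path masks/de2en.pt --mask2-path masks/es2en.pt --which-part encoder
```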