├── .gitignore ├── scripts ├── extract_fp_32.sh ├── eval.multi.sh ├── eval.mono.sh ├── train.sh ├── mono.sh └── mono.train.sh ├── requirements.txt ├── config └── ds_config_zero_2.json ├── upload_to_hub.py ├── xla_spawn.py ├── README.md ├── dataloader.py ├── train.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | __pycache__ -------------------------------------------------------------------------------- /scripts/extract_fp_32.sh: -------------------------------------------------------------------------------- 1 | cd $1 && python zero_to_fp32.py . pytorch_model.bin -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.10.2 2 | datasets==1.12.0 3 | torch==1.9.0 4 | -------------------------------------------------------------------------------- /scripts/eval.multi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for language in en de es fr it nl pt tr ru pl id no sv da vi fi ro cs he hu hr 3 | do 4 | echo "Evaluation for ${language}" 5 | output_dir="output/eval/${language}" 6 | python train.py \ 7 | --languages $language \ 8 | --model_name_or_path $MODEL \ 9 | --output_dir $output_dir \ 10 | --do_predict \ 11 | --per_device_eval_batch_size $BATCH_SIZE \ 12 | --max_seq_len 128 \ 13 | --label_names page_id \ 14 | --dataloader_num_workers 1 15 | done 16 | -------------------------------------------------------------------------------- /scripts/eval.mono.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | for language in de es fr it pt tr ru pl 3 | do 4 | echo "Evaluation for ${language}" 5 | output_dir="output/eval/mono/${language}" 6 | model="output/mono/${language}/checkpoint-1500" 7 | python train.py \ 8 | --languages $language \ 9 | --model_name_or_path $model \ 10 | --output_dir $output_dir \ 11 | --do_predict \ 12 | --per_device_eval_batch_size $BATCH_SIZE \ 13 | --max_seq_len 128 \ 14 | --label_names page_id \ 15 | --dataloader_num_workers 1 16 | done 17 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=4 train.py \ 2 | --languages en de es fr it nl pt tr ru pl id no sv da vi fi ro cs he hu hr \ 3 | --output_dir output/1024 \ 4 | --do_train \ 5 | --per_device_train_batch_size 256 \ 6 | --per_device_eval_batch_size 256 \ 7 | --distributed_softmax \ 8 | --max_steps 3000 \ 9 | --evaluation_strategy steps \ 10 | --eval_steps 250 \ 11 | --max_seq_len 128 \ 12 | --warmup_steps 1000 \ 13 | --label_names page_id \ 14 | --logging_steps 5 \ 15 | --fp16 \ 16 | --metric_for_best_model eval_global_mrr \ 17 | --load_best_model_at_end \ 18 | --save_total_limit 3 \ 19 | --report_to tensorboard \ 20 | --dataloader_num_workers 1 \ 21 | --single_domain \ 22 | --hidden_dropout_prob 0.25 \ 23 | --learning_rate 0.00005 \ 24 | --weight_decay 0.01 \ 25 | --alpha 1 \ 26 | --gradient_checkpointing \ 27 | --deepspeed config/ds_config_zero_2.json 28 | 29 | -------------------------------------------------------------------------------- /config/ds_config_zero_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled":
"auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "optimizer": { 12 | "type": "AdamW", 13 | "params": { 14 | "lr": "auto", 15 | "betas": "auto", 16 | "eps": "auto", 17 | "weight_decay": "auto" 18 | } 19 | }, 20 | 21 | "zero_optimization": { 22 | "stage": 2, 23 | "allgather_partitions": true, 24 | "allgather_bucket_size": 2e8, 25 | "overlap_comm": true, 26 | "reduce_scatter": true, 27 | "reduce_bucket_size": 2e8, 28 | "contiguous_gradients": true 29 | }, 30 | 31 | "gradient_accumulation_steps": "auto", 32 | "gradient_clipping": "auto", 33 | "steps_per_print": 2000, 34 | "train_batch_size": "auto", 35 | "train_micro_batch_size_per_gpu": "auto", 36 | "wall_clock_breakdown": false 37 | } -------------------------------------------------------------------------------- /scripts/mono.sh: -------------------------------------------------------------------------------- 1 | LANGUAGE=en OUTPUT_DIR=output/mono/en MODEL_NAME=roberta-base sh scripts/mono.train.sh 2 | LANGUAGE=de OUTPUT_DIR=output/mono/de MODEL_NAME=deepset/gbert-base sh scripts/mono.train.sh 3 | LANGUAGE=es OUTPUT_DIR=output/mono/es MODEL_NAME=bertin-project/bertin-roberta-base-spanish sh scripts/mono.train.sh 4 | LANGUAGE=fr OUTPUT_DIR=output/mono/fr MODEL_NAME=camembert-base sh scripts/mono.train.sh 5 | LANGUAGE=it OUTPUT_DIR=output/mono/it MODEL_NAME=dbmdz/bert-base-italian-cased sh scripts/mono.train.sh 6 | LANGUAGE=nl OUTPUT_DIR=output/mono/nl MODEL_NAME=DTAI-KULeuven/robbertje-1-gb-shuffled sh scripts/mono.train.sh 7 | LANGUAGE=pt OUTPUT_DIR=output/mono/pt MODEL_NAME=neuralmind/bert-base-portuguese-cased sh scripts/mono.train.sh 8 | LANGUAGE=tr OUTPUT_DIR=output/mono/tr MODEL_NAME=dbmdz/bert-base-turkish-cased sh scripts/mono.train.sh 9 | LANGUAGE=ru OUTPUT_DIR=output/mono/ru MODEL_NAME=DeepPavlov/rubert-base-cased sh scripts/mono.train.sh 10 | LANGUAGE=pl OUTPUT_DIR=output/mono/pl MODEL_NAME=dkleczek/bert-base-polish-uncased-v1 sh scripts/mono.train.sh -------------------------------------------------------------------------------- /scripts/mono.train.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=4 train.py \ 3 | --model_name_or_path $MODEL_NAME \ 4 | --languages $LANGUAGE \ 5 | --output_dir $OUTPUT_DIR \ 6 | --do_train \ 7 | --per_device_train_batch_size 256 \ 8 | --per_device_eval_batch_size 256 \ 9 | --distributed_softmax \ 10 | --max_steps 1500 \ 11 | --evaluation_strategy steps \ 12 | --eval_steps 125 \ 13 | --max_seq_len 128 \ 14 | --warmup_steps 1000 \ 15 | --label_names page_id \ 16 | --logging_steps 5 \ 17 | --fp16 \ 18 | --metric_for_best_model eval_global_mrr \ 19 | --load_best_model_at_end \ 20 | --save_total_limit 3 \ 21 | --report_to tensorboard \ 22 | --dataloader_num_workers 1 \ 23 | --single_domain \ 24 | --hidden_dropout_prob 0.25 \ 25 | --learning_rate 0.00005 \ 26 | --weight_decay 0.01 \ 27 | --alpha 1 \ 28 | --gradient_checkpointing \ 29 | --deepspeed config/ds_config_zero_2.json \ 30 | --limit_valid_size 25 31 | 32 | -------------------------------------------------------------------------------- /upload_to_hub.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from transformers import AutoModel, AutoTokenizer 3 | from sentence_transformers.models import Pooling, Transformer 4 | from sentence_transformers 
import SentenceTransformer 5 | 6 | 7 | parser = ArgumentParser() 8 | parser.add_argument("--checkpoint") 9 | parser.add_argument("--model_name", default="mfaq") 10 | parser.add_argument("--organization", default="clips") 11 | parser.add_argument("--exist_ok", action="store_true") 12 | parser.add_argument("--replace_model_card", action="store_true") 13 | args = parser.parse_args() 14 | 15 | 16 | # model = AutoModel.from_pretrained(args.checkpoint, add_pooling_layer=False) 17 | # tokenizer = AutoTokenizer.from_pretrained(args.checkpoint) 18 | 19 | model = Transformer( 20 | args.checkpoint, 21 | max_seq_length=128, 22 | model_args={"add_pooling_layer": False}, 23 | tokenizer_name_or_path=args.checkpoint 24 | ) 25 | pooling = Pooling(model.auto_model.config.hidden_size, pooling_mode="mean") 26 | st = SentenceTransformer(modules=[model, pooling]) 27 | st.save_to_hub( 28 | args.model_name, 29 | organization=args.organization, 30 | exist_ok=args.exist_ok, 31 | replace_model_card=args.replace_model_card 32 | ) 33 | 34 | -------------------------------------------------------------------------------- /xla_spawn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A simple launcher script for TPU training 16 | 17 | Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py 18 | 19 | :: 20 | >>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE 21 | YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other 22 | arguments of your training script) 23 | 24 | """ 25 | 26 | 27 | import importlib 28 | import sys 29 | from argparse import REMAINDER, ArgumentParser 30 | from pathlib import Path 31 | 32 | import torch_xla.distributed.xla_multiprocessing as xmp 33 | 34 | 35 | def parse_args(): 36 | """ 37 | Helper function parsing the command line options 38 | @retval ArgumentParser 39 | """ 40 | parser = ArgumentParser( 41 | description=( 42 | "PyTorch TPU distributed training launch " 43 | "helper utility that will spawn up " 44 | "multiple distributed processes" 45 | ) 46 | ) 47 | 48 | # Optional arguments for the launch helper 49 | parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).") 50 | 51 | # positional 52 | parser.add_argument( 53 | "training_script", 54 | type=str, 55 | help=( 56 | "The full path to the single TPU training " 57 | "program/script to be launched in parallel, " 58 | "followed by all the arguments for the " 59 | "training script" 60 | ), 61 | ) 62 | 63 | # rest from the training program 64 | parser.add_argument("training_script_args", nargs=REMAINDER) 65 | 66 | return parser.parse_args() 67 | 68 | 69 | def main(): 70 | args = parse_args() 71 | 72 | # Import training_script as a module. 
73 | script_fpath = Path(args.training_script) 74 | sys.path.append(str(script_fpath.parent.resolve())) 75 | mod_name = script_fpath.stem 76 | mod = importlib.import_module(mod_name) 77 | 78 | # Patch sys.argv 79 | sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)] 80 | 81 | xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MFAQ: a Multilingual FAQ Dataset 2 | 3 | MFAQ is a multilingual FAQ retrieval dataset. We also release a multilingual FAQ retrieval model trained on this dataset. 4 | 5 | ## Dataset 6 | The dataset is hosted on the HuggingFace hub. You can find it [here](https://huggingface.co/datasets/clips/mfaq). 7 | 8 | Start by installing the dataset package: 9 | ``` 10 | pip install datasets 11 | ``` 12 | 13 | Then import the dataset: 14 | ```python 15 | from datasets import load_dataset 16 | en_dataset = load_dataset("clips/mfaq", "en") 17 | ``` 18 | You can find more information about the dataset and the available configurations on the [description page](https://huggingface.co/datasets/clips/mfaq). 19 | 20 | ## Model 21 | The pre-trained FAQ retrieval model is also hosted on the HuggingFace Hub. You can find it [here](https://huggingface.co/clips/mfaq). 22 | 23 | Start by installing sentence-transformers: 24 | ``` 25 | pip install sentence-transformers 26 | ``` 27 | 28 | Load the model: 29 | ```python 30 | from sentence_transformers import SentenceTransformer 31 | ``` 32 | 33 | Each question must be prepended with a `<Q>`, answers with an `<A>`. 34 | ```python 35 | question = "<Q>How many models can I host on HuggingFace?" 36 | answers = [ 37 | "<A>All plans come with unlimited private models and datasets.", 38 | "<A>AutoNLP is an automatic way to train and deploy state-of-the-art NLP models, seamlessly integrated with the Hugging Face ecosystem.", 39 | "<A>Based on how much training data and model variants are created, we send you a compute cost and payment link - as low as $10 per job." 40 | ] 41 | model = SentenceTransformer('clips/mfaq') 42 | q_embedding, *a_embeddings = model.encode([question] + answers) 43 | best_answer_idx = sorted(enumerate(a_embeddings), key=lambda x: q_embedding.dot(x[1]), reverse=True)[0][0] 44 | print(answers[best_answer_idx]) 45 | ``` 46 | 47 | ## Training 48 | `train.py` uses the [HuggingFace Trainer](https://huggingface.co/transformers/main_classes/trainer.html) to train a FAQ retrieval model on MFAQ.
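Training with the DeepSpeed ZeRO-2 config shards the optimizer state, so a checkpoint has to be consolidated into a regular `pytorch_model.bin` before it can be evaluated; `scripts/extract_fp_32.sh` simply runs DeepSpeed's `zero_to_fp32.py` inside the checkpoint directory you pass it. `scripts/eval.multi.sh` then reads the model path and batch size from the `MODEL` and `BATCH_SIZE` environment variables. A minimal sketch of the post-training workflow for the configuration shown below (the checkpoint path is only an example):
```
# consolidate the ZeRO-2 shards of a checkpoint into pytorch_model.bin
sh scripts/extract_fp_32.sh output/1024/checkpoint-3000
# evaluate the consolidated checkpoint on every training language
MODEL=output/1024/checkpoint-3000 BATCH_SIZE=256 sh scripts/eval.multi.sh
```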
49 | The following configuration reaches an MRR of 89% on the English subset: 50 | ``` 51 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py \ 52 | --languages en de es fr it nl pt tr ru pl id no sv da vi fi ro cs he hu hr \ 53 | --output_dir output/1024 \ 54 | --do_train \ 55 | --per_device_train_batch_size 256 \ 56 | --per_device_eval_batch_size 256 \ 57 | --distributed_softmax \ 58 | --max_steps 3000 \ 59 | --evaluation_strategy steps \ 60 | --eval_steps 250 \ 61 | --max_seq_len 128 \ 62 | --warmup_steps 1000 \ 63 | --label_names page_id \ 64 | --logging_steps 5 \ 65 | --fp16 \ 66 | --metric_for_best_model eval_global_mrr \ 67 | --load_best_model_at_end \ 68 | --save_total_limit 3 \ 69 | --report_to tensorboard \ 70 | --dataloader_num_workers 1 \ 71 | --single_domain \ 72 | --hidden_dropout_prob 0.25 \ 73 | --learning_rate 0.00005 \ 74 | --weight_decay 0.01 \ 75 | --alpha 1 \ 76 | --gradient_checkpointing \ 77 | --deepspeed config/ds_config_zero_2.json 78 | ``` 79 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | # for the training set, convert everything into a history, replies 2 | # for the validation set, convert everything into a history, replies 3 | import time 4 | import torch 5 | from datasets import load_dataset 6 | import torch.distributed as dist 7 | 8 | 9 | class ValidationDataset(torch.utils.data.Dataset): 10 | def __init__(self, dataset): 11 | self.dataset = dataset 12 | self.index = self._prepare(dataset) 13 | 14 | def _prepare(self, dataset): 15 | index = [] 16 | for page_i, page in enumerate(self.dataset): 17 | for pair_i, pair in enumerate(page["qa_pairs"]): 18 | index.append((page_i, pair_i)) 19 | return index 20 | 21 | def __len__(self): 22 | return sum([e["num_pairs"] for e in self.dataset]) 23 | 24 | def __getitem__(self, idx): 25 | page_i, pair_i = self.index[idx] 26 | page = self.dataset[page_i] 27 | pair = page["qa_pairs"][pair_i] 28 | pair["page_id"] = page["id"] 29 | return pair 30 | 31 | 32 | class MonolingualDataset(torch.utils.data.IterableDataset): 33 | def __init__(self, dataset, *, batch_size, seed = 42, epoch = 0, single_domain = False): 34 | self.dataset = dataset 35 | self.batch_size = batch_size 36 | self.seed = seed 37 | self.epoch = epoch 38 | self.single_domain = single_domain 39 | 40 | def set_epoch(self, epoch): 41 | self.epoch = epoch 42 | 43 | def __len__(self): 44 | return len(self.dataset) 45 | # return sum([e["num_pairs"] for e in self.dataset]) 46 | 47 | def __iter__(self): 48 | while True: 49 | g = torch.Generator() 50 | g.manual_seed(self.seed + self.epoch) 51 | seen_domains = set() 52 | for i in torch.randperm(len(self.dataset), generator=g).tolist(): 53 | page = self.dataset[i] 54 | if page["domain"] in seen_domains: 55 | continue 56 | for pair in page["qa_pairs"]: 57 | pair["page_id"] = page["id"] 58 | yield pair 59 | if self.single_domain: 60 | seen_domains.add(page["domain"]) 61 | self.epoch += 1 62 | 63 | 64 | class IterableDataset(torch.utils.data.IterableDataset): 65 | def __init__(self, datasets, languages, *, probabilities = None, batch_size = 10, seed = 42, single_domain = False, alpha = 0.3): 66 | self.datasets = [MonolingualDataset(e, batch_size=batch_size, single_domain=single_domain) for e in datasets] 67 | self.languages = languages 68 | self.probabilities = self._get_default_probs(alpha) if probabilities is None else torch.Tensor(probabilities) 69 
| self.batch_size = batch_size 70 | self.seed = seed 71 | self.alpha = alpha 72 | 73 | def _get_default_probs(self, alpha): 74 | ds_length = [len(e) for e in self.datasets] 75 | total_length = sum(ds_length) 76 | probs = [(e/total_length)**alpha for e in ds_length] 77 | return torch.Tensor(probs) 78 | 79 | def set_epoch(self, epoch): 80 | self.epoch = epoch 81 | 82 | def __iter__(self): 83 | g = torch.Generator() 84 | g.manual_seed(self.seed) 85 | self.datasets = [iter(e) for e in self.datasets] 86 | while True: 87 | idx = torch.multinomial(self.probabilities, 1)[0].item() 88 | for _ in range(self.batch_size): 89 | item = next(self.datasets[idx]) 90 | yield item 91 | 92 | 93 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/token-classification/run_ner.py 2 | import os 3 | import sys 4 | import torch 5 | import logging 6 | import pandas as pd 7 | import torch.distributed as dist 8 | from typing import Optional, List 9 | from dataclasses import dataclass, field 10 | from datasets import load_dataset, interleave_datasets 11 | from transformers import set_seed, EarlyStoppingCallback 12 | from transformers import Trainer, TrainingArguments, HfArgumentParser 13 | from transformers import AutoModel, AutoTokenizer 14 | from collections import OrderedDict 15 | from transformers.trainer_utils import get_last_checkpoint 16 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint 17 | 18 | 19 | from dataloader import IterableDataset, ValidationDataset 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | @dataclass 26 | class ModelArguments: 27 | gradient_checkpointing: bool = field(default=False) 28 | hidden_dropout_prob: float = field(default=0.1) 29 | attention_probs_dropout_prob: float = field(default=0.1) 30 | model_name_or_path: str = field(default="xlm-roberta-base") 31 | config_name: Optional[str] = field(default=None) 32 | tokenizer_name: Optional[str] = field(default=None) 33 | cache_dir: Optional[str] = field(default=None) 34 | model_revision: str = field(default="main") 35 | 36 | 37 | @dataclass 38 | class DataTrainingArguments: 39 | preprocessing_num_workers: Optional[int] = field(default=None) 40 | max_seq_length: int = field(default=None) 41 | languages: Optional[List[str]] = field(default=None) 42 | probabilities: Optional[List[float]] = field(default=None) 43 | overwrite_cache: bool = field(default=False) 44 | pad_to_max_length: bool = field(default=False) 45 | single_domain: bool = field(default=False) 46 | alpha: float = field(default=0.3) 47 | no_special_token: bool = field(default=False) 48 | limit_valid_size: Optional[int] = field(default=None) 49 | 50 | 51 | @dataclass 52 | class CustomTrainingArgument(TrainingArguments): 53 | distributed_softmax: bool = field(default=False) 54 | 55 | 56 | def distributed_softmax(q_output, a_output, rank, world_size): 57 | q_list = [torch.zeros_like(q_output) for _ in range(world_size)] 58 | a_list = [torch.zeros_like(a_output) for _ in range(world_size)] 59 | dist.all_gather(tensor_list=q_list, tensor=q_output.contiguous()) 60 | dist.all_gather(tensor_list=a_list, tensor=a_output.contiguous()) 61 | q_list[rank] = q_output 62 | a_list[rank] = a_output 63 | q_output = torch.cat(q_list, 0) 64 | a_output = torch.cat(a_list, 0) 65 | return q_output, a_output 66 | 67 | 68 | def
mean_pooling(model_output, attention_mask): 69 | token_embeddings = model_output["last_hidden_state"] # First element of model_output contains all token embeddings 70 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 71 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) 72 | sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) 73 | return sum_embeddings / sum_mask 74 | 75 | 76 | class CustomTrainer(Trainer): 77 | def compute_loss(self, model, inputs, return_outputs=False): 78 | page_id = inputs.pop("page_id", None) 79 | outputs = model(**inputs) 80 | sentence_embeddings = mean_pooling(outputs, inputs['attention_mask']) 81 | q_logits, a_logits = torch.chunk(sentence_embeddings, 2) 82 | if self.args.distributed_softmax and self.args.local_rank != -1 and return_outputs is False: 83 | q_logits, a_logits = distributed_softmax( 84 | q_logits, a_logits, self.args.local_rank, self.args.world_size 85 | ) 86 | labels = torch.arange(q_logits.size(0), device=a_logits.device) 87 | cross_entropy = torch.nn.CrossEntropyLoss() 88 | dp = q_logits.mm(a_logits.transpose(0, 1)) 89 | labels = torch.arange(dp.size(0), device=dp.device) 90 | loss = cross_entropy(dp, labels) 91 | if return_outputs: 92 | outputs = OrderedDict({"q_logits": q_logits, "a_logits": a_logits, "page_id": page_id}) 93 | return (loss, outputs) if return_outputs else loss 94 | 95 | 96 | def get_acc_rr(q_logits, a_logits): 97 | q_logits = torch.from_numpy(q_logits) 98 | a_logits = torch.from_numpy(a_logits) 99 | dp = q_logits.mm(a_logits.transpose(0, 1)) 100 | indices = torch.argsort(dp, dim=-1, descending=True) 101 | targets = torch.arange(indices.size(0), device=indices.device).view(-1, 1) 102 | targets = targets.expand_as(indices) 103 | hits = (targets == indices).nonzero() 104 | ranks = hits[:, -1] + 1 105 | ranks = ranks.float() 106 | acc = ranks.eq(1).float().squeeze() 107 | rr = torch.reciprocal(ranks).squeeze() 108 | return rr, acc 109 | 110 | 111 | def main(): 112 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArgument)) 113 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 114 | training_args.remove_unused_columns = False 115 | 116 | logging.basicConfig( 117 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 118 | datefmt="%m/%d/%Y %H:%M:%S", 119 | handlers=[logging.StreamHandler(sys.stdout)], 120 | ) 121 | 122 | set_seed(training_args.seed) 123 | 124 | model_kwargs = dict( 125 | cache_dir=model_args.cache_dir, 126 | revision=model_args.model_revision, 127 | hidden_dropout_prob=model_args.hidden_dropout_prob, 128 | attention_probs_dropout_prob=model_args.attention_probs_dropout_prob, 129 | add_pooling_layer=False 130 | ) 131 | 132 | if model_args.gradient_checkpointing: 133 | # CANINE does not support gradient checkpointing 134 | model_kwargs["gradient_checkpointing"] = True 135 | 136 | model = AutoModel.from_pretrained( 137 | model_args.model_name_or_path, 138 | **model_kwargs 139 | ) 140 | 141 | tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path 142 | tokenizer = AutoTokenizer.from_pretrained( 143 | tokenizer_name_or_path, 144 | cache_dir=model_args.cache_dir, 145 | use_fast=True, 146 | revision=model_args.model_revision, 147 | additional_special_tokens=None if data_args.no_special_token else ["", "", ""] 148 | ) 149 | if not data_args.no_special_token: 150 | model.resize_token_embeddings(len(tokenizer)) 151 | 152 | datasets =
[load_dataset("clips/mfaq", l) for l in data_args.languages] 153 | train_datasets = [e["train"] for e in datasets] 154 | eval_datasets = [e["validation"] for e in datasets] 155 | if data_args.limit_valid_size: 156 | 157 | eval_datasets = [e.select(range(data_args.limit_valid_size)) for e in eval_datasets] 158 | eval_dataset = ValidationDataset(interleave_datasets(eval_datasets)) 159 | 160 | if training_args.do_train: 161 | world_size = 1 if training_args.world_size is None else training_args.world_size 162 | train_dataset = IterableDataset( 163 | train_datasets, 164 | data_args.languages, 165 | probabilities=data_args.probabilities, 166 | batch_size=training_args.per_device_train_batch_size*world_size, 167 | seed=training_args.seed, 168 | single_domain=data_args.single_domain, 169 | alpha=data_args.alpha 170 | ) 171 | 172 | padding = "max_length" if data_args.pad_to_max_length else True 173 | def collate_fn(batch): 174 | questions, answers, page_ids = [], [], [] 175 | for item in batch: 176 | questions.append(item['question'] if data_args.no_special_token else f"{item['question']}") 177 | answers.append(item['answer'] if data_args.no_special_token else f"{item['answer']}") 178 | page_ids.append(item["page_id"]) 179 | output = tokenizer( 180 | questions + answers, 181 | padding=padding, 182 | truncation=True, 183 | max_length=data_args.max_seq_length, 184 | return_tensors="pt", 185 | pad_to_multiple_of=8 186 | ) 187 | output["page_id"] = torch.Tensor(page_ids) 188 | return output 189 | 190 | def compute_metrics(predictions): 191 | q_output, a_output, page_id = predictions.predictions 192 | unique_page_ids = set(page_id.tolist()) 193 | global_rr, global_acc, pp_mrr, pp_acc = [], [], [], [] 194 | for unique_page_id in unique_page_ids: 195 | selector = page_id == unique_page_id 196 | s_q_output = q_output[selector, :] 197 | s_a_output = a_output[selector, :] 198 | rr, acc = get_acc_rr(s_q_output, s_a_output) 199 | global_rr.append(rr) 200 | global_acc.append(acc) 201 | pp_mrr.append(rr.mean()) 202 | pp_acc.append(acc.mean()) 203 | global_mrr = torch.cat(global_rr).mean() 204 | global_acc = torch.cat(global_acc).mean() 205 | per_page_mrr = torch.stack(pp_mrr).mean() 206 | per_page_acc = torch.stack(pp_acc).mean() 207 | return {"global_mrr": global_mrr, "global_acc": global_acc, "per_page_mrr": per_page_mrr, "per_page_acc": per_page_acc} 208 | 209 | trainer = CustomTrainer( 210 | model=model, 211 | args=training_args, 212 | train_dataset=train_dataset if training_args.do_train else None, 213 | eval_dataset=eval_dataset if training_args.do_eval else None, 214 | tokenizer=tokenizer, 215 | data_collator=collate_fn, 216 | compute_metrics=compute_metrics 217 | ) 218 | 219 | if training_args.do_train: 220 | train_result = trainer.train() 221 | metrics = train_result.metrics 222 | trainer.log_metrics("train", metrics) 223 | trainer.save_metrics("train", metrics) 224 | trainer.save_state() 225 | 226 | 227 | if training_args.do_predict: 228 | logger.info("*** Predict ***") 229 | _, _, metrics = trainer.predict(eval_dataset, metric_key_prefix="predict") 230 | trainer.log_metrics("eval", metrics) 231 | trainer.save_metrics("eval", metrics) 232 | 233 | 234 | def _mp_fn(index): 235 | # For xla_spawn (TPUs) 236 | main() 237 | 238 | 239 | if __name__ == "__main__": 240 | main() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 |
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------