├── plots_and_data_for_paper ├── readme.md ├── UN_LM.pdf ├── UN_LM.png ├── Europarl_LM.pdf ├── Europarl_LM.png ├── mt_individual.csv └── lm_individual.csv ├── create_data ├── make_nc_data │ ├── copy_to_data_folder.sh │ ├── download.sh │ ├── README.md │ └── sample_train_test_data.py ├── make_mtnt_data │ ├── download.sh │ ├── copy_to_data_folder.sh │ ├── README.md │ ├── preprocess.sh │ ├── remove_too_much_punc.py │ └── sample_train_test_data.py ├── make_un_data_for_mt │ ├── copy_to_data_folder.sh │ ├── README.md │ ├── unzip_all.sh │ ├── download_files.sh │ └── sample_train_test_data.py ├── make_un_data_for_lm │ ├── copy_to_data_folder.sh │ ├── README.md │ ├── unzip_all.sh │ ├── download_files.sh │ ├── sample_train_test_data.py │ └── gather_exclusive_paths.py └── make_europarl_data_for_lm │ ├── copy_to_data_folder.sh │ ├── README.md │ ├── download_files.sh │ └── clean_and_split_data.py ├── enviroment_setup ├── activate_poetry.sh ├── install_poetry.sh └── readme.md ├── .gitignore ├── bin ├── run_sacrebleu_eval.sh ├── run_fl_tc.sh ├── run_fl_mt.sh └── run_fl_lm.sh ├── pyproject.toml ├── constants.py ├── README.md ├── create_naacl_plots.py ├── dataset_utils.py └── main_lm.py /plots_and_data_for_paper/readme.md: -------------------------------------------------------------------------------- 1 | This folder contains the CSV files and plots of the figures in the paper. -------------------------------------------------------------------------------- /create_data/make_nc_data/copy_to_data_folder.sh: -------------------------------------------------------------------------------- 1 | mkdir ../../data 2 | mkdir ../../data/nc 3 | cp nc/* ../../data/nc/ -------------------------------------------------------------------------------- /enviroment_setup/activate_poetry.sh: -------------------------------------------------------------------------------- 1 | $HOME/.poetry/bin/poetry shell 2 | source $(poetry env info --path)/bin/activate 3 | -------------------------------------------------------------------------------- /enviroment_setup/install_poetry.sh: -------------------------------------------------------------------------------- 1 | curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python - -------------------------------------------------------------------------------- /plots_and_data_for_paper/UN_LM.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orionw/Multilingual-Federated-Learning/HEAD/plots_and_data_for_paper/UN_LM.pdf -------------------------------------------------------------------------------- /plots_and_data_for_paper/UN_LM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orionw/Multilingual-Federated-Learning/HEAD/plots_and_data_for_paper/UN_LM.png -------------------------------------------------------------------------------- /create_data/make_mtnt_data/download.sh: -------------------------------------------------------------------------------- 1 | wget https://github.com/pmichel31415/mtnt/releases/download/v1.1/MTNT.1.1.tar.gz 2 | tar -xvf MTNT.1.1.tar.gz -------------------------------------------------------------------------------- /create_data/make_nc_data/download.sh: -------------------------------------------------------------------------------- 1 | wget https://xglue.blob.core.windows.net/xglue/xglue_full_dataset.tar.gz 2 | tar -xvf xglue_full_dataset.tar.gz 
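The `copy_to_data_folder.sh` scripts in the `create_data` subfolders (the NC variant appears above) create the shared `data/` directory with plain `mkdir`, which prints a "File exists" error once a second dataset is copied over. A slightly more defensive sketch of the same step — assuming the relative paths stay as in the repo; this is not the script actually shipped — could look like:

```sh
#!/bin/sh
# Create data/nc (and any missing parents) without failing if it already exists,
# then copy the sampled splits into the shared data folder.
mkdir -p ../../data/nc
cp nc/* ../../data/nc/
```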
-------------------------------------------------------------------------------- /create_data/make_mtnt_data/copy_to_data_folder.sh: -------------------------------------------------------------------------------- 1 | mkdir ../../data 2 | mkdir ../../data/mtnt_mt_corpus 3 | cp splits_data/* ../../data/mtnt_mt_corpus/ -------------------------------------------------------------------------------- /create_data/make_un_data_for_mt/copy_to_data_folder.sh: -------------------------------------------------------------------------------- 1 | mkdir ../../data 2 | mkdir ../../data/un_mt_corpus 3 | cp splits_data/* ../../data/un_mt_corpus/ -------------------------------------------------------------------------------- /create_data/make_un_data_for_lm/copy_to_data_folder.sh: -------------------------------------------------------------------------------- 1 | mkdir ../../data 2 | mkdir ../../data/un_corpus 3 | cp un_corpus/splits_data/* ../../data/un_corpus/ -------------------------------------------------------------------------------- /plots_and_data_for_paper/Europarl_LM.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orionw/Multilingual-Federated-Learning/HEAD/plots_and_data_for_paper/Europarl_LM.pdf -------------------------------------------------------------------------------- /plots_and_data_for_paper/Europarl_LM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orionw/Multilingual-Federated-Learning/HEAD/plots_and_data_for_paper/Europarl_LM.png -------------------------------------------------------------------------------- /create_data/make_europarl_data_for_lm/copy_to_data_folder.sh: -------------------------------------------------------------------------------- 1 | mkdir ../../data 2 | mkdir ../../data/europarl_mt_corpus 3 | cp splits_data/* ../../data/europarl_mt_corpus/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | simulation_pytorch/data/* 2 | **/__pycache__/** 3 | slurm-* 4 | poetry.lock 5 | make_wmt_data/europarl-v9.* 6 | make_wmt_data/split_data/* 7 | fl_models 8 | un_corpus 9 | results/* 10 | saved_objects/* 11 | saved_objects* 12 | *.json -------------------------------------------------------------------------------- /bin/run_sacrebleu_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # How to evaluate sacrebleu on output translations 3 | sacrebleu -i -m bleu -b -w 1 --confidence 4 | # If using ja add `--tokenize ja-mecab` and if using zh add `--tokenize zh` -------------------------------------------------------------------------------- /create_data/make_nc_data/README.md: -------------------------------------------------------------------------------- 1 | # Create PAWS-X FL Data 2 | 0. Run `bash download.sh` to download and extract the data 3 | 1. Run `python sample_train_test_data.py` to split the data 4 | 2. 
`bash copy_to_data_folder.sh` to move them to the `data` folder -------------------------------------------------------------------------------- /plots_and_data_for_paper/mt_individual.csv: -------------------------------------------------------------------------------- 1 | Method,Setting,En-Fr,En-Ja,Avg_MTNT,En-Fr,Ar-Es,Ru-Zh,Avg_UN 2 | Centralized,Pretrained,22.8,4.9,13.9,37.4,36.2,22.6,32.1 3 | IID FL,Pretrained,24.9,4.0,14.5,38.7,37.2,21.4,32.4 4 | Non-IID FL,Pretrained,16.1,5.9,11.0,38.0,36.9,21.4,32.1 -------------------------------------------------------------------------------- /create_data/make_europarl_data_for_lm/README.md: -------------------------------------------------------------------------------- 1 | # Steps to recreate data 2 | 0. `bash download_files.sh` will download and unzip the tar files 3 | 1. `python clean_and_split_data.py` to sample data for training, dev, and testing 4 | 2. `bash copy_to_data_folder.sh` to move them to the `data` folder -------------------------------------------------------------------------------- /create_data/make_mtnt_data/README.md: -------------------------------------------------------------------------------- 1 | # Steps to recreate data 2 | 0. `bash download.sh` will download the tar files 3 | 1. `bash preprocess.sh` will prepare the data for M2M-100 to be similar to its training data and re-write the original data 4 | 2. `python sample_train_test_data.py` to sample data for training, dev, and testing 5 | 3. `bash copy_to_data_folder.sh` to move them to the `data` folder -------------------------------------------------------------------------------- /create_data/make_un_data_for_mt/README.md: -------------------------------------------------------------------------------- 1 | # Steps to recreate data 2 | 0. `bash download_files.sh` will download the tar files. NOTE: if this provides an error, try downloading them from the link manually 3 | 1. `unzip_all.sh` will combine the tar.gz files and unzip them 4 | 3. `python sample_train_test_data.py` to sample data for training, dev, and testing 5 | 4. 
`bash copy_to_data_folder.sh` to move them to the `data` folder -------------------------------------------------------------------------------- /create_data/make_un_data_for_mt/unzip_all.sh: -------------------------------------------------------------------------------- 1 | # tar.gz files were split - we need to combine them to read them 2 | cat UNv1.0.ar-es.tar.gz.* > UNv1.0-TEI.ar-es.tar.gz 3 | cat UNv1.0.ru-zh.tar.gz.* > UNv1.0-TEI.ru-zh.tar.gz 4 | cat UNv1.0.en-fr.tar.gz.* > UNv1.0-TEI.en-fr.tar.gz 5 | # now decode them 6 | tar -xvf UNv1.0-TEI.en-fr.tar.gz && 7 | tar -xvf UNv1.0-TEI.ar-es.tar.gz && 8 | tar -xvf UNv1.0-TEI.ru-zh.tar.gz 9 | mkdir splits_data -------------------------------------------------------------------------------- /create_data/make_un_data_for_mt/download_files.sh: -------------------------------------------------------------------------------- 1 | gdown --id 14xQhLb0mJgUF3-UfCNDPFz-Wlub4-q80 # ru-zh 1 2 | gdown --id 1lVsmTDFE_o7XIawEcWXPUN0GXQatjTId # ru-zh 2 3 | gdown --id 1sB5aybnEOqTaHmOBfIPqtn4c-inQDDRe # ar-es 1 4 | gdown --id 1e55Ro28oimBNCTRBoF1MAclTpDgVRu2w # ar-es 2 5 | gdown --id 126paJ81dFHu1wSiXrM47bkFuHDMMHpN_ # fr-en 1 6 | gdown --id 1bs617LoEkn84O_lwOl1NSl4tpF0PvwvE # fr-en 2 7 | gdown --id 11LzF_iQo3-8pwov6Iu7ZeMBl3nOOJECL # fr-en 3 -------------------------------------------------------------------------------- /create_data/make_mtnt_data/preprocess.sh: -------------------------------------------------------------------------------- 1 | for lang in fr ja 2 | do 3 | for split in train valid test 4 | do 5 | src=en 6 | tgt=$lang 7 | input=MTNT/$split/$split.en-$lang.tsv 8 | output=no-punct.en-$lang 9 | python remove_too_much_punc.py --input <(gzip -c $input) --bitext $output --src-lang $src --tgt-lang $tgt 10 | paste $output.$src $output.$tgt | cat -n > $input.corrected 11 | done 12 | done -------------------------------------------------------------------------------- /enviroment_setup/readme.md: -------------------------------------------------------------------------------- 1 | You should be able to setup the enviroment like so from the main directory: 2 | 3 | ## Enviroment Setup 4 | 0. Install poetry (`bash enviroment_setup/install_poetry.sh`) 5 | 1. Activate poetry (`bash enviroment_setup/activate_poetry.sh`) 6 | 2. Install dependecies (`poetry install`) 7 | 8 | If this fails, please consult the [poetry forums](https://python-poetry.org/docs/basic-usage/) or [Github](https://github.com/python-poetry/poetry) for more help. -------------------------------------------------------------------------------- /create_data/make_un_data_for_lm/README.md: -------------------------------------------------------------------------------- 1 | # Steps to recreate data 2 | 0. `bash download_files.sh` will download the tar files. NOTE: if this provides an error, try downloading them from the link manually 3 | 1. `unzip_all.sh` will combine the tar.gz files and unzip them into `UNv1.0-TEI` 4 | 2. `python gather_exclusive_paths.py` to gather only unique sentences from each language and create text files of them 5 | 3. `python sample_train_test_data.py` to sample data for training, dev, and testing 6 | 4. 
`bash copy_to_data_folder.sh` to move them to the `data` folder -------------------------------------------------------------------------------- /create_data/make_un_data_for_lm/unzip_all.sh: -------------------------------------------------------------------------------- 1 | # tar.gz files were split - we need to combine them to read them 2 | cat UNv1.0-TEI.fr.tar.gz.* > UNv1.0-TEI.fr.tar.gz 3 | cat UNv1.0-TEI.ar.tar.gz.* > UNv1.0-TEI.ar.tar.gz 4 | cat UNv1.0-TEI.es.tar.gz.* > UNv1.0-TEI.es.tar.gz 5 | cat UNv1.0-TEI.en.tar.gz.* > UNv1.0-TEI.en.tar.gz 6 | cat UNv1.0-TEI.ru.tar.gz.* > UNv1.0-TEI.ru.tar.gz 7 | # now decode them 8 | tar -xvf UNv1.0-TEI.fr.tar.gz && 9 | tar -xvf UNv1.0-TEI.en.tar.gz && 10 | tar -xvf UNv1.0-TEI.ar.tar.gz && 11 | tar -xvf UNv1.0-TEI.ru.tar.gz && 12 | tar -xvf UNv1.0-TEI.es.tar.gz && 13 | tar -xvf UNv1.0-TEI.zh.tar.gz.00 -------------------------------------------------------------------------------- /create_data/make_un_data_for_lm/download_files.sh: -------------------------------------------------------------------------------- 1 | gdown --id 1b9uh4MxYmmL3C67Vy0_IkQoRcT8oKPzB # en 1 2 | gdown --id 1qR_bUODtnxz7aIo9hbh_jnyY_04dqL5O # en 2 3 | gdown --id 1mbOxrXg2icSb7RuYxbZQrKofP9J8XnCT # fr 1 4 | gdown --id 14B_iJkO7LyLpXfdTTJ7cQm2ey6Pv7YWJ # fr 2 5 | gdown --id 16skUarJy_ly6DJZnkMB70An0zkX4njgn # es 1 6 | gdown --id 1qWpPY8MrXo4nAA6gSTr8F2Yx796W670G # es 2 7 | gdown --id 1SuqHMYalMIVn4C3mmEPjV39H0G0q9yec # ru 1 8 | gdown --id 1nv2e-YG2jjO3La6I69Owubn7Au8XEmKz # ru 2 9 | gdown --id 1z2jxPanShvOP3c-CgL9V5HWcWDJS3tbG # zh 1 10 | gdown --id 1QlddnLM6h_XnYV9b81FoI013lLoCVJ_q # ar 1 11 | gdown --id 1lF_hCnZ-SYxGyFSFiFR2Ni-e9BMW7SY3 # ar 2 12 | 13 | -------------------------------------------------------------------------------- /plots_and_data_for_paper/lm_individual.csv: -------------------------------------------------------------------------------- 1 | Method,Setting,En_E,Cs_E,Lt_E,Es_E,Pl_E,Fi_E,Pt_E,De_E,Avg_Europarl,En_U,Es_U,Fr_U,Ru_U,Zh_U,Ar_U,Avg_UN 2 | Centralized,Random,19.3,4.4,3.9,8.4,4.7,4.6,6.9,10.8,6.7,8.8,5.2,8.3,3.8,4.2,4.4,5.5 3 | IID FL,Random,27.1,5.3,4.4,11.2,5.8,5.4,8.7,15.1,8.3,8.8,5.2,8.5,3.6,3.8,4.4,5.3 4 | Non-IID FL,Random,50.4,6.8,12.3,16.1,18.3,11.3,34.6,21.8,17.9,12.3,11.4,14.8,9.1,8.0,8.3,10.4 5 | Centralized,Pretrained,12.0,3.5,3.3,13.4,4.7,3.8,4.7,6.8,5.7, 6.7,4.1,4.8,2.9,3.2,3.5,4.1 6 | IID FL,Pretrained,10.5,3.9,4.2,6.1,3.7,4.4,5.4,6.7,5.3,6.4,3.9,5.8,2.8,3.2,3.4,4.0 7 | Non-IID FL,Pretrained,8.4,3.7,4.0,6.0,3.7,4.3,5.5,6.5,5.0,6.9,4.5,6.4,4.2,4.3,4.1,4.7 -------------------------------------------------------------------------------- /bin/run_fl_tc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # centralized 4 | python main_lm.py --data nc --model xlm-roberta-base --n_cpus 1 --n_gpus 2 --batch_size 8 --batch_accum 4 --lang_mix 0.99 --centralized --n_iterations 10 --lr 1e-5 5 | 6 | # IID FL 7 | python main_lm.py --data nc --model xlm-roberta-base --n_cpus 1 --n_gpus 6 --batch_size 8 --batch_accum 4 --lang_mix 0.99 --n_iterations 10 --lr 1e-5 8 | 9 | # Non-IID FL 10 | python main_lm.py --data nc --model xlm-roberta-base --n_cpus 1 --n_gpus 6 --batch_size 8 --batch_accum 4 --lang_mix 0.0 --n_iterations 10 --lr 1e-5 11 | 12 | # For eval add "--n_iterations 0 --load_model " 13 | # For random initialization add "--random_init" to the model and change n_iterations to 50. 
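# Example evaluation invocation (sketch only -- the checkpoint path "fl_models/nc_centralized"
# below is a placeholder for your own saved model, not a file shipped with this repo):
# python main_lm.py --data nc --model xlm-roberta-base --n_cpus 1 --n_gpus 2 --batch_size 8 --batch_accum 4 --lang_mix 0.99 --centralized --n_iterations 0 --lr 1e-5 --load_model fl_models/nc_centralized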
14 | -------------------------------------------------------------------------------- /bin/run_fl_mt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # for un_corpus but replace with `mtnt` for mtnt experiments 3 | 4 | # centralized 5 | python main_lm.py --data un_mt_corpus --model facebook/m2m100_418M --n_cpus 1 --n_gpus 2 --batch_size 2 --batch_accum 16 --lang_mix 0.99 --centralized --n_iterations 50 --lr 5e-5 6 | 7 | # IID FL 8 | python main_lm.py --data un_mt_corpus --model facebook/m2m100_418M --n_cpus 1 --n_gpus 3 --batch_size 2 --batch_accum 16 --lang_mix 0.99 --n_iterations 50 --lr 5e-5 9 | 10 | # non-IID FL 11 | python main_lm.py --data un_mt_corpus --model facebook/m2m100_418M --n_cpus 1 --n_gpus 3 --batch_size 2 --batch_accum 16 --lang_mix 0.0 --n_iterations 50 --lr 5e-5 12 | 13 | # For eval add "--n_iterations 0 --load_model " 14 | # For random initialization add "--random_init" to the model (garbage results though, due to the small data) 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "poetry==1.1.10", 4 | ] 5 | build-backend = "poetry.masonry.api" 6 | 7 | [tool.poetry] 8 | name = "MultilingualFederatedNLP" 9 | version = "0.1.0" 10 | description = "Federated Learning Multilingual NLP" 11 | authors = ["Orion Weller, Marc Marone, Vladimir Braverman, Dawn Lawrie, and Benjamin Van Durme"] 12 | 13 | [tool.poetry.dependencies] 14 | python = "~3.7.9" 15 | numpy = "1.21.2" 16 | pandas = "1.3.5" 17 | matplotlib = "3.3.4" 18 | flwr = {extras = ["simulation"], version = "^0.17.0"} 19 | # flwr = {extras = ["simulation"], path = "../../", develop = true } # For development 20 | torch = "1.7.1" 21 | torchvision = "0.8.2" 22 | transformers = "4.12.2" 23 | datasets = "1.15.1" 24 | scikit-learn = "1.0.1" 25 | sentencepiece = "0.1.96" 26 | sacrebleu = "2.0.0" 27 | seaborn = "0.9.1" 28 | gdown = "4.2.2" 29 | -------------------------------------------------------------------------------- /bin/run_fl_lm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # for un_corpus but replace with `wmt` for europarl experiments 3 | 4 | # centralized 5 | python main_lm.py --data un_corpus --model distilbert-base-multilingual-cased --n_cpus 1 --n_gpus 2 --batch_size 10 --batch_accum 3 --lang_mix 0.99 --centralized --n_iterations 100 --lr 5e-5 6 | 7 | # IID FL 8 | python main_lm.py --data un_corpus --model distilbert-base-multilingual-cased --n_cpus 1 --n_gpus 5 --batch_size 10 --batch_accum 3 --lang_mix 0.99 --n_iterations 100 --lr 5e-5 9 | 10 | # non-IID FL 11 | python main_lm.py --data un_corpus --model distilbert-base-multilingual-cased --n_cpus 1 --n_gpus 5 --batch_size 10 --batch_accum 3 --lang_mix 0.0 --n_iterations 100 --lr 5e-5 12 | 13 | # For eval add "--n_iterations 0 --load_model " 14 | # For random initialization add "--random_init" to the model and double n_iterations 15 | -------------------------------------------------------------------------------- /create_data/make_europarl_data_for_lm/download_files.sh: -------------------------------------------------------------------------------- 1 | wget --no-check-certificate http://www.statmt.org/europarl/v9/training-monolingual/europarl-v9.pl.gz 2 | wget --no-check-certificate http://www.statmt.org/europarl/v9/training-monolingual/europarl-v9.cs.gz 3 | wget 
--no-check-certificate http://www.statmt.org/europarl/v9/training-monolingual/europarl-v9.de.gz 4 | wget --no-check-certificate http://www.statmt.org/europarl/v9/training-monolingual/europarl-v9.en.gz 5 | wget --no-check-certificate http://www.statmt.org/europarl/v9/training-monolingual/europarl-v9.es.gz 6 | wget --no-check-certificate http://www.statmt.org/europarl/v9/training-monolingual/europarl-v9.fi.gz 7 | wget --no-check-certificate http://www.statmt.org/europarl/v9/training-monolingual/europarl-v9.lt.gz 8 | wget --no-check-certificate http://www.statmt.org/europarl/v9/training-monolingual/europarl-v9.pt.gz 9 | gunzip europarl-v9.* 10 | -------------------------------------------------------------------------------- /create_data/make_un_data_for_lm/sample_train_test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import json 5 | import random 6 | import pandas as pd 7 | from sklearn.model_selection import train_test_split 8 | import numpy as np 9 | 10 | np.random.seed(1) 11 | random.seed(1) 12 | 13 | def gather_data(): 14 | all_files = {} 15 | for file_path in glob.glob("*.txt"): 16 | lang = file_path.split("/")[-1].split(".")[0] 17 | data = [] 18 | with open(file_path, "r") as fin: 19 | for line in fin: 20 | data.append(line.strip()) 21 | all_files[lang] = data 22 | print(f"Reading in {file_path} with length {len(data)}") 23 | return all_files 24 | 25 | def split_and_save_data(save_path: str, min_size: int, data_w_text: dict): 26 | if not os.path.isdir(save_path): 27 | os.makedirs(save_path) 28 | for lang, data in data_w_text.items(): 29 | data = pd.DataFrame({"text": data}) 30 | print(f"Saving final data for lang {lang}") 31 | kept_data = data.sample(n=min_size, replace=False) 32 | train, dev_and_test = train_test_split(kept_data, test_size=10000) 33 | dev, test = train_test_split(dev_and_test, test_size=0.5) 34 | 35 | final_save_path = os.path.join(save_path, lang + "_train.csv") 36 | train.to_csv(final_save_path, index=False) 37 | dev.to_csv(final_save_path.replace("train", "dev"), index=False) 38 | test.to_csv(final_save_path.replace("train", "test"), index=False) 39 | print(f"Train shape {train.shape} and dev shape {dev.shape} and test shape {test.shape}") 40 | 41 | if __name__ == "__main__": 42 | data_w_text = gather_data() 43 | split_and_save_data("splits_data/", 60000, data_w_text) -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | LANG_MAP_EUROPARL = { 4 | "en": 0, 5 | "cs": 1, 6 | "lt": 2, 7 | "es": 3, 8 | "pl": 4, 9 | "fi": 5, 10 | "pt": 6, 11 | "de": 7 12 | } 13 | 14 | 15 | LANG_MAP_UN_CORPUS = { 16 | "en": 0, 17 | "es": 1, 18 | "zh": 2, 19 | "ru": 3, 20 | "ar": 4, 21 | "fr": 5, 22 | } 23 | 24 | LANG_MAP_PAWSX = { 25 | "de": 0, 26 | "en": 1, 27 | "es": 2, 28 | "fr": 3, 29 | "ja": 4, 30 | "ko": 5, 31 | "zh": 6, 32 | } 33 | 34 | 35 | LANG_MAP_UN_MT_CORPUS = { 36 | "en-fr": 0, 37 | "ar-es": 1, 38 | "ru-zh": 2, 39 | } 40 | 41 | LANG_MAP_MTNT_MT_CORPUS = { 42 | "en-fr": 0, 43 | "en-ja": 1, 44 | } 45 | 46 | 47 | LANG_MAP_NC = { 48 | "en": 0, 49 | "es": 1, 50 | "fr": 2, 51 | "de": 3, 52 | "ru": 4 53 | } 54 | 55 | 56 | MBART_MAP = { 57 | "en": "en_XX", 58 | "cs": "cs_CZ", 59 | "lt": "lt_LT", 60 | "es": "es_XX", 61 | "pl": "pl_PL", 62 | "fi": "fi_FI", 63 | "pt": "pt_XX", 64 | "de": "de_DE", 65 | } 66 | 67 | 68 | POOL_SIZE = { 69 | "brown": 1, 70 | "wmt": 
8, 71 | "un_corpus": 6, 72 | "un_mt_corpus": 3, 73 | "mtnt": 2, 74 | "pawsx": 7, 75 | "nc": 5, 76 | } 77 | 78 | DATA_TO_FILE_PATHS = { 79 | "brown": "data/brown", 80 | "wmt": "data/wmt", 81 | "un_corpus": 'data/un_corpus', 82 | "un_mt_corpus": 'data/un_mt_corpus', 83 | "mtnt": "data/mtnt_mt_corpus", 84 | "pawsx": "data/pawsx", 85 | "nc": "data/nc" 86 | } 87 | 88 | MAP_LANG_MAP = { 89 | "data/wmt": LANG_MAP_EUROPARL, 90 | "data/un_corpus": LANG_MAP_UN_CORPUS, 91 | "data/un_mt_corpus": LANG_MAP_UN_MT_CORPUS, 92 | "data/mtnt_mt_corpus": LANG_MAP_MTNT_MT_CORPUS, 93 | "data/pawsx": LANG_MAP_PAWSX, 94 | "data/nc": LANG_MAP_NC, 95 | } 96 | -------------------------------------------------------------------------------- /create_data/make_mtnt_data/remove_too_much_punc.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import argparse 3 | from string import punctuation 4 | 5 | #Source: https://raw.githubusercontent.com/pytorch/fairseq/main/examples/m2m_100/process_data/remove_too_much_punc.py 6 | 7 | def len_no_punc(s, punc): 8 | return len([ch for ch in s if ch in punc]) 9 | 10 | def filter_overpunc(len_npunc, len_sen): 11 | return len_npunc < 0.5*len_sen 12 | 13 | def main(args): 14 | punc = punctuation + "-|-" 15 | print('Processing file {}'.format(args.input)) 16 | with gzip.open(args.input, 'rt', encoding=args.encoding) as tsv: 17 | with open(args.bitext + '.' + args.src_lang, 'wt', encoding=args.encoding) as fsrc: 18 | with open(args.bitext + '.' + args.tgt_lang, 'wt', encoding=args.encoding) as ftgt: 19 | for line in tsv: 20 | fields = line.split('\t') 21 | 22 | src, tgt = fields[1], fields[2] 23 | 24 | nchar_npunc_src = len_no_punc(src, punc) 25 | nchar_npunc_tgt = len_no_punc(tgt, punc) 26 | 27 | if filter_overpunc(nchar_npunc_src, len(src)) and filter_overpunc(nchar_npunc_tgt, len(tgt)): 28 | fsrc.write(src.strip() + '\n') 29 | ftgt.write(tgt.strip() + '\n') 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--input", required=True, type=str) 34 | parser.add_argument('--encoding', default='utf-8', help='character encoding for input/output') 35 | parser.add_argument('--bitext', type=str, required=True, help='language direction') 36 | parser.add_argument('--src-lang', type=str, required=True, help='Source language') 37 | parser.add_argument('--tgt-lang', type=str, required=True, help='Target language') 38 | main(parser.parse_args()) 39 | -------------------------------------------------------------------------------- /create_data/make_un_data_for_mt/sample_train_test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import json 5 | import random 6 | import pandas as pd 7 | from sklearn.model_selection import train_test_split 8 | import numpy as np 9 | 10 | np.random.seed(1) 11 | random.seed(1) 12 | 13 | def gather_data(): 14 | all_data = {} 15 | lang_sets = ["en-fr", "ar-es", "ru-zh"] 16 | for lang_set in lang_sets: 17 | first, second = lang_set.split("-") 18 | all_data[lang_set] = {} 19 | for cur_lang in [first, second]: 20 | data = [] 21 | with open(f"{lang_set}/UNv1.0.{lang_set}.{cur_lang}", "r") as fin: 22 | for line in fin: 23 | data.append(line.strip()) 24 | all_data[lang_set][cur_lang] = data 25 | print(f"Reading in {lang_set} {cur_lang} with length {len(data)}") 26 | return all_data 27 | 28 | def split_and_save_data(save_path: str, min_size: int, data_w_text: dict): 29 | if not 
os.path.isdir(save_path): 30 | os.makedirs(save_path) 31 | for lang_set, data_dict in data_w_text.items(): 32 | first, second = lang_set.split("-") 33 | data = pd.DataFrame({first: data_dict[first], second: data_dict[second]}) 34 | print(f"Saving final data for lang {lang_set}") 35 | kept_data = data.sample(n=min_size, replace=False) 36 | train, dev_and_test = train_test_split(kept_data, test_size=10000) 37 | dev, test = train_test_split(dev_and_test, test_size=0.5) 38 | 39 | final_save_path = os.path.join(save_path, lang_set + "_train.csv") 40 | train.to_csv(final_save_path, index=False) 41 | dev.to_csv(final_save_path.replace("train", "dev"), index=False) 42 | test.to_csv(final_save_path.replace("train", "test"), index=False) 43 | print(f"Train shape {train.shape} and dev shape {dev.shape} and test shape {test.shape}") 44 | 45 | if __name__ == "__main__": 46 | data_w_text = gather_data() 47 | split_and_save_data("splits_data/", 20000, data_w_text) 48 | # size following https://aclanthology.org/2020.wmt-1.4.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pretrained Models for Multilingual Federated Learning 2 | Code and data setup for our paper: [Pretrained Models for Multilingual Federated Learning](https://aclanthology.org/2022.naacl-main.101/) by *Orion Weller, *Marc Marone, Vladimir Braverman, Dawn Lawrie, and Benjamin Van Durme. Many thanks to the great developers at the [flwr](https://flower.dev/) team who have prepared [excellent examples](https://github.com/adap/flower/tree/main/examples/simulation_pytorch). 3 | 4 | ## Enviroment Setup 5 | NOTE: we used poetry following the advice of the flwr framework. 6 | 7 | 0. Install poetry (`bash enviroment_setup/install_poetry.sh`) 8 | 1. Activate poetry (`bash enviroment_setup/activate_poetry.sh`) 9 | 2. Install dependecies (`poetry install`). NOTE: this takes a few minutes. 10 | 11 | ## Data Setup 12 | 0. After deciding which data setup you would like, look for the corresponding dataset in `create_data` For the sake of this readme, we will use the `mtnt` data. 13 | 1. `cd` into the folder (`cd create_data/make_mtnt_data`) 14 | 2. Follow the instructions in the `readme` located in the folder. It will typically have scripts for downloading, preprocessing, splitting, and then moving the data into the final location for the model. 15 | 16 | ## Training/Evaluating Federated Learning Models 17 | 0. Make sure the enviroment and the data have been set up as above. 18 | 1. Depending on the type of model you want to train (classification, LM, or MT) see the corresponding scripts in `bin/run_fl_{mt,tc,lm}.sh`. Each script contains information about how to run centralized, non-IID FL, or IID FL learning, as well as random initialization and/or evaluation. 19 | 2. To evaluate BLEU scores, be sure to install the sacrebleu script and evaluating using the format described in `bin/run_sacrebleu_eval.sh`. 
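As a concrete sketch of that format (the file names below are placeholders for your detokenized reference and system output files, not files produced by this repo):

```sh
sacrebleu references.detok.txt -i model_outputs.detok.txt -m bleu -b -w 1 --confidence
# For Japanese targets add --tokenize ja-mecab; for Chinese targets add --tokenize zh
```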
20 | 21 | ## Citation 22 | If you found this code or paper helpful, please consider citing: 23 | ``` 24 | @inproceedings{Weller2022PretrainedMF, 25 | title={Pretrained Models for Multilingual Federated Learning}, 26 | author={Orion Weller and Marc Marone and Vladimir Braverman and Dawn J Lawrie and Benjamin Van Durme}, 27 | booktitle={Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT)}, 28 | year={2022} 29 | } 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /create_naacl_plots.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import argparse 4 | import pandas as pd 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | 8 | sns.set_context('poster') 9 | sns.set(font_scale=1.5) 10 | 11 | def create_lm_plot(path: str): 12 | data = pd.read_csv(path, header=0, index_col=None) 13 | E_cols = [col for col in data.columns if "_E" == col[-2:]] 14 | U_cols = [col for col in data.columns if "_U" == col[-2:]] 15 | data["Europarl"] = data[E_cols].mean(axis=1) 16 | data["UN"] = data[U_cols].mean(axis=1) 17 | 18 | 19 | for idx, target_data in enumerate(["Europarl", "UN"]): 20 | with sns.axes_style("whitegrid"): 21 | g = sns.catplot( 22 | data=data, kind="bar", 23 | x="Method", y=target_data, hue="Setting", 24 | ci=None, legend_out=False, legend=False, height=4, aspect=8/4 25 | ) 26 | g.set_axis_labels("Method", "PPL") 27 | # plt.xticks(rotation=45) 28 | plt.title(target_data) 29 | if idx == 0: 30 | plt.legend(loc='upper left', title='Model Type') 31 | plt.tight_layout() 32 | plt.savefig(f"plots_and_data_for_paper/{target_data}_LM.png") 33 | plt.savefig(f"plots_and_data_for_paper/{target_data}_LM.pdf") 34 | plt.close() 35 | 36 | print("Done with LM Plots") 37 | 38 | 39 | def create_mt_plot(path: str): 40 | data = pd.read_csv(path, header=0, index_col=None) 41 | for idx, (target_data, target_scale) in enumerate([("MTNT", (8, 15)), ("UN", (30, 33))]): 42 | with sns.axes_style("whitegrid"): 43 | g = sns.catplot( 44 | data=data, kind="bar", 45 | x="Method", y=f"Avg_{target_data}", hue="Setting", 46 | ci=None, legend=False, height=6, aspect=8/6, legend_out=True 47 | ) 48 | g.set_axis_labels("Method", "BLEU") 49 | plt.title(target_data) 50 | g.set(ylim=target_scale) 51 | # plt.legend(loc='upper left', title='Model Type') 52 | plt.tight_layout() 53 | plt.savefig(f"plots_and_data_for_paper/{target_data}_MT.png") 54 | plt.savefig(f"plots_and_data_for_paper/{target_data}_MT.pdf") 55 | plt.close() 56 | print("Done") 57 | 58 | 59 | if __name__ == "__main__": 60 | create_lm_plot("plots_and_data_for_paper/lm_individual.csv") 61 | create_mt_plot("plots_and_data_for_paper/mt_individual.csv") -------------------------------------------------------------------------------- /create_data/make_nc_data/sample_train_test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import json 5 | import random 6 | import pandas as pd 7 | import numpy as np 8 | 9 | np.random.seed(1) 10 | random.seed(1) 11 | 12 | labels = { 13 | "finance": 0, 14 | "entertainment": 1, 15 | "sports": 2, 16 | "news": 3, 17 | "autos": 4, 18 | "video": 5, 19 | "lifestyle": 6, 20 | "travel": 7, 21 | "health": 8, 22 | "foodanddrink": 9, 23 | } 24 | 25 | def read_tsv(file_path: str, skip_first=False): 26 | data = [] 27 | with open(file_path, "r") 
as fin: 28 | for idx, line in enumerate(fin): 29 | if skip_first and not idx: 30 | continue 31 | skip_flag = False 32 | segments = line.strip().split("\t") 33 | assert len(segments) == 3, segments 34 | if not skip_flag: 35 | data.append({ 36 | "input": segments[0] + segments[1], # combine news title with news body 37 | "label": labels[segments[2]], 38 | }) 39 | return pd.DataFrame(data) 40 | 41 | 42 | # query \t news title \t news body \t news category 43 | def gather_data(): 44 | num_for_train = 8000 45 | num_for_dev_test = 1000 46 | save_path = "splits_data" 47 | datasets = [] 48 | lang_ids = ["de", "en", "es", "fr", "ru"] 49 | for lang_id in lang_ids: 50 | dev = read_tsv(f"xglue_full_dataset/NC/xglue.nc.{lang_id}.dev") # test doesn't have labels / no train for other langs 51 | datasets.append((lang_id, dev)) 52 | 53 | # downsample training sets to simulate FL scenario 54 | for (lang_id, dev) in datasets: 55 | print(lang_id, "saving to file") 56 | save_path = f"nc/{lang_id}" 57 | if not os.path.isdir("nc"): 58 | os.makedirs("nc") 59 | 60 | all_data = dev.sample(frac=1) 61 | train_sampled = all_data.iloc[:num_for_train] 62 | dev = all_data.iloc[num_for_train:num_for_train+num_for_dev_test] 63 | test = all_data.iloc[num_for_train+num_for_dev_test : ] 64 | dev.to_csv(save_path + "_dev.csv", index=None) 65 | test.to_csv(save_path + "_test.csv", index=None) 66 | train_sampled.to_csv(save_path + "_train.csv", index=None) 67 | 68 | print(f"train_sampled shape {train_sampled.shape}") 69 | print(f"dev shape {dev.shape}") 70 | print(f"test shape {test.shape}") 71 | 72 | return {} 73 | 74 | 75 | if __name__ == "__main__": 76 | data_w_text = gather_data() 77 | -------------------------------------------------------------------------------- /create_data/make_europarl_data_for_lm/clean_and_split_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import json 5 | import random 6 | import pandas as pd 7 | from sklearn.model_selection import train_test_split 8 | import numpy as np 9 | 10 | np.random.seed(1) 11 | random.seed(1) 12 | 13 | def gather_all_files(folder_path: str) -> dict: 14 | all_files = {} 15 | for file in glob.glob(os.path.join(folder_path, "europarl-v9.*")): 16 | if "gz" in file: 17 | continue 18 | print(f"Loading {file}") 19 | lang = file.split("/")[-1].split(".")[-1] 20 | data = [] 21 | with open(file, "r") as fin: 22 | for line in fin: 23 | data.append(line) 24 | all_files[lang] = data 25 | 26 | return all_files 27 | 28 | def preprocess_str(x: str) -> str: 29 | x = x.strip() 30 | if len(x) < 3 or len(x.split(" ")) < 3: 31 | return "" 32 | return x 33 | 34 | def preprocess_data(data_w_text: dict) -> dict: 35 | new_data = {} 36 | for lang, data in data_w_text.items(): 37 | data_non_nan = [line for line in data if line not in ["", None]] 38 | data_non_nan = [preprocess_str(x) for x in data_non_nan] 39 | data_non_nan = [line for line in data_non_nan if line not in [""]] 40 | new_data[lang] = pd.DataFrame({"text": data_non_nan}) 41 | return new_data 42 | 43 | def split_and_save_data(save_path: str, min_size: int, data_w_text: dict, test_percent: float): 44 | for lang, data in data_w_text.items(): 45 | print(f"Saving final data for lang {lang}") 46 | kept_data = data.sample(n=min_size, replace=False) 47 | train, dev_and_test = train_test_split(kept_data, test_size=test_percent) 48 | dev, test = train_test_split(dev_and_test, test_size=0.5) 49 | 50 | final_save_path = os.path.join(save_path, lang + 
"_train.csv") 51 | train.to_csv(final_save_path, index=False) 52 | dev.to_csv(final_save_path.replace("train", "dev"), index=False) 53 | test.to_csv(final_save_path.replace("train", "test"), index=False) 54 | print(f"Train shape {train.shape} and dev shape {dev.shape} and test shape {test.shape}") 55 | 56 | def randomly_select_and_split(file_path: str, test_percent: float = 0.333333): 57 | save_path = "splits_data" 58 | if not os.path.isdir(save_path): 59 | os.makedirs(save_path) 60 | 61 | data_w_text = gather_all_files(file_path) 62 | data_w_text = preprocess_data(data_w_text) 63 | info = [(lang, len(data)) for lang, data in data_w_text.items() if data is not None and not data.empty] 64 | min_size = min([30000] + [len(data) for _, data in data_w_text.items() if data is not None and not data.empty]) 65 | print(f"Minimum Data Size is {min_size}") 66 | split_and_save_data(os.path.join(file_path, save_path), min_size, data_w_text, test_percent) 67 | 68 | 69 | if __name__ == "__main__": 70 | randomly_select_and_split("./") 71 | -------------------------------------------------------------------------------- /create_data/make_mtnt_data/sample_train_test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import json 5 | import random 6 | import pandas as pd 7 | import numpy as np 8 | 9 | np.random.seed(1) 10 | random.seed(1) 11 | 12 | def read_tsv(file_path: str, trg_lang: str): 13 | data = [] 14 | with open(file_path, "r") as fin: 15 | for line in fin: 16 | segments = line.strip().split("\t") 17 | if len(segments) != 3: 18 | if len(segments) == 2 and len(segments[0].split(" ")) > 1: 19 | data.append({ 20 | "id": -1, 21 | "en": segments[0], 22 | trg_lang: segments[1] 23 | }) 24 | else: 25 | print(file_path, "error") 26 | breakpoint() 27 | raise Exception(segments) 28 | else: 29 | data.append({ 30 | "id": segments[0], 31 | "en": segments[1], 32 | trg_lang: segments[2] 33 | }) 34 | return pd.DataFrame(data) 35 | 36 | def gather_data(): 37 | save_path = "splits_data" 38 | if not os.path.isdir(save_path): 39 | os.makedirs(save_path) 40 | 41 | en_ja = read_tsv("MTNT/train/train.en-ja.tsv.corrected", trg_lang="ja") 42 | en_fr = read_tsv("MTNT/train/train.en-fr.tsv.corrected", trg_lang="fr") 43 | 44 | # remove super short sentences that are numbers or emojis 45 | en_ja_short = en_ja.apply(lambda x: len(x['en'].split(" ")) < 3, axis=1) 46 | en_ja = en_ja[~en_ja_short] 47 | 48 | en_ja_only = set(en_ja["en"].to_list()) 49 | en_fr_matched_bool = en_fr.apply(lambda x: x["en"] in en_ja_only, axis=1) 50 | matched_en_fr = en_fr[en_fr_matched_bool] 51 | not_matched_en_fr = en_fr[~en_fr_matched_bool] 52 | 53 | num_to_random_sample = len(en_ja) - len(matched_en_fr) 54 | additional_samples = not_matched_en_fr.sample(n=num_to_random_sample) 55 | full_en_fr = pd.concat([matched_en_fr, additional_samples]) 56 | 57 | full_en_fr = full_en_fr[full_en_fr.columns[1:]] 58 | en_ja = en_ja[en_ja.columns[1:]] 59 | 60 | en_ja.to_csv(os.path.join(save_path, "en-ja_train.csv"), index=None) 61 | full_en_fr.to_csv(os.path.join(save_path, "en-fr_train.csv"), index=None) 62 | print(f"En-Ja Train shape {en_ja.shape} ") 63 | print(f"En-Fr Train shape {full_en_fr.shape} ") 64 | 65 | 66 | testing_files = [ 67 | ("MTNT/valid/valid.en-fr.tsv.corrected", "en-fr_dev.csv"), 68 | ("MTNT/valid/valid.en-ja.tsv.corrected", "en-ja_dev.csv"), 69 | ("MTNT/test/test.en-fr.tsv.corrected", "en-fr_test.csv"), 70 | ("MTNT/test/test.en-ja.tsv.corrected", 
"en-ja_test.csv"), 71 | ] 72 | for file_path, save_name in testing_files: 73 | trg_lang = save_name.split("-")[1].split("_")[0] 74 | data = read_tsv(file_path, trg_lang) 75 | data = data[data.columns[1:]] 76 | data.to_csv(os.path.join(save_path, save_name), index=None) 77 | print(f"{save_name} shape {data.shape}") 78 | 79 | return {} 80 | 81 | 82 | if __name__ == "__main__": 83 | data_w_text = gather_data() 84 | # size following https://aclanthology.org/2020.wmt-1.4.pdf -------------------------------------------------------------------------------- /create_data/make_un_data_for_lm/gather_exclusive_paths.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import json 4 | import tqdm 5 | from pathlib import Path 6 | from collections import defaultdict, Counter 7 | import xml.etree.ElementTree as ET 8 | import xml 9 | import numpy as np 10 | import random 11 | random.seed(3) 12 | 13 | LANGS = [ "en", "ar", "es", "fr", "ru", "zh"] 14 | 15 | def read_in_text(xml_file: str): 16 | lines = [] 17 | with open(xml_file, "r") as fin: 18 | try: 19 | root = ET.parse(fin).getroot() 20 | for p_tag in root.findall('text')[0].find("body").findall("p"): 21 | for s_tag in p_tag.findall("s"): 22 | if s_tag.text is not None: 23 | lines.append(s_tag.text) 24 | except xml.etree.ElementTree.ParseError: 25 | print(f"Cannot parse {xml_file}") 26 | if "zh" in xml_file[:10]: 27 | return [line for line in lines if len(line) > 5] 28 | else: 29 | return [line for line in lines if len(line.split(' ')) > 5] 30 | 31 | def get_paths(lang_path: str) -> list: 32 | all_paths = [] 33 | for file_path in tqdm.tqdm(Path(lang_path).rglob(os.path.join("*.xml"))): 34 | relative_path = "/".join(str(file_path).split("/")[2:]) 35 | all_paths.append(relative_path) 36 | assert len(all_paths) == len(set(all_paths)), f"{len(all_paths)} {len(set(all_paths))}" 37 | return all_paths 38 | 39 | def gather_exclusive_paths(base_path: str): 40 | lang_map = {} 41 | total_docs = 0 42 | for lang in LANGS: 43 | print(f'Gathering paths for lang {lang}') 44 | lang_map[lang] = get_paths(os.path.join(base_path, lang)) 45 | total_docs += len(lang_map[lang]) 46 | 47 | print(f"The number of total docs is {total_docs}") 48 | assert total_docs == 799276, f"Expected 799,276 documents - got {total_docs}" 49 | 50 | reverse_dict = defaultdict(list) 51 | for lang, paths in lang_map.items(): 52 | for file_path in paths: 53 | reverse_dict[file_path].append(lang) 54 | 55 | langs_num = [len(lang_list) for _, lang_list in reverse_dict.items()] 56 | lang_counter = Counter(langs_num) 57 | print(f"Average langs={np.mean(langs_num)} for {lang_counter} with total unique={len(reverse_dict)}") 58 | 59 | exclusive_map = defaultdict(list) 60 | for file_path, langs_available in reverse_dict.items(): 61 | # if there are multiple options for languages for a document, randomly choose 62 | if len(langs_available) == 6 and False: 63 | lang_to_use = random.choices( 64 | population=["ar", "en", "es", "fr", "ru", "zh"], 65 | weights=[0.23669934, 0.06437802, 0.18635274, 0.06519098, 0.16047007, 0.28690884], 66 | k=1 67 | ) 68 | # np.array([20909,16215,37016,36719,27827,25134]) 69 | # ar en es fr ru zh 70 | # array([0.12763399, 0.09898059, 0.22595532, 0.22414235, 0.16986326, 71 | # 0.15342449]) 72 | # [20909,16215,37016,36719,27827,25134] / 6 73 | # (1/6) + ((1/6) - (arr / 163820)) 74 | else: 75 | lang_to_use = random.sample(langs_available, 1) 76 | assert len(lang_to_use) == 1 77 | # derive the full path back 78 | 
exclusive_map[lang_to_use[0]].append(f"UNv1.0-TEI/{lang_to_use[0]}/{file_path}") 79 | 80 | # sanity check here that they are unique 81 | total = [] 82 | for lang, paths in exclusive_map.items(): 83 | lang_data = list(exclusive_map[lang]) 84 | exclusive_map[lang] = lang_data 85 | print(f"Lang {lang} has {len(lang_data)} unique items") 86 | total.extend(lang_data) 87 | assert len(total) == len(set(total)), f"{len(total)} vs {len(set(total))}" 88 | 89 | with open("exclusive_dict.json", "w") as fout: 90 | json.dump(exclusive_map, fout, indent=4) 91 | 92 | for lang, paths in exclusive_map.items(): 93 | print(f"Gathering lines for {lang}") 94 | data = [] 95 | for file_path in tqdm.tqdm(paths): 96 | data.extend(read_in_text(file_path)) 97 | print(f"Has {len(data)} lines") 98 | with open(f"{lang}.txt", "w") as fout: 99 | for line in data: 100 | fout.write(line + "\n") 101 | 102 | 103 | if __name__ == "__main__": 104 | gather_exclusive_paths("UNv1.0-TEI") -------------------------------------------------------------------------------- /dataset_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | import random 5 | import pickle 6 | import copy 7 | from collections import OrderedDict 8 | from pathlib import Path 9 | from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union 10 | import shutil 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | from torch.nn.utils.rnn import pad_sequence 16 | from torch.utils.data import DataLoader 17 | 18 | from transformers import DataCollatorForLanguageModeling, set_seed 19 | 20 | set_seed(1) 21 | 22 | from constants import MBART_MAP, MAP_LANG_MAP 23 | 24 | 25 | class LineByLineTextDataset(torch.utils.data.Dataset): 26 | """ 27 | Deprecated Huggingface Dataset 28 | """ 29 | 30 | def __init__(self, tokenizer, file_path: str, block_size: int = 512, test_flag: int = 0, examples = None): 31 | if examples is not None: 32 | self.examples = [torch.tensor(e, dtype=torch.long) for e in examples] 33 | else: 34 | if os.path.isfile(file_path) is False: 35 | raise ValueError(f"Input file path {file_path} not found") 36 | 37 | with open(file_path, encoding="utf-8") as f: 38 | lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())] 39 | if test_flag: 40 | lines = lines[:test_flag] 41 | 42 | batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size) 43 | self.examples = batch_encoding["input_ids"] 44 | self.examples = [torch.tensor(e, dtype=torch.long) for e in self.examples] 45 | 46 | def __len__(self): 47 | return len(self.examples) 48 | 49 | def __getitem__(self, i) -> Dict[str, torch.tensor]: 50 | return self.examples[i] 51 | 52 | 53 | class MultilingualDataset(LineByLineTextDataset): 54 | """ 55 | Loads CSV files in a directory where each file is a separate language 56 | Can be loaded by passing in `examples` where examples is a tuple of list of integers representing words and the language string 57 | """ 58 | def __init__(self, tokenizer, file_path: str, split: str = "train", block_size: int = 512, test_flag: int = 0, examples = None, skip_langs: list = []): 59 | LANG_MAP = MAP_LANG_MAP[file_path] 60 | skip_langs = [] 61 | if examples is not None: 62 | self.examples = [(torch.tensor(e, dtype=torch.long), lang) for e, lang in examples if lang not in skip_langs] 63 | else: 64 | if os.path.isdir(file_path) is False: 65 | raise ValueError(f"Input file directory {file_path} not found") 66 | 
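            # Each <lang>_<split>.csv found under file_path is tokenized once below and its
            # token ids are cached in a sibling .pkl file, so later runs skip re-tokenization.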
67 | self.examples = [] 68 | for lang_file_path in glob.glob(os.path.join(file_path, f"*_{split}.csv")): 69 | lang = lang_file_path.split("/")[-1].split("_")[0] 70 | if os.path.isfile(lang_file_path.replace(".csv", ".pkl")): 71 | with open(lang_file_path.replace(".csv", ".pkl"), "rb") as fin: 72 | examples = pickle.load(fin) 73 | else: 74 | lines = pd.read_csv(lang_file_path, header=0, index_col=None)["text"].tolist() 75 | if test_flag: 76 | lines = lines[:test_flag] 77 | 78 | batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size) 79 | examples = batch_encoding["input_ids"] 80 | # cache the tokenization to save time 81 | with open(lang_file_path.replace(".csv", ".pkl"), "wb") as fout: 82 | pickle.dump(examples, fout) 83 | 84 | examples = [(torch.tensor(e, dtype=torch.long), LANG_MAP[lang]) for e in examples] 85 | self.examples.extend(examples) 86 | 87 | 88 | class MTDataset(LineByLineTextDataset): 89 | """ 90 | Loads CSV files in a directory where each file is a separate language 91 | Can be loaded by passing in `examples` where examples is a tuple of two lists of integers representing words and the language string 92 | """ 93 | def __init__(self, tokenizer, file_path: str, split: str = "train", block_size: int = 512, test_flag: int = 0, examples = None, skip_langs: list = []): 94 | LANG_MAP = MAP_LANG_MAP[file_path] 95 | skip_langs = [] 96 | if examples is not None: 97 | self.examples = [([torch.tensor(e, dtype=torch.long), torch.tensor(l, dtype=torch.long)], lang) for ((e, l), lang) in examples if lang not in skip_langs] 98 | else: 99 | if os.path.isdir(file_path) is False: 100 | raise ValueError(f"Input file directory {file_path} not found") 101 | 102 | self.examples = [] 103 | for lang_file_path in glob.glob(os.path.join(file_path, f"*_{split}.csv")): 104 | lang = lang_file_path.split("/")[-1].split("_")[0] 105 | if os.path.isfile(lang_file_path.replace(".csv", ".pkl")): 106 | with open(lang_file_path.replace(".csv", ".pkl"), "rb") as fin: 107 | examples = pickle.load(fin) 108 | else: 109 | data = pd.read_csv(lang_file_path, header=0, index_col=None) 110 | if test_flag: 111 | lines = lines[:test_flag] 112 | 113 | all_examples = [] 114 | 115 | # order them according to direction we want 116 | col_order = data.columns 117 | for col_name in col_order: 118 | tokenizer.src_lang = col_name 119 | batch_encoding = tokenizer(data[col_name].tolist(), add_special_tokens=True, truncation=True, max_length=block_size) 120 | examples = batch_encoding["input_ids"] 121 | all_examples.append(examples) 122 | 123 | assert len(all_examples[0]) == len(all_examples[1]) 124 | examples = list(zip(*all_examples)) # both langs in each instance 125 | # cache the tokenization to save time 126 | with open(lang_file_path.replace(".csv", ".pkl"), "wb") as fout: 127 | pickle.dump(examples, fout) 128 | 129 | examples = [([torch.tensor(e, dtype=torch.long), torch.tensor(l, dtype=torch.long)], LANG_MAP[lang]) for (e, l) in examples] 130 | self.examples.extend(examples) 131 | 132 | class PAWSDataset(LineByLineTextDataset): 133 | """ 134 | Loads CSV files in a directory where each file is a separate language 135 | Can be loaded by passing in `examples` where examples is a tuple of two lists of integers representing words and the language string 136 | """ 137 | def __init__(self, tokenizer, file_path: str, split: str = "train", block_size: int = 512, test_flag: int = 0, examples = None, skip_langs: list = []): 138 | LANG_MAP = MAP_LANG_MAP[file_path] 139 | skip_langs = [] 140 | if 
examples is not None: 141 | self.examples = [([torch.tensor(e, dtype=torch.long), torch.tensor(l, dtype=torch.long)], lang) for ((e, l), lang) in examples if lang not in skip_langs] 142 | else: 143 | if os.path.isdir(file_path) is False: 144 | raise ValueError(f"Input file directory {file_path} not found") 145 | 146 | self.examples = [] 147 | for lang_file_path in glob.glob(os.path.join(file_path, f"*_{split}.csv")): 148 | lang = lang_file_path.split("/")[-1].split("_")[0] 149 | if os.path.isfile(lang_file_path.replace(".csv", ".pkl")): 150 | with open(lang_file_path.replace(".csv", ".pkl"), "rb") as fin: 151 | examples = pickle.load(fin) 152 | else: 153 | data = pd.read_csv(lang_file_path, header=0, index_col=None) 154 | all_examples = [] 155 | # order them according to direction we want 156 | sents1 = data["sentence1"].tolist() 157 | sents2 = data["sentence2"].tolist() 158 | for idx in range(len(data)): 159 | tokenizer.src_lang = lang 160 | encoding = tokenizer(sents1[idx], sents2[idx], add_special_tokens=True, truncation=True, max_length=block_size) 161 | examples = encoding["input_ids"] 162 | label = data.iloc[idx].label 163 | all_examples.append((examples, int(label))) 164 | 165 | # cache the tokenization to save time 166 | with open(lang_file_path.replace(".csv", ".pkl"), "wb") as fout: 167 | pickle.dump(all_examples, fout) 168 | examples = all_examples 169 | 170 | 171 | 172 | examples = [([torch.tensor(e, dtype=torch.long), torch.tensor(l, dtype=torch.long).unsqueeze(0)], LANG_MAP[lang]) for (e, l) in examples] 173 | self.examples.extend(examples) 174 | 175 | if test_flag: 176 | print(f"Using a debug run of {test_flag} examples") 177 | self.examples = self.examples[:test_flag] 178 | 179 | 180 | 181 | class ClassificationDataset(LineByLineTextDataset): 182 | """ 183 | Loads CSV files in a directory where each file is a separate language 184 | Can be loaded by passing in `examples` where examples is a tuple of two lists of integers representing words and the language string 185 | """ 186 | def __init__(self, tokenizer, file_path: str, split: str = "train", block_size: int = 512, test_flag: int = 0, examples = None, skip_langs: list = []): 187 | LANG_MAP = MAP_LANG_MAP[file_path] 188 | skip_langs = [] 189 | if examples is not None: 190 | self.examples = [([torch.tensor(e, dtype=torch.long), torch.tensor(l, dtype=torch.long)], lang) for ((e, l), lang) in examples if lang not in skip_langs] 191 | else: 192 | if os.path.isdir(file_path) is False: 193 | raise ValueError(f"Input file directory {file_path} not found") 194 | 195 | self.examples = [] 196 | for lang_file_path in glob.glob(os.path.join(file_path, f"*_{split}.csv")): 197 | lang = lang_file_path.split("/")[-1].split("_")[0] 198 | if os.path.isfile(lang_file_path.replace(".csv", ".pkl")): 199 | with open(lang_file_path.replace(".csv", ".pkl"), "rb") as fin: 200 | examples = pickle.load(fin) 201 | else: 202 | data = pd.read_csv(lang_file_path, header=0, index_col=None) 203 | all_examples = [] 204 | 205 | # order them according to direction we want 206 | labels = data["label"].tolist() 207 | sents = data["input"].tolist() 208 | tokenizer.src_lang = lang 209 | encoding = tokenizer(sents, add_special_tokens=True, truncation=True, max_length=block_size) 210 | all_examples = list(zip(encoding["input_ids"], labels)) 211 | 212 | # cache the tokenization to save time 213 | with open(lang_file_path.replace(".csv", ".pkl"), "wb") as fout: 214 | pickle.dump(all_examples, fout) 215 | examples = all_examples 216 | 217 | examples = 
[([torch.tensor(e, dtype=torch.long), torch.tensor(l, dtype=torch.long).unsqueeze(0)], LANG_MAP[lang]) for (e, l) in examples] 218 | self.examples.extend(examples) 219 | 220 | if test_flag: 221 | print(f"Using a debug run of {test_flag} examples") 222 | self.examples = self.examples[:test_flag] 223 | 224 | def get_dataset_type(path_to_data): 225 | is_multilingual = "wmt" in str(path_to_data) or "un_corpus" in str(path_to_data) 226 | is_mt = "mt_corpus" in str(path_to_data) 227 | is_paws = "pawsx" in str(path_to_data) 228 | if is_paws: 229 | dataset_type = PAWSDataset 230 | elif "nc" in str(path_to_data): 231 | dataset_type = ClassificationDataset 232 | elif is_multilingual: 233 | dataset_type = MultilingualDataset 234 | elif is_mt: 235 | dataset_type = MTDataset 236 | else: 237 | dataset_type = LineByLineTextDataset 238 | return dataset_type 239 | 240 | 241 | def get_dataset(path_to_data: Path, cid: str, partition: str): 242 | # generate path to cid's data 243 | path_to_data = path_to_data / cid / (partition + ".npy") 244 | data = np.load(path_to_data, allow_pickle=True) 245 | dataset_type = get_dataset_type(str(path_to_data)) 246 | return dataset_type(None, "/".join(str(path_to_data).split("/")[:2]), examples=data.tolist()) 247 | 248 | 249 | def get_random_id_splits(total: int, val_ratio: float, shuffle: bool = True): 250 | """splits a list of length `total` into two following a 251 | (1-val_ratio):val_ratio partitioning. 252 | 253 | By default the indices are shuffled before creating the split and 254 | returning. 255 | """ 256 | 257 | if isinstance(total, int): 258 | indices = list(range(total)) 259 | else: 260 | indices = total 261 | 262 | split = int(np.floor(val_ratio * len(indices))) 263 | if not split: 264 | split = 1 # need at least 1 validation instance 265 | if shuffle: 266 | np.random.shuffle(indices) 267 | return indices[split:], indices[:split] 268 | 269 | 270 | def make_collate_fn(tokenizer): 271 | def collate_fn(batch): 272 | tensors = pad_sequence(batch, batch_first=True, padding_value=tokenizer.pad_token_id) 273 | attn_mask = torch.ones_like(tensors) 274 | is_padding = tensors == tokenizer.pad_token_id 275 | attn_mask[is_padding] = 0 # is padding 276 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15) 277 | output_batch = data_collator(tuple(tensors)) 278 | output_batch["attention_mask"] = attn_mask 279 | return output_batch 280 | return collate_fn 281 | 282 | def make_collate_fn_wlang(tokenizer): 283 | def collate_fn_wlang(batch): 284 | langs = torch.tensor([lang for (_, lang) in batch]) 285 | batched_tensors = pad_sequence([num for (num, _) in batch], batch_first=True, padding_value=tokenizer.pad_token_id) 286 | attn_mask = torch.ones_like(batched_tensors) 287 | is_padding = batched_tensors == tokenizer.pad_token_id 288 | attn_mask[is_padding] = 0 289 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15) 290 | output_batch = data_collator(tuple(batched_tensors)) 291 | output_batch["langs"] = langs 292 | output_batch["attention_mask"] = attn_mask 293 | return output_batch 294 | return collate_fn_wlang 295 | 296 | def make_collate_fn_mt_wlang(tokenizer): 297 | def collate_fn_mt_wlang(batch): 298 | langs = torch.tensor([lang for (_, lang) in batch]) 299 | batched_input_tensors = pad_sequence([e for ((e, l), _) in batch], batch_first=True, padding_value=tokenizer.pad_token_id) 300 | batched_label_tensors = pad_sequence([l for ((e, l), _) in batch], batch_first=True, 
padding_value=tokenizer.pad_token_id) 301 | attn_mask = torch.ones_like(batched_input_tensors) 302 | is_padding = batched_input_tensors == tokenizer.pad_token_id 303 | attn_mask[is_padding] = 0 304 | output_batch = { 305 | "langs": langs, 306 | "attention_mask": attn_mask, 307 | "labels": batched_label_tensors, 308 | "input_ids": batched_input_tensors 309 | } 310 | return output_batch 311 | return collate_fn_mt_wlang 312 | 313 | def flatten(t): 314 | return [item for sublist in t for item in sublist] 315 | 316 | class EvenClassSampler: 317 | def __init__(self, classes): 318 | self.classes = classes 319 | self.class_idxs = [[] for _ in range(len(set(self.classes)))] 320 | [self.class_idxs[class_num].append(i) for i, class_num in enumerate(self.classes)] 321 | for i in range(len(self.class_idxs)): 322 | random.shuffle(self.class_idxs[i]) 323 | self.new_indexes = flatten(list(zip(*self.class_idxs))) 324 | 325 | def __iter__(self): 326 | return iter(self.new_indexes) 327 | 328 | 329 | def get_collate_fn(data, tokenizer): 330 | if data == "brown": 331 | return make_collate_fn(tokenizer) 332 | elif data in ["un_mt_corpus", "mtnt", "mtnt_mt_corpus", "pawsx", "nc"]: 333 | return make_collate_fn_mt_wlang(tokenizer) 334 | else: 335 | return make_collate_fn_wlang(tokenizer) 336 | 337 | def get_dataloader( 338 | path_to_data: str, cid: str, is_train: bool, batch_size: int, workers: int, data: str, 339 | tokenizer, shuffle: bool = False, lang_mix: int = -1 340 | ): 341 | """Generates trainset/valset object and returns appropiate dataloader.""" 342 | partition = "train" if is_train else "val" 343 | if type(path_to_data) not in [MultilingualDataset, torch.utils.data.Dataset, LineByLineTextDataset, \ 344 | MTDataset, PAWSDataset, ClassificationDataset]: 345 | dataset = get_dataset(Path(path_to_data), cid, partition) 346 | else: 347 | dataset = path_to_data 348 | 349 | # we use as number of workers all the cpu cores assigned to this actor 350 | kwargs = {"num_workers": workers, "pin_memory": True, "drop_last": False} 351 | if lang_mix == 1.0: 352 | kwargs["sampler"] = EvenClassSampler([item[1] for item in dataset]) 353 | elif shuffle: 354 | kwargs["shuffle"] = True 355 | 356 | c_func = get_collate_fn(data, tokenizer) 357 | return DataLoader(dataset, batch_size=batch_size, collate_fn=c_func, **kwargs) 358 | 359 | def split(a, n): 360 | k, m = divmod(len(a), n) 361 | return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)) 362 | 363 | def do_fl_partitioning_brown(path_to_dataset, dataset, pool_size, val_ratio=0.0): 364 | dataset = [item.numpy() for item in dataset] # need to use numpy to save since PyTorch wants same sized batch 365 | random.shuffle(dataset) 366 | partitions = list(split(dataset, pool_size)) 367 | 368 | # now save partitioned dataset to disk 369 | # first delete dir containing splits (if exists), then create it 370 | splits_dir = Path(path_to_dataset).parent / "federated" 371 | if splits_dir.exists(): 372 | shutil.rmtree(splits_dir) 373 | Path.mkdir(splits_dir, parents=True) 374 | 375 | for p in range(pool_size): 376 | cur_data = np.array(partitions[p], dtype=object) 377 | # create dir 378 | Path.mkdir(splits_dir / str(p)) 379 | 380 | if val_ratio > 0.0: 381 | # split data according to val_ratio 382 | train_idx, val_idx = get_random_id_splits(len(cur_data), val_ratio) 383 | val_cur_data = cur_data[np.array(val_idx)] 384 | np.save(splits_dir / str(p) / "val.npy", val_cur_data) 385 | 386 | # remaining for training 387 | cur_data = cur_data[np.array(train_idx)] 388 | # save train set 389 
| np.save(splits_dir / str(p) / "train.npy", cur_data) 390 | 391 | return splits_dir 392 | 393 | 394 | def convert_to_np(item) -> list: 395 | if type(item) == list: 396 | return [i.numpy() for i in item] 397 | else: 398 | return item.numpy() 399 | 400 | 401 | def do_fl_partitioning(path_to_dataset: str, dataset, pool_size: int, cache_str: str, 402 | lang_mix: float = 0.0, val_ratio=0.0): 403 | # NOTE: tensor may be a list of tensors if seq to seq or something 404 | dataset = [(convert_to_np(tensor_item), lang_id) for (tensor_item, lang_id) in dataset] 405 | dataset_df = pd.DataFrame(dataset, columns=["tensor", "lang_id"]) 406 | sample_df = dataset_df.copy() 407 | 408 | if pool_size == 1: # centralized 409 | partitions = [dataset] 410 | else: # distributed 411 | partition_size = len(dataset_df) // pool_size 412 | 413 | same_lang_num = int(partition_size * (1 - lang_mix)) # floored 414 | if lang_mix == 1.0: 415 | same_lang_num = int(partition_size) # get them separated then zip them later 416 | partitions = [[] for x in range(len(dataset_df.lang_id.unique()))] 417 | 418 | # start by making partitions by lang only, sampling lang_mix 419 | for (lang, lang_df) in dataset_df.groupby(["lang_id"]): 420 | sampled_for_lang = lang_df.sample(n=same_lang_num) 421 | sampled_idx = sampled_for_lang.index 422 | sample_df = sample_df.drop(sampled_idx) 423 | for (idx, row) in sampled_for_lang.iterrows(): 424 | partitions[lang].append((row["tensor"], row["lang_id"])) 425 | 426 | # now sample the rest from the great pool of available instances 427 | for partition_idx in range(len(partitions)): 428 | left_over_sampling = partition_size - same_lang_num 429 | sampled_for_lang = sample_df.sample(n=left_over_sampling) 430 | sampled_idx = sampled_for_lang.index 431 | sample_df = sample_df.drop(sampled_idx) 432 | for (idx, row) in sampled_for_lang.iterrows(): 433 | partitions[partition_idx].append((row["tensor"], row["lang_id"])) 434 | 435 | if lang_mix == 1.0: 436 | # we want the batches to be perfectly split with each language 437 | # zip them together - creates a batch of each one 438 | partition_list = list(zip(*partitions)) 439 | num_batches_per_partition = len(partition_list) // len(partitions) 440 | # now divide it into an almost equal number of batches per device 441 | for partition_num in range(len(partitions)): 442 | start_batch = partition_num * num_batches_per_partition 443 | end_batch = (partition_num + 1) * num_batches_per_partition 444 | if partition_num == (len(partitions) - 1): 445 | end_batch = len(partition_list) 446 | batches_for_partition = partition_list[start_batch:end_batch] 447 | all_items_in_batches = [] 448 | [all_items_in_batches.extend(list(batch)) for batch in batches_for_partition] 449 | partitions[partition_num] = all_items_in_batches 450 | 451 | zero_partitions_langs = pd.Series([item[1] for item in partitions[0]]).value_counts(normalize=True) 452 | print(f"The 0th partition has lang id mapping percent:\n{zero_partitions_langs} with {len(partitions[0])} instances") 453 | 454 | # now save partitioned dataset to disk 455 | # first delete dir containing splits (if exists), then create it 456 | splits_dir = Path(path_to_dataset) / ("federated_" + cache_str) 457 | if splits_dir.exists(): 458 | shutil.rmtree(splits_dir) 459 | Path.mkdir(splits_dir, parents=True) 460 | 461 | for p in range(pool_size): 462 | cur_data = np.array(partitions[p], dtype=object) 463 | # create dir 464 | Path.mkdir(splits_dir / str(p)) 465 | 466 | if val_ratio > 0.0: 467 | # split data according to val_ratio 468 | 
train_idx, val_idx = get_random_id_splits(len(cur_data), val_ratio) 469 | val_cur_data = cur_data[np.array(val_idx)] 470 | np.save(splits_dir / str(p) / "val.npy", val_cur_data) 471 | 472 | # remaining for training 473 | cur_data = cur_data[np.array(train_idx)] 474 | # save train set 475 | np.save(splits_dir / str(p) / "train.npy", cur_data) 476 | 477 | return splits_dir -------------------------------------------------------------------------------- /main_lm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | from collections import OrderedDict, defaultdict 5 | from pathlib import Path 6 | from typing import Dict, Callable, Optional, Tuple 7 | import traceback 8 | import tqdm 9 | 10 | import math 11 | import bisect 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader 17 | 18 | import flwr as fl 19 | from flwr.common.typing import Scalar 20 | import ray 21 | from sacrebleu.metrics import BLEU, CHRF, TER 22 | 23 | import transformers 24 | from transformers import ( 25 | CONFIG_MAPPING, 26 | MODEL_MAPPING, 27 | AdamW, 28 | AutoConfig, 29 | AutoModelForCausalLM, 30 | AutoTokenizer, 31 | SchedulerType, 32 | default_data_collator, 33 | get_scheduler, 34 | set_seed, 35 | MT5Model, 36 | T5Tokenizer, 37 | MBartForCausalLM, 38 | DistilBertForSequenceClassification, 39 | XLMRobertaForMaskedLM, 40 | XLMRobertaForSequenceClassification, 41 | BartForCausalLM, 42 | logging, 43 | DistilBertForMaskedLM, 44 | M2M100ForConditionalGeneration, 45 | M2M100Tokenizer, 46 | set_seed 47 | ) 48 | 49 | seed_val = 1 50 | set_seed(seed_val) 51 | print("Seed is", seed_val) 52 | bleu = BLEU() 53 | 54 | import warnings 55 | warnings.filterwarnings("ignore") 56 | logging.set_verbosity_error() # hides warning about CasualLM not loading encoder 57 | 58 | from dataset_utils import ( 59 | LineByLineTextDataset, 60 | get_dataset, 61 | get_random_id_splits, 62 | make_collate_fn_wlang, 63 | make_collate_fn, 64 | get_dataloader, 65 | do_fl_partitioning_brown, 66 | do_fl_partitioning, 67 | MultilingualDataset, 68 | get_dataset_type, 69 | MTDataset, 70 | ) 71 | 72 | from constants import * 73 | 74 | BIG_FILE_CACHE = "./cache" 75 | 76 | ## Global Vars that are set under `if __name__ == "__main__"` 77 | ACCUM_STEPS = 1 78 | BATCH_SIZE = 2 79 | CUDA_COUNT = 0 # need to keep track for clients, iterative take the next one 80 | RANDOM_INIT = False 81 | MODEL_NAME = "" 82 | DATA = "" 83 | client_resources = {"num_gpus": 0, "num_cpus": 1} # NOTE: can do fractional GPUs, this is per process/client 84 | GPU_MAPPING = {} 85 | EVAL_NUM = 0 86 | NUM_SKIP_EVAL = 1 87 | PREV_LOSS = -1.0 88 | tokenizer = None 89 | LEARNING_RATE = None 90 | LANG_MIX = None 91 | CACHE_STR = None 92 | ALL_OPTIMIZERS = {} 93 | TOP_N_SCORES = [] 94 | GLOBAL_LANG_MAP = None 95 | 96 | # borrowed from Pytorch quickstart example 97 | def train(net, trainloader, epochs, optimizer, device: str, cid: str = "", get_accuracy: bool = False): 98 | """Train the network on the training set.""" 99 | global ACCUM_STEPS 100 | net.train() 101 | net.zero_grad() 102 | losses = [] 103 | total, correct = 0, 0 104 | for _ in range(epochs): 105 | for batch_idx, batch in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)): 106 | input_ids = batch["input_ids"] 107 | label_ids = batch["labels"] 108 | attn_mask = batch["attention_mask"] 109 | 110 | label_ids = label_ids.to(device) 111 | input_ids = 
input_ids.to(device) 112 | attn_mask = attn_mask.to(device) 113 | 114 | x = {"input_ids": input_ids, "labels": label_ids, "attention_mask": attn_mask} 115 | output = net(**x) 116 | loss = output.loss / ACCUM_STEPS 117 | loss.backward() 118 | losses.append(output.loss.cpu().detach().repeat(len(batch))) 119 | 120 | label_ids = label_ids.cpu() 121 | label_ids = label_ids.cpu() 122 | attn_mask = attn_mask.cpu() 123 | 124 | if get_accuracy: 125 | pred_labels = output.logits.argmax(dim=-1).cpu() 126 | truth_labels = label_ids.squeeze(-1).cpu() 127 | correct += torch.sum(torch.eq(pred_labels, truth_labels)).cpu().detach().item() 128 | total += len(pred_labels) 129 | 130 | if (batch_idx + 1) % ACCUM_STEPS == 0: 131 | optimizer.step() 132 | net.zero_grad() 133 | loss = 0 134 | 135 | net = net.cpu() 136 | loss = 0 137 | net.zero_grad() 138 | label_ids = label_ids.to("cpu") 139 | label_ids = label_ids.to("cpu") 140 | attn_mask = attn_mask.to("cpu") 141 | mean_loss = torch.cat(losses).mean() 142 | if get_accuracy: 143 | print(f"TRAIN Accuracy for is {correct/total}") 144 | print(f"Got a TRAIN PPL value of {mean_loss.detach().item()} and {torch.exp(mean_loss).detach().item()} \ 145 | for cid={cid}, label={batch['langs'][0].item()}") 146 | 147 | def test(net, testloader, device: str, get_accuracy: bool = False): 148 | """Validate the network on the entire test set.""" 149 | net.eval() 150 | losses = [] 151 | correct = 0 152 | total = 0 153 | labels_to_losses = defaultdict(list) 154 | labels_to_accuracies = defaultdict(dict) 155 | with torch.no_grad(): 156 | for idx, batch in tqdm.tqdm(enumerate(testloader), total=len(testloader)): 157 | input_ids = batch["input_ids"] 158 | label_ids = batch["labels"] 159 | attn_mask = batch["attention_mask"] 160 | input_ids = input_ids.to(device) 161 | attn_mask = attn_mask.to(device) 162 | label_ids = label_ids.to(device) 163 | x = {"input_ids": input_ids, "labels": label_ids, "attention_mask": attn_mask} 164 | output = net(**x) 165 | if get_accuracy: 166 | pred_labels = output.logits.argmax(dim=-1).cpu() 167 | truth_labels = label_ids.squeeze(-1).cpu() 168 | correct += torch.sum(torch.eq(pred_labels, truth_labels)).item() 169 | total += len(pred_labels) 170 | loss = output.loss 171 | try: 172 | assert len(set(batch["langs"].numpy().tolist())) == 1, set(batch["langs"].numpy().tolist()) 173 | labels_to_losses[batch["langs"][0].item()].append(output.loss.item()) 174 | if "correct" not in labels_to_accuracies[batch["langs"][0].item()]: 175 | labels_to_accuracies[batch["langs"][0].item()]["correct"] = 0 176 | labels_to_accuracies[batch["langs"][0].item()]["total"] = 0 177 | labels_to_accuracies[batch["langs"][0].item()]["correct"] += torch.sum(torch.eq(pred_labels, truth_labels)).item() 178 | labels_to_accuracies[batch["langs"][0].item()]["total"] += len(pred_labels) 179 | except Exception as e: 180 | print(f"Cant make lang ppls unless entire batch is the same: use a round num for batch size: {e}") 181 | losses.append(loss.repeat(len(batch))) 182 | 183 | # from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py 184 | mean_loss = torch.cat(losses).mean() 185 | max_num = max(list(labels_to_losses.keys())) 186 | for label in range(max_num+1): 187 | print(f"For EVAL Label {label} the average PPL is {np.exp(np.mean(labels_to_losses[label]))}") 188 | print(f"For EVAL Label {label} the average accuracy is {labels_to_accuracies[label]['correct'] / labels_to_accuracies[label]['total']}") 189 | net = net.to("cpu") 190 | 
if get_accuracy: 191 | print(f"EVAL Accuracy is {correct/total}") 192 | label_ids = label_ids.to("cpu") 193 | label_ids = label_ids.to("cpu") 194 | attn_mask = attn_mask.to("cpu") 195 | return mean_loss.item(), torch.exp(mean_loss).detach().item() 196 | 197 | 198 | def test_mt(net, testloader, device: str, get_accuracy: bool = False): 199 | """Validate the network on the entire test set.""" 200 | net.eval() 201 | generated = [] 202 | targets = [] 203 | is_long_enough = lambda x: len(x) > 3 204 | make_list_long = lambda x: [item for item in x if is_long_enough(item)] 205 | def compute_bleu(generated, targets): 206 | assert len(generated) == len(targets) 207 | final_gen = [] 208 | final_targ = [] 209 | for idx in range(len(generated)): 210 | if is_long_enough(generated[idx]) and is_long_enough(targets[idx]): 211 | final_gen.append(generated[idx]) 212 | final_targ.append(targets[idx]) 213 | 214 | return bleu.corpus_score(final_gen, final_targ).score 215 | 216 | labels_to_generated = defaultdict(list) 217 | labels_to_targets = defaultdict(list) 218 | with torch.no_grad(): 219 | print(len(testloader)) 220 | for idx, batch in enumerate(testloader): 221 | print(idx) 222 | input_ids = batch["input_ids"] 223 | label_ids = batch["labels"] 224 | assert len(set(batch["labels"][:, 0].tolist())) == 1 # handling multiple languages is tricky with the forced BOS 225 | attn_mask = batch["attention_mask"] 226 | input_ids = input_ids.to(device) 227 | attn_mask = attn_mask.to(device) 228 | label_ids = label_ids.to(device) 229 | x = {"input_ids": input_ids, "attention_mask": attn_mask} 230 | output = net.generate(**x, forced_bos_token_id=batch["labels"][:, 0][0].item()) 231 | try: 232 | assert len(set(batch["langs"].numpy().tolist())) == 1 233 | labels_to_generated[batch["langs"][0].item()].extend(tokenizer.batch_decode(output, skip_special_tokens=True)) 234 | labels_to_targets[batch["langs"][0].item()].extend(tokenizer.batch_decode(label_ids, skip_special_tokens=True)) 235 | except Exception as e: 236 | print("Cant make lang ppls unless entire batch is the same: use a round num for batch size") 237 | raise e 238 | 239 | max_num = max(list(labels_to_targets.keys())) 240 | all_gen = [] 241 | all_trg = [] 242 | for label in range(max_num+1): 243 | with open(MODEL_NAME.replace(".pt", "") + f".label.{label}.pred", "w") as fout: 244 | for sent in labels_to_generated[label]: 245 | fout.write(sent + "\n") 246 | with open(MODEL_NAME.replace(".pt", "") + f".label.{label}.trg", "w") as fout: 247 | for sent in labels_to_targets[label]: 248 | fout.write(sent + "\n") 249 | print(f"For Label {label} the average BLEU is {compute_bleu(labels_to_generated[label], labels_to_targets[label])}") 250 | all_gen.extend(labels_to_generated[label]) 251 | all_trg.extend(labels_to_targets[label]) 252 | bleu_score = compute_bleu(all_gen, all_trg) 253 | print(f"Total overall BLEU across langs is {bleu_score}") 254 | exit(1) 255 | 256 | net = net.to("cpu") 257 | label_ids = label_ids.to("cpu") 258 | label_ids = label_ids.to("cpu") 259 | attn_mask = attn_mask.to("cpu") 260 | return bleu_score, bleu_score # keep same format, but uneeded 261 | 262 | 263 | # Flower client that will be spawned by Ray 264 | # Adapted from Pytorch quickstart example 265 | class RayClient(fl.client.NumPyClient): 266 | def __init__(self, cid: str, fed_dir_data: str, optimizer, net): 267 | global CUDA_COUNT 268 | global GPU_MAPPING 269 | global LEARNING_RATE 270 | 271 | self.cid = cid 272 | self.fed_dir = Path(fed_dir_data) 273 | self.properties: Dict[str, Scalar] = 
{"tensor_type": "numpy.ndarray"} 274 | 275 | # instantiate model 276 | self.net = net 277 | self.optimizer = optimizer 278 | 279 | # determine device 280 | cuda_available = torch.cuda.is_available() 281 | device_str = f"cuda:0" if cuda_available else "cpu" # CUDA zero defaults to CUDA_VISIBLE_DEVICES 282 | self.device = torch.device(device_str) 283 | 284 | 285 | def get_parameters(self): 286 | return [val.cpu().numpy() for _, val in self.net.state_dict().items()] 287 | 288 | def get_properties(self, ins): 289 | return self.properties 290 | 291 | def set_parameters(self, parameters): 292 | params_dict = zip(self.net.state_dict().keys(), parameters) 293 | state_dict = OrderedDict( 294 | {k: torch.from_numpy(np.copy(v)) for k, v in params_dict} 295 | ) 296 | self.net.load_state_dict(state_dict, strict=True) 297 | 298 | def fit(self, parameters, config): 299 | global DATA 300 | global BATCH_SIZE 301 | self.set_parameters(parameters) 302 | global tokenizer 303 | 304 | try: 305 | 306 | # load data for this client and get trainloader 307 | num_workers = len(ray.worker.get_resource_ids()["CPU"]) 308 | trainloader = get_dataloader( 309 | self.fed_dir, 310 | self.cid, 311 | is_train=True, 312 | batch_size=BATCH_SIZE, 313 | workers=num_workers, 314 | data=DATA, 315 | tokenizer=tokenizer, 316 | shuffle=True, 317 | lang_mix=LANG_MIX, 318 | ) 319 | 320 | # send model to device 321 | self.net.to(self.device) 322 | 323 | # train 324 | train(self.net, trainloader, int(config["epochs"]), self.optimizer, device=self.device, 325 | cid=self.cid, get_accuracy=("pawsx" in DATA or "nc" in DATA)) 326 | 327 | except Exception as e: 328 | print(f"Error failed in train was `{e}`") 329 | print(traceback.format_exc()) 330 | raise e 331 | 332 | # return local model and statistics 333 | return self.get_parameters(), len(trainloader.dataset), {} 334 | 335 | def evaluate(self, parameters, config): 336 | global tokenizer 337 | self.set_parameters(parameters) 338 | 339 | # load data for this client and get trainloader 340 | num_workers = len(ray.worker.get_resource_ids()["CPU"]) 341 | valloader = get_dataloader( 342 | self.fed_dir, self.cid, is_train=False, batch_size=BATCH_SIZE, workers=num_workers, 343 | tokenizer=tokenizer, shuffle=False, lang_mix=LANG_MIX 344 | ) 345 | 346 | # send model to device 347 | self.net.to(self.device) 348 | 349 | # evaluate 350 | loss, accuracy = test(self.net, valloader, device=self.device) 351 | self.net.to("cpu") 352 | 353 | # return statistics 354 | return float(loss), len(valloader.dataset), {f"perplexity_{self.cid}": float(accuracy)} 355 | 356 | 357 | def fit_config(rnd: int) -> Dict[str, str]: 358 | """Return a configuration with static batch size and (local) epochs.""" 359 | global BATCH_SIZE 360 | config = { 361 | "epoch_global": str(rnd), 362 | "epochs": str(1), 363 | "batch_size": str(BATCH_SIZE), 364 | } 365 | return config 366 | 367 | 368 | class TopItem: 369 | # to manage saving the top_N items 370 | def __init__(self, score: float, path: str): 371 | self.score = score 372 | self.path = path 373 | 374 | def __lt__(self, other) -> bool: 375 | return self.score < other.score 376 | 377 | def to_str(self) -> str: 378 | return f"Score: {self.score} at Path: {self.path}" 379 | 380 | 381 | def set_weights(model: torch.nn.ModuleList, weights: fl.common.Weights) -> None: 382 | """Set model weights from a list of NumPy ndarrays.""" 383 | state_dict = OrderedDict( 384 | { 385 | k: torch.Tensor(np.atleast_1d(v)) 386 | for k, v in zip(model.state_dict().keys(), weights) 387 | } 388 | ) 389 | 
model.load_state_dict(state_dict, strict=True) 390 | 391 | 392 | def get_eval_fn( 393 | testset, lang_mix: float 394 | ) -> Callable[[fl.common.Weights], Optional[Tuple[float, float]]]: 395 | """Return an evaluation function for centralized evaluation.""" 396 | 397 | def evaluate(weights: fl.common.Weights) -> Optional[Tuple[float, float]]: 398 | """Use the entire test set for evaluation.""" 399 | global CUDA_COUNT 400 | global BATCH_SIZE 401 | global DATA 402 | global GPU_MAPPING 403 | global EVAL_NUM 404 | global MODEL_NAME 405 | global tokenizer 406 | global PREV_LOSS 407 | global NUM_SKIP_EVAL 408 | global TOP_N_SCORES 409 | KEEP_PILE = 2 410 | 411 | if EVAL_NUM % NUM_SKIP_EVAL == 1: # after every epoch basically 412 | print(f"Skipping with EVAL_NUM={EVAL_NUM} and NUM_SKIP_EVAL={NUM_SKIP_EVAL}") 413 | EVAL_NUM += 1 414 | return PREV_LOSS, {"perplexity": math.exp(PREV_LOSS)} 415 | 416 | model = make_huggingface_model() 417 | 418 | set_weights(model, weights) 419 | 420 | # determine device 421 | if os.environ.get("CUDA_VISIBLE_DEVICES") is None: 422 | cuda_is_available = False 423 | else: 424 | cuda_is_available = torch.cuda.is_available() 425 | device_str = f"cuda:{GPU_MAPPING['server']}" if cuda_is_available else "cpu" 426 | 427 | device = torch.device(device_str) 428 | model.to(device) 429 | 430 | if len(GPU_MAPPING) == 1: 431 | if "_mt_" in DATA or "mtnt" in DATA: 432 | # MT Eval only 433 | print("Running Eval on MT") 434 | test_fn = test_mt 435 | batch_size = 1 436 | else: 437 | test_fn = test 438 | batch_size = BATCH_SIZE 439 | else: 440 | test_fn = test 441 | batch_size = BATCH_SIZE 442 | 443 | testloader = get_dataloader( 444 | testset, -1, is_train=False, batch_size=batch_size, workers=3, 445 | tokenizer=tokenizer, shuffle=False, data=DATA, lang_mix=LANG_MIX 446 | ) 447 | 448 | 449 | loss, accuracy = test_fn(model, testloader, device=device, get_accuracy=("pawsx" in DATA or "nc" in DATA)) 450 | 451 | if len(TOP_N_SCORES) < KEEP_PILE or loss < TOP_N_SCORES[-1].score: 452 | if "_cont" in CACHE_STR: 453 | save_path = f"{BIG_FILE_CACHE}/{MODEL_NAME.split('/')[-1][-3]}/{CACHE_STR}/" 454 | else: 455 | save_path = f"{BIG_FILE_CACHE}/{MODEL_NAME.split('/')[-1]}/{CACHE_STR}/" 456 | 457 | bisect.insort(TOP_N_SCORES, TopItem(loss, save_path + f"{EVAL_NUM}.pt")) 458 | print([item.to_str() for item in TOP_N_SCORES]) 459 | TOP_N_SCORES, to_remove = TOP_N_SCORES[:KEEP_PILE], TOP_N_SCORES[KEEP_PILE:] 460 | 461 | for top_item in to_remove: 462 | print(f"Removing {top_item.to_str()}") 463 | os.remove(top_item.path) 464 | 465 | if not os.path.isdir(save_path): 466 | os.makedirs(save_path) 467 | 468 | 469 | torch.save(model, f"{save_path}/{EVAL_NUM}.pt") 470 | 471 | 472 | EVAL_NUM += 1 473 | 474 | PREV_LOSS = loss 475 | 476 | # return statistics 477 | return loss, {"perplexity": accuracy} 478 | 479 | return evaluate 480 | 481 | 482 | def make_tokenizer(model_name: str): 483 | global tokenizer 484 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 485 | return tokenizer 486 | 487 | 488 | def make_huggingface_model(): 489 | global CACHE_STR 490 | try: 491 | config = AutoConfig.from_pretrained(MODEL_NAME) 492 | except Exception as e: 493 | pass # loading model doesn't need this 494 | 495 | warnings.filterwarnings("ignore") 496 | logging.set_verbosity_error() # hides warning about CasualLM not loading encoder 497 | if ".pt" in MODEL_NAME[-3:]: 498 | print(f"Loading model {MODEL_NAME}") 499 | model = torch.load(MODEL_NAME) 500 | CACHE_STR = MODEL_NAME.split("/")[-2] + "_cont" 501 | 
elif not RANDOM_INIT: 502 | if "gpt2" in MODEL_NAME: 503 | model = AutoModelForCausalLM.from_pretrained( 504 | MODEL_NAME, 505 | from_tf=False, 506 | cache_dir=f"{BIG_FILE_CACHE}/huggingface_cache/" 507 | ) 508 | elif "xlm" in MODEL_NAME: 509 | try: 510 | model = XLMRobertaForSequenceClassification.from_pretrained( 511 | MODEL_NAME, 512 | cache_dir=f"{BIG_FILE_CACHE}/huggingface_cache/", 513 | num_labels=10 514 | ) 515 | except Exception as e: 516 | breakpoint() 517 | print(e) 518 | elif "bert" in MODEL_NAME: 519 | model = DistilBertForMaskedLM.from_pretrained( 520 | MODEL_NAME, 521 | cache_dir=f"{BIG_FILE_CACHE}/huggingface_cache/" 522 | ) 523 | elif "m2m" in MODEL_NAME: 524 | model = M2M100ForConditionalGeneration.from_pretrained( 525 | MODEL_NAME, 526 | cache_dir=f"{BIG_FILE_CACHE}/huggingface_cache/" 527 | ) 528 | else: 529 | raise NotImplementedError(f"Haven't implemented model={MODEL_NAME}") 530 | 531 | else: 532 | print("Training new model from scratch") 533 | if "xlm" in MODEL_NAME: 534 | config.num_labels = 10 535 | model = XLMRobertaForSequenceClassification( 536 | config 537 | ) 538 | elif "bert" in MODEL_NAME: 539 | model = DistilBertForMaskedLM(config) 540 | elif "m2m" in MODEL_NAME: 541 | model = M2M100ForConditionalGeneration(config) 542 | 543 | return model 544 | 545 | 546 | # Start Ray simulation (a _default server_ will be created) 547 | # This example does: 548 | # 1. Prepares the data 549 | # 2. Partitions the dataset into N splits, where N is the total number of 550 | # clients. We refer to this as `pool_size`. The partition can be IID or non-IID 551 | # 3. Starts a Ray-based simulation where a % of clients are sampled each round. 552 | # 4. After the M rounds end, the global model is evaluated on the entire testset. 553 | # Also, the global model is evaluated on the valset partition residing in each 554 | # client. This is useful to get a sense of how well the global model can generalise 555 | # to each client's data.
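# A minimal illustrative sketch (never called by this script) of how a single
# client's partition produced by `do_fl_partitioning` is read back from disk.
# Each client `p` is written to <data_dir>/federated_<cache_str>/<p>/train.npy
# (plus val.npy when val_ratio > 0), and `get_dataset`/`get_dataloader` resolve
# the same <fed_dir>/<cid>/<partition>.npy path. The directory name in the
# default argument below is a made-up example; `Path` and `np` are the imports
# already at the top of this module.
def _sketch_read_client_partition(fed_dir: str = "data/un_corpus/federated_0.0_5e-05_un_corpus",
                                  cid: str = "0", partition: str = "train") -> list:
    path = Path(fed_dir) / cid / f"{partition}.npy"
    # each entry is a (numpy tensor(s), lang_id) pair, which `get_dataset` hands
    # back to the dataset class via `examples=data.tolist()`
    return np.load(path, allow_pickle=True).tolist()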
556 | if __name__ == "__main__": 557 | parser = argparse.ArgumentParser() 558 | parser.add_argument("--data", type=str, help="dataset path to use", required=True) 559 | parser.add_argument("--model", type=str, help="The model name to use", required=True) 560 | parser.add_argument("--n_cpus", type=int, help="The number of CPUs to use PER MACHINE", default=-1) 561 | parser.add_argument("--n_gpus", type=float, help="The number of GPUs to use TOTAL", default=0.0) 562 | parser.add_argument("--frac_fit", type=float, help="The percent of nodes to sample each time", default=1.0) 563 | parser.add_argument("--batch_size", type=int, default=16, help="The batch size to use") 564 | parser.add_argument("--batch_accum", type=int, default=8, help="The batch accumulation steps to do") 565 | parser.add_argument("--n_iterations", type=int, default=120, help="The number of iterations to do") 566 | parser.add_argument("--lang_mix", type=float, default=0.0, help="The lang mixture to use (0 for separate, 1 for uniform)") 567 | parser.add_argument("--lr", type=float, default=5e-5, help="The learning rate to use for the optimizer") 568 | parser.add_argument("--random_init", action="store_true", help="Whether to train from a randomly initialized model") 569 | parser.add_argument("--load_model", type=str, help="Path to a saved model to load (overrides --model)") 570 | parser.add_argument("--centralized", action="store_true", help="Whether to run centralized training instead of federated") 571 | parser.add_argument("--test", type=int, default=0, help="If nonzero, run a debug run with only this many examples") 572 | args_parsed = parser.parse_args() 573 | 574 | BATCH_SIZE = args_parsed.batch_size 575 | RANDOM_INIT = args_parsed.random_init 576 | MODEL_NAME = args_parsed.model if args_parsed.load_model is None else args_parsed.load_model 577 | DATA = args_parsed.data 578 | ACCUM_STEPS = args_parsed.batch_accum 579 | num_rounds = args_parsed.n_iterations 580 | LEARNING_RATE = args_parsed.lr 581 | LANG_MIX = args_parsed.lang_mix 582 | pool_size = POOL_SIZE[args_parsed.data] # number of dataset partitions (= number of total clients) 583 | 584 | if args_parsed.centralized: 585 | pool_size = 1 586 | 587 | cache_str = str(args_parsed.lang_mix) + "_" + str(args_parsed.lr) + f"_{DATA}" if pool_size != 1 else f"centralized_{DATA}_" + str(args_parsed.lr) 588 | if args_parsed.random_init: 589 | cache_str += "_random" 590 | CACHE_STR = cache_str 591 | 592 | if args_parsed.n_gpus != 0.0: 593 | N_GPUS = args_parsed.n_gpus 594 | if num_rounds == 0: 595 | GPU_MAPPING["server"] = 0 596 | else: 597 | if N_GPUS < 2 and num_rounds != 0: 598 | print(f"Given N_GPUs={N_GPUS}, need 2+ for client(s) and server to have separate GPUs.
Use CPU instead otherwise") 599 | exit(1) 600 | num_iter_for_epoch = pool_size // (N_GPUS-1) if pool_size % (N_GPUS - 1) == 0 else (pool_size // (N_GPUS-1)) + 1 601 | args_parsed.frac_fit = 1 / num_iter_for_epoch 602 | # TODO implement fractional GPU options if desired 603 | num_rounds = int(num_rounds * num_iter_for_epoch) 604 | client_resources["num_gpus"] = 1.0 605 | print(f"Using 1 GPU per client with {args_parsed.frac_fit} clients sampled per round out of {pool_size} clients") 606 | 607 | gpus = os.environ.get("CUDA_VISIBLE_DEVICES").split(",") 608 | num_of_gpus_per_round = int(pool_size // num_iter_for_epoch) 609 | NUM_SKIP_EVAL = num_iter_for_epoch 610 | GPU_MAPPING["server"] = len(gpus) - 1 611 | for cid in range(pool_size): 612 | cid_gpu_idx = int(cid % num_of_gpus_per_round) 613 | GPU_MAPPING[cid] = int(gpus[cid_gpu_idx]) 614 | print(f"GPU mapping is: {GPU_MAPPING}") 615 | 616 | if args_parsed.n_cpus != -1: 617 | client_resources["num_cpus"] = args_parsed.n_cpus 618 | 619 | tokenizer = make_tokenizer(args_parsed.model) 620 | file_path_data = DATA_TO_FILE_PATHS[args_parsed.data] 621 | GLOBAL_LANG_MAP = MAP_LANG_MAP[file_path_data] 622 | 623 | if args_parsed.data == "brown": 624 | trainset = LineByLineTextDataset(tokenizer, f"{file_path_data}/train.txt", test_flag=args_parsed.test) 625 | testset = LineByLineTextDataset(tokenizer, f"{file_path_data}/dev.txt", test_flag=args_parsed.test) 626 | fed_dir = do_fl_partitioning_brown( 627 | f"{file_path_data}/train.txt", trainset.examples, pool_size=pool_size, val_ratio=0.0 628 | ) 629 | else: 630 | dataset_type = get_dataset_type(file_path_data) 631 | trainset = dataset_type(tokenizer, file_path_data, split="train", test_flag=args_parsed.test) 632 | split_name = "dev" if num_rounds != 0 else "test" 633 | print("Eval set is", split_name) 634 | testset = dataset_type(tokenizer, file_path_data, split=split_name, test_flag=args_parsed.test) 635 | fed_dir = do_fl_partitioning( 636 | file_path_data, trainset.examples, pool_size=pool_size, lang_mix=args_parsed.lang_mix, 637 | cache_str=cache_str, val_ratio=0.0 # we manually do test 638 | ) 639 | 640 | # configure the strategy 641 | strategy = fl.server.strategy.FedAvg( 642 | fraction_fit=args_parsed.frac_fit, 643 | min_fit_clients=1, 644 | min_available_clients=pool_size, # All clients should be available 645 | on_fit_config_fn=fit_config, 646 | eval_fn=get_eval_fn(testset, args_parsed.lang_mix if not args_parsed.centralized else "central"), # centralised testset evaluation of global model 647 | ) 648 | 649 | def client_fn(cid: str, optimizers=ALL_OPTIMIZERS): 650 | net = make_huggingface_model() 651 | # create a single client instance 652 | if cid not in optimizers: 653 | # Split weights in two groups, one with weight decay and the other not. 
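# Parameter names containing "bias" or "LayerNorm.weight" go into the second
# group, which always gets weight_decay=0.0, mirroring the standard Hugging Face
# fine-tuning recipe of not decaying biases and LayerNorm weights. Note that
# `weight_decay` is also set to 0 just below, so in this script both groups end
# up undecayed.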
654 | no_decay = ["bias", "LayerNorm.weight"] 655 | weight_decay = 0 656 | learning_rate = LEARNING_RATE 657 | optimizer_grouped_parameters = [ 658 | { 659 | "params": [p for n, p in net.named_parameters() if not any(nd in n for nd in no_decay)], 660 | "weight_decay": weight_decay, 661 | }, 662 | { 663 | "params": [p for n, p in net.named_parameters() if any(nd in n for nd in no_decay)], 664 | "weight_decay": 0.0, 665 | }, 666 | ] 667 | optimizers[cid] = AdamW(optimizer_grouped_parameters, lr=learning_rate) 668 | return RayClient(cid, fed_dir, optimizers[cid], net) 669 | 670 | # (optional) specify ray config 671 | ray_config = {"include_dashboard": False} 672 | 673 | # start simulation 674 | fl.simulation.start_simulation( 675 | client_fn=client_fn, 676 | num_clients=pool_size, 677 | client_resources=client_resources, 678 | num_rounds=num_rounds, 679 | strategy=strategy, 680 | ray_init_args=ray_config, 681 | ) 682 | --------------------------------------------------------------------------------
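A note on the pattern that closes `main_lm.py`: `client_fn` keeps one optimizer per client id in the module-level `ALL_OPTIMIZERS` dict, so the same AdamW object is handed back whenever that client is sampled in a later round instead of a freshly initialized one. Below is a minimal standalone sketch of this caching pattern, using a toy model and hypothetical names rather than anything from the repo:

import torch
from torch.optim import AdamW  # the script uses transformers' AdamW; torch's keeps the sketch standalone

_OPTIMIZERS: dict = {}

def get_client_optimizer(cid: str, net: torch.nn.Module, lr: float = 5e-5) -> AdamW:
    # create the optimizer lazily the first time a client id appears, then reuse it
    if cid not in _OPTIMIZERS:
        _OPTIMIZERS[cid] = AdamW(net.parameters(), lr=lr)
    return _OPTIMIZERS[cid]

toy_net = torch.nn.Linear(4, 2)  # stand-in for make_huggingface_model()
assert get_client_optimizer("0", toy_net) is get_client_optimizer("0", toy_net)  # same object on the second call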