├── src ├── data │ ├── tknzr │ │ ├── tokenizer-ca,de,en,es,fr,ru.tknzr │ │ │ ├── special_tokens_map.json │ │ │ └── tokenizer_config.json │ │ └── tokenizer-de,en,es,fr,nl,ru.tknzr │ │ │ ├── special_tokens_map.json │ │ │ └── tokenizer_config.json │ ├── wiki40b.py │ ├── process_wiki40b.py │ ├── train_tokenizer_wiki40b.py │ ├── utils.py │ └── slimpajama.py ├── eval_datasets.py ├── run.py ├── models.py └── dataloader.py ├── script ├── doge.sh ├── doge_ood_wiki40b.sh ├── base.sh └── doge_ood.sh ├── requirements.txt ├── README.md ├── LICENSE ├── config ├── doge_ood │ ├── wiki40b_catalan.json │ ├── book_ood.json │ ├── c4_ood.json │ ├── cc_ood.json │ ├── arxiv_ood.json │ ├── github_ood.json │ ├── wiki_ood.json │ └── stack_ood.json ├── doge.json └── gpt_base │ ├── baseline.json │ ├── reweight_doge.json │ ├── reweight_doremi50k.json │ └── reweight_doremi10k.json └── .gitignore /src/data/tknzr/tokenizer-ca,de,en,es,fr,ru.tknzr/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": "<|endoftext|>", 3 | "eos_token": "<|endoftext|>", 4 | "unk_token": "<|endoftext|>" 5 | } 6 | -------------------------------------------------------------------------------- /src/data/tknzr/tokenizer-de,en,es,fr,nl,ru.tknzr/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": "<|endoftext|>", 3 | "eos_token": "<|endoftext|>", 4 | "unk_token": "<|endoftext|>" 5 | } 6 | -------------------------------------------------------------------------------- /script/doge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pip install -r requirements.txt 3 | export WANDB_API_KEY="put your authorize key here, to find it: https://wandb.ai/authorize" 4 | 5 | python src/run.py --config_json config/doge.json --wandb_proj doge --wandb_run DOGE-proxy-82M --total_iterations 10000 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tiktoken 2 | --find-links https://download.pytorch.org/whl/torch_stable.html 3 | torch==2.0.0+cu118 4 | torchaudio==2.0.0+cu118 5 | torchvision==0.15.0+cu118 6 | tqdm==4.65.0 7 | transformers==4.36.2 8 | accelerate 9 | wandb 10 | datasets 11 | zstandard 12 | scikit-learn 13 | sacremoses 14 | tokenizers 15 | tfds-nightly 16 | tensorflow -------------------------------------------------------------------------------- /script/doge_ood_wiki40b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pip install -r requirements.txt 3 | export WANDB_API_KEY="put your authorize key here, to find it: https://wandb.ai/authorize" 4 | 5 | # process dataset 6 | python src/data/process_wiki40b.py 7 | python src/run.py --config_json config/doge_ood/wiki40b_catalan.json --wandb_proj doge --wandb_run DOGEood-Catalan --total_iterations 10000 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DoGE 2 | 3 | ## Requirements 4 | > pip install -r requirements.txt 5 | 6 | ## Run DoGE-proxy for Domain Reweighting 7 | - Universal Generalization on SlimPajama 8 | Replace `WANDB_API_KEY` with your own authorization key.
9 | > bash script/doge.sh 10 | 11 | - Out-of-Domain Generalization 12 | 13 | **SlimPajama** 14 | > bash script/doge_ood.sh 15 | 16 | **Wiki40b-Catalan** 17 | > bash script/doge_ood_wiki40b.sh 18 | 19 | ## Train Base Model (DoGE-base) 20 | > bash script/base.sh 21 | -------------------------------------------------------------------------------- /src/data/tknzr/tokenizer-ca,de,en,es,fr,ru.tknzr/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "0": { 5 | "content": "<|endoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | } 12 | }, 13 | "bos_token": "<|endoftext|>", 14 | "clean_up_tokenization_spaces": true, 15 | "eos_token": "<|endoftext|>", 16 | "model_max_length": 1024, 17 | "tokenizer_class": "GPT2Tokenizer", 18 | "unk_token": "<|endoftext|>" 19 | } 20 | -------------------------------------------------------------------------------- /src/data/tknzr/tokenizer-de,en,es,fr,nl,ru.tknzr/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "0": { 5 | "content": "<|endoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | } 12 | }, 13 | "bos_token": "<|endoftext|>", 14 | "clean_up_tokenization_spaces": true, 15 | "eos_token": "<|endoftext|>", 16 | "model_max_length": 1024, 17 | "tokenizer_class": "GPT2Tokenizer", 18 | "unk_token": "<|endoftext|>" 19 | } 20 | -------------------------------------------------------------------------------- /script/base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pip install -r requirements.txt 3 | export WANDB_API_KEY="put your authorize key here, to find it: https://wandb.ai/authorize" 4 | 5 | python src/run.py --config_json config/gpt_base/baseline.json --wandb_proj doge --wandb_run BASE-82M --total_iterations 20000 6 | python src/run.py --config_json config/gpt_base/reweight_doge.json --wandb_proj doge --wandb_run DOGE-base-82M --total_iterations 20000 7 | python src/run.py --config_json config/gpt_base/reweight_doremi50k.json --wandb_proj doge --wandb_run DOREMI50k-82M --total_iterations 20000 8 | python src/run.py --config_json config/gpt_base/reweight_doremi10k.json --wandb_proj doge --wandb_run DOREMI10k-82M --total_iterations 20000 9 | -------------------------------------------------------------------------------- /script/doge_ood.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pip install -r requirements.txt 3 | export WANDB_API_KEY="put your authorize key here, to find it: https://wandb.ai/authorize" 4 | 5 | python src/run.py --config_json config/doge_ood/arxiv_ood.json --wandb_proj doge --wandb_run DOGEood-Arxiv --total_iterations 10000 6 | python src/run.py --config_json config/doge_ood/book_ood.json --wandb_proj doge --wandb_run DOGEood-Book --total_iterations 10000 7 | python src/run.py --config_json config/doge_ood/c4_ood.json --wandb_proj doge --wandb_run DOGEood-C4 --total_iterations 10000 8 | python src/run.py --config_json config/doge_ood/cc_ood.json --wandb_proj doge --wandb_run DOGEood-CC --total_iterations 10000 9 | python src/run.py --config_json config/doge_ood/github_ood.json --wandb_proj doge --wandb_run DOGEood-Github 
--total_iterations 10000 10 | python src/run.py --config_json config/doge_ood/stack_ood.json --wandb_proj doge --wandb_run DOGEood-Stack --total_iterations 10000 11 | python src/run.py --config_json config/doge_ood/wiki_ood.json --wandb_proj doge --wandb_run DOGEood-Wiki --total_iterations 10000 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Olivia-fsm 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/data/wiki40b.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import XLMTokenizer, AutoTokenizer 3 | import numpy as np 4 | import os 5 | import torch 6 | from datasets import load_dataset, Dataset 7 | import tensorflow_datasets as tfds 8 | 9 | 10 | tknzr = AutoTokenizer.from_pretrained('src/data/tknzr/tokenizer-ca,de,en,es,fr,ru.tknzr') 11 | end_of_doc_token = '' 12 | languages = ['en', 'ar', 'zh-cn', 'zh-tw', 'nl', 'fr', 'de', 'it', 'ja', 'ko', 'pl', 'pt', 'ru', 'es', 'th', 'tr', 'bg', 'ca', 'cs', 'da', 'el', 'et', 'fa', 'fi', 'he', 'hi', 'hr', 'hu', 'id', 'lt', 'lv', 'ms', 'no', 'ro', 'sk', 'sl', 'sr', 'sv', 'tl', 'uk', 'vi'] 13 | 14 | 15 | def get_wiki40b(subset='en', num_proc=40, 16 | return_torch=True): 17 | """ https://huggingface.co/datasets/wiki40b 18 | """ 19 | WIKI_40B_PATH = os.path.join(os.path.dirname(__file__), "datasets", "wiki40b") 20 | SUBSET_PATH = os.path.join(WIKI_40B_PATH, subset) 21 | train_path = os.path.join(SUBSET_PATH, f"{subset}_train.bin") 22 | test_path = os.path.join(SUBSET_PATH, f"{subset}_test.bin") 23 | 24 | train_data = np.memmap(train_path, dtype=np.uint32, mode='r') 25 | test_data = np.memmap(test_path, dtype=np.uint32, mode='r') 26 | print(f'Subset {subset}: train[{len(train_data)}] | val[{len(test_data)}]') 27 | if return_torch: 28 | train_data = torch.tensor(np.array(train_data, dtype=np.int32)) 29 | test_data = torch.tensor(np.array(test_data, dtype=np.int32)) 30 | return {'train': train_data, 'val': test_data, 'test': test_data} 31 | 32 | if __name__ == '__main__': 33 | get_wiki40b(subset='ca', num_proc=10, return_torch=False) -------------------------------------------------------------------------------- /config/doge_ood/wiki40b_catalan.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "dataset": "wiki40b-en-de-fr-es-ru-ca", 3 | "train_domains":"en,de,fr,es,ru", 4 | "tgt_domains":"ca", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 42, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "vocab_size=52000,bos_token_id=0,eos_token_id=0,n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_wiki40b_6l", 17 | "output_dir": "exp/", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "lr_scheduler_name": "linear_warmup_cosine", 25 | "lr_end": 1e-4, 26 | "reweight_eps": 0.0, 27 | "mu": 0.001, 28 | "dw_max": 5.0, 29 | "dw_min": 0.0, 30 | "max_grad_norm": 5.0, 31 | "per_device_train_batch_size": 32, 32 | "warmup_ratio": 0.05, 33 | "warmup_steps": 500, 34 | "max_steps": 10000, 35 | "save_steps": 1000, 36 | "eval_steps": 1000, 37 | "gradient_accumulation_steps": 1, 38 | "save_strategy": "steps", 39 | "evaluation_strategy": "steps", 40 | "logging_steps": 1, 41 | "save_total_limit": 10, 42 | "ddp_find_unused_parameters": false, 43 | "downstream_num_shots": 5, 44 | "downstream_datasets": "", 45 | "eval_all_checkpoints": true, 46 | "skip_perplexity_eval": true, 47 | "use_cpu": false, 48 | "ddp_backend": "nccl", 49 | "compute_pertoken_losses": false, 50 | "overwrite_output_dir": false, 51 | "local_rank": -1, 52 | "domain_update_per_iter": null 53 | } -------------------------------------------------------------------------------- /config/doge.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_full-mix", 3 | "train_domains":"arxiv,book,c4,cc,github,stackexchange,wikipedia", 4 | "tgt_domains":"mix", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 16, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_82M", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.01, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 5.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 1000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": false, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } 
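The `config_overrides` strings above follow the comma-separated `key=value` convention of Hugging Face's causal-LM examples. As an illustration only (the actual model construction lives in `src/models.py` and `src/run.py`, which are not shown in full here), a minimal sketch of how such an override string can be applied to a fresh GPT-2 config:

```python
from transformers import GPT2Config, GPT2LMHeadModel

# Start from a default GPT-2 config and apply the override string from
# config/doge.json; update_from_string casts each value to the attribute's existing type.
config = GPT2Config()
config.update_from_string("n_positions=512,n_embd=768,n_layer=6,n_head=12")

# Instantiate the small proxy model and report its size.
model = GPT2LMHeadModel(config)
print(f"proxy parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
```

The same convention covers the `vocab_size=52000,bos_token_id=0,eos_token_id=0,...` override in `config/doge_ood/wiki40b_catalan.json`, which matches the multilingual tokenizer shipped under `src/data/tknzr/`.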
-------------------------------------------------------------------------------- /config/doge_ood/book_ood.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_ood-book", 3 | "train_domains":"arxiv,cc,c4,github,stackexchange,wikipedia", 4 | "tgt_domains":"book", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 16, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_6l", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.05, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 5000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": false, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/doge_ood/c4_ood.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_ood-c4", 3 | "train_domains":"arxiv,book,cc,github,stackexchange,wikipedia", 4 | "tgt_domains":"c4", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 16, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_6l", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.05, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 5000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | 
"compute_pertoken_losses": false, 52 | "overwrite_output_dir": false, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/doge_ood/cc_ood.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_ood-cc", 3 | "train_domains":"arxiv,book,c4,github,stackexchange,wikipedia", 4 | "tgt_domains":"cc", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 16, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_6l", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.05, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 5000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": false, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/doge_ood/arxiv_ood.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_ood-arxiv", 3 | "train_domains":"book,cc,c4,github,stackexchange,wikipedia", 4 | "tgt_domains":"arxiv", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 16, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_6l", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.05, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 5000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 
47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": false, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/doge_ood/github_ood.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_ood-github", 3 | "train_domains":"arxiv,book,cc,c4,stackexchange,wikipedia", 4 | "tgt_domains":"github", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 16, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_6l", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.05, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 5000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": false, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/doge_ood/wiki_ood.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_ood-wikipedia", 3 | "train_domains":"arxiv,book,cc,c4,github,stackexchange", 4 | "tgt_domains":"wikipedia", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 16, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_6l", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.05, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 5000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 
43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": false, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/doge_ood/stack_ood.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_ood-stackexchange", 3 | "train_domains":"arxiv,book,cc,c4,github,wikipedia", 4 | "tgt_domains":"stackexchange", 5 | "train_dw": null, 6 | "val_dw": null, 7 | "max_train_samples": null, 8 | "max_eval_samples": null, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 16, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=6,n_head=12", 16 | "run_name": "doge_6l", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": true, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.05, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 5000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": false, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/gpt_base/baseline.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_full-mix", 3 | "train_domains":"arxiv,book,cc,c4,github,stackexchange,wikipedia", 4 | "tgt_domains":"mix", 5 | "train_dw": "0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0", 6 | "val_dw": "0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0", 7 | "max_train_samples": null, 8 | "max_eval_samples": 20000, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 42, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=12,n_head=12", 16 | "run_name": "BASE_82M", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": false, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.01, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | 
"warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 1000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": "trivia_qa,web_questions,lambada,natural_questions", 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": true, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/gpt_base/reweight_doge.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_full-mix", 3 | "train_domains":"arxiv,book,cc,c4,github,stackexchange,wikipedia", 4 | "tgt_domains":"mix", 5 | "train_dw": "0.088,0.045,0.269,0.214,0.070,0.166,0.148,0", 6 | "val_dw": "0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0", 7 | "max_train_samples": null, 8 | "max_eval_samples": 20000, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 42, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=12,n_head=12", 16 | "run_name": "doge_82M", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": false, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.01, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 1000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": "trivia_qa,web_questions,lambada,natural_questions", 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": true, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/gpt_base/reweight_doremi50k.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_full-mix", 3 | "train_domains":"arxiv,book,cc,c4,github,stackexchange,wikipedia", 4 | "tgt_domains":"arxiv,book,cc,c4,github,stackexchange,wikipedia", 5 | "train_dw": "0.04235,0.08201,0.381,0.1141,0.0654,0.0847,0.2305,0", 6 | "val_dw": "0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0", 7 | "max_train_samples": null, 8 | "max_eval_samples": 20000, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 42, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=12,n_head=12", 16 | "run_name": 
"DOREMI-50k_82M", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": false, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.01, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 1000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": null, 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": true, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /config/gpt_base/reweight_doremi10k.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "slim_full-mix", 3 | "train_domains":"arxiv,book,cc,c4,github,stackexchange,wikipedia", 4 | "tgt_domains":"arxiv,book,cc,c4,github,stackexchange,wikipedia", 5 | "train_dw": "0.08205,0.02007,0.01395,0.2961,0.294,0.09507,0.1987,0", 6 | "val_dw": "0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0.1428,0", 7 | "max_train_samples": null, 8 | "max_eval_samples": 20000, 9 | "max_downstream_samples": null, 10 | "max_token_length": 512, 11 | "seed": 42, 12 | "preprocessing_num_workers": 1, 13 | "model_name_or_path": null, 14 | "model_type": "gpt2", 15 | "config_overrides": "n_positions=512,n_embd=768,n_layer=12,n_head=12", 16 | "run_name": "DOREMI-10k_82M", 17 | "output_dir": "path to save checkpoints", 18 | "do_train": true, 19 | "do_eval": true, 20 | "do_predict": false, 21 | "learning_rate": 5e-4, 22 | "weight_decay": 1e-2, 23 | "reweight_domains": false, 24 | "doremi": false, 25 | "ref_model": null, 26 | "lr_scheduler_name": "linear_warmup_cosine", 27 | "lr_end": 1e-4, 28 | "reweight_eps": 0.0, 29 | "mu": 0.01, 30 | "dw_max": 5.0, 31 | "dw_min": 0.0, 32 | "max_grad_norm": 1.0, 33 | "per_device_train_batch_size": 32, 34 | "warmup_ratio": 0.05, 35 | "warmup_steps": 500, 36 | "max_steps": 10000, 37 | "save_steps": 1000, 38 | "eval_steps": 1000, 39 | "gradient_accumulation_steps": 1, 40 | "save_strategy": "steps", 41 | "evaluation_strategy": "steps", 42 | "logging_steps": 10, 43 | "save_total_limit": 10, 44 | "ddp_find_unused_parameters": false, 45 | "downstream_num_shots": 5, 46 | "downstream_datasets": "lambada", 47 | "eval_all_checkpoints": true, 48 | "skip_perplexity_eval": true, 49 | "use_cpu": false, 50 | "ddp_backend": "nccl", 51 | "compute_pertoken_losses": false, 52 | "overwrite_output_dir": true, 53 | "local_rank": -1, 54 | "domain_update_per_iter": null 55 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Dataset folder 2 | src/data/datasets/ 3 | wandb/ 4 | exps/ 5 | exp/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # 
ipynb 16 | *.ipynb 17 | 18 | # images 19 | *.png 20 | *.jpg 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # Datasets 145 | datasets/* 146 | -------------------------------------------------------------------------------- /src/data/process_wiki40b.py: -------------------------------------------------------------------------------- 1 | import tensorflow_datasets as tfds 2 | from transformers import XLMTokenizer, AutoTokenizer 3 | import numpy as np 4 | import concurrent.futures 5 | import os 6 | from tqdm import tqdm 7 | 8 | tokenizer = AutoTokenizer.from_pretrained('src/data/tknzr/tokenizer-ca,de,en,es,fr,ru.tknzr') 9 | end_of_doc_token = '' 10 | 11 | # languages = ['en', 'ar', 'zh-cn', 'zh-tw', 'nl', 'fr', 'de', 'it', 'ja', 'ko', 'pl', 'pt', 'ru', 'es', 'th', 'tr', 'bg', 'ca', 'cs', 'da', 'el', 'et', 'fa', 'fi', 'he', 'hi', 'hr', 'hu', 'id', 'lt', 'lv', 'ms', 'no', 'ro', 'sk', 'sl', 'sr', 'sv', 'tl', 'uk', 'vi'] 12 | languages = ['en', 'fr', 'de', 'ru', 'es', 'nl', 'ca'] 13 | 14 | 15 | def process_document(doc): 16 | text = doc['text'].numpy().decode('utf-8') 17 | tokens = tokenizer.tokenize(text) + [end_of_doc_token] 18 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 19 | return token_ids 20 | 21 | WIKI40B_PATH = os.path.join(os.path.dirname(__file__), "datasets", "wiki40b") 22 | for lang in languages: 23 | dataset, dataset_info = tfds.load(f"wiki40b/{lang}", with_info=True) 24 | os.makedirs(os.path.join(WIKI40B_PATH, f"{lang}"), exist_ok=True) 25 | for split in ['test', 'train']: 26 | 27 | filename = os.path.join(WIKI40B_PATH, f"{lang}", f'{lang}_{split}.bin') 28 | if os.path.exists(filename): 29 | print(f"{filename} already exist, skipping ...") 30 | continue 31 | 32 | print(f"Size of the {split} set for {lang}: {len(dataset[split])} documents") 33 | 34 | dataset_list = list(dataset[split]) 35 | 36 | results = [] 37 | for doc in tqdm(dataset_list, desc=f"tokenizing {split} for {lang}"): 38 | results.append(process_document(doc)) 39 | 40 | # with concurrent.futures.ProcessPoolExecutor(max_workers=20) as executor: 41 | # # Submit all tasks to the executor 42 | # futures = [executor.submit(process_document, doc) for doc in dataset_list] 43 | # # Create a tqdm progress bar 44 | # results = [] 45 | # with tqdm(total=len(futures), desc=f"tokenizing {split} for {lang}") as progress: 46 | # for future in concurrent.futures.as_completed(futures): 47 | # result = future.result() 48 | # results.append(result) 49 | # # Update progress bar for each completed task 50 | # progress.update(1) 51 | 52 | print(f"Processed {len(results)} documents in {split} set for {lang}") 53 | 54 | arr_len = np.sum([len(seq) for seq in results]) 55 | dtype = np.uint32 56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 57 | idx = 0 58 | for start in tqdm(range(0, len(results), 1024), desc=f'writing {filename}'): 59 | arr_batch = np.array([x for y in results[start:start+1024] for x in y], dtype=np.uint32) 60 | arr[idx : idx + len(arr_batch)] = arr_batch 61 | idx += len(arr_batch) 62 | arr.flush() 63 | 
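The loop above writes each language split as a flat stream of token ids in a `.bin` file. As a reading-side sketch (the helper name `load_blocks` and the example call are illustrative, not part of the codebase; the dtype must match the `np.uint32` used at write time, and the block size of 512 mirrors `max_token_length` in the configs):

```python
import numpy as np
import torch

def load_blocks(bin_path, block_size=512, dtype=np.uint32):
    """Memory-map a tokenized .bin file and view it as [n_blocks, block_size]."""
    data = np.memmap(bin_path, dtype=dtype, mode="r")
    n_blocks = len(data) // block_size
    # Drop the ragged tail, copy into a regular array, and hand it to torch.
    blocks = np.array(data[: n_blocks * block_size]).reshape(n_blocks, block_size)
    return torch.from_numpy(blocks.astype(np.int64))

# Example (path as produced by the loop above):
# catalan_train = load_blocks("src/data/datasets/wiki40b/ca/ca_train.bin")
```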
-------------------------------------------------------------------------------- /src/data/train_tokenizer_wiki40b.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | """ 3 | Train a tokenizer on a subset of languages of wiki40B 4 | 5 | pip install -q tfds-nightly tensorflow tqdm transformers 6 | """ 7 | 8 | import argparse 9 | import tensorflow_datasets as tfds 10 | from transformers import AutoTokenizer 11 | from tqdm import tqdm 12 | import random 13 | import numpy as np 14 | import os 15 | 16 | 17 | end_of_doc_token = '' 18 | 19 | 20 | args_parser = argparse.ArgumentParser() 21 | # DomainConfigArguments 22 | args_parser.add_argument('--langs', default='en,de,fr,ru,es,nl', type=str) 23 | args_parser.add_argument('--weights', default='0.1,0.1,0.1,0.1,0.1,1', type=str) # subsampling large language dataset 24 | args_parser.add_argument('--data_path', default='path to wiki40b data', type=str) 25 | args_parser.add_argument('--vocab_size', default=52000, type=int) 26 | 27 | 28 | def get_text(doc): 29 | text = doc['text'].numpy().decode('utf-8') 30 | tokens = text + ' ' + end_of_doc_token 31 | return tokens 32 | 33 | 34 | def get_training_corpus(languages, weights): 35 | for lang, weight in zip(languages, weights): 36 | dataset, dataset_info = tfds.load(f"wiki40b/{lang}", with_info=True, data_dir='./tensorflow_datasets') 37 | for split in ['test', 'validation', 'train']: 38 | 39 | print(f"Size of the {split} set for {lang}: {len(dataset[split])} documents") 40 | dataset_list = list(dataset[split]) 41 | 42 | for doc in tqdm(dataset_list, desc=f"tokenizing {split} for {lang}"): 43 | if random.random() < weight: 44 | yield get_text(doc) 45 | 46 | 47 | def process_document(doc, tokenizer): 48 | text = doc['text'].numpy().decode('utf-8') 49 | tokens = tokenizer.tokenize(text) 50 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 51 | return token_ids 52 | 53 | 54 | def tokenize_dataset(languages, tokenizer, data_path, tokenizer_name): 55 | 56 | for lang in languages: 57 | dataset, dataset_info = tfds.load(f"wiki40b/{lang}", with_info=True) 58 | os.makedirs(f"{data_path}/{tokenizer_name}_{lang}", exist_ok=True) 59 | for split in ['test', 'validation', 'train']: 60 | 61 | filename = os.path.join(f"{data_path}/{tokenizer_name}_{lang}/", f'{lang}_{split}.bin') 62 | if os.path.exists(filename): 63 | print(f"{filename} already exist, skipping ...") 64 | continue 65 | 66 | print(f"Size of the {split} set for {lang}: {len(dataset[split])} documents") 67 | 68 | dataset_list = list(dataset[split]) 69 | 70 | results = [] 71 | for doc in tqdm(dataset_list, desc=f"tokenizing {split} for {lang}"): 72 | results.append(process_document(doc, tokenizer)) 73 | 74 | print(f"Processed {len(results)} documents in {split} set for {lang}") 75 | 76 | arr_len = np.sum([len(seq) for seq in results]) 77 | dtype = np.uint32 78 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 79 | idx = 0 80 | for start in tqdm(range(0, len(results), 1024), desc=f'writing {filename}'): 81 | arr_batch = np.array([x for y in results[start:start+1024] for x in y], dtype=np.uint32) 82 | arr[idx : idx + len(arr_batch)] = arr_batch 83 | idx += len(arr_batch) 84 | arr.flush() 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | 90 | args = args_parser.parse_args() 91 | 92 | languages = args.langs.split(',') 93 | weights = [float(x) for x in args.weights.split(',')] 94 | if len(weights) == 1: 95 | weights = [weights[0]] * len(languages) 96 | 97 | assert len(languages) == 
len(weights) 98 | 99 | print(languages) 100 | print(weights) 101 | 102 | lang_str = ','.join(sorted(languages)) 103 | tokenizer_name = f"tokenizer-{lang_str}" 104 | 105 | if not os.path.exists(f"{tokenizer_name}.tknzr"): 106 | 107 | old_tokenizer = AutoTokenizer.from_pretrained("gpt2") 108 | tokenizer = old_tokenizer.train_new_from_iterator(get_training_corpus(languages, weights), args.vocab_size) 109 | tokenizer.save_pretrained(f"{tokenizer_name}.tknzr") 110 | 111 | else: 112 | 113 | tokenizer = AutoTokenizer.from_pretrained(f"./tokenizer-{lang_str}.tknzr") 114 | 115 | tokenize_dataset(languages, tokenizer, args.data_path, tokenizer_name) 116 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | import random 4 | import itertools 5 | 6 | from .slimpajama import get_slimpajama as get_slim_redpajama, get_slimpajama_6b as get_slim_redpajama_6b, SUBSET2META 7 | from .wiki40b import get_wiki40b 8 | 9 | SLIMPAJAMA_DOMAINS = ['arxiv', 'book', 'c4', 'cc', 'github', 'stackexchange', 'wikipedia'] 10 | 11 | def get_dataset(args, dataset=None) -> Dict[str, np.ndarray]: 12 | """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is 13 | contained in its own python file. The expected format at the moment is a dictionary of np.memmap 14 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """ 15 | if dataset is not None: 16 | trg_dataset = dataset 17 | else: 18 | trg_dataset = args.dataset 19 | print(f"Loading train dataset '{trg_dataset}'") 20 | 21 | if 'wiki40b' in trg_dataset: 22 | lang_list = ['en', 'ar', 'zh-cn', 'zh-tw', 'nl', 'fr', 'de', 'it', 'ja', 'ko', 'pl', 'pt', 'ru', 'es', 'th', 'tr', 'bg', 'ca', 'cs', 'da', 'el', 'et', 'fa', 'fi', 'he', 'hi', 'hr', 'hu', 'id', 'lt', 'lv', 'ms', 'no', 'ro', 'sk', 'sl', 'sr', 'sv', 'tl', 'uk', 'vi'] 23 | subset_list = trg_dataset.split('-')[1:] 24 | rst_dict = {} 25 | rst_dict['train'] = {} 26 | rst_dict['val'] = {} 27 | rst_dict['test'] = {} 28 | for subset in subset_list: 29 | if subset not in lang_list: 30 | continue 31 | subset_data = get_wiki40b(subset=subset, num_proc=10) 32 | rst_dict['train'][subset] = subset_data['train'] 33 | rst_dict['val'][subset] = subset_data['val'] 34 | rst_dict['test'][subset] = subset_data['test'] 35 | print(f"Subset {subset}: train[{len(subset_data['train'])}]|val[{len(subset_data['val'])}]") 36 | return rst_dict 37 | 38 | if 'slim_6b' in trg_dataset: 39 | subset = trg_dataset.split('-')[1] 40 | if subset == 'all' or args.eval_all_domains: 41 | all_train_list, all_val_list = [], [] 42 | rst_dict = {} 43 | rst_dict['train'] = {} 44 | rst_dict['val'] = {} 45 | for k in SUBSET2META.keys(): 46 | subset_data = get_slim_redpajama_6b(subset=k, num_proc=10) 47 | rst_dict['train'][k] = subset_data['train'] 48 | rst_dict['val'][k] = subset_data['val'] 49 | all_train_list.append(subset_data['train']) 50 | all_val_list.append(subset_data['val']) 51 | train_data = np.concatenate(all_train_list) 52 | val_data = np.concatenate(all_val_list) 53 | rst_dict['train']['all'] = train_data 54 | rst_dict['val']['all'] = val_data 55 | 56 | if subset != 'all': 57 | rst_dict['train'] = rst_dict['train'][subset] 58 | if 'all' in rst_dict['val'].keys(): 59 | rst_dict['val'].pop('all') 60 | return rst_dict 61 | return get_slim_redpajama_6b(subset=subset, num_proc=10) 62 | elif 'slim_full' in trg_dataset: 63 | subset = 
trg_dataset.split('-')[1] 64 | if subset =='all': 65 | rst_dict = {} 66 | rst_dict['train'] = {} 67 | rst_dict['val'] = {} 68 | n_items_val = 5000 69 | 70 | for k in SLIMPAJAMA_DOMAINS: 71 | subset_data = get_slim_redpajama(subset=k, num_proc=10) 72 | rst_dict['train'][k] = subset_data['train'] 73 | rst_dict['val'][k] = subset_data['val'][:n_items_val*args.max_token_length] 74 | return rst_dict 75 | elif subset == 'mix': 76 | rst_dict = {} 77 | rst_dict['train'] = {} 78 | rst_dict['val'] = {} 79 | mix_data_train = [] 80 | mix_data_val = [] 81 | n_items_mix_train = (2000000000//args.max_token_length)//7 82 | n_items_mix_val = 2000 83 | 84 | for k in SLIMPAJAMA_DOMAINS: 85 | subset_data = get_slim_redpajama(subset=k, num_proc=10) 86 | rst_dict['train'][k] = subset_data['train'][:-n_items_mix_train*args.max_token_length] 87 | rst_dict['val'][k] = subset_data['val'][:n_items_mix_val*args.max_token_length] 88 | mix_data_train.append(subset_data['train'][-n_items_mix_train*args.max_token_length:]) 89 | mix_data_val.append(subset_data['val'][-n_items_mix_val*args.max_token_length:]) 90 | 91 | mix_train_data = np.concatenate(mix_data_train) 92 | mix_val_data = np.concatenate(mix_data_val) 93 | # shuffle 94 | A = np.arange(0, len(mix_train_data), args.max_token_length) 95 | np.random.shuffle(A) 96 | mix_train_data = np.concatenate([mix_train_data[i:i+args.max_token_length] for i in A]) 97 | 98 | B = np.arange(0, len(mix_val_data), args.max_token_length) 99 | np.random.shuffle(B) 100 | mix_val_data = np.concatenate([mix_val_data[i:i+args.max_token_length] for i in B]) 101 | 102 | rst_dict['train']['mix'] = mix_train_data 103 | rst_dict['val']['mix'] = mix_val_data 104 | return rst_dict 105 | return get_slim_redpajama(subset=subset, num_proc=10) 106 | elif 'slim_ood' in trg_dataset: 107 | subset_ood = trg_dataset.split('-')[1] 108 | rst_dict = {} 109 | rst_dict['train'] = {} 110 | rst_dict['val'] = {} 111 | n_items_train = 2000000000 112 | n_items_train_ood = 50000000 113 | n_items_val = 5000 114 | 115 | for k in SLIMPAJAMA_DOMAINS: 116 | subset_data = get_slim_redpajama(subset=k, num_proc=10) 117 | if k==subset_ood: 118 | rst_dict['train'][k] = subset_data['train'][:n_items_train_ood] 119 | rst_dict['val'][k] = subset_data['val'][:n_items_val*args.max_token_length] 120 | else: 121 | rst_dict['train'][k] = subset_data['train'][:n_items_train] 122 | rst_dict['val'][k] = subset_data['val'][:n_items_val*args.max_token_length] 123 | return rst_dict 124 | else: 125 | raise NotImplementedError(f"Unknow dataset key '{trg_dataset}'") 126 | -------------------------------------------------------------------------------- /src/eval_datasets.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, concatenate_datasets 2 | import string 3 | 4 | 5 | def substring_until(s, split_strs): 6 | idx = len(s) 7 | for split_str in split_strs: 8 | try: 9 | new_idx = s.index(split_str) 10 | if new_idx < idx: 11 | idx = new_idx 12 | except Exception: 13 | pass 14 | return s[:idx] 15 | 16 | 17 | def pred_postprocess_default(pred): 18 | pred = pred.strip().lower() 19 | return substring_until(pred, ['\n']).strip().lower().translate(str.maketrans('', '', string.punctuation)) 20 | 21 | 22 | def eval_func_default(answer, pred, prompt, model=None, tokenizer=None, inputs=None, trainer=None): 23 | if not isinstance(answer, list): 24 | answer = [answer.strip().lower().translate(str.maketrans('', '', string.punctuation))] 25 | else: 26 | answer = 
[a.strip().lower().translate(str.maketrans('', '', string.punctuation)) for a in answer] 27 | return pred in answer 28 | 29 | 30 | def get_eval_dataset(dataset_name, num_shots, seed=42): 31 | 32 | # defaults 33 | top_k = 1 34 | top_p = 0 35 | temperature = 1 36 | num_shots = num_shots 37 | max_new_tokens = 20 38 | shuffle_train = True 39 | 40 | eval_func = eval_func_default 41 | pred_postprocess_func = pred_postprocess_default 42 | 43 | # load fewshot dataset 44 | if dataset_name == 'trivia_qa': 45 | dataset = load_dataset(dataset_name, name='rc.nocontext') 46 | dataset_train = dataset['train'] 47 | dataset_val = dataset['validation'] 48 | input_key = 'question' 49 | output_key = 'answer' 50 | 51 | def prompt_transform(ex, context_exs): 52 | prompt = '\n\n'.join([f"Question: {c_ex[input_key]}\nAnswer: {c_ex[output_key]['aliases'][0]}" for c_ex in context_exs]) 53 | prompt += f"\n\nQuestion: {ex[input_key]}\nAnswer:" 54 | 55 | answer_list = ex[output_key]['aliases'] 56 | return {'prompt': prompt, 'answer': answer_list} 57 | 58 | elif dataset_name == 'natural_questions': 59 | dataset = load_dataset("lucadiliello/naturalquestionsshortqa") 60 | dataset_train = dataset['train'] 61 | dataset_val = dataset['validation'] 62 | 63 | def prompt_transform(ex, context_exs): 64 | prompt = '\n\n'.join([f"Q: {c_ex['question']}?\n\nA: {c_ex['answers'][0]}" 65 | for c_ex in context_exs]) 66 | prompt += f"\n\nQ: {ex['question']}?\n\nA:" 67 | 68 | answer_list = ex['answers'] 69 | return {'prompt': prompt, 'answer': answer_list} 70 | 71 | elif dataset_name == 'web_questions': 72 | dataset = load_dataset(dataset_name) 73 | dataset_train = dataset['train'] 74 | dataset_val = dataset['test'] 75 | 76 | def prompt_transform(ex, context_exs): 77 | prompt = '\n\n'.join([f"Question: {c_ex['question']}\nAnswer: {c_ex['answers'][0]}" 78 | for c_ex in context_exs]) 79 | prompt += f"\n\nQuestion: {ex['question']}\nAnswer:" 80 | 81 | answer_list = ex['answers'] 82 | return {'prompt': prompt, 'answer': answer_list} 83 | 84 | elif dataset_name == 'lambada': 85 | dataset = load_dataset(dataset_name) 86 | dataset_train = dataset['validation'] 87 | dataset_val = dataset['test'] 88 | 89 | def prompt_transform(ex, context_exs): 90 | words = ex['text'].split(' ') 91 | ex_input = ' '.join(words[:-1]) 92 | ex_answer = words[-1] 93 | 94 | context_ex_toks = [c_ex['text'].split(' ') for c_ex in context_exs] 95 | prompt = '\n\n'.join([f"Input: {' '.join(c_ex_toks[:-1])}\nOutput: {c_ex_toks[-1]}" 96 | for c_ex_toks in context_ex_toks]) 97 | prompt += f"\n\nInput: {ex_input}\nOutput:" 98 | prompt = "Complete the following sentences.\n\n" + prompt 99 | 100 | answer_list = [ex_answer] 101 | return {'prompt': prompt, 'answer': answer_list} 102 | 103 | elif dataset_name == 'squad_v2': 104 | dataset = load_dataset(dataset_name) 105 | # dataset_train = dataset['train'] 106 | shuffle_train = False 107 | 108 | dataset_val = dataset['validation'] 109 | 110 | # get indices for each title 111 | dataset_val_chunks = [] 112 | dataset_train_chunks = [] 113 | all_titles = set([ex['title'] for ex in dataset_val]) 114 | for i, title in enumerate(all_titles): 115 | title_dataset_val = dataset_val.filter(lambda x: x['title'] == title).shuffle(seed + i) 116 | title_dataset_train = title_dataset_val.select(list(reversed(range(len(title_dataset_val))))) 117 | assert(len(title_dataset_train) == len(title_dataset_val)) 118 | dataset_train_chunks.append(title_dataset_train) 119 | dataset_val_chunks.append(title_dataset_val) 120 | 121 | dataset_train = 
concatenate_datasets(dataset_train_chunks) 122 | dataset_val = concatenate_datasets(dataset_val_chunks) 123 | 124 | def prompt_transform(ex, context_exs): 125 | for c_ex in [ex] + context_exs: 126 | if len(c_ex['answers']['text']) == 0: 127 | c_ex['answers']['text'] = ['unanswerable'] 128 | assert(c_ex['title'] == ex['title']) 129 | 130 | prompt = f"Title: {ex['title']}\n\nBackground: {ex['context']}\n\n" 131 | prompt += '\n\n'.join([f"Question: {c_ex['question']}\n\nAnswer (use Background or answer \"unanswerable\"): {c_ex['answers']['text'][0]}"]) 132 | prompt += f"\n\nQuestion: {ex['question']}\n\nAnswer (use Background or answer \"unanswerable\"):" 133 | 134 | answer_list = ex['answers']['text'] 135 | return {'prompt': prompt, 'answer': answer_list} 136 | 137 | def eval_func(answer, pred, prompt, model, tokenizer, inputs, trainer): 138 | if not isinstance(answer, list): 139 | answer = [answer.strip().lower().translate(str.maketrans('', '', string.punctuation))] 140 | else: 141 | answer = [a.strip().lower().translate(str.maketrans('', '', string.punctuation)) for a in answer] 142 | return pred in answer 143 | 144 | else: 145 | raise ValueError(f"Dataset {dataset_name} not supported") 146 | 147 | return { 148 | 'top_k': top_k, 149 | 'top_p': top_p, 150 | 'temperature': temperature, 151 | 'num_shots': num_shots, 152 | 'max_new_tokens': max_new_tokens, 153 | 'prompt_transform': prompt_transform, 154 | 'dataset_train': dataset_train, 155 | 'shuffle_train': shuffle_train, 156 | 'dataset_val': dataset_val, 157 | 'eval_func': eval_func, 158 | 'pred_postprocess_func': pred_postprocess_func, } -------------------------------------------------------------------------------- /src/data/slimpajama.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import tiktoken 5 | from datasets import load_dataset 6 | import os 7 | import torch 8 | tknzr = tiktoken.get_encoding("gpt2") 9 | 10 | SUBSET2META = { 11 | 'arxiv': 'RedPajamaArXiv', 12 | 'book': 'RedPajamaBook', 13 | 'cc': 'RedPajamaCommonCrawl', 14 | 'c4': 'RedPajamaC4', 15 | 'github': 'RedPajamaGithub', 16 | 'stackexchange': 'RedPajamaStackExchange', 17 | 'wikipedia': 'RedPajamaWikipedia', 18 | 19 | } 20 | 21 | def get_slimpajama(subset='arxiv', num_proc=40, 22 | return_torch=False,): 23 | """ Full: https://huggingface.co/datasets/cerebras/SlimPajama-627B 24 | 6B-subset: DKYoon/SlimPajama-6B 25 | """ 26 | # { 27 | # "text": ..., 28 | # "meta": {"url": "...", "timestamp": "...", "source": "...", "language": "...", ...}, 29 | # "red_pajama_subset": "common_crawl" | "c4" | "github" | "books" | "arxiv" | "wikipedia" | "stackexchange" 30 | # } 31 | SLIM_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slim_redpajama/") 32 | SUBSET_PATH = os.path.join(SLIM_DATA_PATH, subset) 33 | subset_name = SUBSET2META[subset] 34 | if not os.path.exists(os.path.join(SUBSET_PATH, 'val.bin')): 35 | os.makedirs(SUBSET_PATH, exist_ok=True) 36 | dataset = load_dataset("cerebras/SlimPajama-627B", split=['train', 'test']) 37 | data_dict = {} 38 | data_dict['train'] = dataset[0].filter(lambda example: example["meta"]['redpajama_set_name']==subset_name) 39 | data_dict['val'] = dataset[1].filter(lambda example: example["meta"]['redpajama_set_name']==subset_name) 40 | 41 | def process(example): 42 | ids = tknzr.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 43 | ids.append(tknzr.eot_token) # add the end of text token, e.g. 
50256 for gpt2 bpe 44 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 45 | out = {'ids': ids, 'len': len(ids)} 46 | return out 47 | 48 | # tokenize the dataset 49 | tokenized = {} 50 | tokenized['train'] = data_dict['train'].map( 51 | process, 52 | remove_columns=['text', 'meta'], 53 | desc="tokenizing the splits", 54 | num_proc=num_proc, 55 | ) 56 | tokenized['val'] = data_dict['val'].map( 57 | process, 58 | remove_columns=['text', 'meta'], 59 | desc="tokenizing the splits", 60 | num_proc=num_proc, 61 | ) 62 | 63 | # concatenate all the ids in each dataset into one large file we can use for training 64 | for split, dset in tokenized.items(): 65 | arr_len = np.sum(dset['len']) 66 | filename = os.path.join(SUBSET_PATH, f'{split}.bin') 67 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 68 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 69 | total_batches = 100 70 | 71 | idx = 0 72 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 73 | # Batch together samples for faster write 74 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 75 | arr_batch = np.concatenate(batch['ids']) 76 | # Write into mmap 77 | arr[idx : idx + len(arr_batch)] = arr_batch 78 | idx += len(arr_batch) 79 | arr.flush() 80 | 81 | train_data = np.memmap(os.path.join(SUBSET_PATH, 'train.bin'), dtype=np.uint16, mode='r') 82 | val_data = np.memmap(os.path.join(SUBSET_PATH, 'val.bin'), dtype=np.uint16, mode='r') 83 | print(f'Subset {subset}: train[{len(train_data)}] | val[{len(val_data)}]') 84 | if return_torch: 85 | train_data = torch.tensor(np.array(train_data, dtype=np.int32)) 86 | val_data = torch.tensor(np.array(val_data, dtype=np.int32)) 87 | return {'train': train_data, 'val': val_data} 88 | 89 | 90 | def get_slimpajama_6b(subset='arxiv', num_proc=40, 91 | return_torch=False): 92 | """ Full: https://huggingface.co/datasets/cerebras/SlimPajama-627B 93 | 6B-subset: DKYoon/SlimPajama-6B 94 | """ 95 | # { 96 | # "text": ..., 97 | # "meta": {"url": "...", "timestamp": "...", "source": "...", "language": "...", ...}, 98 | # "red_pajama_subset": "common_crawl" | "c4" | "github" | "books" | "arxiv" | "wikipedia" | "stackexchange" 99 | # } 100 | REDPAJIMA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slim_6b/") 101 | SUBSET_PATH = os.path.join(REDPAJIMA_DATA_PATH, subset) 102 | subset_name = SUBSET2META[subset] 103 | print('Load subset_name: ', subset_name) 104 | if not os.path.exists(os.path.join(SUBSET_PATH, 'val.bin')): 105 | os.makedirs(SUBSET_PATH, exist_ok=True) 106 | dataset = load_dataset("DKYoon/SlimPajama-6B", split=['train', 'test']) 107 | print(dataset) 108 | data_dict = {} 109 | data_dict['train'] = dataset[0].filter(lambda example: example["meta"]['redpajama_set_name']==subset_name) 110 | data_dict['val'] = dataset[1].filter(lambda example: example["meta"]['redpajama_set_name']==subset_name) 111 | 112 | print(data_dict['train']) 113 | def process(example): 114 | 'Processing dataset...' 115 | ids = tknzr.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 116 | ids.append(tknzr.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 117 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
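# 'len' is stored next to the token ids so the writing loop further down can
# pre-size the uint16 memmap (arr_len = np.sum(dset['len'])) before everything
# is concatenated into train.bin / val.bin.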
118 | out = {'ids': ids, 'len': len(ids)} 119 | return out 120 | 121 | # tokenize the dataset 122 | tokenized = {} 123 | tokenized['train'] = data_dict['train'].map( 124 | process, 125 | remove_columns=['text', 'meta', '__index_level_0__'], 126 | desc="tokenizing the splits", 127 | num_proc=num_proc, 128 | ) 129 | tokenized['val'] = data_dict['val'].map( 130 | process, 131 | remove_columns=['text', 'meta', '__index_level_0__'], 132 | desc="tokenizing the splits", 133 | num_proc=num_proc, 134 | ) 135 | 136 | # concatenate all the ids in each dataset into one large file we can use for training 137 | for split, dset in tokenized.items(): 138 | print('Columns: ', dset.features) 139 | arr_len = np.sum(dset['len']) 140 | filename = os.path.join(SUBSET_PATH, f'{split}.bin') 141 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 142 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 143 | total_batches = 10 144 | 145 | idx = 0 146 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 147 | # Batch together samples for faster write 148 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 149 | arr_batch = np.concatenate(batch['ids']) 150 | # Write into mmap 151 | arr[idx : idx + len(arr_batch)] = arr_batch 152 | idx += len(arr_batch) 153 | arr.flush() 154 | 155 | train_data = np.memmap(os.path.join(SUBSET_PATH, 'train.bin'), dtype=np.uint16, mode='r') 156 | val_data = np.memmap(os.path.join(SUBSET_PATH, 'val.bin'), dtype=np.uint16, mode='r') 157 | if return_torch: 158 | train_data = torch.tensor(np.array(train_data, dtype=np.int32)) 159 | val_data = torch.tensor(np.array(val_data, dtype=np.int32)) 160 | return {'train': train_data, 'val': val_data} 161 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | from trainer import * 2 | import logging 3 | from pathlib import Path 4 | import os 5 | import sys 6 | import json 7 | import numpy as np 8 | import argparse 9 | import datasets 10 | import torch 11 | import pickle 12 | 13 | import transformers 14 | from transformers import ( 15 | CONFIG_MAPPING, 16 | AutoConfig, 17 | AutoModelForCausalLM, 18 | AutoTokenizer, 19 | HfArgumentParser, 20 | set_seed, 21 | ) 22 | from eval_datasets import get_eval_dataset 23 | from dataloader import DataTrainingArguments, DomainConfigArguments, get_data_collator, get_train_eval_datasets 24 | from models import CausalLMOutputWithDomainIDs, ModelArguments, get_model_from_config, GPT2DoGE 25 | from trainer import FullTrainingArguments 26 | 27 | args_parser = argparse.ArgumentParser() 28 | # DomainConfigArguments 29 | args_parser.add_argument('--config_json', default='path to json config file', type=str) 30 | args_parser.add_argument('--wandb_proj', default='doge_universal', type=str) 31 | args_parser.add_argument('--wandb_run', default=None, type=str) 32 | args_parser.add_argument('--curriculum_path', default=None, type=str) 33 | args_parser.add_argument('--cc_selection', action='store_true') 34 | args_parser.add_argument('--cc_ns', default=10, type=int) 35 | args_parser.add_argument('--cc_steps', default=1000, type=int) 36 | args_parser.add_argument('--total_iterations', default=10000, type=int) 37 | 38 | 39 | def main(): 40 | args = args_parser.parse_args() 41 | os.environ["WANDB_PROJECT"] = args.wandb_proj # name your W&B project 42 | config_parser = 
HfArgumentParser((ModelArguments, DataTrainingArguments, FullTrainingArguments)) 43 | if args.config_json is not None: 44 | model_args, data_args, training_args = config_parser.parse_json_file(json_file=args.config_json) 45 | else: 46 | model_args, data_args, training_args = config_parser.parse_args_into_dataclasses() 47 | 48 | if args.wandb_run is None: 49 | wandb_run_name = training_args.run_name 50 | else: 51 | wandb_run_name = args.wandb_run 52 | 53 | training_args.local_rank = -1 54 | print("training local_rank: ", training_args.local_rank) 55 | 56 | # Setup logging 57 | logging.basicConfig( 58 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 59 | datefmt="%m/%d/%Y %H:%M:%S", 60 | handlers=[logging.StreamHandler(sys.stdout)], 61 | ) 62 | train_ds, val_ds, domain_config, tokenizer, train_dataset_ls = get_train_eval_datasets(data_config=data_args, 63 | verbose=True, 64 | doremi=training_args.doremi, 65 | ) 66 | 67 | data_collator=get_data_collator(tokenizer, do_padding=data_args.do_padding, max_length=data_args.max_token_length) 68 | grad_acc_steps = training_args.gradient_accumulation_steps 69 | if training_args.doremi: 70 | # TODO: train reference model for doremi 71 | if training_args.ref_model is not None: 72 | logger.info("*** Load Reference Model (DoReMi) ***") 73 | ref_model, _ = get_model_from_config(model_args, doge=True, ref_model_path=training_args.ref_model) 74 | else: 75 | ref_model, _ = get_model_from_config(model_args, doge=True, ref_model_path=None) 76 | set_seed(training_args.seed) 77 | training_args.ddp_find_unused_parameters = False 78 | torch.cuda.empty_cache() 79 | # Initialize our Trainer 80 | ref_trainer = DoGETrainer( 81 | model=ref_model, 82 | args=training_args, 83 | domain_args=domain_config, 84 | train_dataset=train_ds if training_args.do_train else None, 85 | eval_dataset=val_ds if training_args.do_eval else None, 86 | tokenizer=tokenizer, 87 | data_collator=data_collator, 88 | selected_modules_ls=None, 89 | wandb_run_name="ref_doremi_"+wandb_run_name, 90 | output_dir=os.path.join(training_args.output_dir, "ref_doremi_"+wandb_run_name), 91 | grad_acc=1, 92 | ) 93 | if training_args.do_train: 94 | logger.info("*** Train Reference Model (DoReMi) ***") 95 | checkpoint = None 96 | ref_trainer.train(resume_from_checkpoint=None) 97 | ref_model.to("cuda") 98 | else: 99 | ref_model = None 100 | if args.cc_selection: 101 | cc_model, _ = get_model_from_config(model_args, doge=True) 102 | # Set seed before initializing model. 
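# Cancellation-effect ("cc") module selection: the short run below (configured by
# cc_ns / cc_steps) is only used to populate trainer.selected_modules, the subset of
# parameter groups later used to compute the domain weights W; the print-out that
# follows reports what fraction of the proxy model's parameters that subset covers.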
103 | set_seed(training_args.seed) 104 | # turn off find unused parameters 105 | training_args.ddp_find_unused_parameters = False 106 | 107 | torch.cuda.empty_cache() 108 | # Initialize our Trainer 109 | trainer = DoGETrainer( 110 | model=cc_model, 111 | args=training_args, 112 | domain_args=domain_config, 113 | train_dataset=train_ds if training_args.do_train else None, 114 | eval_dataset=val_ds if training_args.do_eval else None, 115 | tokenizer=tokenizer, 116 | data_collator=data_collator, 117 | cc_selection=args.cc_selection, 118 | cc_ns=args.cc_ns, 119 | cc_steps=args.cc_steps, 120 | selected_modules_ls=None, 121 | wandb_run_name="cc_"+wandb_run_name, 122 | output_dir=os.path.join(training_args.output_dir, "cc_"+wandb_run_name), 123 | grad_acc=1, 124 | ) 125 | if training_args.do_train: 126 | logger.info("*** Assessing Cancellation Effect ***") 127 | checkpoint = None 128 | trainer.train(resume_from_checkpoint=None) 129 | selected_modules_ls = trainer.selected_modules 130 | weight_dict = trainer.prev_w 131 | selected_params_num = 0 132 | for k in selected_modules_ls: 133 | selected_params_num += weight_dict[k].flatten().shape[0] 134 | print('Selected Modules: ') 135 | for m in selected_modules_ls: 136 | print('| ', m) 137 | print('Total parameters to compute W: ', selected_params_num, f'({selected_params_num*100/cc_model.num_parameters()}%)') 138 | else: 139 | selected_modules_ls = None 140 | selected_params_num = None 141 | 142 | ## Start Training ## 143 | # Detecting last checkpoint. 144 | doge_model, doge_config = get_model_from_config(model_args, doge=True) 145 | print("DoGE model parameters: ", doge_model.num_parameters()) 146 | print("Num. GPU used: ", training_args.n_gpu) 147 | print("Gradient accumulate steps: ", training_args.gradient_accumulation_steps) 148 | last_checkpoint = None 149 | num_skip_examples = 0 150 | output_dir = os.path.join(training_args.output_dir, wandb_run_name) 151 | if os.path.isdir(output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 152 | last_checkpoint = get_last_checkpoint(output_dir) 153 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 154 | logger.info( 155 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 156 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 157 | ) 158 | state = TrainerState.load_from_json(str(Path(last_checkpoint) / TRAINER_STATE_NAME)) 159 | global_batch_size = training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size 160 | num_skip_examples = state.global_step * global_batch_size 161 | logger.info(f"Skipping {num_skip_examples} examples") 162 | else: 163 | os.makedirs(output_dir, exist_ok=True) 164 | # Set seed before initializing model. 
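# When resuming, num_skip_examples above counts the examples already consumed
# (global_step * train batch size * gradient-accumulation steps * world size),
# presumably so the streaming dataloader can be fast-forwarded past data the
# checkpointed model has already seen.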
165 | set_seed(training_args.seed) 166 | # turn off find unused parameters 167 | training_args.ddp_find_unused_parameters = False 168 | 169 | torch.cuda.empty_cache() 170 | # Initialize our Trainer 171 | trainer = DoGETrainer( 172 | model=doge_model, 173 | args=training_args, 174 | domain_args=domain_config, 175 | train_dataset=train_ds if training_args.do_train else None, 176 | eval_dataset=val_ds if training_args.do_eval else None, 177 | tokenizer=tokenizer, 178 | data_collator=data_collator, 179 | selected_modules_ls=selected_modules_ls, 180 | selected_params_num=selected_params_num, 181 | cc_selection=None, 182 | cc_ns=None, 183 | cc_steps=None, 184 | wandb_run_name=wandb_run_name, 185 | output_dir=output_dir, 186 | total_iterations=args.total_iterations, 187 | grad_acc=grad_acc_steps, 188 | ref_model=ref_model, 189 | train_dataset_ls=train_dataset_ls, 190 | ) 191 | 192 | if training_args.do_train: 193 | logger.info("*** Train ***") 194 | checkpoint = None 195 | if training_args.resume_from_checkpoint is not None: 196 | checkpoint = training_args.resume_from_checkpoint 197 | elif last_checkpoint is not None: 198 | checkpoint = last_checkpoint 199 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 200 | trainer.save_model() # Saves the tokenizer too for easy upload 201 | 202 | # Evaluation 203 | if training_args.do_eval: 204 | logger.info("*** Evaluate ***") 205 | 206 | if training_args.eval_all_checkpoints: 207 | checkpoint_dir_list = trainer.get_all_checkpoints(training_args.output_dir) 208 | else: 209 | checkpoint_dir_list = [get_last_checkpoint(training_args.output_dir)] 210 | 211 | for checkpoint_dir in checkpoint_dir_list: 212 | trainer.load_checkpoint(checkpoint_dir) 213 | state = TrainerState.load_from_json(str(Path(checkpoint_dir) / TRAINER_STATE_NAME)) 214 | trainer.state.global_step = state.global_step 215 | 216 | if not training_args.skip_perplexity_eval: 217 | metrics = trainer.evaluate() 218 | trainer.log_metrics("eval", metrics) 219 | trainer.save_metrics("eval", metrics) 220 | 221 | if training_args.downstream_datasets is not None: 222 | dataset_names = training_args.downstream_datasets.split(',') 223 | downstream_metrics = trainer.evaluate_fewshot( 224 | dataset_names, 225 | max_samples=data_args.max_downstream_samples, 226 | num_shots=training_args.downstream_num_shots) 227 | trainer.log_metrics("eval", downstream_metrics) 228 | trainer.save_metrics("eval", downstream_metrics) 229 | 230 | print('DoGE launched! 
❤️‍🔥❤️‍🔥') 231 | 232 | if __name__ == "__main__": 233 | main() -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | import os 4 | import sys 5 | import json 6 | import numpy as np 7 | 8 | import datasets 9 | from dataclasses import dataclass, field 10 | # from typing import Optional 11 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 12 | 13 | import transformers 14 | from transformers import ( 15 | TrainingArguments, 16 | MODEL_FOR_CAUSAL_LM_MAPPING, 17 | CONFIG_MAPPING, 18 | AutoConfig, 19 | GPT2LMHeadModel, 20 | AutoModelForCausalLM, 21 | AutoTokenizer, 22 | HfArgumentParser, 23 | set_seed, 24 | ) 25 | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions 26 | from transformers.trainer import TRAINER_STATE_NAME 27 | 28 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 29 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 30 | 31 | import torch 32 | from torch.nn import CrossEntropyLoss 33 | from collections import namedtuple 34 | from contextlib import nullcontext 35 | 36 | @dataclass 37 | class CausalLMOutputWithDomainIDs(CausalLMOutputWithCrossAttentions): 38 | domain_ids: Optional[torch.LongTensor] = None 39 | reference_pertoken_loss: Optional[torch.FloatTensor] = None # corresponds to uniq_domain_ids 40 | pertoken_loss: Optional[torch.FloatTensor] = None # corresponds to uniq_domain_ids 41 | token_mask: Optional[torch.FloatTensor] = None # 1.0 for tokens that are not padding 42 | hidden_states: Optional[torch.FloatTensor] = None # embeddings before linear + softmax 43 | 44 | 45 | @dataclass 46 | class ModelArguments: 47 | """ 48 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 49 | """ 50 | 51 | model_name_or_path: Optional[str] = field( 52 | default=None, 53 | metadata={ 54 | "help": ( 55 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 56 | ) 57 | }, 58 | ) 59 | model_type: Optional[str] = field( 60 | default='gpt2', 61 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 62 | ) 63 | config_overrides: Optional[str] = field( 64 | default="n_positions=512,n_embd=768,n_layer=12,n_head=12", 65 | metadata={ 66 | "help": ( 67 | "Override some existing default config settings when a model is trained from scratch. 
Example: " 68 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 69 | ) 70 | }, 71 | ) 72 | config_name: Optional[str] = field( 73 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 74 | ) 75 | tokenizer_name: Optional[str] = field( 76 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 77 | ) 78 | cache_dir: Optional[str] = field( 79 | default=None, 80 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 81 | ) 82 | use_fast_tokenizer: bool = field( 83 | default=True, 84 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 85 | ) 86 | model_revision: str = field( 87 | default="main", 88 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 89 | ) 90 | use_auth_token: bool = field( 91 | default=False, 92 | metadata={ 93 | "help": ( 94 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 95 | "with private models)." 96 | ) 97 | }, 98 | ) 99 | torch_dtype: Optional[str] = field( 100 | default=None, 101 | metadata={ 102 | "help": ( 103 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 104 | "dtype will be automatically derived from the model's weights." 105 | ), 106 | "choices": ["auto", "bfloat16", "float16", "float32"], 107 | }, 108 | ) 109 | 110 | def __post_init__(self): 111 | if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): 112 | raise ValueError( 113 | "--config_overrides can't be used in combination with --config_name or --model_name_or_path" 114 | ) 115 | 116 | def get_model_from_config(model_args:ModelArguments, 117 | doge=False, 118 | ref_model_path=None,): 119 | config_kwargs = { 120 | "cache_dir": model_args.cache_dir, 121 | "revision": model_args.model_revision, 122 | "use_auth_token": True if model_args.use_auth_token else None, 123 | } 124 | if model_args.config_name: 125 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 126 | elif model_args.model_name_or_path: 127 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 128 | else: 129 | config = CONFIG_MAPPING[model_args.model_type]() 130 | print("You are instantiating a new config instance from scratch.") 131 | if model_args.config_overrides is not None: 132 | print(f"Overriding config: {model_args.config_overrides}") 133 | config.update_from_string(model_args.config_overrides) 134 | print(f"New config: {config}") 135 | if doge: 136 | if ref_model_path is not None: 137 | return GPT2DoGE(config).from_pretrained(ref_model_path), config 138 | return GPT2DoGE(config), config 139 | 140 | return GPT2LMHeadModel(config), config 141 | 142 | 143 | class GPT2DoGE(GPT2LMHeadModel): 144 | 145 | def __init__(self, config): 146 | super().__init__(config) 147 | self.ignore_index = -100 148 | # self.loss_fct: compute mean token loss for standard training 149 | self.loss_fct = CrossEntropyLoss(reduction='mean', ignore_index=self.ignore_index) 150 | # self.pertoken_loss_fct: compute token loss for proxy model 151 | self.pertoken_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=self.ignore_index) 152 | 153 | def _forward(self, 154 | input_ids: Optional[torch.LongTensor] = None, 155 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = 
None, 156 | attention_mask: Optional[torch.FloatTensor] = None, 157 | token_type_ids: Optional[torch.LongTensor] = None, 158 | position_ids: Optional[torch.LongTensor] = None, 159 | head_mask: Optional[torch.FloatTensor] = None, 160 | inputs_embeds: Optional[torch.FloatTensor] = None, 161 | encoder_hidden_states: Optional[torch.Tensor] = None, 162 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 163 | use_cache: Optional[bool] = None, 164 | output_attentions: Optional[bool] = None, 165 | output_hidden_states: Optional[bool] = None, 166 | return_dict: Optional[bool] = None, 167 | last_token_only: Optional[bool] = False,): 168 | """ 169 | last_token_only: whether to return the logit for the last token only, 170 | of shape (batch_size, vocab_size) 171 | """ 172 | 173 | transformer_outputs = self.transformer( 174 | input_ids, 175 | past_key_values=past_key_values, 176 | attention_mask=attention_mask, 177 | token_type_ids=token_type_ids, 178 | position_ids=position_ids, 179 | head_mask=head_mask, 180 | inputs_embeds=inputs_embeds, 181 | encoder_hidden_states=encoder_hidden_states, 182 | encoder_attention_mask=encoder_attention_mask, 183 | use_cache=use_cache, 184 | output_attentions=output_attentions, 185 | output_hidden_states=output_hidden_states, 186 | return_dict=return_dict, 187 | ) 188 | hidden_states = transformer_outputs[0] 189 | 190 | if last_token_only: 191 | hidden_states = hidden_states[:, -1] 192 | 193 | # Set device for model parallelism 194 | if self.model_parallel: 195 | torch.cuda.set_device(self.transformer.first_device) 196 | hidden_states = hidden_states.to(self.lm_head.weight.device) 197 | 198 | lm_logits = self.lm_head(hidden_states) 199 | 200 | CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "hidden_states"]) 201 | if output_hidden_states: 202 | return CausalLMOutput(logits=lm_logits, hidden_states=hidden_states) 203 | else: 204 | return CausalLMOutput(logits=lm_logits, hidden_states=None) 205 | 206 | 207 | def forward(self, 208 | input_ids: Optional[torch.LongTensor] = None, 209 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 210 | attention_mask: Optional[torch.FloatTensor] = None, 211 | token_type_ids: Optional[torch.LongTensor] = None, 212 | position_ids: Optional[torch.LongTensor] = None, 213 | head_mask: Optional[torch.FloatTensor] = None, 214 | inputs_embeds: Optional[torch.FloatTensor] = None, 215 | encoder_hidden_states: Optional[torch.Tensor] = None, 216 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 217 | labels: Optional[torch.LongTensor] = None, 218 | use_cache: Optional[bool] = None, 219 | output_attentions: Optional[bool] = None, 220 | output_hidden_states: Optional[bool] = None, 221 | return_dict: Optional[bool] = None, 222 | # new params 223 | domain_ids: Optional[torch.LongTensor] = None, 224 | return_pertoken_losses: Optional[bool] = False, 225 | last_token_only: Optional[bool] = False, 226 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions, CausalLMOutputWithDomainIDs]: 227 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 228 | if not return_pertoken_losses: 229 | # perform standard training 230 | fwd_output = self._forward( 231 | input_ids=input_ids, 232 | past_key_values=past_key_values, 233 | attention_mask=attention_mask, 234 | token_type_ids=token_type_ids, 235 | position_ids=position_ids, 236 | head_mask=head_mask, 237 | inputs_embeds=inputs_embeds, 238 | encoder_hidden_states=encoder_hidden_states, 239 | encoder_attention_mask=encoder_attention_mask, 
240 | use_cache=use_cache, 241 | output_attentions=output_attentions, 242 | output_hidden_states=output_hidden_states, 243 | return_dict=return_dict, 244 | last_token_only=last_token_only) 245 | lm_logits = fwd_output.logits 246 | 247 | loss = None 248 | pertoken_loss = None 249 | token_mask = None 250 | if labels is not None: 251 | # move labels to correct device to enable model parallelism 252 | labels = labels.to(lm_logits.device) 253 | # Shift so that tokens < n predict n 254 | shift_logits = lm_logits[..., :-1, :].contiguous() 255 | shift_labels = labels[..., 1:].contiguous() 256 | # Flatten the tokens 257 | loss = self.loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 258 | else: 259 | # train proxy model 260 | fwd_output = self._forward( 261 | input_ids=input_ids, 262 | past_key_values=past_key_values, 263 | attention_mask=attention_mask, 264 | token_type_ids=token_type_ids, 265 | position_ids=position_ids, 266 | head_mask=head_mask, 267 | inputs_embeds=inputs_embeds, 268 | encoder_hidden_states=encoder_hidden_states, 269 | encoder_attention_mask=encoder_attention_mask, 270 | use_cache=use_cache, 271 | output_attentions=output_attentions, 272 | output_hidden_states=output_hidden_states, 273 | return_dict=return_dict, 274 | last_token_only=last_token_only) 275 | lm_logits = fwd_output.logits 276 | 277 | loss = None 278 | pertoken_loss = None 279 | token_mask = None 280 | if labels is not None: 281 | # move labels to correct device to enable model parallelism 282 | labels = labels.to(lm_logits.device) 283 | ignore_index = -100 284 | shift_logits = lm_logits[..., :-1, :].contiguous() 285 | shift_labels = labels[..., 1:].contiguous() 286 | # Flatten the tokens 287 | pertoken_loss = self.pertoken_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 288 | 289 | pertoken_loss = pertoken_loss.view(shift_labels.size(0), shift_labels.size(1)) 290 | token_mask = shift_labels.ne(ignore_index).float() # not equal to PAD 291 | 292 | loss = pertoken_loss.sum() / token_mask.sum() 293 | if not return_dict: 294 | output = (lm_logits, None, fwd_output.hidden_states, None, domain_ids, pertoken_loss, token_mask) 295 | return ((loss,) + output) if loss is not None else output 296 | 297 | out_hidden_states = fwd_output.hidden_states 298 | return CausalLMOutputWithDomainIDs( 299 | loss=loss, 300 | logits=lm_logits, 301 | past_key_values=None, 302 | hidden_states=out_hidden_states, 303 | attentions=None, 304 | domain_ids=domain_ids, 305 | pertoken_loss=pertoken_loss, 306 | token_mask=token_mask) 307 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | ## Data Arguments ## 2 | from dataclasses import dataclass, field 3 | import pickle 4 | # from typing import Optional 5 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 6 | from transformers import TrainingArguments, MODEL_FOR_CAUSAL_LM_MAPPING 7 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 8 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 9 | import transformers 10 | from data.utils import get_dataset 11 | import random 12 | import numpy as np 13 | from datasets import Dataset, IterableDataset, load_from_disk 14 | from datasets.iterable_dataset import RandomlyCyclingMultiSourcesExamplesIterable 15 | from pathlib import Path 16 | from collections import Counter 17 | from copy import deepcopy 18 
| from transformers import AutoTokenizer 19 | import torch 20 | from torch.utils.data import WeightedRandomSampler 21 | from cfg_tokenizer import CFGTokenizer 22 | # transformers.utils.move_cache('/mloraw1/sfan/huggingface_cache') 23 | 24 | RANDOM_BATCH_SIZE = 512 25 | DEFAULT_SEED=111 26 | 27 | 28 | @dataclass 29 | class DataTrainingArguments: 30 | """ 31 | Arguments pertaining to what data we are going to input our model for training and eval. 32 | """ 33 | 34 | dataset_dir: str = field( 35 | default='.', metadata={"help": "Path to the dataset directory."} 36 | ) 37 | dataset: str = field( 38 | default='redpajama-all', metadata={"help": "Name of the dataset."} 39 | ) 40 | curriculum_path: str = field( 41 | default=None, metadata={"help": "Path to stage-wise curriculum domain weights (.pkl)."} 42 | ) 43 | train_domains: str = field( 44 | default='arxiv,book,cc,c4,github,wikipedia', metadata={"help": "domain names for training."} 45 | ) 46 | tgt_domains: str = field( 47 | default='stackexchange', metadata={"help": "target domain name(s) for generalization."} 48 | ) 49 | train_dw: str = field( 50 | default=None, metadata={"help": "training domain weights."} 51 | ) 52 | val_dw: str = field( 53 | default=None, metadata={"help": "validation domain weights."} 54 | ) 55 | eval_dataset_dir: str = field( 56 | default=None, metadata={"help": "Path to the eval dataset directory. Defaults to dataset_dir"} 57 | ) 58 | eval_dataset_name: str = field( 59 | default=None, metadata={"help": "Name of the eval dataset. Defaults to dataset_name."} 60 | ) 61 | max_train_samples: Optional[int] = field( 62 | default=None, 63 | metadata={ 64 | "help": ( 65 | "For debugging purposes or quicker training, truncate the number of training examples to this " 66 | "value if set." 67 | ) 68 | }, 69 | ) 70 | max_eval_samples: Optional[int] = field( 71 | default=None, 72 | metadata={ 73 | "help": ( 74 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 75 | "value if set." 76 | ) 77 | }, 78 | ) 79 | max_downstream_samples: Optional[int] = field( 80 | default=None, 81 | metadata={ 82 | "help": ( 83 | "For quicker downstream evaluation, limit the number of examples if set." 84 | ) 85 | }, 86 | ) 87 | max_token_length: int = field( 88 | default=512, 89 | metadata={ 90 | "help": ( 91 | "Input sequence length after tokenization. " 92 | ) 93 | }, 94 | ) 95 | block_size: Optional[int] = field( 96 | default=None, 97 | metadata={ 98 | "help": ( 99 | "Optional input sequence length after tokenization. " 100 | "The training dataset will be truncated in block of this size for training. " 101 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 102 | ) 103 | }, 104 | ) 105 | overwrite_cache: bool = field( 106 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 107 | ) 108 | do_padding: bool = field( 109 | default=False, metadata={"help": "Pad the inputs."} 110 | ) 111 | preprocessing_num_workers: Optional[int] = field( 112 | default=None, 113 | metadata={"help": "The number of processes to use for the preprocessing."}, 114 | ) 115 | shuffle: bool = field( 116 | default=True, metadata={"help": "Shuffle the training data on the fly"} 117 | ) 118 | keep_in_memory: bool = field( 119 | default=False, metadata={"help": "keep data in memory"} 120 | ) 121 | 122 | 123 | 124 | @dataclass 125 | class DomainConfigArguments: 126 | """ 127 | Domain config settings. 
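    train_dw / val_dw hold per-domain sampling weights indexed via domain2idx; the same
    train_dw tensor is also passed to interleave_datasets as probabilities_handle, so
    updating it in place changes the training mixture on the fly.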
""" 128 | 129 | domain_list: list = field( 130 | default_factory=list, 131 | # default=['arxiv', 'book', 'c4', 'cc', 'github', 'stackexchange', 'wikipedia'], 132 | metadata={"help": "List of domain names."} 133 | ) 134 | train_dw: torch.Tensor = field( 135 | default=None, 136 | metadata={"help": "Training domain weights."} 137 | ) 138 | val_dw: torch.Tensor = field( 139 | default=None, 140 | metadata={"help": "Validation domain weights."} 141 | ) 142 | idx2domain: dict = field( 143 | default_factory=dict, 144 | metadata={"help": "index mapping to domain names."} 145 | ) 146 | domain2idx: dict = field( 147 | default_factory=dict, 148 | metadata={"help": "domain names mapping to indices."} 149 | ) 150 | train_ids: torch.Tensor = field( 151 | default=None, 152 | metadata={"help": "Training domain indices."} 153 | ) 154 | tgt_ids: torch.Tensor = field( 155 | default=None, 156 | metadata={"help": "Target domain indices."} 157 | ) 158 | curriculum_path: str = field( 159 | default=None, metadata={"help": "Path to stage-wise curriculum domain weights (.pkl)."} 160 | ) 161 | 162 | def domain_gen(data, seq_len, domain_id=None): 163 | if domain_id is None: 164 | for i in range(len(data)//seq_len): 165 | yield {"input_ids": data[i*seq_len:(i+1)*seq_len]} 166 | else: 167 | for i in range(len(data)//seq_len): 168 | yield {"domain_ids": torch.tensor([domain_id], dtype=torch.long), "input_ids": data[i*seq_len:(i+1)*seq_len]} 169 | 170 | class UpdatableRandomlyCyclingMultiSourcesExamplesIterable( 171 | RandomlyCyclingMultiSourcesExamplesIterable): 172 | 173 | def __init__(self, ex_iterables, generator, probabilities=None, probabilities_handle=None, stopping_strategy="all_exhausted", 174 | curriculum_dict=None): 175 | ''' 176 | probabilities: vector of static probabilities over training 177 | probabilities_handle: handle to domain weights buffer in model params 178 | ''' 179 | super().__init__(ex_iterables, generator, stopping_strategy=stopping_strategy) 180 | self.probabilities_handle = probabilities_handle 181 | self.probabilities = probabilities 182 | self.curriculum_dict = curriculum_dict 183 | if curriculum_dict is not None: 184 | self.step = 0 185 | 186 | @staticmethod 187 | def _iter_random_indices(rng, num_sources, probabilities_handle=None, p=None, random_batch_size=RANDOM_BATCH_SIZE): 188 | while True: 189 | # read domain weights 190 | if probabilities_handle is not None: 191 | p = probabilities_handle.detach().cpu().numpy() 192 | yield from WeightedRandomSampler(weights=p, num_samples=random_batch_size, replacement=True) 193 | 194 | def _give_indice_iterator(self): 195 | rng = deepcopy(self.generator) 196 | return self._iter_random_indices(rng, len(self.ex_iterables), probabilities_handle=self.probabilities_handle, probabilities=self.probabilities) 197 | 198 | def shard_data_sources(self, shard_indices): 199 | return self 200 | 201 | @property 202 | def n_shards(self): 203 | return 1 204 | 205 | def shuffle_data_sources(self, seed): 206 | self.ex_iterables = [ex_iterable.shuffle_data_sources(seed) for ex_iterable in self.ex_iterables] 207 | return self 208 | 209 | 210 | def interleave_datasets(datasets, probabilities=None, probabilities_handle=None, seed=None, stopping_strategy='all_exhausted'): 211 | iterable_datasets = [] 212 | for dataset in datasets: 213 | if not isinstance(dataset, IterableDataset): 214 | iterable_datasets.append(dataset.to_iterable_dataset()) 215 | else: 216 | iterable_datasets.append(dataset) 217 | 218 | ex_iterables = [d._ex_iterable for d in iterable_datasets] 219 | 220 
| generator = np.random.default_rng(seed) 221 | ex_iterable = UpdatableRandomlyCyclingMultiSourcesExamplesIterable( 222 | ex_iterables, generator=generator, 223 | probabilities=probabilities, probabilities_handle=probabilities_handle, 224 | stopping_strategy=stopping_strategy) 225 | 226 | return IterableDataset(ex_iterable=ex_iterable) 227 | 228 | def get_train_eval_datasets(data_config:DataTrainingArguments, 229 | verbose:bool=False, 230 | doremi:bool=False, 231 | **kwargs): 232 | data_dict = get_dataset(data_config) 233 | if 'all' in data_dict['train'].keys(): 234 | del data_dict['train']['all'] 235 | if 'all' in data_dict['val'].keys(): 236 | del data_dict['val']['all'] 237 | if doremi and ('mix' in data_dict['train'].keys()): 238 | del data_dict['train']['mix'] 239 | del data_dict['val']['mix'] 240 | 241 | seed = 42 242 | sequence_length = data_config.max_token_length 243 | max_train_samples = data_config.max_train_samples 244 | max_eval_samples = data_config.max_eval_samples 245 | 246 | domain_list = list(data_dict['train'].keys()) 247 | idx2domain = {i:dom for i,dom in enumerate(domain_list)} 248 | domain2idx = {dom:i for i,dom in idx2domain.items()} 249 | train_ids = torch.tensor([domain2idx[name] for name in data_config.train_domains.split(',')]) 250 | tgt_ids = torch.tensor([domain2idx[name] for name in data_config.tgt_domains.split(',')]) 251 | 252 | all_domain_ids = torch.concat([train_ids, tgt_ids]).numpy() 253 | curriculum_dict = None 254 | if data_config.curriculum_path is not None: 255 | with open(data_config.curriculum_path, "rb") as trg: 256 | curriculum_dict = pickle.load(trg) 257 | 258 | if curriculum_dict is not None: 259 | train_dw = torch.tensor(curriculum_dict[0], dtype=torch.float) 260 | elif data_config.train_dw is None: 261 | train_dw = torch.ones(len(domain_list), dtype=torch.float)/len(set(all_domain_ids)) 262 | if len(domain_list)>len(set(all_domain_ids)): 263 | exclude_ids = torch.tensor([i for i in range(len(domain_list)) if torch.tensor(i) not in all_domain_ids]) 264 | train_dw[exclude_ids] = 0.0 265 | else: 266 | train_dw = torch.tensor([float(i) for i in data_config.train_dw.split(",")]) 267 | 268 | if data_config.val_dw is None: 269 | if 'mix' not in data_config.tgt_domains: 270 | val_dw = torch.zeros(len(domain_list), dtype=torch.float) 271 | val_dw[tgt_ids] = 1/len(tgt_ids) 272 | else: 273 | val_dw = torch.ones(len(domain_list), dtype=torch.float)/len(domain_list) 274 | else: 275 | val_dw = torch.tensor([float(i) for i in data_config.val_dw.split(",")]) 276 | 277 | domain_config = DomainConfigArguments( 278 | domain_list=domain_list, 279 | idx2domain=idx2domain, 280 | domain2idx=domain2idx, 281 | train_ids=train_ids, 282 | tgt_ids=tgt_ids, 283 | train_dw=train_dw, 284 | val_dw=val_dw, 285 | curriculum_path=data_config.curriculum_path, 286 | **kwargs) 287 | train_dict = {domain2idx[dom]:v for dom,v in data_dict['train'].items()} 288 | val_dict = {domain2idx[dom]:v for dom,v in data_dict['val'].items() if val_dw[domain2idx[dom]]>0} 289 | 290 | train_dataset_ls, val_dataset_ls = [], [] 291 | for k in train_dict.keys(): 292 | train_domain_dataset = IterableDataset.from_generator(domain_gen, 293 | gen_kwargs={'data': train_dict[k], 294 | 'seq_len': sequence_length, 295 | 'domain_id': k, 296 | } 297 | ) 298 | train_dataset_ls.append(train_domain_dataset) 299 | if verbose: 300 | print(f'{idx2domain[k]} loaded!') 301 | 302 | val_dw_gen = [] 303 | for k in val_dict.keys(): 304 | val_domain_dataset = IterableDataset.from_generator(domain_gen, 305 | 
gen_kwargs={'data': val_dict[k], 306 | 'seq_len': sequence_length, 307 | 'domain_id': k, 308 | } 309 | ) 310 | val_dataset_ls.append(val_domain_dataset) 311 | val_dw_gen.append(val_dw[k]) 312 | 313 | train_ds = interleave_datasets( 314 | train_dataset_ls, 315 | probabilities=train_dw, 316 | probabilities_handle=train_dw, 317 | seed=seed) 318 | val_ds = interleave_datasets( 319 | val_dataset_ls, 320 | probabilities=torch.tensor(val_dw_gen), 321 | probabilities_handle=torch.tensor(val_dw_gen), 322 | seed=seed) 323 | 324 | 325 | def take_data_generator(ds, max_samples): 326 | idx = 0 327 | for ex in ds: 328 | yield ex 329 | idx += 1 330 | if max_samples is not None and idx >= max_samples: 331 | return 332 | if max_train_samples is not None: 333 | train_ds = IterableDataset.from_generator(take_data_generator, gen_kwargs={'ds': train_ds, 'max_samples': max_train_samples}) 334 | if max_eval_samples is not None: 335 | val_ds = IterableDataset.from_generator(take_data_generator, gen_kwargs={'ds': val_ds, 'max_samples': max_eval_samples}) 336 | if 'wiki40b' in data_config.dataset: 337 | tokenizer = AutoTokenizer.from_pretrained('/scratch/pagliard/doge/exp/doge_frozen_weights_12l_catalan-mu0001_seed42_10k/checkpoint-10000') 338 | elif 'cfg' in data_config.dataset: 339 | tokenizer = CFGTokenizer 340 | else: 341 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 342 | if 'cfg' not in data_config.dataset: 343 | tokenizer.model_max_length=data_config.max_token_length 344 | if data_config.curriculum_path is not None: 345 | return train_ds, val_ds, domain_config, tokenizer, train_dataset_ls 346 | else: 347 | return train_ds, val_ds, domain_config, tokenizer, None 348 | 349 | def get_data_collator(tokenizer, return_tensors='pt', do_padding=False, max_length=1024): 350 | def data_collator(features): 351 | if not do_padding: 352 | try: 353 | batch = { 354 | k: torch.tensor([f[k] for f in features]) 355 | for k in features[0].keys() if k!='input_ids' 356 | } 357 | if not torch.is_tensor(batch['input_ids']): 358 | batch['input_ids'] = torch.tensor([np.array(f['input_ids'], dtype=np.int32) for f in features]) 359 | except Exception: 360 | batch = { 361 | k: torch.tensor([np.array(f[k], dtype=np.int32) for f in features]) 362 | for k in features[0].keys() 363 | } 364 | else: 365 | try: 366 | batch = tokenizer.pad(features, return_tensors=return_tensors, pad_to_multiple_of=max_length) 367 | except: 368 | raise Exception 369 | batch['input_ids'] = batch['input_ids'].long() 370 | if 'attention_mask' not in batch: 371 | batch['attention_mask'] = torch.ones_like(batch['input_ids']).long() 372 | else: 373 | batch['attention_mask'] = batch['attention_mask'].long() 374 | 375 | batch.pop("special_tokens_mask", None) 376 | if 'labels' not in batch: 377 | labels = batch['input_ids'].clone() 378 | batch["labels"] = labels 379 | 380 | try: 381 | if tokenizer.pad_token_id is not None: 382 | batch['labels'][batch['labels'] == tokenizer.pad_token_id] = -100 383 | except: 384 | pass 385 | 386 | if 'domain_ids' not in batch and 'domain_id' in batch: 387 | batch['domain_ids'] = batch['domain_id'] # compat 388 | batch.pop('domain_id') 389 | # print(batch) 390 | return batch 391 | return data_collator 392 | --------------------------------------------------------------------------------
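Usage sketch (illustrative, not a file from this repository): the snippet below mimics the mechanism that interleave_datasets in src/dataloader.py builds on. _iter_random_indices re-reads probabilities_handle for every batch of source indices, so an in-place update of that tensor (which is presumably how the DoGE trainer adjusts the domain weights during training) immediately shifts the sampling mixture. The domain names and weight values here are invented for illustration.

import torch
from torch.utils.data import WeightedRandomSampler

# stands in for the shared `probabilities_handle` tensor (e.g. train_dw)
domain_weights = torch.tensor([0.2, 0.3, 0.5])  # hypothetical weights for 3 domains

def sample_domain_indices(handle, n=8):
    # re-read the handle on every call, mirroring _iter_random_indices
    p = handle.detach().cpu().numpy()
    return list(WeightedRandomSampler(weights=p, num_samples=n, replacement=True))

print(sample_domain_indices(domain_weights))           # indices drawn mostly from domain 2
domain_weights.copy_(torch.tensor([0.9, 0.05, 0.05]))  # in-place update, as a trainer would do
print(sample_domain_indices(domain_weights))           # now mostly domain 0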