├── evaluation ├── .gitkeep ├── README.md └── eval_lm_eval_harness.sh ├── mixture_config ├── config_1m │ └── .gitkeep ├── config_1b │ ├── pile_cc_only.yaml │ ├── doremi.yaml │ ├── human.yaml │ └── regmix.yaml ├── README.md ├── visualize_mixture.py └── synthesize_mixture.py ├── misc ├── 1m_pile_cc_loss.pdf ├── method_figure.png ├── prior_vs_regmix.pdf └── weight_distributions.png ├── model_training ├── preprocess │ ├── tokenizer │ │ ├── gptneox │ │ │ └── tokenizer_config.json │ │ └── starcoder │ │ │ ├── special_tokens_map.json │ │ │ └── tokenizer_config.json │ ├── download_dataset.py │ ├── run_preprocess.sh │ └── prepare_file_domain.py ├── requirements.txt ├── convert_lit_to_hf.sh ├── pretrain_tinyllama_1b.sh ├── pretrain_tinyllama_1m.sh ├── lit_gpt │ ├── __init__.py │ ├── tokenizer.py │ ├── fused_rotary_embedding.py │ ├── config.py │ ├── fused_cross_entropy.py │ ├── packed_dataset.py │ ├── model.py │ ├── utils.py │ └── speed_monitor.py ├── README.md └── convert_lit_checkpoint.py ├── LICENSE ├── regression_fitting ├── README.md ├── collect_mixture_data.py └── collect_loss_data.py ├── .gitignore ├── data ├── test_mixture_1B.csv ├── test_pile_loss_1B.csv ├── test_mixture_1m.csv └── test_mixture_60m.csv └── README.md /evaluation/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mixture_config/config_1m/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /misc/1m_pile_cc_loss.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/regmix/HEAD/misc/1m_pile_cc_loss.pdf -------------------------------------------------------------------------------- /misc/method_figure.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/regmix/HEAD/misc/method_figure.png -------------------------------------------------------------------------------- /misc/prior_vs_regmix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/regmix/HEAD/misc/prior_vs_regmix.pdf -------------------------------------------------------------------------------- /misc/weight_distributions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/regmix/HEAD/misc/weight_distributions.png -------------------------------------------------------------------------------- /mixture_config/config_1b/pile_cc_only.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | train_the_pile_pile_cc: 1.0 3 | valid: 4 | valid_the_pile_pile_cc: 1.0 5 | model_name: tinyllama_1_1b -------------------------------------------------------------------------------- /model_training/preprocess/tokenizer/gptneox/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": "<|endoftext|>", 4 | "clean_up_tokenization_spaces": true, 5 | "eos_token": "<|endoftext|>", 6 | "model_max_length": 2048, 7 | "tokenizer_class": "GPTNeoXTokenizer", 8 | "unk_token": "<|endoftext|>" 9 | } 10 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # RegMix Evaluation 2 | 3 | ## Using lm-eval-harness 4 | 5 | First you should install the lm-eval package from the github repository, run: 6 | 7 | ``` 8 | git clone https://github.com/EleutherAI/lm-evaluation-harness 9 | cd lm-evaluation-harness 10 | pip install -e . 
11 | ``` 12 | 13 | Then you can use the `eval_lm_eval_harness.sh` to evaluate your model. Please remember to modify `model_args` in the script. -------------------------------------------------------------------------------- /model_training/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.1.0dev 2 | lightning==2.1.2 3 | lightning[app] 4 | jsonargparse[signatures] # CLI 5 | pandas 6 | pyarrow 7 | tokenizers 8 | sentencepiece 9 | wandb 10 | zstd 11 | 12 | # for finetuning 13 | bitsandbytes==0.40.0 14 | transformers==4.31.0 15 | peft==0.4.0 16 | accelerate==0.21.0 17 | einops==0.6.1 18 | evaluate==0.4.0 19 | scikit-learn==1.2.2 20 | sentencepiece==0.1.99 21 | wandb==0.15.3 22 | -------------------------------------------------------------------------------- /model_training/preprocess/tokenizer/starcoder/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "additional_special_tokens": [ 3 | "<|endoftext|>", 4 | "", 5 | "", 6 | "", 7 | "", 8 | "", 9 | "", 10 | "", 11 | "", 12 | "", 13 | "", 14 | "", 15 | "", 16 | "", 17 | "", 18 | "", 19 | "", 20 | "", 21 | "" 22 | ], 23 | "bos_token": "<|endoftext|>", 24 | "eos_token": "<|endoftext|>", 25 | "unk_token": "<|endoftext|>" 26 | } 27 | -------------------------------------------------------------------------------- /model_training/convert_lit_to_hf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # tinyllama_1_1b, tinyllama_1M, more can be found in lit_gpt/config.py 4 | export ARCH_NAME="tinyllama_1_1b" 5 | export INP_FOLDER="lit_checkpoint_folder" 6 | export FILE_NAME="iter-025000-ckpt.pth" 7 | export OUT_FOLDER="converted_huggingface_model_folder" 8 | 9 | # convert the model into Huggingface compatible format, and save the config.json 10 | python convert_lit_checkpoint.py --checkpoint_name "$FILE_NAME" --inp_dir "$INP_FOLDER" --out_dir 
"""Download a RegMix dataset snapshot from the Hugging Face Hub."""
from huggingface_hub import snapshot_download
from argparse import ArgumentParser

# The two datasets this script knows how to fetch: the full corpus and the
# smaller sample of it.
DATASET_CHOICES = ["sail/regmix-data", "sail/regmix-data-sample"]

# Use argparse to handle command line arguments.
# The flag is optional: when omitted, the sample dataset is downloaded.
parser = ArgumentParser()
parser.add_argument(
    "--dataset_name",
    type=str,
    default="sail/regmix-data-sample",
    choices=DATASET_CHOICES,
    help="Dataset name to download",
)
args = parser.parse_args()

# You can choose to download regmix-data, or regmix-data-sample.
# The snapshot is written to a local directory named after the repo id.
snapshot_download(
    repo_id=args.dataset_name,
    repo_type="dataset",
    local_dir=args.dataset_name,
    local_dir_use_symlinks=False,
)
valid_the_pile_pile_cc: 1.0 21 | model_name: tinyllama_1_1b 22 | -------------------------------------------------------------------------------- /mixture_config/config_1b/human.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | train_the_pile_arxiv: 0.1052 3 | train_the_pile_freelaw: 0.0386 4 | train_the_pile_nih_exporter: 0.0052 5 | train_the_pile_pubmed_central: 0.1071 6 | train_the_pile_wikipedia_en: 0.0919 7 | train_the_pile_dm_mathematics: 0.0198 8 | train_the_pile_github: 0.0427 9 | train_the_pile_philpapers: 0.0027 10 | train_the_pile_stackexchange: 0.0929 11 | train_the_pile_enron_emails: 0.0030 12 | train_the_pile_gutenberg_pg_19: 0.0199 13 | train_the_pile_pile_cc: 0.1121 14 | train_the_pile_ubuntu_irc: 0.0074 15 | train_the_pile_europarl: 0.0043 16 | train_the_pile_hackernews: 0.0075 17 | train_the_pile_pubmed_abstracts: 0.0845 18 | train_the_pile_uspto_backgrounds: 0.0420 19 | valid: 20 | valid_the_pile_pile_cc: 1.0 21 | model_name: tinyllama_1_1b 22 | -------------------------------------------------------------------------------- /model_training/pretrain_tinyllama_1b.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT=YOUR_PROJECT_NAME 2 | export WANDB_ENTITY=YOUR_WANDB_ENTITY 3 | export WANDB_API_KEY=YOUR_WANDB_API_KEY 4 | 5 | export MODEL_NAME=tinyllama_1B_n$1 6 | export WANDB_NAME=$MODEL_NAME 7 | export NUMBER_OF_GPUS=8 8 | # you can specify the config name here or pass it as an argument $1 9 | export CONFIG_NAME=$1 10 | 11 | lightning run model \ 12 | --node-rank=0 \ 13 | --main-address=127.0.0.1 \ 14 | --accelerator=cuda \ 15 | --num-nodes=1 \ 16 | --devices=$NUMBER_OF_GPUS \ 17 | pretrain/tinyllama.py --devices $NUMBER_OF_GPUS \ 18 | --train_data_dir lit_dataset_regmix \ 19 | --val_data_dir lit_dataset_regmix \ 20 | --data_yaml_file ../mixture_config/config_1b/$CONFIG_NAME.yaml \ 21 | --out_name $MODEL_NAME \ 22 | --resume True 23 | 
-------------------------------------------------------------------------------- /model_training/pretrain_tinyllama_1m.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT=YOUR_PROJECT_NAME 2 | export WANDB_ENTITY=YOUR_WANDB_ENTITY 3 | export WANDB_API_KEY=YOUR_WANDB_API_KEY 4 | 5 | export MODEL_NAME=tinyllama_1M_n$1 6 | export WANDB_NAME=$MODEL_NAME 7 | export NUMBER_OF_GPUS=1 8 | # you can specify the config index here or pass it as an argument 9 | export CONFIG_INDEX=$1 10 | 11 | lightning run model \ 12 | --node-rank=0 \ 13 | --main-address=127.0.0.1 \ 14 | --accelerator=cuda \ 15 | --num-nodes=1 \ 16 | --devices=$NUMBER_OF_GPUS \ 17 | pretrain/tinyllama.py --devices $NUMBER_OF_GPUS \ 18 | --train_data_dir lit_dataset_regmix \ 19 | --val_data_dir lit_dataset_regmix \ 20 | --data_yaml_file ../mixture_config/config_1m/n$CONFIG_INDEX.yaml \ 21 | --out_name $MODEL_NAME \ 22 | --resume True 23 | -------------------------------------------------------------------------------- /model_training/preprocess/tokenizer/starcoder/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "additional_special_tokens": [ 4 | "<|endoftext|>", 5 | "", 6 | "", 7 | "", 8 | "", 9 | "", 10 | "", 11 | "", 12 | "", 13 | "", 14 | "", 15 | "", 16 | "", 17 | "", 18 | "", 19 | "", 20 | "", 21 | "", 22 | "" 23 | ], 24 | "bos_token": "<|endoftext|>", 25 | "eos_token": "<|endoftext|>", 26 | "model_max_length": 1000000000000000019884624838656, 27 | "tokenizer_class": "GPT2Tokenizer", 28 | "unk_token": "<|endoftext|>", 29 | "vocab_size": 49152 30 | } 31 | -------------------------------------------------------------------------------- /model_training/lit_gpt/__init__.py: -------------------------------------------------------------------------------- 1 | from lit_gpt.model import GPT 2 | from lit_gpt.config import Config 3 | from lit_gpt.tokenizer import 
#!/bin/bash
# Evaluate one model with lm-eval-harness on every task in `tasks`,
# once per few-shot setting from 0 to 5.
# bash is required: the task list is a bash array.

# root folder for all evaluation outputs
export OUT_FOLDER="eval_out"
# custom name for saving the output
export model_name="data-mixture-regmix-1b-seed-1"
# forwarded to `lm_eval --model_args`; modify it to evaluate another model
export model_args="pretrained=sail/data-mixture-regmix-1b,revision=seed-1"
# task list
tasks=(
    'social_iqa'
    'hellaswag'
    'piqa'
    'openbookqa'
    'lambada_standard'
    'sciq'
    'arc_easy'
    'copa'
    'race'
    'logiqa'
    'qqp'
    'winogrande'
    'multirc'
)

for few_shot in 0 1 2 3 4 5; do
    for task in "${tasks[@]}"; do
        # print the task name
        echo "Evaluating task: $task"

        # expansions are quoted so values containing spaces or glob
        # characters reach lm_eval as single arguments
        lm_eval --model hf \
            --model_args "$model_args" \
            --tasks "$task" \
            --batch_size auto:4 \
            --num_fewshot "$few_shot" \
            --output_path "$OUT_FOLDER/$few_shot/$model_name/$task"
    done
done
| 3 | python download_dataset.py --dataset_name sail/regmix-data-sample 4 | # WARNING: you can choose to download the full dataset (around 1TB) by running the following command 5 | # python download_dataset.py --dataset_name sail/regmix-data 6 | 7 | # We use gptneox tokenizer for the dataset to be consistent with the flagship method DoReMi 8 | python prepare_file_domain.py --source_path sail/regmix-data-sample --tokenizer_path tokenizer/gptneox --destination_path ../lit_dataset_regmix --short_name $DATASET_SHORT_NAME --split train 9 | 10 | # 131136 = 2049 * 64, which means the chunk size is relatively smaller due to the size of the validation set, especially for low-resource domains 11 | python prepare_file_domain.py --source_path sail/regmix-data-sample --tokenizer_path tokenizer/gptneox --destination_path ../lit_dataset_regmix --short_name $DATASET_SHORT_NAME --split valid --chunk_size 131136 12 | -------------------------------------------------------------------------------- /mixture_config/config_1b/regmix.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | train_the_pile_arxiv: 0.0012046169821426883 3 | train_the_pile_freelaw: 0.001454510048554701 4 | train_the_pile_nih_exporter: 0.001231640306882902 5 | train_the_pile_pubmed_central: 0.003108561825532002 6 | train_the_pile_wikipedia_en: 0.01593264140324679 7 | train_the_pile_dm_mathematics: 0.00031106907908634156 8 | train_the_pile_github: 0.00022861228152440253 9 | train_the_pile_philpapers: 1.329107360676338e-05 10 | train_the_pile_stackexchange: 0.00029547405933203174 11 | train_the_pile_enron_emails: 0.0016691646199353991 12 | train_the_pile_gutenberg_pg_19: 0.001612531300038395 13 | train_the_pile_pile_cc: 0.8701291419934237 14 | train_the_pile_ubuntu_irc: 0.06417728505869834 15 | train_the_pile_europarl: 2.9166170357771267e-06 16 | train_the_pile_hackernews: 0.011925517591888925 17 | train_the_pile_pubmed_abstracts: 0.02424425081714838 18 | 
train_the_pile_uspto_backgrounds: 0.0024587749419225434 19 | valid: 20 | valid_the_pile_pile_cc: 1.0 21 | model_name: tinyllama_1_1b 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sea AI Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /regression_fitting/README.md: -------------------------------------------------------------------------------- 1 | # Regression Fitting 2 | 3 | This directory contains the code for fitting the regression model and for predicting the optimal data mixture. 4 | 5 | ## Prepare Data 6 | 7 | Before fitting the regression model, you need to prepare the data obtained after training the proxy models. 
If you have not trained the proxy models, please refer to the [mixture_config](../mixture_config/README.md) directory and the [model_training](../model_training/README.md) directory for more details. 8 | 9 | In our paper, we use the validation loss on the Pile-CC subset as the **Target**, and the domain weights as the **Features** for regression model fitting. The data we have already prepared is stored in the [data](../data) directory. You can also prepare your own data by following the instructions in the [mixture_config](../mixture_config/README.md) directory. 10 | 11 | 12 | ## Model Fitting 13 | 14 | You can follow the [notebook](regression.ipynb) to do both: 15 | - Regression fitting with proxy model training logs 16 | - Simulation and selection of the optimal data mixture 17 | 18 | With the notebook, you can easily fit the regression model and predict the optimal data mixture for training large language models. -------------------------------------------------------------------------------- /mixture_config/README.md: -------------------------------------------------------------------------------- 1 | ## Data Mixture Configuration 2 | 3 | This directory contains scripts for synthesizing and visualizing data mixtures used in the experiments described in the paper. 4 | 5 | ### synthesize_mixture.py 6 | 7 | This script is used to synthesize data mixtures for training the proxy models (1M models in the paper). You can use the following command to generate the data mixtures: 8 | 9 | ```bash 10 | python synthesize_mixture.py --num_configs 512 --output_folder /path/to/configs 11 | ``` 12 | 13 | By default, it generates 512 configurations following the settings specified within the script. The configurations are saved in the `config_1m` directory. 14 | 15 | ### visualize_mixture.py 16 | 17 | This script is used to visualize the data mixtures generated for training the proxy models. By default, it visualizes the configurations stored in the `config_1m` directory. 
def read_config(config_file):
    """Read the training-domain weights from one mixture YAML file.

    Only entries of the top-level ``train`` mapping whose key starts with
    ``train_`` are kept. Keys are returned unchanged (no prefix is stripped).
    Float weights are rounded to 5 decimal places, integer weights are kept
    as they are, and values of any other type are dropped.

    Args:
        config_file: Path of the YAML mixture config to read.

    Returns:
        A dict mapping each kept ``train_*`` key to its numeric weight.

    Raises:
        KeyError: If the file has no top-level ``train`` section.
    """
    # safe_load builds only plain Python objects, which is all a mixture
    # config contains; the encoding is pinned so the result does not depend
    # on the platform default.
    with open(config_file, "r", encoding="utf8") as f:
        config = yaml.safe_load(f)

    flatten_dict = {}
    for key, value in config["train"].items():
        if not key.startswith("train_"):
            continue
        # bool is a subclass of int but is not a weight, so it is skipped
        # before the numeric checks below.
        if isinstance(value, bool):
            continue
        if isinstance(value, float):
            flatten_dict[key] = round(value, 5)
        elif isinstance(value, int):
            flatten_dict[key] = value
    return flatten_dict
output_dict[index_name] = config 41 | # convert the dict to dataframe 42 | df = pd.DataFrame(output_dict).T 43 | # the index column is the index name 44 | df.index.name = "index" 45 | # order by index name 46 | df = df.sort_index() 47 | df.to_csv(write_file_path) 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("--write_file_path", type=str, default="train_mixture_1m.csv") 52 | parser.add_argument("--config_folder", type=str, default="../mixture_config/config_1m") 53 | 54 | args = parser.parse_args() 55 | write_file_path = args.write_file_path 56 | config_folder = args.config_folder 57 | 58 | gather_mixture_data(write_file_path, config_folder) -------------------------------------------------------------------------------- /mixture_config/visualize_mixture.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | from collections import defaultdict 5 | import matplotlib.pyplot as plt 6 | from argparse import ArgumentParser 7 | 8 | def visualize_yaml_points(folder): 9 | print(f"Processing YAML files in folder: {folder}") 10 | 11 | train_weight_dict = defaultdict(list) 12 | train_group_zero_dict = defaultdict(int) 13 | 14 | for file_path in os.listdir(folder): 15 | if file_path.endswith(".yaml"): 16 | full_path = os.path.join(folder, file_path) 17 | print(f"Processing file: {full_path}") 18 | 19 | with open(full_path, "r", encoding="utf8") as f: 20 | config = yaml.safe_load(f) 21 | train_config = config.get("train", {}) 22 | 23 | for k, v in train_config.items(): 24 | train_weight_dict[k].append(float(v)) 25 | 26 | zero_count = sum(1 for v in train_config.values() if float(v) < 1e-7) 27 | train_group_zero_dict[zero_count] += 1 28 | 29 | print("\n--- Weight Statistics ---\n") 30 | for k, v in train_weight_dict.items(): 31 | print(f"{k}: Max = {np.max(v):.6f}") 32 | 33 | print("\n--- Weight Distributions ---\n") 34 | for k, v in 
class Tokenizer:
    """Single interface over a SentencePiece or a Hugging Face tokenizer.

    The backend is chosen from the files found in ``checkpoint_dir``:
    ``tokenizer.model`` selects SentencePiece; otherwise ``tokenizer.json``
    selects Hugging Face ``tokenizers``, with the special tokens read from
    ``tokenizer_config.json`` in the same directory.
    """

    def __init__(self, checkpoint_dir: Path) -> None:
        """Load the tokenizer stored in ``checkpoint_dir``.

        Args:
            checkpoint_dir: Directory holding the tokenizer files. A plain
                string is accepted as well as a ``Path``.

        Raises:
            NotImplementedError: If neither ``tokenizer.model`` nor
                ``tokenizer.json`` exists in ``checkpoint_dir``.
            ValueError: If a special token named in ``tokenizer_config.json``
                is unknown to the Hugging Face tokenizer.
        """
        checkpoint_dir = Path(checkpoint_dir)
        # some checkpoints have both files, `.model` takes precedence
        if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
            from sentencepiece import SentencePieceProcessor

            self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
            self.backend = "sentencepiece"
            self.bos_id = self.processor.bos_id()
            self.eos_id = self.processor.eos_id()
        elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
            from tokenizers import Tokenizer as HFTokenizer

            self.processor = HFTokenizer.from_file(str(vocabulary_path))
            self.backend = "huggingface"
            # JSON is UTF-8; do not rely on the platform default encoding
            with open(checkpoint_dir / "tokenizer_config.json", encoding="utf-8") as fp:
                config = json.load(fp)
            # bos is optional in the config, eos is mandatory
            bos_token = config.get("bos_token")
            self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
            self.eos_id = self.token_to_id(config["eos_token"])
        else:
            raise NotImplementedError(
                f"No 'tokenizer.model' or 'tokenizer.json' found in {str(checkpoint_dir)!r}"
            )

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the backend.

        For the Hugging Face backend, added tokens are not counted.
        """
        if self.backend == "huggingface":
            return self.processor.get_vocab_size(with_added_tokens=False)
        if self.backend == "sentencepiece":
            return self.processor.vocab_size()
        raise RuntimeError(f"Unknown tokenizer backend: {self.backend!r}")

    def token_to_id(self, token: str) -> int:
        """Return the id of ``token``.

        Raises:
            ValueError: If the backend returns ``None`` for the token.
        """
        if self.backend == "huggingface":
            id_ = self.processor.token_to_id(token)
        elif self.backend == "sentencepiece":
            id_ = self.processor.piece_to_id(token)
        else:
            raise RuntimeError(f"Unknown tokenizer backend: {self.backend!r}")
        if id_ is None:
            raise ValueError(f"token {token!r} not found in the collection.")
        return id_

    def encode(
        self,
        string: str,
        device: Optional[torch.device] = None,
        bos: bool = False,
        eos: bool = True,
        max_length: int = -1,
    ) -> torch.Tensor:
        """Tokenize ``string`` into a 1-D ``torch.int`` tensor.

        Args:
            string: Text to tokenize.
            device: Device on which the returned tensor is created.
            bos: Prepend the bos id.
            eos: Append the eos id.
            max_length: If positive, keep only the first ``max_length`` ids.
                Truncation happens after bos/eos are added, so a truncated
                sequence can lose its eos id.

        Raises:
            NotImplementedError: If ``bos`` is requested but the tokenizer
                has no bos token.
        """
        if self.backend == "huggingface":
            tokens = self.processor.encode(string).ids
        elif self.backend == "sentencepiece":
            tokens = self.processor.encode(string)
        else:
            raise RuntimeError(f"Unknown tokenizer backend: {self.backend!r}")
        if bos:
            bos_id = self.bos_id
            if bos_id is None:
                raise NotImplementedError("This tokenizer does not defined a bos token")
            tokens = [bos_id] + tokens
        if eos:
            tokens = tokens + [self.eos_id]
        if max_length > 0:
            tokens = tokens[:max_length]
        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tensor: torch.Tensor) -> str:
        """Decode a tensor of token ids (0-dim or list-like) back into text."""
        # a 0-dim tensor holds a single id; wrap it so the backend gets a list
        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
        return self.processor.decode(tokens)
20 | """ 21 | batch, seqlen, nheads, headdim = x.shape 22 | rotary_seqlen, rotary_dim = cos.shape 23 | rotary_dim *= 2 24 | assert rotary_dim <= headdim 25 | assert seqlen <= rotary_seqlen 26 | assert sin.shape == (rotary_seqlen, rotary_dim // 2) 27 | x_ro = x[..., :rotary_dim] 28 | x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) 29 | out = torch.empty_like(x) if not inplace else x 30 | out_ro = out[..., :rotary_dim] 31 | if inplace: 32 | o1, o2 = x1, x2 33 | else: 34 | o1, o2 = ( 35 | out_ro.chunk(2, dim=-1) 36 | if not interleaved 37 | else (out_ro[..., ::2], out_ro[..., 1::2]) 38 | ) 39 | rotary_emb.apply_rotary( 40 | x1, 41 | x2, 42 | rearrange(cos[:seqlen], "s d -> s 1 d"), 43 | rearrange(sin[:seqlen], "s d -> s 1 d"), 44 | o1, 45 | o2, 46 | False, 47 | ) 48 | if not inplace and rotary_dim < headdim: 49 | out[..., rotary_dim:].copy_(x[..., rotary_dim:]) 50 | ctx.save_for_backward(cos, sin) 51 | ctx.interleaved = interleaved 52 | ctx.inplace = inplace 53 | return out if not inplace else x 54 | 55 | @staticmethod 56 | def backward(ctx, do): 57 | cos, sin = ctx.saved_tensors 58 | _, seqlen, _, headdim = do.shape 59 | rotary_dim = cos.shape[-1] 60 | rotary_dim *= 2 61 | inplace = ctx.inplace 62 | do_ro = do[..., :rotary_dim] 63 | do1, do2 = ( 64 | do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) 65 | ) 66 | dx = torch.empty_like(do) if not inplace else do 67 | if inplace: 68 | dx1, dx2 = do1, do2 69 | else: 70 | dx_ro = dx[..., :rotary_dim] 71 | dx1, dx2 = ( 72 | dx_ro.chunk(2, dim=-1) 73 | if not ctx.interleaved 74 | else (dx_ro[..., ::2], dx_ro[..., 1::2]) 75 | ) 76 | rotary_emb.apply_rotary( 77 | do1, 78 | do2, 79 | rearrange(cos[:seqlen], "s d -> s 1 d"), 80 | rearrange(sin[:seqlen], "s d -> s 1 d"), 81 | dx1, 82 | dx2, 83 | True, 84 | ) 85 | if not inplace and rotary_dim < headdim: 86 | dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) 87 | return dx, None, None, None, None 88 | 89 
def export_wandb_runs(write_file_path, enable_empty_row=False):
    """Export the metrics logged at ``SELECT_STEP`` for each matching run to CSV.

    Iterates over the module-level ``runs``. A run is used when its name
    starts with ``RUN_NAME_PREFIX``; the rest of the name is parsed as the
    integer row index. For each such run the history rows whose
    ``trainer/global_step`` equals ``SELECT_STEP`` are selected and, per
    column in ``KEY_METRICS``, the first non-NaN value is kept. Only the
    first run that yields a row is kept for any given run name.

    Args:
        write_file_path: Destination CSV path.
        enable_empty_row: If True, a run with no usable value at
            ``SELECT_STEP`` contributes a row of NaN; if False it is skipped.
    """
    step_column = "trainer/global_step"
    wanted_columns = set(KEY_METRICS) | {step_column}

    output_data = []
    records = set()
    for run in tqdm(runs):
        # skip invalid runs
        if not run.name.startswith(RUN_NAME_PREFIX):
            print("skip", run.name)
            continue

        # no duplicated run name; checked before the history download so a
        # repeated name costs no extra API traffic
        if run.name in records:
            continue

        run_index = int(run.name.replace(RUN_NAME_PREFIX, ""))
        # NOTE(review): the very large `samples` value is presumably there so
        # that the row for SELECT_STEP is not dropped by history sampling --
        # confirm against the wandb API documentation.
        data_frame = run.history(samples=10000000)
        keep_columns = [col for col in data_frame.columns if col in wanted_columns]
        # only keep the pre-defined columns
        data_frame = data_frame[keep_columns]
        # select the row when global_step = SELECT_STEP
        data_frame = data_frame[data_frame[step_column] == SELECT_STEP]
        # take the first non-nan value for each column
        first_non_nan_indices = data_frame.apply(lambda col: col.first_valid_index())
        # first_valid_index() gives None for an all-NaN column; isna() detects
        # that whether pandas stores it as None or as NaN
        if len(data_frame) == 0 or first_non_nan_indices.isna().any():
            if not enable_empty_row:
                print("skip", run.name)
                continue
            # add a row of nan if no row is selected
            data_frame = pd.DataFrame(
                [[float("nan") for _ in range(len(keep_columns))]], columns=keep_columns
            )
            data_frame[step_column] = SELECT_STEP
        else:
            data_frame = pd.DataFrame(
                {col: [data_frame[col][idx]] for col, idx in first_non_nan_indices.items()}
            )

        data_frame["index"] = run_index
        # each frame has exactly one row at this point
        output_data.append(data_frame.to_dict("records")[0])
        records.add(run.name)

    runs_df = pd.DataFrame.from_dict(output_data)
    # set the index column as the index
    runs_df.set_index("index", inplace=True)
    # order by index
    runs_df = runs_df.sort_index()

    # the step column was only needed for row selection; drop it from the output
    runs_df = runs_df.drop(columns=[step_column])
    runs_df.to_csv(write_file_path)
# Glob pattern (relative to the source directory) of the files of each split
slimpajama_sets = {
    "train": "train/*",
    "valid": "valid/*",
}


def prepare_full(
    source_path: Path,
    tokenizer_path: Path,
    destination_path: Path,
    chunk_size: int,
    shortname: str = "the_pile_unzip",
    split: str = "train",
    filenames_subset: List[str] = None,
    process_id: int = 0
) -> None:
    """Tokenize JSONL files and pack their token ids into binary chunk files.

    Every line of every input file is parsed as JSON, its ``"text"`` field is
    encoded with the tokenizer and appended to a ``PackedDatasetBuilder`` that
    writes files named ``{split}_{shortname}_{process_id}_*`` into
    ``destination_path``.

    Args:
        source_path: Directory containing the split sub-directories.
        tokenizer_path: Path handed to ``Tokenizer``.
        destination_path: Output directory, created if missing.
        chunk_size: Number of tokens per output chunk.
        shortname: Name component of the output file prefix.
        split: Key of ``slimpajama_sets`` ("train" or "valid").
        filenames_subset: Files to process; when None, all files matching the
            split pattern under ``source_path`` are used.
        process_id: Distinguishes the output prefix of concurrent builders.

    Raises:
        RuntimeError: If there is no file to process.
    """
    destination_path.mkdir(parents=True, exist_ok=True)

    tokenizer = Tokenizer(tokenizer_path)
    # Use the provided filenames_subset or default to all filenames of the split
    filenames = filenames_subset
    if filenames is None:
        filenames = sorted(glob.glob(os.path.join(source_path, slimpajama_sets[split])))

    if not filenames:
        raise RuntimeError(
            f"No files matching {slimpajama_sets[split]} found at {source_path}. \n"
            "Make sure you download the data..."
        )

    builder = packed_dataset.PackedDatasetBuilder(
        outdir=destination_path,
        prefix=f"{split}_{shortname}_{process_id}",  # Use process_id to differentiate builders
        chunk_size=chunk_size,
        sep_token=tokenizer.bos_id,
        dtype="auto",
        vocab_size=tokenizer.vocab_size,
    )

    for filepath in filenames:
        print(f"Processing {filepath}")
        with open(filepath, "r", encoding="utf-8") as f:
            for row in tqdm(f):
                text = json.loads(row)["text"]
                text_ids = tokenizer.encode(text)
                builder.add_array(np.array(text_ids, dtype=builder.dtype))
    # NOTE(review): builder.write_reminder() is never called, so tokens left in
    # the last, partially filled chunk are presumably not written - confirm
    # that dropping them is intended.


def process_file(args):
    """Unpack one task tuple built by ``prepare`` and run ``prepare_full`` on its single file."""
    source_path, tokenizer_path, destination_path, chunk_size, subset_name, split, filename, index = args
    prepare_full(source_path, tokenizer_path, destination_path, chunk_size, subset_name, split, [filename], index)


def prepare(
    source_path: Path = Path("sail/regmix-data"),
    tokenizer_path: Path = Path("tokenizer/gptneox"),
    destination_path: Path = Path("data/regmix_data"),
    short_name: str = "ind",
    chunk_size: int = 2049 * 256,
    split: str = "train",
    percentage: float = 1.0,
) -> None:
    """Tokenize every file of a split in parallel, one task per input file.

    Args:
        source_path: Directory containing the split sub-directories.
        tokenizer_path: Path handed to ``Tokenizer``.
        destination_path: Output directory for the packed chunks.
        short_name: Prefix of the per-file subset name; the file's base name
            (up to its first ".") is appended to it.
        chunk_size: Number of tokens per output chunk.
        split: Key of ``slimpajama_sets`` ("train" or "valid").
        percentage: Fraction (0.0-1.0) of the matched files to process.

    Raises:
        RuntimeError: If no file is left to process.
    """
    import time

    pattern = slimpajama_sets[split]
    # glob returns files in arbitrary order; sort so that the selected subset
    # and the per-file process ids are reproducible
    filenames = sorted(glob.glob(os.path.join(source_path, pattern), recursive=True))
    filenames = filenames[:int(len(filenames) * percentage)]

    # fail early with a clear message instead of Pool(processes=0)
    if not filenames:
        raise RuntimeError(
            f"No files matching {pattern} found at {source_path} "
            f"(after keeping a fraction of {percentage}). \n"
            "Make sure you download the data..."
        )

    start_time = time.time()

    tasks = []
    for i, filename in enumerate(filenames):
        subset_name = short_name + "_" + os.path.basename(filename).split(".")[0]
        tasks.append((source_path, tokenizer_path, destination_path, chunk_size, subset_name, split, filename, i))

    with Pool(processes=min(cpu_count(), len(tasks))) as pool:
        pool.map(process_file, tasks)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Time taken: {elapsed_time:.2f} seconds")
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | .DS_Store -------------------------------------------------------------------------------- /model_training/README.md: -------------------------------------------------------------------------------- 1 | # Model Training 2 | 3 | This document describes how to train a model with the awesome [TinyLlama](https://github.com/jzhang38/TinyLlama) codebase. And a large part of this README is borrowed from the TinyLlama README. 4 | 5 | ## Setup 6 | 7 | You can choose to install the dependencies manually or use the provided Docker image. The Docker image is recommended for a hassle-free setup. If you choose to install the dependencies manually, please follow the instructions below. 8 | 9 | ### Installation Manually 10 | 11 | You should install PyTorch with CUDA support, along with xformers and Flash-Attention 2. 
12 | 13 | #### Prerequisites 14 | 15 | - CUDA 11.8 installed on your system 16 | 17 | #### 1. Install PyTorch Nightly 18 | 19 | First, you should install the nightly build of PyTorch with CUDA support: 20 | 21 | ```bash 22 | pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev' 23 | ``` 24 | 25 | #### 2. Install Xformers 26 | 27 | ```bash 28 | pip install -U xformers --index-url https://download.pytorch.org/whl/cu118 29 | ``` 30 | 31 | #### 3. Install Flash-Attention 2 32 | 33 | Next, we'll install Flash-Attention 2 and its associated operators: 34 | 35 | ```bash 36 | git clone https://github.com/Dao-AILab/flash-attention.git 37 | cd flash-attention && \ 38 | python setup.py install && \ 39 | cd csrc/rotary && pip install . && \ 40 | cd ../layer_norm && pip install . && \ 41 | cd ../xentropy && pip install . && \ 42 | cd ../.. && rm -rf flash-attention 43 | ``` 44 | 45 | #### 4. Install Remaining Dependencies 46 | 47 | Finally, install the remaining required packages: 48 | 49 | ```bash 50 | pip install -r requirements.txt tokenizers sentencepiece transformers wandb datasets huggingface_hub 51 | ``` 52 | 53 | > [!NOTE] 54 | > The build process for Flash-Attention may take 5 minutes or more. 55 | > If you encounter any issues, ensure your CUDA installation is correct and compatible with the PyTorch version you're installing. 56 | > It's recommended to use a virtual environment for this installation to avoid conflicts with other Python packages. 57 | 58 | If you need any further assistance or encounter any issues during the installation, please don't hesitate to ask for help. 
59 | 60 | ### Using Docker 61 | 62 | You can also use the following command to pull the pre-built Docker image: 63 | 64 | ```shell 65 | docker pull siviltaramqian/tinyllama:latest 66 | ``` 67 | 68 | You can then run the Docker image with the following command: 69 | 70 | ```shell 71 | docker run --gpus all -it --rm siviltaramqian/tinyllama:latest ./pretrain_tinyllama_1m.sh 1 72 | ``` 73 | 74 | ## Wandb Integration 75 | 76 | By default we use wandb to collect the data, which avoids saving a large number of small models and logs on the local machine. If you want to use wandb, you need to create an account on [wandb](https://wandb.ai/site) and get an API key. Then you should set the following environment variables in both `pretrain_tinyllama_1m.sh` and `pretrain_tinyllama_1b.sh`: 77 | 78 | ```shell 79 | # wandb project name, entity, and API key 80 | export WANDB_PROJECT=YOUR_PROJECT_NAME 81 | export WANDB_ENTITY=YOUR_WANDB_ENTITY 82 | export WANDB_API_KEY=YOUR_WANDB_API_KEY 83 | ``` 84 | 85 | ## Preprocess 86 | 87 | Before training the model, you need to preprocess the data. We provide an easy-to-use script for preprocessing the data. You can use the following command to preprocess the data: 88 | 89 | ```shell 90 | cd preprocess 91 | bash run_preprocess.sh 92 | ``` 93 | 94 | By default you will first download the `regmix-data-sample` dataset from Hugging Face and then preprocess the data. The JSONL data will be saved in the `preprocess/sail/regmix-data-sample` directory, and the preprocessed data will be saved in the `lit_dataset_regmix` directory. 95 | 96 | ## Train 97 | 98 | After preprocessing the data, you can train the model using the following command: 99 | 100 | ```shell 101 | ./pretrain_tinyllama_1m.sh 1 102 | ``` 103 | 104 | The passed argument is the configuration index. After the setup described in [mixture_config](../mixture_config/), you should have 512 configurations for training the proxy models.
You can change the configuration index to train different configurations. The training of 1M models should take around 20 minutes to finish on a single A100 GPU, which is pre-trained on 1B tokens. 105 | 106 | You can also train a larger model using the following command: 107 | 108 | ```shell 109 | ./pretrain_tinyllama_1b.sh regmix 110 | ``` 111 | 112 | The regmix is the configuration name for the 1B model. The full configuration is stored in the [mixture_config/config_1b](../mixture_config/config_1b/) directory. The training of 1B models should take around 2 days to finish on 8 x A100 GPU, which is pre-trained on 25B tokens. 113 | -------------------------------------------------------------------------------- /data/test_mixture_1B.csv: -------------------------------------------------------------------------------- 1 | index,train_the_pile_arxiv,train_the_pile_freelaw,train_the_pile_nih_exporter,train_the_pile_pubmed_central,train_the_pile_wikipedia_en,train_the_pile_dm_mathematics,train_the_pile_github,train_the_pile_philpapers,train_the_pile_stackexchange,train_the_pile_enron_emails,train_the_pile_gutenberg_pg_19,train_the_pile_pile_cc,train_the_pile_ubuntu_irc,train_the_pile_europarl,train_the_pile_hackernews,train_the_pile_pubmed_abstracts,train_the_pile_uspto_backgrounds 2 | 0,0.123,0.065,0.0,0.126,0.036,0.0,0.034,0.0,0.039,0.0,0.0,0.27,0.0,0.0,0.0,0.0,0.307 3 | 1,0.066,0.071,0.0,0.211,0.013,0.0,0.153,0.033,0.097,0.0,0.0,0.101,0.0,0.0,0.011,0.136,0.106 4 | 2,0.055,0.052,0.004,0.177,0.02,0.011,0.095,0.0,0.18,0.0,0.016,0.381,0.001,0.0,0.005,0.0,0.003 5 | 3,0.059,0.083,0.0,0.174,0.177,0.0,0.194,0.0,0.0,0.0,0.0,0.192,0.005,0.109,0.0,0.005,0.0 6 | 4,0.201,0.004,0.014,0.243,0.01,0.03,0.017,0.0,0.103,0.0,0.002,0.359,0.0,0.0,0.0,0.014,0.002 7 | 5,0.036,0.212,0.0,0.153,0.005,0.047,0.205,0.0,0.075,0.0,0.0,0.209,0.0,0.001,0.0,0.002,0.055 8 | 6,0.042,0.113,0.0,0.089,0.022,0.007,0.028,0.0,0.011,0.0,0.217,0.232,0.08,0.117,0.018,0.011,0.011 9 | 
7,0.126,0.21,0.0,0.123,0.055,0.008,0.008,0.0,0.129,0.0,0.035,0.288,0.0,0.0,0.0,0.016,0.0 10 | 8,0.184,0.009,0.0,0.094,0.035,0.007,0.106,0.0,0.142,0.0,0.0,0.341,0.0,0.0,0.002,0.005,0.075 11 | 9,0.226,0.046,0.0,0.261,0.001,0.001,0.189,0.0,0.077,0.0,0.01,0.114,0.003,0.0,0.0,0.039,0.033 12 | 10,0.107,0.276,0.0,0.157,0.009,0.0,0.024,0.0,0.051,0.0,0.001,0.273,0.0,0.003,0.034,0.009,0.056 13 | 11,0.139,0.048,0.0,0.184,0.032,0.001,0.055,0.0,0.109,0.0,0.0,0.354,0.0,0.003,0.0,0.075,0.0 14 | 12,0.101,0.047,0.001,0.119,0.049,0.092,0.078,0.0,0.002,0.0,0.051,0.283,0.057,0.0,0.0,0.061,0.057 15 | 13,0.099,0.002,0.022,0.501,0.003,0.0,0.017,0.043,0.065,0.0,0.091,0.055,0.0,0.006,0.0,0.007,0.088 16 | 14,0.251,0.024,0.0,0.101,0.17,0.0,0.048,0.019,0.007,0.0,0.0,0.339,0.017,0.0,0.0,0.0,0.024 17 | 15,0.147,0.046,0.0,0.196,0.14,0.008,0.237,0.0,0.06,0.0,0.012,0.111,0.0,0.0,0.001,0.01,0.032 18 | 16,0.228,0.016,0.0,0.204,0.02,0.036,0.02,0.004,0.002,0.0,0.0,0.244,0.0,0.004,0.0,0.196,0.026 19 | 17,0.0,0.019,0.0,0.084,0.159,0.009,0.012,0.0,0.052,0.0,0.001,0.361,0.296,0.0,0.0,0.001,0.007 20 | 18,0.501,0.005,0.0,0.156,0.17,0.0,0.022,0.017,0.062,0.0,0.002,0.061,0.002,0.0,0.0,0.0,0.002 21 | 19,0.101,0.03,0.0,0.272,0.021,0.099,0.124,0.0,0.113,0.0,0.054,0.154,0.0,0.001,0.0,0.011,0.02 22 | 20,0.047,0.014,0.0,0.163,0.218,0.0,0.137,0.0,0.173,0.0,0.001,0.19,0.029,0.007,0.011,0.008,0.001 23 | 21,0.031,0.073,0.0,0.053,0.129,0.0,0.066,0.0,0.12,0.0,0.089,0.057,0.001,0.0,0.031,0.351,0.001 24 | 22,0.078,0.024,0.0,0.302,0.027,0.0,0.04,0.0,0.007,0.0,0.002,0.499,0.0,0.0,0.0,0.0,0.021 25 | 23,0.068,0.181,0.0,0.126,0.07,0.001,0.195,0.0,0.24,0.0,0.0,0.023,0.0,0.0,0.0,0.059,0.036 26 | 24,0.074,0.214,0.0,0.135,0.011,0.0,0.121,0.006,0.024,0.0,0.001,0.088,0.001,0.0,0.004,0.132,0.189 27 | 25,0.076,0.085,0.0,0.214,0.005,0.0,0.127,0.0,0.204,0.0,0.147,0.138,0.002,0.0,0.0,0.001,0.001 28 | 26,0.05,0.039,0.0,0.049,0.068,0.019,0.042,0.0,0.146,0.0,0.01,0.302,0.0,0.008,0.0,0.01,0.255 29 | 
27,0.067,0.052,0.0,0.221,0.052,0.0,0.101,0.0,0.001,0.0,0.265,0.214,0.026,0.0,0.0,0.002,0.0 30 | 28,0.244,0.023,0.0,0.064,0.151,0.0,0.073,0.0,0.02,0.0,0.017,0.383,0.01,0.0,0.0,0.007,0.007 31 | 29,0.073,0.087,0.026,0.175,0.017,0.101,0.1,0.019,0.054,0.0,0.0,0.12,0.134,0.037,0.0,0.053,0.002 32 | 30,0.234,0.015,0.0,0.086,0.287,0.026,0.04,0.0,0.022,0.0,0.0,0.134,0.0,0.0,0.0,0.022,0.134 33 | 31,0.08,0.134,0.0,0.255,0.058,0.037,0.171,0.0,0.015,0.0,0.045,0.182,0.0,0.0,0.0,0.016,0.008 34 | 32,0.105,0.007,0.0,0.407,0.045,0.054,0.017,0.0,0.126,0.0,0.009,0.167,0.0,0.007,0.0,0.047,0.008 35 | 33,0.295,0.029,0.0,0.061,0.124,0.0,0.006,0.0,0.006,0.0,0.047,0.364,0.0,0.026,0.004,0.0,0.037 36 | 34,0.142,0.122,0.001,0.065,0.0,0.001,0.006,0.003,0.001,0.0,0.014,0.618,0.001,0.0,0.0,0.0,0.025 37 | 35,0.279,0.01,0.0,0.184,0.0,0.0,0.108,0.0,0.097,0.0,0.039,0.198,0.0,0.0,0.0,0.083,0.002 38 | 36,0.052,0.07,0.253,0.4,0.003,0.0,0.033,0.0,0.019,0.0,0.0,0.031,0.0,0.0,0.018,0.002,0.119 39 | 37,0.251,0.007,0.007,0.331,0.107,0.0,0.13,0.0,0.021,0.0,0.001,0.006,0.12,0.0,0.0,0.005,0.014 40 | 38,0.239,0.087,0.0,0.223,0.029,0.0,0.049,0.0,0.202,0.0,0.0,0.156,0.0,0.0,0.0,0.012,0.001 41 | 39,0.157,0.062,0.0,0.039,0.096,0.007,0.057,0.0,0.174,0.0,0.015,0.181,0.0,0.089,0.012,0.016,0.095 42 | 40,0.422,0.213,0.0,0.08,0.019,0.001,0.026,0.0,0.003,0.0,0.01,0.026,0.0,0.0,0.0,0.101,0.099 43 | 41,0.466,0.075,0.0,0.07,0.006,0.0,0.044,0.0,0.078,0.0,0.0,0.2,0.0,0.0,0.0,0.028,0.031 44 | 42,0.027,0.041,0.0,0.116,0.021,0.001,0.067,0.0,0.137,0.0,0.001,0.549,0.002,0.0,0.0,0.002,0.037 45 | 43,0.063,0.089,0.0,0.219,0.001,0.05,0.291,0.0,0.002,0.0,0.0,0.238,0.0,0.001,0.001,0.045,0.0 46 | 44,0.121,0.008,0.0,0.093,0.008,0.016,0.012,0.0,0.408,0.0,0.006,0.156,0.013,0.001,0.0,0.005,0.153 47 | 45,0.041,0.025,0.0,0.111,0.092,0.062,0.121,0.0,0.124,0.0,0.0,0.214,0.129,0.006,0.012,0.012,0.052 48 | 46,0.033,0.048,0.0,0.22,0.027,0.002,0.169,0.0,0.082,0.0,0.057,0.312,0.0,0.0,0.0,0.0,0.05 49 | 
47,0.114,0.116,0.0,0.081,0.038,0.031,0.109,0.0,0.001,0.0,0.021,0.428,0.001,0.0,0.0,0.031,0.029 50 | 48,0.082,0.12,0.0,0.051,0.067,0.034,0.205,0.0,0.036,0.0,0.0,0.371,0.0,0.0,0.0,0.029,0.004 51 | 49,0.091,0.084,0.0,0.343,0.0,0.174,0.144,0.0,0.009,0.0,0.019,0.122,0.001,0.003,0.001,0.006,0.004 52 | 50,0.194,0.04,0.022,0.126,0.046,0.028,0.048,0.01,0.099,0.0,0.04,0.229,0.0,0.002,0.0,0.089,0.027 53 | 51,0.011,0.022,0.0,0.37,0.006,0.0,0.14,0.0,0.058,0.0,0.216,0.101,0.033,0.0,0.002,0.026,0.015 54 | 52,0.039,0.063,0.0,0.079,0.0,0.002,0.482,0.0,0.012,0.0,0.0,0.269,0.0,0.0,0.0,0.002,0.052 55 | 53,0.294,0.119,0.0,0.186,0.023,0.005,0.023,0.0,0.001,0.0,0.002,0.213,0.023,0.0,0.0,0.024,0.088 56 | 54,0.012,0.16,0.0,0.311,0.014,0.0,0.117,0.0,0.004,0.0,0.236,0.037,0.007,0.0,0.0,0.007,0.094 57 | 55,0.25,0.058,0.0,0.104,0.044,0.0,0.028,0.0,0.06,0.0,0.0,0.363,0.0,0.0,0.0,0.086,0.007 58 | 56,0.137,0.085,0.0,0.085,0.059,0.0,0.039,0.0,0.017,0.009,0.007,0.435,0.0,0.0,0.001,0.004,0.122 59 | 57,0.176,0.007,0.0,0.05,0.122,0.001,0.088,0.069,0.05,0.0,0.0,0.339,0.006,0.012,0.0,0.004,0.077 60 | 58,0.471,0.038,0.0,0.218,0.005,0.0,0.097,0.0,0.016,0.0,0.018,0.112,0.017,0.0,0.0,0.001,0.006 61 | 59,0.081,0.153,0.0,0.17,0.017,0.033,0.041,0.048,0.077,0.0,0.001,0.268,0.095,0.0,0.0,0.0,0.016 62 | 60,0.107,0.016,0.0,0.218,0.003,0.0,0.238,0.0,0.113,0.0,0.0,0.272,0.001,0.0,0.0,0.02,0.013 63 | 61,0.278,0.141,0.0,0.257,0.099,0.009,0.041,0.0,0.027,0.0,0.0,0.128,0.0,0.0,0.0,0.0,0.02 64 | 62,0.119,0.085,0.027,0.294,0.02,0.073,0.038,0.0,0.046,0.001,0.026,0.232,0.0,0.001,0.0,0.013,0.025 65 | 63,0.131,0.006,0.03,0.075,0.0,0.093,0.369,0.0,0.06,0.0,0.002,0.188,0.001,0.003,0.017,0.016,0.009 66 | -------------------------------------------------------------------------------- /model_training/lit_gpt/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Literal, Optional, Type 3 | 4 | import torch 5 | from 
@dataclass
class Config:
    """Hyper-parameters describing a lit-GPT style transformer.

    Fields left as ``None`` (``padded_vocab_size``, ``n_query_groups``,
    ``intermediate_size``) are derived from the other fields in
    ``__post_init__``.
    """

    org: str = "Lightning-AI"
    name: str = "lit-GPT"
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    n_embd: int = 4096
    resid_pdrop: float = 0.0
    embd_pdrop: float = 0.0
    attn_pdrop: float = 0.0
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    rope_base: int = 10000
    bias: bool = True
    # to use multi-head attention (MHA), set this to `n_head` (default)
    # to use multi-query attention (MQA), set this to 1
    # to use grouped-query attention (GQA), set this to a value in between
    # Example with `n_head=4`
    #          ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    #          │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
    #          └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #            │    │    │    │         │        │                 │
    #          ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    #          │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
    #          └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #            │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
    #          ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
    #          │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
    #          └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
    #          ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
    #                  MHA                    GQA                   MQA
    #            n_query_groups=4       n_query_groups=2      n_query_groups=1
    #
    # credit https://arxiv.org/pdf/2305.13245.pdf
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    # "FusedRMSNorm" is resolved by `norm_class` below, so it is a valid value too
    _norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
    intermediate_size: Optional[int] = None
    condense_ratio: int = 1

    def __post_init__(self):
        """Validate the head/group layout and fill in the derived fields."""
        # error checking
        assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head"
        # pad the vocab size up to a multiple of `padding_multiple` unless it was given explicitly
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
        # compute the number of query groups
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"
        else:
            self.n_query_groups = self.n_head
        # compute the intermediate size for MLP if not set
        if self.intermediate_size is None:
            if self._mlp_class == "LLaMAMLP":
                raise ValueError("The config needs to set the `intermediate_size`")
            self.intermediate_size = 4 * self.n_embd

    @property
    def head_size(self) -> int:
        """Embedding width of a single attention head (``n_embd // n_head``)."""
        return self.n_embd // self.n_head

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build a config from the ``name_to_config`` entry ``name``, overridden by ``kwargs``.

        Raises:
            KeyError: If ``name`` is not a key of ``name_to_config``.
        """
        conf_dict = name_to_config[name].copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @property
    def mlp_class(self) -> Type:
        """MLP class from ``lit_gpt.model`` named by ``_mlp_class``."""
        # `self._mlp_class` cannot be the type to keep the config json serializable
        return getattr(lit_gpt.model, self._mlp_class)

    @property
    def norm_class(self) -> Type:
        """Normalization class named by ``_norm_class``.

        The two RMSNorm variants are imported lazily from ``lit_gpt.rmsnorm``;
        any other name is looked up in ``torch.nn``.
        """
        # `self._norm_class` cannot be the type to keep the config json serializable
        if self._norm_class == "RMSNorm":
            from lit_gpt.rmsnorm import RMSNorm

            return RMSNorm
        elif self._norm_class == "FusedRMSNorm":
            from lit_gpt.rmsnorm import FusedRMSNorm

            return FusedRMSNorm
        return getattr(torch.nn, self._norm_class)
_norm_class="FusedRMSNorm", 118 | norm_eps=1e-5, 119 | _mlp_class="LLaMAMLP", 120 | intermediate_size=512 121 | ), 122 | dict( 123 | org="RegMix Paper", 124 | name="tinycoder_1M", 125 | block_size=2048, 126 | vocab_size=49152, 127 | padding_multiple=64, 128 | n_layer=2, 129 | n_head=8, 130 | n_embd=256, 131 | rotary_percentage=1.0, 132 | parallel_residual=False, 133 | bias=False, 134 | _norm_class="FusedRMSNorm", 135 | norm_eps=1e-5, 136 | _mlp_class="LLaMAMLP", 137 | intermediate_size=512 138 | ), 139 | dict( 140 | org="RegMix Paper", 141 | name="tinyllama_60M", 142 | block_size=2048, 143 | vocab_size=50432, 144 | padding_multiple=64, 145 | n_layer=10, 146 | n_head=8, 147 | n_embd=768, 148 | rotary_percentage=1.0, 149 | parallel_residual=False, 150 | bias=False, 151 | _norm_class="FusedRMSNorm", 152 | norm_eps=1e-5, 153 | _mlp_class="LLaMAMLP", 154 | intermediate_size=1536 155 | ), 156 | dict( 157 | org="RegMix Paper", 158 | name="tinyllama_1_1b", 159 | block_size=2048, 160 | vocab_size=50432, 161 | padding_multiple=64, 162 | n_layer=22, 163 | n_head=16, 164 | n_embd=2048, 165 | rotary_percentage=1.0, 166 | parallel_residual=False, 167 | bias=False, 168 | _norm_class="FusedRMSNorm", 169 | norm_eps=1e-5, #Llama 2 use 1e-5. Llama 1 use 1e-6 170 | _mlp_class="LLaMAMLP", 171 | intermediate_size=5632 172 | ) 173 | ] 174 | configs = regmix_llama 175 | name_to_config = {config["name"]: config for config in configs} 176 | -------------------------------------------------------------------------------- /model_training/lit_gpt/fused_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn as nn 5 | import xentropy_cuda_lib 6 | 7 | # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for 8 | # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent 9 | # version of PyTorch. 
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    """Autograd function computing per-sample softmax cross entropy with ``xentropy_cuda_lib``.

    Supports label smoothing and an ignored label value. When ``process_group``
    is given, the vocab dimension of the logits is split across the ranks of
    that group.
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.

        Returns the per-sample losses; entries whose label equals ignored_index are set to 0.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            # first global vocab id owned by this rank
            vocab_start_index = rank * vocab_size

            # shift the labels into this partition's local vocab range, keeping ignored labels as is
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(
                logits, labels_local, smoothing, world_size * vocab_size
            )
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            lse_allgather = torch.empty(
                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
            )
            torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(), group=process_group
            )
            # the reduction runs asynchronously while the global lse is computed below
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
            lse_local = lse_allgather[
                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
            ]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
                    lse - lse_allgather.sum(dim=0)
                )
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        """Return the gradient w.r.t. ``logits``; the other forward inputs get no gradient."""
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        # ignored samples contribute no gradient
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(
            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
        )
        # one entry per forward() input after ctx: logits, labels, smoothing,
        # ignored_index, inplace_backward, process_group
        return grad_logits, None, None, None, None, None
self.ignore_index, 142 | self.inplace_backward, 143 | self.process_group, 144 | ) 145 | if self.reduction == "mean": 146 | return loss.sum() / (target != self.ignore_index).sum() 147 | else: 148 | return loss -------------------------------------------------------------------------------- /model_training/lit_gpt/packed_dataset.py: -------------------------------------------------------------------------------- 1 | # Very loosely inspired by indexed_dataset in Fairseq, Megatron 2 | # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py 3 | 4 | 5 | import os 6 | import random 7 | import struct 8 | 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import IterableDataset, get_worker_info 12 | 13 | dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16} 14 | 15 | 16 | def code(dtype): 17 | for k in dtypes: 18 | if dtypes[k] == dtype: 19 | return k 20 | raise ValueError(dtype) 21 | 22 | 23 | HDR_MAGIC = b"LITPKDS" 24 | HDR_SIZE = 24 # bytes 25 | 26 | 27 | class PackedDataset(IterableDataset): 28 | def __init__( 29 | self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0 30 | ): 31 | self._filenames = filenames 32 | self._n_chunks = n_chunks 33 | self._block_size = block_size 34 | self._seed = seed 35 | self._shuffle = shuffle 36 | self._wrap = wrap 37 | self._num_processes = num_processes 38 | self._process_rank = process_rank 39 | 40 | def __iter__(self): 41 | worker_info = get_worker_info() 42 | num_workers = worker_info.num_workers if worker_info is not None else 1 43 | worker_id = worker_info.id if worker_info is not None else 0 44 | num_shards = num_workers * self._num_processes 45 | shard_id = self._process_rank * num_workers + worker_id 46 | 47 | max_num_files = len(self._filenames) // num_shards * num_shards 48 | filenames = self._filenames[shard_id:max_num_files:num_shards] 49 | return 
PackedDatasetIterator( 50 | filenames=filenames, 51 | n_chunks=self._n_chunks, 52 | block_size=self._block_size, 53 | seed=self._seed, 54 | shuffle=self._shuffle, 55 | wrap=self._wrap, 56 | ) 57 | 58 | 59 | class PackedDatasetBuilder(object): 60 | def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None, add_sep_token_doc_end=True): 61 | if dtype == "auto": 62 | if vocab_size is None: 63 | raise ValueError("vocab_size cannot be None when dtype='auto'") 64 | if vocab_size is not None and vocab_size < 65500: 65 | self._dtype = np.uint16 66 | else: 67 | self._dtype = np.int32 68 | else: 69 | self._dtype = dtype 70 | self._counter = 0 71 | self._chunk_size = chunk_size 72 | self._outdir = outdir 73 | self._prefix = prefix 74 | self._sep_token = sep_token 75 | self._add_sep_token_doc_end = add_sep_token_doc_end 76 | self._arr = np.zeros(self._chunk_size, dtype=self._dtype) 77 | self._arr.fill(self._sep_token) 78 | self._idx = 0 79 | self._version = 1 80 | self._filenames = [] 81 | 82 | def _write_chunk(self): 83 | filename = f"{self._prefix}_{self._counter:010d}.bin" 84 | filename = os.path.join(self._outdir, filename) 85 | 86 | with open(filename, "wb") as f: 87 | f.write(HDR_MAGIC) 88 | f.write(struct.pack(" self._chunk_size: 108 | part_len = self._chunk_size - self._idx 109 | self._arr[self._idx : self._idx + part_len] = arr[:part_len] 110 | self._write_chunk() 111 | arr = arr[part_len:] 112 | 113 | arr_len = arr.shape[0] 114 | self._arr[self._idx : self._idx + arr_len] = arr 115 | self._idx += arr_len 116 | 117 | def write_reminder(self): 118 | self._write_chunk() 119 | 120 | 121 | class PackedDatasetIterator: 122 | def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): 123 | self._seed = seed 124 | self._shuffle = shuffle 125 | self._rng = np.random.default_rng(seed) if shuffle else None 126 | self._block_idxs = None 127 | 128 | self._wrap = wrap 129 | 130 | # TODO: instead of filenames, we could have a single text 
stream 131 | # (or text file) with the sequence of all files to be 132 | # fetched/loaded. 133 | self._filenames = filenames 134 | self._file_idx = 0 135 | 136 | self._n_chunks = n_chunks 137 | 138 | self._dtype = None 139 | self._block_size = block_size 140 | self._n_blocks = None 141 | 142 | self._mmaps = [] 143 | self._buffers = [] 144 | 145 | self._block_idxs = [] 146 | self._curr_idx = 0 147 | 148 | self._load_n_chunks() 149 | 150 | def _read_header(self, path): 151 | with open(path, "rb") as f: 152 | magic = f.read(len(HDR_MAGIC)) 153 | assert magic == HDR_MAGIC, "File doesn't match expected format." 154 | version = struct.unpack(" len(self._filenames[self._file_idx :]): 171 | # if not self._wrap: 172 | # raise StopIteration 173 | self._file_idx = 0 174 | # print("Loading chunks from files:") 175 | # print(self._filenames) 176 | for i in range(self._n_chunks): 177 | filename = self._filenames[self._file_idx + i] 178 | if self._dtype is None: 179 | self._dtype, self._chunk_size = self._read_header(filename) 180 | self._n_blocks = self._chunk_size // self._block_size 181 | # TODO: check header matches with previous files 182 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) 183 | self._mmaps.append(mmap) 184 | self._buffers.append(memoryview(mmap)) 185 | 186 | self._file_idx += self._n_chunks 187 | n_all_blocks = self._n_chunks * self._n_blocks 188 | 189 | self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks) 190 | 191 | self._curr_idx = 0 192 | 193 | def __del__(self): 194 | self._close_mmaps() 195 | del self._mmaps 196 | del self._buffers 197 | 198 | def __iter__(self): 199 | return self 200 | 201 | def __next__(self): 202 | if self._curr_idx >= len(self._block_idxs): 203 | self._load_n_chunks() 204 | # TODO: trigger fetching next next n_chunks if remote 205 | block_idx = self._block_idxs[self._curr_idx] 206 | chunk_id = block_idx // self._n_blocks 207 | buffer = self._buffers[chunk_id] 208 | elem_id 
class CombinedDataset(IterableDataset):
    """Interleave several iterable datasets by weighted random choice.

    Each ``next()`` draws one dataset index from a ``random.Random(seed)``
    generator according to ``weights`` and yields that dataset's next item.
    When ``weights`` is omitted every dataset gets the same weight.
    """

    def __init__(self, datasets, seed, weights=None):
        self._seed = seed
        self._datasets = datasets
        count = len(datasets)
        # Uniform weights unless the caller supplied explicit ones.
        self._weights = [1 / count] * count if weights is None else weights
        self._rng = random.Random(seed)

    def __iter__(self):
        # The iterator state is initialized here: one fresh iterator per dataset.
        self._dataset_iters = list(map(iter, self._datasets))
        return self

    def __next__(self):
        candidates = range(len(self._datasets))
        (chosen,) = self._rng.choices(candidates, weights=self._weights, k=1)
        source = self._dataset_iters[chosen]
        return next(source)
def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
    """Split a concrete parameter name into a template and its layer number.

    The dot-separated component at position ``idx`` is parsed as an integer
    and replaced by ``"{}"`` in the returned template, e.g.
    ``layer_template("transformer.h.3.norm_1.weight", 2)`` returns
    ``("transformer.h.{}.norm_1.weight", 3)``.

    Raises:
        ValueError: If the component at ``idx`` is not an integer.
        IndexError: If the name has fewer than ``idx + 1`` components.
    """
    parts = layer_name.split(".")
    number = int(parts[idx])
    parts[idx] = "{}"
    return ".".join(parts), number


def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor:
    """Materialize ``param`` and optionally cast it to ``dtype``.

    Objects exposing ``_load_tensor`` (as produced by ``lazy_load()``) are
    loaded first. When ``dtype`` is given and differs from the tensor's dtype
    the tensor is converted. ``name`` is used only for progress messages.
    """
    if hasattr(param, "_load_tensor"):
        # support tensors loaded via `lazy_load()`
        print(f"Loading {name!r} into RAM")
        param = param._load_tensor()
    if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype:
        print(f"Converting {name!r} from {param.dtype} to {dtype}")
        param = param.to(dtype)
    return param


def copy_weights_llama(
    config: Config,
    state_dict: Dict[str, torch.Tensor],
    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
):
    """Copy lit-gpt weights into ``state_dict`` under Hugging Face Llama names.

    Fused ``*.attn.attn.weight`` tensors are split into separate q/k/v
    projections with ``tensor_split``; every other parameter is renamed via a
    fixed mapping. ``state_dict`` is filled in place.

    Args:
        config: Model configuration, passed on to ``tensor_split``.
        state_dict: Destination mapping, mutated in place.
        lit_weights: Source mapping of lit-gpt parameter names to tensors.
        saver: Optional incremental saver; when given, each tensor is passed
            through ``saver.store_early`` before being stored.

    Raises:
        KeyError: If a parameter name has no entry in the mapping.
    """
    weight_map = {
        "transformer.wte.weight": "model.embed_tokens.weight",
        "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
        "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
        "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
        "transformer.h.{}.mlp.swiglu.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
        "transformer.h.{}.mlp.swiglu.w2.weight": "model.layers.{}.mlp.up_proj.weight",
        "transformer.h.{}.mlp.swiglu.w3.weight": "model.layers.{}.mlp.down_proj.weight",
        "transformer.ln_f.weight": "model.norm.weight",
        "lm_head.weight": "lm_head.weight",
    }

    def store(to_name, tensor):
        # Single place that applies the optional saver before recording a tensor.
        if saver is not None:
            tensor = saver.store_early(tensor)
        state_dict[to_name] = tensor

    for name, param in lit_weights.items():
        if name.endswith(".attn.attn.weight"):
            _, number = layer_template(name, 2)
            qkv = load_param(param, name, None)
            projections = tensor_split(qkv, config)
            for proj_name, tensor in zip(("q_proj", "k_proj", "v_proj"), projections):
                store(f"model.layers.{number}.self_attn.{proj_name}.weight", tensor)
        elif "transformer.h" in name:
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name]
            # A mapping value of None marks a parameter that is intentionally dropped.
            if to_name is None:
                continue
            store(to_name.format(number), load_param(param, name, None))
        else:
            store(weight_map[name], load_param(param, name, None))


def save_huggingface_config(config: Config, out_dir: Path) -> None:
    """Write a Hugging Face ``config.json`` for the converted Llama model.

    Starts from a fixed Llama template and overrides the size-related fields
    from ``config``. ``num_key_value_heads`` is written from
    ``config.n_query_groups`` so that grouped-query checkpoints, whose k/v
    projections are produced per query group by ``tensor_split``, get a config
    that matches the exported projection shapes.

    Args:
        config: Model configuration providing ``n_embd``, ``intermediate_size``,
            ``vocab_size``, ``n_head``, ``n_query_groups``, ``n_layer``,
            ``block_size`` and ``rope_base``.
        out_dir: Existing directory that receives ``config.json``.
    """
    default_config_str = """{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 16,
  "num_hidden_layers": 22,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.35.0",
  "use_cache": true,
  "vocab_size": 50432
}
"""
    config_dict = json.loads(default_config_str)
    # modify the dict according to the config
    config_dict.update(
        {
            "hidden_size": config.n_embd,
            "intermediate_size": config.intermediate_size,
            "vocab_size": config.vocab_size,
            "num_attention_heads": config.n_head,
            # One key/value head per query group (equals n_head without GQA).
            "num_key_value_heads": config.n_query_groups,
            "num_hidden_layers": config.n_layer,
            "max_position_embeddings": config.block_size,
            "rope_theta": config.rope_base,
        }
    )
    # save the config to the output directory
    with open(out_dir / "config.json", "w") as f:
        json.dump(config_dict, f, indent=4, ensure_ascii=False)


def tensor_split(
    param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused QKV weight into separate query, key and value weights.

    The rows of ``param`` are treated as ``config.n_query_groups`` equally
    sized groups. Within each group the last ``config.head_size`` rows are the
    value rows, the ``config.head_size`` rows before them are the key rows,
    and everything before that is query rows. The per-group slices are
    concatenated in group order.

    Returns:
        Tuple ``(q, k, v)`` of concatenated row blocks.
    """
    head_size = config.head_size
    total_rows = param.shape[0]
    group_rows = total_rows // config.n_query_groups

    q_parts, k_parts, v_parts = [], [], []
    for start in range(0, total_rows, group_rows):
        end = start + group_rows
        k_start = end - 2 * head_size
        v_start = end - head_size
        q_parts.append(param[start:k_start, :])
        k_parts.append(param[k_start:v_start, :])
        v_parts.append(param[v_start:end, :])

    return torch.cat(q_parts), torch.cat(k_parts), torch.cat(v_parts)


def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return ``lit_weights["model"]`` when present, else ``lit_weights`` itself."""
    return lit_weights.get("model", lit_weights)


def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None:
    """Reject checkpoints that still contain LoRA or adapter parameters.

    Only the last dot-separated component of each parameter name is inspected.

    Raises:
        ValueError: If any such component contains ``"lora"``.
        NotImplementedError: If ``adapter_bias`` or ``gating_factor`` is present.
    """
    weight_names = {wk.split(".")[-1] for wk in lit_weights}
    # LoRA or QLoRA
    if any("lora" in wn for wn in weight_names):
        raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.")
    # adapter v2. adapter_bias will only be in adapter_v2
    elif "adapter_bias" in weight_names:
        raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.")
    # adapter. gating_factor is in adapter and adapter_v2
    elif "gating_factor" in weight_names:
        raise NotImplementedError("Converting models finetuned with adapter not yet supported.")


@torch.inference_mode()
def convert_lit_checkpoint(*, checkpoint_name: str, inp_dir: Path, out_dir: Path, model_name: str) -> None:
    """Convert a lit-gpt checkpoint into a safetensors file plus ``config.json``.

    Args:
        checkpoint_name: File name of the checkpoint inside ``inp_dir``.
        inp_dir: Directory containing the lit-gpt checkpoint.
        out_dir: Directory receiving ``model.safetensors`` and ``config.json``.
        model_name: Name resolved through ``Config.from_name``.
    """
    config = Config.from_name(model_name)
    copy_fn = partial(copy_weights_llama, config)

    # initialize a new empty state dict to hold our new weights
    sd = {}

    # checkpoint_name cannot be hardcoded because there exists different outputs such as
    # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth"")
    pth_file = inp_dir / checkpoint_name
    bin_file = out_dir / "model.safetensors"

    with incremental_save(bin_file) as saver:
        with contextlib.ExitStack() as stack:
            lit_weights = stack.enter_context(lazy_load(pth_file))
            lit_weights = maybe_unwrap_state_dict(lit_weights)
            check_conversion_supported(lit_weights)
            # Incremental save will trigger error, so the saver is deliberately not used
            copy_fn(sd, lit_weights, saver=None)
            gc.collect()
    # if there is any, remove the original checkpoint; a missing file is not an error
    with contextlib.suppress(FileNotFoundError):
        os.remove(bin_file)
    print(f"Saving model to {bin_file}")
    # use safe tensor to save the model
    save_file(sd, bin_file, metadata={"format": "pt"})
    # save the config
    save_huggingface_config(config, out_dir)


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(convert_lit_checkpoint, as_positional=False)
-------------------------------------------------------------------------------- /data/test_pile_loss_1B.csv: -------------------------------------------------------------------------------- 1 | index,metric/the_pile_arxiv_val_loss,metric/the_pile_freelaw_val_loss,metric/the_pile_pubmed_central_val_loss,metric/the_pile_wikipedia_en_val_loss,metric/the_pile_dm_mathematics_val_loss,metric/the_pile_github_val_loss,metric/the_pile_stackexchange_val_loss,metric/the_pile_gutenberg_pg_19_val_loss,metric/the_pile_pile_cc_val_loss,metric/the_pile_ubuntu_irc_val_loss,metric/the_pile_hackernews_val_loss,metric/the_pile_pubmed_abstracts_val_loss,metric/the_pile_uspto_backgrounds_val_loss 2 | 0,1.772475243,1.929630041,1.771565676,2.464135885,2.071599376,1.139918923,1.869868159,3.004983805,2.932116032,2.854564877,3.07279878,2.561388472,2.050786957 3 | 1,1.82132411,1.933466315,1.707535386,2.625478745,1.660228216,0.945311725,1.759435415,3.127051616,3.065447092,2.910561164,2.84590098,2.258346934,2.202809584 4 | 2,1.852889061,1.95453918,1.763566494,2.495107412,1.34842858,0.982568622,1.703691363,2.664145139,2.887698889,2.552369055,2.810529253,2.56329232,2.487578113 5 | 3,1.871118426,1.921996236,1.781103969,2.307990551,2.067314806,0.950841486,2.076438189,3.025715754,2.983541489,2.341011081,3.339028852,2.542783248,2.662802867 6 | 4,1.718078017,2.224662781,1.712182164,2.542777061,1.280074385,1.17312777,1.768723369,2.871865072,2.903541088,3.040912651,3.244347956,2.399535512,2.495547278 7 | 5,1.901320457,1.78795743,1.781135678,2.64007473,1.258559084,0.924151599,1.78302002,3.031377172,2.983458757,3.038995902,3.271326735,2.564927231,2.29115331 8 | 6,1.921761632,1.873723388,1.845577598,2.516471148,1.382343393,1.196356654,1.995227337,2.344593933,2.953521729,1.957720802,2.773050563,2.556264589,2.458662116 9 | 7,1.77666688,1.777289391,1.782250285,2.400994778,1.361461239,1.239314795,1.756182194,2.560659283,2.907395601,2.965600229,3.255798042,2.475817577,2.601054312 10 | 
8,1.733498216,2.150779009,1.796490788,2.457215786,1.365791748,0.973023057,1.717704773,2.984162243,2.905272961,2.983678682,2.896047123,2.544957132,2.243457784 11 | 9,1.698770761,1.987907529,1.698378682,2.750323534,1.465964686,0.920929253,1.765942097,2.821070007,3.06698823,2.419606044,3.107153075,2.35053658,2.323238266 12 | 10,1.793516636,1.744507432,1.762725472,2.570724487,2.069928596,1.165017128,1.846324086,2.930554712,2.925321341,2.888525089,2.675413838,2.476934515,2.274876941 13 | 11,1.759275317,1.963539481,1.726604939,2.456609964,1.521095621,1.055192709,1.756072998,2.980760282,2.89629674,3.026083492,3.229554371,2.321108264,2.552013111 14 | 12,1.804684877,1.970746756,1.767432809,2.423353434,1.219493369,1.056476116,2.019491434,2.528429764,2.924890757,2.011688993,3.11616733,2.355455017,2.276971511 15 | 13,1.770414352,2.378318071,1.646798015,2.721501112,1.656821814,1.183585525,1.832362771,2.482357877,3.131352186,3.132860956,3.411323905,2.357311472,2.214604392 16 | 14,1.709278941,2.043212175,1.798157573,2.281183004,1.927816931,1.115984082,1.982399583,2.950737151,2.894127131,2.148647456,3.254193661,2.621867779,2.359158275 17 | 15,1.754067659,1.985329509,1.741275907,2.335630178,1.352117009,0.903721273,1.793732285,2.750845534,3.029270887,3.098874081,3.056008303,2.468921354,2.337581342 18 | 16,1.700834513,2.101693392,1.690371156,2.523807049,1.270254751,1.23278451,2.11239934,3.017164888,2.95811367,3.010311501,3.344694657,2.21014311,2.328827596 19 | 17,2.271117926,2.084325552,1.865561604,2.297910452,1.372556316,1.257833362,1.856941223,2.896000391,2.891341448,1.765116104,3.238645355,2.663124948,2.489389285 20 | 18,1.618142962,2.254678011,1.748286128,2.31829834,1.805385822,1.155425191,1.825154066,2.982484897,3.091902494,2.454433373,3.284224034,2.602948994,2.561028812 21 | 19,1.782048702,2.026126862,1.7100178,2.535834551,1.207641524,0.956748188,1.741969228,2.532587109,3.001332521,3.048768464,3.281247812,2.41874416,2.364264697 22 | 
20,1.876127839,2.127156258,1.777980566,2.276982307,1.705852722,0.951963186,1.713380694,2.961901045,2.961784363,2.068814743,2.811564416,2.519042107,2.566841146 23 | 21,1.906871796,1.923280239,1.772812486,2.343648434,1.605948029,1.041365027,1.762238383,2.474646503,3.054457903,2.567356814,2.737299739,2.18347252,2.574668078 24 | 22,1.817482591,2.027499437,1.70693171,2.443053484,1.670460384,1.143073201,1.996132016,2.819352114,2.847448587,3.080217679,3.229791853,2.487786857,2.34910971 25 | 23,1.827135563,1.81673646,1.769877315,2.461703777,1.453074942,0.906314373,1.677154541,3.163745211,3.180683136,3.077856632,3.355683372,2.377372356,2.334127121 26 | 24,1.819347143,1.784926534,1.735267162,2.643065691,2.066571151,0.990238309,1.888345957,3.042556886,3.082452297,2.609074167,2.96527544,2.271061026,2.125906709 27 | 25,1.817542076,1.90008831,1.747023582,2.642925262,2.009400777,0.950198472,1.693772197,2.399357917,3.014286995,2.41114802,3.243299042,2.552621569,2.596184639 28 | 26,1.867398977,1.99461019,1.84950614,2.389622211,1.315774692,1.079162836,1.738507986,2.732801263,2.914672852,3.006989559,3.250029229,2.529228248,2.089175113 29 | 27,1.847033381,1.953591824,1.749462605,2.415326357,1.878203621,1.029386163,2.061135292,2.299435777,2.94673419,2.111459306,3.314567064,2.535692867,2.607327887 30 | 28,1.712209702,2.037532568,1.829063654,2.285975933,1.97360207,1.047486782,1.893887401,2.638953252,2.874346972,2.224412577,3.234814558,2.552971132,2.445108412 31 | 29,1.832629085,1.913353801,1.747409821,2.601226807,1.211900973,0.993856013,1.81689918,3.069857709,3.056817055,1.868226534,3.308858234,2.355970554,2.574367501 32 | 30,1.711629391,2.115322351,1.790459871,2.235873461,1.291792054,1.128371716,1.928955317,3.046414687,2.993082285,2.756694447,3.368788235,2.445066672,2.180521409 33 | 31,1.815322042,1.838795066,1.720708609,2.416088104,1.268845066,0.953660727,1.903256178,2.54843412,2.972913504,3.059209029,3.306315529,2.412181263,2.443442542 34 | 
32,1.770895004,2.194992304,1.664762259,2.455331326,1.242625635,1.159087777,1.751788855,2.811875153,2.989976645,2.768311875,3.303146826,2.29782962,2.422744673 35 | 33,1.694593787,2.018740892,1.837630868,2.310950279,2.025455408,1.371160507,2.071642399,2.523039088,2.886336088,3.108337993,2.883659234,2.648759076,2.322922528 36 | 34,1.773963928,1.836924434,1.839129567,2.672827005,1.54620685,1.39827621,2.149958134,2.639478547,2.817120314,2.497406466,3.232984467,2.639764184,2.345454887 37 | 35,1.679174662,2.157022953,1.714757085,2.772733688,1.998600162,0.97308594,1.75348568,2.588482193,2.984440327,2.655125073,3.272802458,2.302544055,2.520912165 38 | 36,1.912855387,2.029915094,1.733725667,2.944387674,2.182944706,1.224697351,2.047427654,3.447351329,3.340331554,3.150656575,2.980066857,2.509507104,2.271252717 39 | 37,1.692226887,2.27069068,1.689585686,2.414070606,1.869588309,0.975606084,1.878979921,3.130741551,3.273200512,1.903542394,3.433000979,2.443930429,2.414507561 40 | 38,1.693622828,1.893389344,1.719364762,2.516539574,1.918037401,1.037318826,1.691018224,3.05889434,2.999653339,3.038144929,3.247869975,2.428412307,2.546700147 41 | 39,1.761396408,1.950439811,1.865528226,2.375950336,1.373941558,1.039267898,1.719983816,2.712838116,2.978061438,3.024896236,2.799376964,2.534143607,2.234953191 42 | 40,1.636999965,1.785862088,1.759807229,2.64171505,1.523516422,1.202476621,2.114892006,2.858855538,3.229916811,3.225806486,3.337870243,2.307553106,2.210491835 43 | 41,1.623373389,1.910992384,1.788724303,2.62486577,1.718988064,1.07633841,1.782186508,3.043124765,2.984581947,3.08239796,3.196723169,2.432397733,2.323502355 44 | 42,1.928222656,1.975226521,1.801280379,2.466973305,1.482576089,1.034046412,1.737078786,2.854251071,2.838392258,2.434076508,3.028771107,2.579322021,2.313450098 45 | 43,1.845522523,1.899145484,1.734078288,2.712808132,1.248003653,0.89366287,1.964516521,3.002593541,2.978356838,3.075605506,3.033838023,2.367844801,2.612222536 46 | 
44,1.765939713,2.208936691,1.801998615,2.643113613,1.31825893,1.138341069,1.6276443,2.881691281,3.020575285,2.166395942,3.252965965,2.558464602,2.163741044 47 | 45,1.88476193,2.052605152,1.799130917,2.37244606,1.239317411,0.961004138,1.730748296,3.023983099,2.95250845,1.85507975,2.773547508,2.509233111,2.29265387 48 | 46,1.89928925,1.965020418,1.742106318,2.475737333,1.445367287,0.941100001,1.771437764,2.504724105,2.910966158,2.983030013,3.233900763,2.549721702,2.287726825 49 | 47,1.795859933,1.854085207,1.807201862,2.429303885,1.28402801,1.016085625,2.036367178,2.617231067,2.872348547,2.516399247,3.246430039,2.443977053,2.342180972 50 | 48,1.834916949,1.860092878,1.854009032,2.386642456,1.279370474,0.932895362,1.835905433,2.95723773,2.89528656,3.011356025,3.23593757,2.485421216,2.498710189 51 | 49,1.793205857,1.900776982,1.692210078,2.851662874,1.180990467,0.966184497,1.94066596,2.708598787,3.055478334,2.506101472,3.093063536,2.437020233,2.500860387 52 | 50,1.723504901,1.984735727,1.738213181,2.431320667,1.284084044,1.067343712,1.766240835,2.562374399,2.943610668,3.029079142,3.264163986,2.304242556,2.334802308 53 | 51,1.987421393,2.069823503,1.695324421,2.640327692,1.572589845,0.967883766,1.813545585,2.347271207,3.051230907,2.065844331,2.984993729,2.373526302,2.405406403 54 | 52,1.905190349,1.947031856,1.844300747,2.730347872,1.431928307,0.831067204,1.86265409,3.047344516,2.967854261,2.704577582,3.25806724,2.645423714,2.310086445 55 | 53,1.675474167,1.850112319,1.723280311,2.518544197,1.380700342,1.209388494,2.135176897,2.929589442,2.965378284,2.126141673,3.320211891,2.386400587,2.216182577 56 | 54,2.000249386,1.822113037,1.715187669,2.617889404,1.78605917,1.00979805,2.030132532,2.334677637,3.149806738,2.36357518,3.457481465,2.444207175,2.228284281 57 | 55,1.700823069,1.935306311,1.75629878,2.418775797,2.040744133,1.141554713,1.822455764,2.941230353,2.889985323,2.848643927,3.115848664,2.322724908,2.426582054 58 | 
56,1.772464275,1.886731267,1.806898475,2.379571199,2.085706639,1.141046047,1.935325623,2.737732085,2.8610425,2.98950763,2.952221981,2.547611391,2.188945786 59 | 57,1.749443054,2.18036437,1.853727937,2.323466539,1.487762663,1.015905738,1.827180743,2.94223612,2.904024839,2.307071061,3.115009855,2.602976681,2.256023864 60 | 58,1.618191242,2.005319118,1.714857101,2.685310602,1.648767589,1.000048041,1.89790535,2.725076439,3.068170786,2.145767331,3.238878009,2.515766957,2.461793805 61 | 59,1.822707176,1.828627229,1.771233797,2.537379026,1.279437165,1.103004694,1.796471715,2.935739698,2.940165758,1.919956616,3.242210773,2.589107027,2.400718635 62 | 60,1.788510799,2.098105431,1.732894301,2.634582996,1.473131576,0.899435103,1.739066482,2.828013376,2.958841801,2.457704885,2.972730461,2.422380078,2.393242665 63 | 61,1.683397532,1.833518147,1.708714008,2.368860483,1.349857561,1.117126703,1.903519034,3.028511022,3.010350943,3.115333341,3.35389071,2.535585046,2.363933364 64 | 62,1.771267414,1.893473268,1.699028254,2.517937183,1.229804138,1.117470622,1.855293393,2.627224487,2.951977491,3.099171854,3.275360685,2.377167469,2.340537553 65 | 63,1.773458838,2.24131465,1.814669013,2.865906,1.212895713,0.848714828,1.768502593,2.987658588,3.016209126,2.499511378,2.771228405,2.477572256,2.461724748 -------------------------------------------------------------------------------- /mixture_config/synthesize_mixture.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import yaml 4 | import argparse 5 | import os 6 | 7 | SEED = 42 8 | random.seed(SEED) 9 | np.random.seed(SEED) 10 | 11 | # Temperature for the prior distribution, if your distribution is too skewed, you can use a temperature to smooth it 12 | TEMP = 0.5 13 | 14 | # The minimum and maximum strength for the dirichlet distribution. 
15 | # With a small value, the distribution will be more concentrated, and with a large value, the distribution will be more uniform. 16 | MIN_STRENGH = 0.1 17 | MAX_STRENGH = 5.0 18 | 19 | # We first sample SAMPLE_MULTIPLIER times more samples and then randomly select some of them 20 | SAMPLE_MULTIPLIER = 100 21 | 22 | # How many epochs are allowed for each domain for the large-scale model training. This hyper-parameter 23 | # is used because of the natural trade-off between the reweighting vs. the number of available tokens in each domain. 24 | # Usually we think repeating 4 epochs is okay for language model pre-training, and here we set it as 15 25 | # because the number of available tokens in The Pile is much larger than the token amount for training Chinchilla-Optimal 1B models (i.e., 25B tokens). 26 | # However, if you want to train the large-scale model with all available tokens, you can use less than 4 epochs also in the proxy 27 | # model training. 28 | MAXIMUM_USAGE = 15 29 | 30 | # Assume that we have 1B (512,000 examples, and 2048 tokens per example) tokens 31 | # for the proxy model training, the minimum sampling rate 2e-4 indicates that 32 | # at least there will be 100 examples for each domain, which is statistically significant. 33 | # 34 | # If you use fewer tokens for training the proxy models, you may increase the minimum sampling rate 35 | # to ensure the statistical significance of the domain. I personally recommend using at least 1e-5 36 | # if you have 1B tokens for training the proxy models. 
def get_token_distribution():
    """Return the prior training mixture and the validation groups.

    Returns:
        dict: ``{"train": {...}, "valid": {...}}`` where ``train`` maps each
        training domain to its share of tokens and ``valid`` maps each
        validation domain to ``1.0``.
    """
    # The prior distribution of the tokens may change with the tokenizer,
    # so these shares are specific to the tokenizer they were measured with.
    train_shares = (
        ("train_the_pile_arxiv", 0.113285273),
        ("train_the_pile_freelaw", 0.079608651),
        ("train_the_pile_nih_exporter", 0.003913491),
        ("train_the_pile_pubmed_central", 0.185375901),
        ("train_the_pile_wikipedia_en", 0.051081359),
        ("train_the_pile_dm_mathematics", 0.015962925),
        ("train_the_pile_github", 0.101750772),
        ("train_the_pile_philpapers", 0.003707518),
        ("train_the_pile_stackexchange", 0.066529351),
        ("train_the_pile_enron_emails", 0.001750772),
        ("train_the_pile_gutenberg_pg_19", 0.027085479),
        ("train_the_pile_pile_cc", 0.236869207),
        ("train_the_pile_ubuntu_irc", 0.01184346),
        ("train_the_pile_europarl", 0.007929969),
        ("train_the_pile_hackernews", 0.008032956),
        ("train_the_pile_pubmed_abstracts", 0.038825953),
        ("train_the_pile_uspto_backgrounds", 0.046446962),
    )

    # valid cannot be ignored if the generated config should be evaluated on the target set
    valid_names = (
        "valid_the_pile_arxiv",
        "valid_the_pile_dm_mathematics",
        "valid_the_pile_enron_emails",
        "valid_the_pile_europarl",
        "valid_the_pile_freelaw",
        "valid_the_pile_github",
        "valid_the_pile_gutenberg_pg_19",
        "valid_the_pile_hackernews",
        "valid_the_pile_nih_exporter",
        "valid_the_pile_philpapers",
        "valid_the_pile_pile_cc",
        "valid_the_pile_pubmed_abstracts",
        "valid_the_pile_pubmed_central",
        "valid_the_pile_stackexchange",
        "valid_the_pile_ubuntu_irc",
        "valid_the_pile_uspto_backgrounds",
        "valid_the_pile_wikipedia_en",
    )

    distribution = {}
    distribution["train"] = dict(train_shares)
    distribution["valid"] = dict.fromkeys(valid_names, 1.0)
    return distribution
def generate_weights_dirichlet(prior_dist,
                               train_groups,
                               minimum_number,
                               num_samples=128,
                               enable_bound=True,
                               temperature=1.0):
    """Sample domain-weight vectors from Dirichlet distributions around a prior.

    For each of ``num_samples * SAMPLE_MULTIPLIER`` attempts, one candidate is
    drawn per concentration strength (log-spaced between ``MIN_STRENGH`` and
    ``MAX_STRENGH``) and one of them is kept at random. Candidates that violate
    the per-domain usage bound are rejected. Survivors have entries below
    ``minimum_number`` zeroed, are renormalized and rounded to multiples of
    ``minimum_number``, then deduplicated, and ``num_samples`` of them are
    selected at random.

    Args:
        prior_dist: Normalized prior weight per domain (array-like).
        train_groups: Domain names aligned with ``prior_dist``; used for logging.
        minimum_number: Smallest non-zero weight and the rounding step.
        num_samples: Number of weight vectors to return.
        enable_bound: When true, reject samples in which any domain exceeds
            ``min(prior * MAXIMUM_USAGE, 1.0)``.
        temperature: When below 1.0, the prior is raised to this power and
            renormalized before sampling.

    Returns:
        numpy.ndarray: Array of shape ``(num_samples, len(prior_dist))``.

    Raises:
        ValueError: If fewer than ``num_samples`` distinct samples survive.
    """
    prior_dist = np.asarray(prior_dist, dtype=float)

    if enable_bound:
        # Bound for rejection sampling: a domain's tokens cannot be repeated
        # more than MAXIMUM_USAGE times. This uses the un-tempered prior on
        # purpose, because the bound reflects how many tokens actually exist.
        number_bound = [[0.0, min(share * MAXIMUM_USAGE, 1.0)] for share in prior_dist]
    else:
        number_bound = None

    # apply temperature (the argument, not the module-level default)
    if temperature < 1.0:
        prior_dist = prior_dist ** temperature
        prior_dist = prior_dist / np.sum(prior_dist)
        print("\n\nWith temperature: ", prior_dist)

    if number_bound is not None:
        print("\n\nThe domain usage bound (maximum domain weight): ")
        # print the bound for each group
        for group, (_, upper) in zip(train_groups, number_bound):
            print(f"{group}: {upper}")

    # The strength grid does not depend on the attempt, so build it once.
    if MIN_STRENGH == MAX_STRENGH:
        strengths = None
    else:
        strengths = np.logspace(np.log10(MIN_STRENGH), np.log10(MAX_STRENGH), 15)

    final_samples = []
    # combine reject sampling with dirichlet distribution
    for _ in range(num_samples * SAMPLE_MULTIPLIER):
        if strengths is None:
            samples = np.random.dirichlet(prior_dist * MIN_STRENGH, 1)
        else:
            # Draw one candidate per strength, then keep a random one. The
            # order of random calls is kept so seeded runs stay reproducible.
            candidates = [np.random.dirichlet(prior_dist * strength, 1)
                          for strength in strengths]
            samples = random.choice(candidates)

        # Reject the sample when any domain falls outside its [lower, upper] bound.
        if number_bound is not None and any(
            value < lower or value > upper
            for value, (lower, upper) in zip(samples[0], number_bound)
        ):
            continue

        # post normalization, set zero for the number less than minimum_number
        samples = np.where(samples < minimum_number, 0.0, samples)
        # round samples into the same scale of minimum_number
        samples = samples / np.sum(samples, axis=1).reshape(-1, 1)
        samples = np.round(samples / minimum_number) * minimum_number
        # add the samples to the final_samples
        final_samples.append(samples[0])

    print("\nThe number of avaiable samples: ", len(final_samples))
    # deduplicate the samples (removes nearly identical weight vectors)
    final_samples = sort_and_deduplicate(np.array(final_samples))
    print("The number of deduplicated samples: ", len(final_samples))
    if len(final_samples) < num_samples:
        raise ValueError(
            f"Only {len(final_samples)} distinct samples survived rejection sampling, "
            f"but {num_samples} were requested; relax the bounds or raise SAMPLE_MULTIPLIER."
        )
    selected_samples = random.sample(final_samples, num_samples)
    print("The number of selected samples: ", len(selected_samples))
    selected_samples = np.stack(selected_samples, axis=0)
    return selected_samples
has the bug of not showing the last step 240 | content += "\n" + "max_step: 1001" 241 | 242 | # never save the model, just using the wandb log for regression fitting 243 | content += "\n" + "save_step_interval: 2000" 244 | content += "\n" + "eval_step_interval: 100" 245 | 246 | # constant learning rate for the small model 247 | content += "\n" + "learning_rate: 0.0004" 248 | content += "\n" + "min_lr: 0.0004" 249 | # the warmup step is 100 250 | content += "\n" + "warmup_steps: 100" 251 | f.write(content) 252 | 253 | def sort_and_deduplicate(data, threshold=1e-5): 254 | """ 255 | Remove identify configs to avoid duplicated training. 256 | """ 257 | arr = np.array(data) 258 | sorted_indices = np.lexsort(arr.T) 259 | sorted_arr = arr[sorted_indices] 260 | result = [sorted_arr[0]] 261 | 262 | for i in range(1, len(sorted_arr)): 263 | diff = np.sum(np.abs(sorted_arr[i] - result[-1])) 264 | if diff > threshold: 265 | result.append(sorted_arr[i]) 266 | 267 | return result 268 | 269 | if __name__ == "__main__": 270 | parser = argparse.ArgumentParser() 271 | parser.add_argument("--output_folder", type=str, default="config_1m") 272 | parser.add_argument("--num_configs", type=int, default=512) 273 | 274 | args = parser.parse_args() 275 | output_folder = args.output_folder 276 | num_samples = args.num_configs 277 | 278 | # if not exist, create the folder 279 | if not os.path.exists(output_folder): 280 | os.makedirs(output_folder) 281 | 282 | output_paths = [] 283 | for i in range(1, num_samples + 1): 284 | output_paths.append(f"{output_folder}/n{i}.yaml") 285 | 286 | generate_config_from_prior(output_paths, 287 | prior_config=get_token_distribution()) 288 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🧬 RegMix: Data Mixture as Regression for Language Model Pre-training 2 | 3 | Welcome to the official repository of 
[RegMix](https://huggingface.co/papers/2407.01492), a new approach to optimizing data mixtures for large language model (LLM) pre-training! 4 | Join our [Discord](https://discord.gg/y3hpRaXGaU) for more discussions! 5 | 6 | ## 🌟 What is RegMix? 7 | 8 | RegMix is a novel method that treats data mixture selection as a **regression task**. By training small "proxy" models on diverse data mixtures and analyzing their performance, RegMix builds a regression model that can predict the optimal data mixture for training large-scale LLMs. 9 | 10 | ![RegMix Method Visualization](misc/method_figure.png) 11 | 12 | ## 🚀 How RegMix Works 13 | 14 | RegMix follows a four-step process to optimize LLM training: 15 | 16 | 1. **Generate Configs**: Create various different data mixture configurations. 17 | 2. **Train Small Models**: Use these configs to train small "proxy" models. 18 | 3. **Fit Regression Model**: Analyze the performance of these models (e.g., the validation loss on Pile-CC) to build a predictive regression model. 19 | 4. **Train Large Model**: Use the predicted optimal mixture to train a large-scale LLM. 20 | 21 | ## 🧰 What's in This Repo? 22 | 23 | Our repository is organized into four main components: 24 | 25 | 1. [**mixture_config**](mixture_config): Tools for synthesizing and visualizing data mixtures. 26 | - Generate diverse data mixture configurations. 27 | - Visualize the generated mixtures. 28 | 29 | 2. [**regression_fitting**](regression_fitting): The heart of RegMix. 30 | - Fit the regression model using small model performance data, and by default we use the validation loss on the Pile-CC dataset. 31 | - Simulate and predict the optimal data mixture for large-scale training. 32 | 33 | 3. [**model_training**](model_training): Leveraging [TinyLlama](https://github.com/jzhang38/TinyLlama) for model training. 34 | - Train small (1M parameter) proxy models. 35 | - Scale up to large (1B+ parameter) language models using the predicted optimal mixture. 36 | 37 | 4. 
[**evaluation**](evaluation): [Work in Progress] Reproduce our evaluation results. 38 | 39 | 40 | ## 🛠 Applying RegMix on Your Dataset 41 | 42 | Want to leverage the power of RegMix for your own language model training? Follow these steps to apply RegMix to your unique dataset: 43 | 44 | ### 1. Prepare Your Data 45 | 46 | - **Organize Your Dataset**: Split your data into distinct categories or domains, each with its own unique file-name prefix. 47 | - **Format Requirements**: Ensure your data is in a compatible format (e.g., JSON lines, where each line is a valid JSON object containing the `text`). 48 | 49 | An example dataset organization folder structure is as follows: 50 | 51 | ``` 52 | ├── train 53 | │ ├── domain1-0.jsonl 54 | │ ├── domain1-1.jsonl 55 | │ ├── ... 56 | │ ├── domain2-0.jsonl 57 | │ ├── domain2-1.jsonl 58 | │ ├── ... 59 | ├── valid 60 | │ ├── domain1-0.jsonl 61 | │ ├── domain2-0.jsonl 62 | ``` 63 | 64 | > [!CAUTION] 65 | > We use prefix-based matching to identify files from the same domain. Therefore, please make sure the prefix is unique for each domain. And please make sure to use the `-` after the prefix to avoid one prefix becoming a substring of another prefix. Our [training code](https://github.com/sail-sg/regmix/blob/e0c0357a312dbb3d62e3aa58a13cf42dd3ef42ee/model_training/pretrain/tinyllama.py#L464) will by default use `-` to make sure that we load the correct domain files! 66 | > Please avoid making each jsonl file too large, as this can make preprocessing take a long time. 67 | 68 | You can also find the [regmix-data-sample](https://huggingface.co/datasets/sail/regmix-data-sample) for your reference. 69 | 70 | ### 2. Generate Mixture Configurations 71 | 72 | Use the tools in the `mixture_config/` directory to: 73 | - Create a range of data mixture configurations. 74 | - Visualize these mixtures to understand their composition. 
75 | 76 | Before generating configurations, ensure you have changed the function `get_token_distribution` in `synthesize_mixture.py` to match your dataset's domain names and token distributions. 77 | 78 | For example, let's assume you have three domains: `domain1`, `domain2`, and `domain3`, and your `DATASET_SHORT_NAME` defined in `model_training/preprocess/run_preprocess.sh` is `your_dataset`. 79 | 80 | To accommodate these domains, you should modify the `get_token_distribution` function as follows: 81 | 82 | ```python 83 | def get_token_distribution(): 84 | # This example uses an equal distribution for each domain. 85 | # Adjust these values if certain domains have significantly more available tokens. 86 | train = { 87 | "train_your_dataset_domain1": 0.33, 88 | "train_your_dataset_domain2": 0.33, 89 | "train_your_dataset_domain3": 0.34, 90 | } 91 | # The validation set can be omitted if not needed for 1M model training 92 | valid = { 93 | "valid_your_dataset_domain1": 1.0, 94 | "valid_your_dataset_domain2": 1.0, 95 | "valid_your_dataset_domain3": 1.0, 96 | } 97 | return {"train": train, "valid": valid} 98 | ``` 99 | 100 | Next, generate the mixture configurations: 101 | 102 | ```bash 103 | python synthesize_mixture.py --num_configs 512 --output_folder /path/to/configs 104 | ``` 105 | 106 | > [!TIP] 107 | > The number of configurations typically has the most significant impact on the regression model's accuracy. In our experiments, we utilized 512 configurations for 17 domains. If you're working with fewer domains, you may be able to achieve comparable results with a reduced number of configurations. However, if you have access to additional resources, consider training more proxy models to enhance the regression model's accuracy. This approach is particularly beneficial when dealing with a large number of domains. 108 | > You can also start from a small number of configurations to see if the regression model can predict the unseen data mixture well. 
109 | 110 | Finally, visualize the generated mixtures to understand their composition: 111 | 112 | ```bash 113 | python visualize_mixture.py --config_folder /path/to/configs 114 | ``` 115 | 116 | ### 3. Train Proxy Models 117 | 118 | Utilize the `model_training/` scripts to train small "proxy" models on your generated mixtures. Remember to modify the config folder path in the `pretrain_tinyllama_1m.sh` script to match your generated configurations. 119 | 120 | ```bash 121 | cd model_training 122 | for i in {1..512}; do 123 | ./pretrain_tinyllama_1m.sh $i 124 | done 125 | ``` 126 | 127 | ### 4. Fit the Regression Model 128 | 129 | With your proxy model results, use the collect scripts to prepare the data for regression fitting. The first step is to organize the mixture configs into a CSV file: 130 | 131 | ```bash 132 | python collect_mixture_data.py --write_file_path train_mixture_1m_your_dataset.csv --config_folder /path/to/configs 133 | ``` 134 | 135 | The second step is to collect the target performance data for the proxy models. By default we use the Pile-CC validation loss as the target, which is collected from wandb using the `wandb` API. 136 | 137 | ```bash 138 | python collect_loss_data.py --write_file_path train_loss_1m_your_dataset.csv 139 | ``` 140 | 141 | Finally, fit the regression model using the collected data and predict the optimal mixture following the instructions in `regression_fitting/regression.ipynb`. 142 | 143 | ### 5. Train Your Large-Scale LLM 144 | 145 | You can save your final predicted optimal mixture into a yaml file `optimal_mixture.yaml`, in a format similar to the configs under [mixture_config/config_1b](mixture_config/config_1b). 
An example of the optimal mixture is as follows: 146 | 147 | ```yaml 148 | train: 149 | train_the_pile_arxiv: 0.0012046169821426883 150 | train_the_pile_freelaw: 0.001454510048554701 151 | train_the_pile_nih_exporter: 0.001231640306882902 152 | train_the_pile_pubmed_central: 0.003108561825532002 153 | train_the_pile_wikipedia_en: 0.01593264140324679 154 | train_the_pile_dm_mathematics: 0.00031106907908634156 155 | train_the_pile_github: 0.00022861228152440253 156 | train_the_pile_philpapers: 1.329107360676338e-05 157 | train_the_pile_stackexchange: 0.00029547405933203174 158 | train_the_pile_enron_emails: 0.0016691646199353991 159 | train_the_pile_gutenberg_pg_19: 0.001612531300038395 160 | train_the_pile_pile_cc: 0.8701291419934237 161 | train_the_pile_ubuntu_irc: 0.06417728505869834 162 | train_the_pile_europarl: 2.9166170357771267e-06 163 | train_the_pile_hackernews: 0.011925517591888925 164 | train_the_pile_pubmed_abstracts: 0.02424425081714838 165 | train_the_pile_uspto_backgrounds: 0.0024587749419225434 166 | valid: 167 | valid_the_pile_pile_cc: 1.0 168 | model_name: tinyllama_1_1b 169 | ``` 170 | 171 | Finally, use the predicted optimal mixture to train your model. Put your `optimal_mixture.yaml` under the folder `mixture_config/config_1b` and run the following script: 172 | 173 | ```bash 174 | cd model_training 175 | ./pretrain_tinyllama_1b.sh optimal_mixture 176 | ``` 177 | 178 | You get the final model trained with the optimal mixture! 179 | 180 | ### Tips for Success 181 | 182 | - **Data Diversity**: Ensure your initial dataset covers a wide range of domains. 183 | - **Proxy Model Size**: While we use 1M parameter models, you might need to adjust based on your computational resources and dataset size. 184 | - **Evaluation**: Choosing the correct target is crucial for the generic downstream performance improvement. You may want to use the loss on a high-quality and diverse validation dataset like Pile-CC. 
We also recommend using the awesome [paloma evaluation suite](https://huggingface.co/datasets/allenai/paloma) from AI2 for evaluation. 185 | 186 | ### Customization Options 187 | 188 | RegMix is flexible and can be adapted to your specific needs: 189 | - Adjust the number and size of proxy models. 190 | - Modify the regression model architecture or features. 191 | - Incorporate domain-specific metrics in your optimization objective. 192 | 193 | Remember, the key to RegMix's success is in capturing the relationship between data mixture and model performance. The more informative your proxy training runs are, the better your final mixture prediction will be! 194 | 195 | ## 📦 Data and Model Release 196 | 197 | We've made our data and trained models available on HuggingFace! 198 | 199 | ### Model 200 | 201 | Below are the full models, you can load each model with the following code: 202 | 203 | ```python 204 | from transformers import AutoModel, AutoTokenizer 205 | 206 | model_name, revision = "sail/data-mixture-random-1b", "model-index-1" 207 | model = AutoModel.from_pretrained(model_name, revision=revision) 208 | tokenizer = AutoTokenizer.from_pretrained(model_name, revision=revision) 209 | ``` 210 | 211 | And the detailed name and revision of each model is as follows: 212 | 213 | | Model Name | Revisions | Description | Link | 214 | |------------|-----------|-------------|------| 215 | | sail/data-mixture-random-1b | `model-index-1` to `model-index-64` | 64 models with random data mixtures to study correlation between data mixture and downstream performance | [🤗 Hugging Face](https://huggingface.co/sail/data-mixture-random-1b) | 216 | | sail/data-mixture-human-1b | `seed-1` to `seed-5` | 5 models with human-selected data mixture (baseline), using different seeds | [🤗 Hugging Face](https://huggingface.co/sail/data-mixture-human-1b) | 217 | | sail/data-mixture-doremi-1b | `seed-1` to `seed-5` | 5 models with DoReMi best-performing data mixture (baseline), using 
different seeds | [🤗 Hugging Face](https://huggingface.co/sail/data-mixture-doremi-1b) | 218 | | sail/data-mixture-pile-cc-1b | `seed-1` to `seed-5` | 5 models with Pile-CC only data mixture, using different seeds | [🤗 HuggingFace](https://huggingface.co/sail/data-mixture-pile-cc-1b) | 219 | | sail/data-mixture-regmix-1b | `seed-1` to `seed-5` | 5 models with RegMix data mixture, using different seeds | [🤗 HuggingFace](https://huggingface.co/sail/data-mixture-regmix-1b) | 220 | 221 | ### Data 222 | 223 | We also provide both the full data and the sample data for your reference on HuggingFace. You can download them manually or use the following code to download them: 224 | 225 | ```python 226 | from huggingface_hub import snapshot_download 227 | 228 | # You can choose to download regmix-data, or regmix-data-sample 229 | snapshot_download(repo_id="sail/regmix-data-sample", 230 | repo_type='dataset', 231 | local_dir="sail/regmix-data-sample", 232 | local_dir_use_symlinks=False) 233 | ``` 234 | 235 | Some of the details about these two datasets are as follows: 236 | 237 | | Dataset Name | Description | Size | Link | 238 | |--------------|-------------|------|------| 239 | | sail/regmix-data | Full dataset for RegMix, resplitted from [pile-uncopyrighted](https://huggingface.co/datasets/monology/pile-uncopyrighted) | 250B tokens (~1TB disk space) | [🤗 Hugging Face](https://huggingface.co/datasets/sail/regmix-data) | 240 | | sail/regmix-data-sample | Sample dataset from regmix-data, we keep one file for each domain | 5B tokens (~20GB disk space) | [🤗 Hugging Face](https://huggingface.co/datasets/sail/regmix-data-sample) | 241 | 242 | ## 🔍 Evaluation [Work in Progress] 243 | 244 | Stay tuned! We're currently working on providing comprehensive evaluation setup in the `evaluation` directory. 
245 | 246 | ## 📚 Citation 247 | 248 | If RegMix helps your research, please cite our paper: 249 | 250 | ```bibtex 251 | @article{liu2024regmix, 252 | title={RegMix: Data Mixture as Regression for Language Model Pre-training}, 253 | author={Liu, Qian and Zheng, Xiaosen and Muennighoff, Niklas and Zeng, Guangtao and Dou, Longxu and Pang, Tianyu and Jiang, Jing and Lin, Min}, 254 | journal={arXiv preprint arXiv:2407.01492}, 255 | year={2024} 256 | } 257 | ``` 258 | 259 | ## 🤝 Get in Touch 260 | 261 | Excited about RegMix? Have questions? We'd love to hear from you! 262 | 263 | Contact us at: 264 | - liuqian@sea.com 265 | - xszheng.2020@phdcs.smu.edu.sg 266 | 267 | Join us in scalable and efficient data mixture with RegMix! 268 | -------------------------------------------------------------------------------- /model_training/lit_gpt/model.py: -------------------------------------------------------------------------------- 1 | """Full definition of a GPT NeoX Language Model, all of it in this single file. 2 | 3 | Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and 4 | https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. 
5 | """ 6 | import math 7 | from typing import Any, List, Optional, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | from lightning_utilities.core.imports import RequirementCache 12 | from typing_extensions import Self 13 | from lit_gpt.config import Config 14 | from xformers.ops import SwiGLU 15 | 16 | RoPECache = Tuple[torch.Tensor, torch.Tensor] 17 | KVCache = Tuple[torch.Tensor, torch.Tensor] 18 | FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1") 19 | 20 | 21 | class GPT(nn.Module): 22 | def __init__(self, config: Config) -> None: 23 | super().__init__() 24 | assert config.padded_vocab_size is not None 25 | self.config = config 26 | 27 | self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) 28 | self.transformer = nn.ModuleDict( 29 | dict( 30 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 31 | h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), 32 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 33 | ) 34 | ) 35 | self.rope_cache: Optional[RoPECache] = None 36 | self.mask_cache: Optional[torch.Tensor] = None 37 | self.kv_caches: List[KVCache] = [] 38 | 39 | def _init_weights(self, module: nn.Module, n_layer) -> None: 40 | """Meant to be used with `gpt.apply(gpt._init_weights)`.""" 41 | # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf 42 | # print module name 43 | if isinstance(module, nn.Embedding): 44 | # RWKV: set it to 1e-4 45 | torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.size(1))) 46 | # torch.nn.init.normal_(module.weight, -1e-4, 1e-4) 47 | elif isinstance(module, nn.Linear): 48 | # fan-in variance scaling intializer 49 | torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.size(1))) 50 | if module.bias is not None: 51 | torch.nn.init.zeros_(module.bias) 52 | # GPT-NeoX 53 | for name, p in module.named_parameters(): 54 | if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == 
"w3.weight" and isinstance(module, SwiGLU)): #if use xformer swiglu, fc2 layer will be renamed to w3 55 | nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(p.shape[-1]) / n_layer) 56 | 57 | 58 | def reset_cache(self) -> None: 59 | self.kv_caches.clear() 60 | if self.mask_cache is not None and self.mask_cache.device.type == "xla": 61 | # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179 62 | self.rope_cache = None 63 | self.mask_cache = None 64 | 65 | def forward( 66 | self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None 67 | ) -> torch.Tensor: 68 | B, T = idx.size() 69 | use_kv_cache = input_pos is not None 70 | 71 | block_size = self.config.block_size 72 | if max_seq_length is None: 73 | max_seq_length = block_size 74 | if use_kv_cache: # not relevant otherwise 75 | assert ( 76 | max_seq_length >= T 77 | ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" 78 | assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" 79 | assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" 80 | 81 | if self.rope_cache is None: 82 | self.rope_cache = self.build_rope_cache(idx) 83 | # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. 
since we only need the mask 84 | # for the kv-cache support (only during inference), we only create it in that situation 85 | # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 86 | if use_kv_cache and self.mask_cache is None: 87 | self.mask_cache = self.build_mask_cache(idx) 88 | 89 | cos, sin = self.rope_cache 90 | if use_kv_cache: 91 | 92 | cos = cos.index_select(0, input_pos) 93 | sin = sin.index_select(0, input_pos) 94 | mask = self.mask_cache.index_select(2, input_pos) 95 | mask = mask[:, :, :, :max_seq_length] 96 | else: 97 | cos = cos[:T] 98 | sin = sin[:T] 99 | mask = None 100 | 101 | # forward the model itself 102 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 103 | 104 | if self.config.embd_pdrop > 0: 105 | x = nn.functional.dropout(x, p=self.config.embd_pdrop, 106 | training=self.training) 107 | 108 | if not use_kv_cache: 109 | for block in self.transformer.h: 110 | x, *_ = block(x, (cos, sin), max_seq_length) 111 | else: 112 | self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2) 113 | for i, block in enumerate(self.transformer.h): 114 | x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i]) 115 | 116 | x = self.transformer.ln_f(x) 117 | 118 | return self.lm_head(x) # (b, t, vocab_size) 119 | 120 | @classmethod 121 | def from_name(cls, name: str, **kwargs: Any) -> Self: 122 | return cls(Config.from_name(name, **kwargs)) 123 | 124 | def build_rope_cache(self, idx: torch.Tensor) -> RoPECache: 125 | return build_rope_cache( 126 | seq_len=self.config.block_size, 127 | n_elem=int(self.config.rotary_percentage * self.config.head_size), 128 | base=self.config.rope_base, 129 | dtype=torch.bfloat16, 130 | device=idx.device, 131 | condense_ratio=self.config.condense_ratio, 132 | ) 133 | 134 | def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor: 135 | ones = torch.ones((self.config.block_size, self.config.block_size), 
device=idx.device, dtype=torch.bool) 136 | return torch.tril(ones).unsqueeze(0).unsqueeze(0) 137 | 138 | def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]: 139 | B = idx.size(0) 140 | heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups 141 | 142 | k_cache_shape = ( 143 | B, 144 | max_seq_length, 145 | heads, 146 | rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size), 147 | ) 148 | v_cache_shape = (B, max_seq_length, heads, self.config.head_size) 149 | device = idx.device 150 | return [ 151 | (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device)) 152 | for _ in range(self.config.n_layer) 153 | ] 154 | 155 | 156 | class Block(nn.Module): 157 | def __init__(self, config: Config) -> None: 158 | super().__init__() 159 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 160 | self.attn = CausalSelfAttention(config) 161 | if not config.shared_attention_norm: 162 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 163 | self.mlp = config.mlp_class(config) 164 | self.config = config 165 | def forward( 166 | self, 167 | x: torch.Tensor, 168 | rope: RoPECache, 169 | max_seq_length: int, 170 | mask: Optional[torch.Tensor] = None, 171 | input_pos: Optional[torch.Tensor] = None, 172 | kv_cache: Optional[KVCache] = None, 173 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 174 | 175 | n_1 = self.norm_1(x) 176 | h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache) 177 | if self.config.parallel_residual: 178 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 179 | ffn = self.mlp(n_2) 180 | x = x + h + ffn 181 | else: 182 | if self.config.shared_attention_norm: 183 | raise NotImplementedError( 184 | "No checkpoint amongst the ones we support uses this configuration" 185 | " (non-parallel residual and shared attention norm)." 
186 | ) 187 | 188 | x = x + h 189 | ffn = self.mlp(self.norm_2(x)) 190 | if self.config.resid_pdrop: 191 | ffn = nn.functional.dropout(ffn, p=self.config.resid_pdrop, training=self.training) 192 | x = x + ffn 193 | return x, new_kv_cache 194 | 195 | 196 | class CausalSelfAttention(nn.Module): 197 | def __init__(self, config: Config) -> None: 198 | super().__init__() 199 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 200 | # key, query, value projections for all heads, but in a batch 201 | self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) 202 | # output projection 203 | self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 204 | self.config = config 205 | 206 | def forward( 207 | self, 208 | x: torch.Tensor, 209 | rope: RoPECache, 210 | max_seq_length: int, 211 | mask: Optional[torch.Tensor] = None, 212 | input_pos: Optional[torch.Tensor] = None, 213 | kv_cache: Optional[KVCache] = None, 214 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 215 | from .fused_rotary_embedding import apply_rotary_emb_func 216 | 217 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 218 | 219 | qkv = self.attn(x) 220 | 221 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 222 | q_per_kv = self.config.n_head // self.config.n_query_groups 223 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 224 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs) 225 | # qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 226 | 227 | # split batched computation into three 228 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) 229 | 230 | # repeat k and v if necessary 231 | # Peiyuan: we do not need to do this as flash attention 2 already support GQA 232 | # if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA 
(inefficient!) 233 | # # for MHA this is a no-op 234 | # k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 235 | # v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 236 | 237 | q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs) 238 | k = k.reshape(B, T, -1, self.config.head_size) 239 | v = v.reshape(B, T, -1, self.config.head_size) 240 | 241 | cos, sin = rope 242 | 243 | # apply rope in fp32 significanly stabalize training 244 | # fused rope expect (batch_size, seqlen, nheads, headdim) 245 | q = apply_rotary_emb_func(q, cos, sin, False, True) 246 | k = apply_rotary_emb_func(k, cos, sin, False, True) 247 | 248 | # n_elem = int(self.config.rotary_percentage * self.config.head_size) 249 | 250 | # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) 251 | # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) 252 | # print( (q_roped - q).sum()) 253 | # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 254 | # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 255 | 256 | if kv_cache is not None: 257 | cache_k, cache_v = kv_cache 258 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 259 | # check if reached token limit 260 | if input_pos[-1] >= max_seq_length: 261 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 262 | # shift 1 position to the left 263 | cache_k = torch.roll(cache_k, -1, dims=1) 264 | cache_v = torch.roll(cache_v, -1, dims=1) 265 | 266 | k = cache_k.index_copy_(1, input_pos, k) 267 | v = cache_v.index_copy_(1, input_pos, v) 268 | kv_cache = k, v 269 | 270 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 271 | 272 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 273 | 274 | # output projection 275 | y = self.proj(y) 276 | 277 | return y, kv_cache 278 | 279 | def scaled_dot_product_attention( 280 | self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: 
Optional[torch.Tensor] = None 281 | ): 282 | scale = 1.0 / math.sqrt(self.config.head_size) 283 | 284 | if ( 285 | FlashAttention2Available 286 | and mask is None 287 | and q.device.type == "cuda" 288 | and q.dtype in (torch.float16, torch.bfloat16) 289 | ): 290 | from flash_attn import flash_attn_func 291 | return flash_attn_func(q, k, v, dropout_p=self.config.attn_pdrop, softmax_scale=scale, causal=True) 292 | 293 | q = q.transpose(1, 2) 294 | k = k.transpose(1, 2) 295 | v = v.transpose(1, 2) 296 | if q.size() != k.size(): 297 | k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1) 298 | v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1) 299 | y = torch.nn.functional.scaled_dot_product_attention( 300 | q, k, v, attn_mask=mask, dropout_p=self.config.attn_pdrop, scale=scale, is_causal=mask is None 301 | ) 302 | return y.transpose(1, 2) 303 | 304 | 305 | class GptNeoxMLP(nn.Module): 306 | def __init__(self, config: Config) -> None: 307 | super().__init__() 308 | self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 309 | self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 310 | 311 | def forward(self, x: torch.Tensor) -> torch.Tensor: 312 | x = self.fc(x) 313 | x = torch.nn.functional.gelu(x) 314 | return self.proj(x) 315 | 316 | 317 | class LLaMAMLP(nn.Module): 318 | def __init__(self, config: Config) -> None: 319 | super().__init__() 320 | # self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 321 | # self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 322 | # self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 323 | self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False) 324 | def forward(self, x: torch.Tensor) -> torch.Tensor: 325 | # x_fc_1 = self.fc_1(x) 326 | # x_fc_2 = self.fc_2(x) 327 | # x = torch.nn.functional.silu(x_fc_1) * x_fc_2 328 | # return self.proj(x) 329 | return 
def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
) -> RoPECache:
    """Precompute the cosine/sine tables used by rotary position embeddings.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

    Args:
        seq_len: number of positions to cover.
        n_elem: number of rotated elements; one frequency per pair of elements.
        dtype: dtype of the tensors the tables will be applied to.
        device: device on which the tables are created.
        base: base of the geometric frequency progression.
        condense_ratio: divisor applied to the position indices.

    Returns:
        ``(cos, sin)``, each of shape ``(seq_len, ceil(n_elem / 2))``.
    """
    # Inverse frequencies: base ** (-2j / n_elem) for j = 0 .. n_elem/2 - 1.
    exponents = torch.arange(0, n_elem, 2, device=device) / n_elem
    inv_freq = 1.0 / (base**exponents)

    # Positions 0 .. seq_len - 1, optionally condensed.
    positions = torch.arange(seq_len, device=device) / condense_ratio

    # Rotation angle for every (position, frequency) pair.
    angles = torch.outer(positions, inv_freq)
    cos, sin = torch.cos(angles), torch.sin(angles)

    # bfloat16 keeps its own dtype so the tables match q/k for the fused
    # rotary kernel; float16 and int8 use half precision to mimic complex32.
    cast_to = {
        torch.bfloat16: torch.bfloat16,
        torch.float16: torch.float16,
        torch.int8: torch.float16,
    }.get(dtype)
    if cast_to is None:
        return cos, sin
    return cos.to(cast_to), sin.to(cast_to)
@contextmanager
def quantization(mode: Optional[str] = None):
    """Temporarily replace ``torch.nn.Linear`` with a quantized linear class.

    While the context is active, code that instantiates ``torch.nn.Linear``
    gets the class selected by ``mode``. The original class is restored on
    exit, including when the body raises.

    Args:
        mode: ``None`` (no-op), ``"bnb.int8"``, ``"bnb.fp4"``, ``"bnb.fp4-dq"``,
            ``"bnb.nf4"``, ``"bnb.nf4-dq"`` or ``"gptq.int4"``.

    Raises:
        ValueError: if ``mode`` is not one of the values above.
    """
    if mode is None:
        yield
        return

    # mode -> (quant_type, compress_statistics) for the bitsandbytes 4-bit variants
    bnb_4bit_modes = {
        "bnb.fp4": ("fp4", False),
        "bnb.fp4-dq": ("fp4", True),
        "bnb.nf4": ("nf4", False),
        "bnb.nf4-dq": ("nf4", True),
    }

    if mode == "bnb.int8":
        from quantize.bnb import InferenceLinear8bitLt

        quantized_linear_cls = InferenceLinear8bitLt
    elif mode in bnb_4bit_modes:
        from quantize.bnb import Linear4bit

        quant_type, compress_statistics = bnb_4bit_modes[mode]

        # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses
        class QuantizedLinear(Linear4bit):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, quant_type=quant_type, compress_statistics=compress_statistics, **kwargs)

        quantized_linear_cls = QuantizedLinear
    elif mode == "gptq.int4":
        from quantize.gptq import ColBlockQuantizedLinear

        class QuantizedLinear(ColBlockQuantizedLinear):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, bits=4, tile_cols=-1, **kwargs)

        quantized_linear_cls = QuantizedLinear
    else:
        raise ValueError(f"Unknown quantization mode: {mode}")

    torch_linear_cls = torch.nn.Linear
    torch.nn.Linear = quantized_linear_cls
    try:
        yield
    finally:
        # Always undo the global patch, otherwise an exception inside the
        # `with` body would leave torch.nn.Linear quantized process-wide.
        torch.nn.Linear = torch_linear_cls
_load_tensor(): 121 | t = old_lt() 122 | return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) 123 | 124 | data._load_tensor = _load_tensor 125 | return data 126 | return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) 127 | 128 | @classmethod 129 | def rebuild_tensor_v2( 130 | cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None 131 | ): 132 | rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) 133 | metatensor = torch._utils._rebuild_tensor_v2( 134 | storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata 135 | ) 136 | storageinfo = storage.archiveinfo 137 | return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) 138 | 139 | def _load_tensor(self): 140 | name, storage_cls, fn, device, size = self.storageinfo 141 | dtype = self.metatensor.dtype 142 | 143 | uts = ( 144 | self.archiveinfo.zipfile_context.zf.get_storage_from_record( 145 | f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage 146 | ) 147 | ._typed_storage() 148 | ._untyped_storage 149 | ) 150 | with warnings.catch_warnings(): 151 | warnings.simplefilter("ignore") 152 | storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) 153 | return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) 154 | 155 | @classmethod 156 | def __torch_function__(cls, func, types, args=(), kwargs=None): 157 | if kwargs is None: 158 | kwargs = {} 159 | loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args] 160 | return func(*loaded_args, **kwargs) 161 | # gc.collect would be costly here, maybe do it optionally 162 | 163 | def __getattr__(self, name): 164 | # properties 165 | ## TODO: device, is_...?? 166 | ## TODO: mH, mT, H, T, data, imag, real 167 | ## name ??? 
class LazyLoadingUnpickler(pickle.Unpickler):
    """Unpickler that swaps torch's tensor rebuild hooks for lazy variants.

    The three torch rebuild functions listed in ``find_class`` are replaced by
    the matching ``NotYetLoadedTensor`` classmethods, and storages are created
    on the ``meta`` device with the persistent id attached to them.
    """

    def __init__(self, file, zipfile_context):
        super().__init__(file)
        self.zipfile_context = zipfile_context

    def find_class(self, module, name):
        # Resolve normally first so unknown globals still fail as usual.
        resolved = super().find_class(module, name)
        lazy_builders = {
            ("torch._utils", "_rebuild_tensor_v2"): "rebuild_tensor_v2",
            ("torch._tensor", "_rebuild_from_type_v2"): "rebuild_from_type_v2",
            ("torch._utils", "_rebuild_parameter"): "rebuild_parameter",
        }
        builder_name = lazy_builders.get((module, name))
        if builder_name is None:
            return resolved
        # Bind this unpickler as `archiveinfo` of the lazy builder.
        return partial(getattr(NotYetLoadedTensor, builder_name), archiveinfo=self)

    def persistent_load(self, pid):
        # pid is a 5-tuple; only the storage class (second item) is used here.
        _, storage_cls, _, _, _ = pid
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            meta_storage = torch.storage.TypedStorage(dtype=storage_cls().dtype, device="meta")
        # Keep the full persistent id on the storage for later use.
        meta_storage.archiveinfo = pid
        return meta_storage
def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
    """Exit with an explanatory message unless ``checkpoint_dir`` is usable.

    A usable directory exists and contains ``lit_model.pth``,
    ``lit_config.json``, ``tokenizer_config.json`` and one of
    ``tokenizer.json`` / ``tokenizer.model``.

    Args:
        checkpoint_dir: directory expected to hold a checkpoint.

    Raises:
        SystemExit: with code 1, after printing the problem to stderr.
    """
    has_tokenizer = (checkpoint_dir / "tokenizer.json").is_file() or (checkpoint_dir / "tokenizer.model").is_file()
    required = {
        "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
        "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
        "tokenizer.json OR tokenizer.model": has_tokenizer,
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }

    is_directory = checkpoint_dir.is_dir()
    if is_directory and all(required.values()):
        # we're good
        return

    if is_directory:
        missing = [label for label, present in required.items() if not present]
        problem = f" is missing the files: {missing!r}"
    else:
        problem = " is not a checkpoint directory"

    # list locally available checkpoints
    candidates = list(Path("checkpoints").glob("*/*"))
    extra = ""
    if candidates:
        quoted = [repr(str(candidate.resolve())) for candidate in candidates]
        options = "\n --checkpoint_dir ".join([""] + quoted)
        extra = f"\nYou have downloaded locally:{options}\n"

    error_message = (
        f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
        f"{extra}\nSee all download options by running:\n python scripts/download.py"
    )
    print(error_message, file=sys.stderr)
    raise SystemExit(1)
class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that writes tensor storages through ``saver`` as it meets them.

    Storages are not embedded in the pickle stream. ``persistent_id`` returns
    a ``("storage", type, key, location, numel)`` tuple for each one, and the
    storage bytes are handed to ``saver._write_storage_and_return_key``.
    """

    def __init__(self, saver, *args, **kwargs):
        """Create the pickler; ``*args``/``**kwargs`` go to ``pickle.Pickler``."""
        super().__init__(*args, **kwargs)
        # storage.data_ptr() -> dtype it was first saved with
        self.storage_dtypes = {}
        self.saver = saver
        # storage._cdata -> key returned by the saver, so a storage is written once
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        """Return the out-of-band id for storages, ``None`` for everything else."""
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            # The proxy already holds the id tuple built when it was created.
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                # Untyped storage: recorded as uint8, with its byte count as numel.
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            # Reuse the key if this storage was already written by this pickler.
            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None
def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]],
    targets: torch.Tensor,
    chunk_size: int = 128,
    ignore_index: int = -1,
) -> torch.Tensor:
    """Compute mean cross entropy, optionally in chunks to limit peak memory.

    With large sequence lengths the start of ``backward`` allocates a large
    memory chunk that can dominate memory usage. Chunking the computation
    lets intermediate buffers be freed on the go, reducing the spike.

    Every path returns the same value: the mean loss over positions whose
    target differs from ``ignore_index``.

    Args:
        logits: a single logits tensor, or a list of logits tensors that were
            already chunked along dim 1 (chunked ``lm_head``).
        targets: target indices; positions equal to ``ignore_index`` are skipped.
        chunk_size: rows per chunk when ``logits`` is a single tensor. ``0``
            disables chunking of the loss computation.
        ignore_index: target value that does not contribute to the loss.

    Returns:
        Scalar loss tensor. On the chunked paths an all-ignored batch gives
        ``0`` instead of ``nan``.
    """
    if isinstance(logits, list):
        # lm_head was chunked (we are fine-tuning)
        if chunk_size == 0:
            # don't want to chunk cross entropy
            logits = torch.cat(logits, dim=1)
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)

        # chunk cross entropy along the same boundaries as the logits
        logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
        target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
    else:
        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)
        if chunk_size == 0:
            # no chunking at all
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)

        # lm_head wasn't chunked, chunk cross entropy
        logit_chunks = logits.split(chunk_size)
        target_chunks = targets.split(chunk_size)

    loss_chunks = [
        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none")
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    # Ignored positions contribute a loss of 0, so a plain `.mean()` would
    # dilute the result. Divide by the number of non-ignored targets instead,
    # clamped to 1 so an all-ignored batch yields 0 rather than nan.
    non_masked_elems = (targets != ignore_index).sum()
    return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))
def get_default_supported_precision(training: bool, tpu: bool = False) -> str:
    """Pick a default precision string for the current hardware.

    Args:
        training: if true, return the ``-mixed`` variant, otherwise ``-true``.
        tpu: whether a TPU device is used; TPUs always get ``"32-true"``.

    Returns:
        ``"32-true"`` on TPU. Otherwise a ``bf16-*`` precision when CUDA is
        unavailable or reports bf16 support, else a ``16-*`` precision.
    """
    if tpu:
        return "32-true"
    variant = "mixed" if training else "true"
    use_bf16 = not torch.cuda.is_available() or torch.cuda.is_bf16_supported()
    width = "bf16" if use_bf16 else "16"
    return f"{width}-{variant}"
"16-true": 1.513e15 / 2, 32 | "16-mixed": 1.513e15 / 2, 33 | "bf16-true": 1.513e15 / 2, 34 | "bf16-mixed": 1.513e15 / 2, 35 | "8-true": 3.026e15 / 2, 36 | "8-mixed": 3.026e15 / 2, 37 | }, 38 | # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf 39 | # sxm and pcie have same flop counts 40 | "a100": { 41 | "64-true": 19.5e12, 42 | "32-true": 19.5e12, 43 | "16-true": 312e12, 44 | "16-mixed": 312e12, 45 | "bf16-true": 312e12, 46 | "bf16-mixed": 312e12, 47 | }, 48 | # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf 49 | "a10g": {"32-true": 31.2e12, "16-true": 125e12, "16-mixed": 125e12, "bf16-true": 125e12, "bf16-mixed": 125e12}, 50 | # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf 51 | "v100-sxm": {"64-true": 7.8e12, "32-true": 15.7e12, "16-true": 125e12, "16-mixed": 125e12}, 52 | "v100-pcie": {"64-true": 7e12, "32-true": 14e12, "16-true": 112e12, "16-mixed": 112e12}, 53 | "v100s-pcie": {"64-true": 8.2e12, "32-true": 16.4e12, "16-true": 130e12, "16-mixed": 130e12}, 54 | # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf 55 | # sxm and pcie have same flop counts 56 | "t4": {"32-true": 8.1e12, "16-true": 65e12, "16-mixed": 65e12, "8-true": 130e12, "int4": 260e12}, 57 | # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf 58 | "quadro rtx 5000": {"32-true": 11.2e12, "16-true": 89.2e12, "16-mixed": 89.2e12}, 59 | } 60 | 61 | TPU_AVAILABLE_FLOPS = { 62 | # flop count for each TPU generation is the same for all precisions 63 | # since bfloat16 precision is always used for performing matrix operations 64 | # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16 65 | # source: 
https://arxiv.org/pdf/1907.10701.pdf 66 | "v2": 45e12, 67 | # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3 68 | "v3": 123e12, 69 | # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4 70 | "v4": 275e12, 71 | } 72 | 73 | 74 | def get_flops_available(device: torch.device, precision: str) -> Optional[float]: 75 | if device.type == "cuda": 76 | device_name = torch.cuda.get_device_name(device).lower() 77 | if "h100" in device_name and "hbm3" in device_name: 78 | device_name = "h100-sxm" 79 | elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name): 80 | device_name = "h100-pcie" 81 | elif "a100" in device_name: 82 | device_name = "a100" 83 | elif "a10g" in device_name: 84 | device_name = "a10g" 85 | elif "v100-sxm" in device_name: 86 | device_name = "v100-sxm" 87 | elif "v100-pcie" in device_name: 88 | device_name = "v100-pcie" 89 | elif "t4" in device_name: 90 | device_name = "t4" 91 | elif "quadro rtx 5000" in device_name: 92 | device_name = "quadro rtx 5000" 93 | else: 94 | device_name = None 95 | 96 | if device_name is not None: 97 | try: 98 | return int(GPU_AVAILABLE_FLOPS[device_name][precision]) 99 | except KeyError: 100 | raise KeyError( 101 | f"flop count not found for {device_name} with precision: {precision}; " 102 | "MFU cannot be calculated and reported." 103 | ) 104 | elif device.type == "xla": 105 | from torch_xla.experimental import tpu 106 | 107 | device_name = tpu.get_tpu_env()["TYPE"].lower() 108 | try: 109 | return int(TPU_AVAILABLE_FLOPS[device_name]) 110 | except KeyError: 111 | raise KeyError( 112 | f"flop count not found for {device_name} with precision: {precision}; " 113 | "MFU cannot be calculated and reported." 
class SpeedMonitorBase:
    """Logs the training throughput and utilization.

    Logged keys (throughput keys appear once ``window_size + 1`` batches have
    been recorded):

    - ``throughput/batches_per_sec``: rolling average (over the ``window_size``
      most recent batches) of the number of batches processed per second.
    - ``throughput/samples_per_sec``: rolling average of samples per second.
    - ``throughput/tokens_per_sec``: rolling average of tokens per second. This
      may include padding depending on the dataset. Only logged when
      ``lengths`` is passed.
    - ``throughput/flops_per_sec``: flops summed across ranks over the window,
      divided by the elapsed time. Only logged when ``flops_per_batch`` is
      passed.
    - ``throughput/device/batches_per_sec``, ``throughput/device/samples_per_sec``,
      ``throughput/device/tokens_per_sec``, ``throughput/device/flops_per_sec``:
      the values above divided by the world size.
    - ``throughput/device/mfu``: ``throughput/device/flops_per_sec`` divided by
      ``flops_available``. Only logged when ``flops_available`` is truthy.
    - ``total_tokens``: logged together with the token throughput.
    - ``metric/train_loss``, ``metric/train_ppl``: mean of the most recent
      (at most ``log_iter_interval``) training losses and its exponential.
      Only logged when ``train_loss`` is passed.
    - ``time/train``, ``time/val``, ``time/total``: elapsed training time,
      elapsed validation time and their sum, in ``time_unit``.
    - ``samples``: the ``samples`` value passed to the last call.

    Notes:
        - The implementation assumes that devices are homogeneous as it normalizes by the world size.
        - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using
          samples/sec or batches/sec to measure throughput under this circumstance.
        - Be careful when comparing MFU numbers across projects, as this will highly depend on the
          ``flops_per_batch``. There is no widespread, realistic, and reliable implementation to compute them.

    Args:
        flops_available: peak flops of one device, or ``None``/``0`` to skip MFU.
        log_dict: callable invoked as ``log_dict(metrics, step_count)``.
        window_size: number of batches to use for a rolling average of throughput.
        time_unit: unit for ``time/*`` logging: ``"seconds"``, ``"minutes"``,
            ``"hours"`` or ``"days"``.
        log_iter_interval: metrics are logged every ``log_iter_interval`` calls;
            also the number of losses averaged for ``metric/train_loss``.

    Raises:
        ValueError: if ``time_unit`` is not one of the supported values.
    """

    def __init__(
        self,
        flops_available: Optional[float],
        log_dict: Callable[[Dict, int], None],
        window_size: int = 100,
        time_unit: str = "hours",
        log_iter_interval: int = 1,
    ):
        self.flops_available = flops_available
        self.log_dict = log_dict
        self.log_iter_interval = log_iter_interval
        # Track the batch num samples and wct to compute throughput over a window of batches
        self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
        self.history_training_loss: Deque[float] = deque(maxlen=log_iter_interval)
        self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
        self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
        self.history_flops: Deque[int] = deque(maxlen=window_size + 1)

        # Number of seconds in one `time_unit`.
        if time_unit == "seconds":
            self.divider = 1
        elif time_unit == "minutes":
            self.divider = 60
        elif time_unit == "hours":
            self.divider = 60 * 60
        elif time_unit == "days":
            self.divider = 60 * 60 * 24
        else:
            raise ValueError(
                f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
            )

        # Keep track of time spent evaluating
        self.total_eval_wct = 0.0
        self.iter = -1

    def on_train_batch_end(
        self,
        samples: int,  # total samples seen (per device)
        train_elapsed: float,  # total training time (seconds)
        world_size: int,
        step_count: int,
        flops_per_batch: Optional[int] = None,  # (per device)
        lengths: Optional[int] = None,  # total length of the samples seen (per device)
        train_loss: Optional[float] = None,
    ):
        """Record one finished batch and log metrics every ``log_iter_interval`` calls."""
        self.iter += 1
        metrics = {}

        self.history_samples.append(samples)
        if train_loss is not None:
            # Only real losses enter the window: a stored None would make the
            # sum() below raise TypeError on a later call that does pass a loss.
            self.history_training_loss.append(train_loss)
        if lengths is not None:
            self.history_lengths.append(lengths)
            # if lengths are passed, there should be as many values as samples
            assert len(self.history_samples) == len(self.history_lengths)
        self.history_wct.append(train_elapsed)
        if len(self.history_wct) == self.history_wct.maxlen:
            # The window is full: differences between newest and oldest entries
            # span `window_size` batches.
            elapsed_batches = len(self.history_samples) - 1
            elapsed_samples = self.history_samples[-1] - self.history_samples[0]
            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
            samples_per_sec = elapsed_samples * world_size / elapsed_wct
            dev_samples_per_sec = elapsed_samples / elapsed_wct
            metrics.update(
                {
                    "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
                    "throughput/samples_per_sec": samples_per_sec,
                    "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
                    "throughput/device/samples_per_sec": dev_samples_per_sec,
                }
            )
            if lengths is not None:
                elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
                avg_length = elapsed_lengths / elapsed_batches
                metrics.update(
                    {
                        "throughput/tokens_per_sec": samples_per_sec * avg_length,
                        "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
                        "total_tokens": avg_length * world_size * samples,
                    }
                )
            if train_loss is not None:
                avg_loss = sum(self.history_training_loss) / len(self.history_training_loss)
                metrics.update(
                    {
                        "metric/train_loss": avg_loss,
                        "metric/train_ppl": math.exp(avg_loss),
                    }
                )

        if flops_per_batch is not None:
            # sum of flops per batch across ranks
            self.history_flops.append(flops_per_batch * world_size)
        if len(self.history_flops) == self.history_flops.maxlen:
            # Drop the oldest entry so the sum covers the same `window_size`
            # batches as the elapsed wall-clock time.
            elapsed_flops = sum(self.history_flops) - self.history_flops[0]
            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
            flops_per_sec = elapsed_flops / elapsed_wct
            device_flops_per_sec = flops_per_sec / world_size
            metrics.update(
                {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec}
            )
            if self.flops_available:
                metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available

        metrics.update(
            {
                "time/train": train_elapsed / self.divider,
                "time/val": self.total_eval_wct / self.divider,
                "time/total": (train_elapsed + self.total_eval_wct) / self.divider,
                "samples": samples,
            }
        )
        if self.iter % self.log_iter_interval == 0:
            self.log_dict(metrics, step_count)

    def eval_end(self, eval_elapsed: float):
        """Add ``eval_elapsed`` seconds to the accumulated validation time."""
        self.total_eval_wct += eval_elapsed  # seconds
SpeedMonitorCallback(Callback): 313 | def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None: 314 | super().__init__() 315 | self.speed_monitor: Optional[SpeedMonitorBase] = None 316 | self.speed_monitor_kwargs = kwargs 317 | self.length_fn = length_fn 318 | self.batch_size = batch_size 319 | self.eval_t0: int = 0 320 | self.train_t0: int = 0 321 | self.total_lengths: int = 0 322 | 323 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: 324 | if self.speed_monitor is not None: 325 | return # already setup 326 | # TODO: this will not work properly if a precision plugin is passed to Trainer 327 | flops_available = get_flops_available( 328 | trainer.strategy.root_device, trainer._accelerator_connector._precision_flag 329 | ) 330 | self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs) 331 | 332 | @trainer_rank_zero_only 333 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 334 | if trainer.fit_loop._should_accumulate(): 335 | return 336 | 337 | self.train_t0 = time.perf_counter() 338 | 339 | @trainer_rank_zero_only 340 | def on_train_batch_end( 341 | self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int 342 | ) -> None: 343 | self.total_lengths += self.length_fn(batch) 344 | if trainer.fit_loop._should_accumulate(): 345 | return 346 | train_elapsed = time.perf_counter() - self.train_t0 347 | assert self.speed_monitor is not None 348 | iter_num = trainer.fit_loop.total_batch_idx 349 | assert (measured_flops := pl_module.measured_flops) is not None 350 | self.speed_monitor.on_train_batch_end( 351 | (iter_num + 1) * self.batch_size, 352 | train_elapsed, 353 | # this assumes that device FLOPs are the same and that all devices have the same batch size 354 | trainer.world_size, 355 | flops_per_batch=measured_flops, 356 | lengths=self.total_lengths, 357 | ) 358 | 359 | 
@trainer_rank_zero_only 360 | def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 361 | self.eval_t0 = time.perf_counter() 362 | 363 | @trainer_rank_zero_only 364 | def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: 365 | eval_elapsed = time.perf_counter() - self.eval_t0 366 | assert self.speed_monitor is not None 367 | self.speed_monitor.eval_end(eval_elapsed) 368 | 369 | 370 | def flops_per_param(config: Config, n_params: int) -> int: 371 | flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation 372 | # this assumes that all samples have a fixed length equal to the block size 373 | # which is most likely false during finetuning 374 | flops_per_seq = flops_per_token * config.block_size 375 | attn_flops_per_seq = config.n_layer * 2 * 2 * (config.n_embd * (config.block_size**2)) 376 | return flops_per_seq + attn_flops_per_seq 377 | 378 | 379 | def estimate_flops(model: GPT) -> int: 380 | """Measures estimated FLOPs for MFU. 381 | 382 | Refs: 383 | * https://ar5iv.labs.arxiv.org/html/2205.05198#A1 384 | * https://ar5iv.labs.arxiv.org/html/2204.02311#A2 385 | """ 386 | # using all parameters for this is a naive over estimation because not all model parameters actually contribute to 387 | # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage 388 | # (~10%) compared to the measured FLOPs, making those lower but more realistic. 389 | # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper. 
390 | n_trainable_params = num_parameters(model, requires_grad=True) 391 | trainable_flops = flops_per_param(model.config, n_trainable_params) 392 | # forward + backward + gradients (assumes no gradient accumulation) 393 | ops_per_step = 3 if model.training else 1 394 | n_frozen_params = num_parameters(model, requires_grad=False) 395 | frozen_flops = flops_per_param(model.config, n_frozen_params) 396 | # forward + backward 397 | frozen_ops_per_step = 2 if model.training else 1 398 | return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops 399 | 400 | 401 | def measure_flops(model: GPT, x: torch.Tensor) -> int: 402 | """Measures real FLOPs for HFU""" 403 | flop_counter = FlopCounterMode(model, display=False) 404 | ctx = nullcontext() if model.training else torch.no_grad() 405 | with ctx, flop_counter: 406 | y = model(x) 407 | if model.training: 408 | y.sum().backward() 409 | return flop_counter.get_total_flops() 410 | -------------------------------------------------------------------------------- /data/test_mixture_1m.csv: -------------------------------------------------------------------------------- 1 | index,train_the_pile_arxiv,train_the_pile_freelaw,train_the_pile_nih_exporter,train_the_pile_pubmed_central,train_the_pile_wikipedia_en,train_the_pile_dm_mathematics,train_the_pile_github,train_the_pile_philpapers,train_the_pile_stackexchange,train_the_pile_enron_emails,train_the_pile_gutenberg_pg_19,train_the_pile_pile_cc,train_the_pile_ubuntu_irc,train_the_pile_europarl,train_the_pile_hackernews,train_the_pile_pubmed_abstracts,train_the_pile_uspto_backgrounds 2 | 1,0.08,0.0,0.0,0.464,0.0,0.0,0.0,0.0,0.0,0.0,0.015,0.353,0.001,0.0,0.042,0.039,0.005 3 | 2,0.035,0.016,0.0,0.022,0.0,0.04,0.0,0.001,0.097,0.0,0.0,0.632,0.066,0.0,0.0,0.0,0.092 4 | 3,0.014,0.0,0.004,0.022,0.003,0.084,0.477,0.0,0.0,0.0,0.209,0.052,0.0,0.002,0.013,0.093,0.027 5 | 4,0.215,0.0,0.0,0.238,0.012,0.0,0.05,0.009,0.186,0.0,0.258,0.0,0.0,0.001,0.032,0.0,0.0 6 | 
5,0.732,0.0,0.0,0.0,0.004,0.027,0.035,0.001,0.113,0.004,0.0,0.005,0.04,0.037,0.0,0.0,0.003 7 | 6,0.024,0.034,0.0,0.097,0.001,0.015,0.338,0.0,0.089,0.001,0.086,0.117,0.005,0.0,0.105,0.0,0.089 8 | 7,0.007,0.0,0.0,0.0,0.0,0.0,0.988,0.0,0.0,0.002,0.0,0.0,0.0,0.0,0.0,0.002,0.0 9 | 8,0.0,0.0,0.0,0.005,0.0,0.0,0.0,0.0,0.865,0.0,0.111,0.0,0.0,0.0,0.0,0.0,0.019 10 | 9,0.0,0.0,0.021,0.001,0.0,0.013,0.005,0.001,0.0,0.0,0.0,0.489,0.016,0.0,0.0,0.001,0.454 11 | 10,0.189,0.0,0.0,0.752,0.0,0.0,0.0,0.0,0.0,0.0,0.021,0.019,0.003,0.0,0.004,0.0,0.012 12 | 11,0.0,0.011,0.0,0.0,0.02,0.0,0.328,0.0,0.204,0.023,0.023,0.0,0.008,0.0,0.0,0.0,0.383 13 | 12,0.0,0.021,0.0,0.0,0.0,0.008,0.0,0.006,0.026,0.0,0.0,0.778,0.0,0.002,0.012,0.016,0.131 14 | 13,0.0,0.004,0.0,0.331,0.037,0.001,0.562,0.015,0.0,0.0,0.039,0.01,0.0,0.0,0.0,0.001,0.0 15 | 14,0.0,0.0,0.006,0.0,0.507,0.0,0.002,0.0,0.03,0.0,0.001,0.084,0.0,0.0,0.0,0.0,0.371 16 | 15,0.017,0.0,0.004,0.0,0.4,0.001,0.229,0.004,0.182,0.017,0.002,0.0,0.031,0.001,0.001,0.0,0.111 17 | 16,0.174,0.0,0.0,0.504,0.0,0.0,0.0,0.0,0.081,0.0,0.125,0.0,0.0,0.044,0.0,0.0,0.071 18 | 17,0.24,0.001,0.0,0.0,0.0,0.011,0.038,0.0,0.0,0.0,0.319,0.144,0.0,0.001,0.022,0.223,0.0 19 | 18,0.109,0.116,0.004,0.0,0.0,0.16,0.171,0.0,0.0,0.0,0.0,0.0,0.003,0.103,0.035,0.031,0.265 20 | 19,0.007,0.0,0.0,0.001,0.0,0.0,0.679,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.047,0.09,0.176 21 | 20,0.547,0.015,0.0,0.011,0.0,0.0,0.332,0.006,0.036,0.0,0.0,0.0,0.001,0.036,0.01,0.005,0.0 22 | 21,0.0,0.142,0.014,0.509,0.001,0.002,0.148,0.0,0.034,0.0,0.062,0.005,0.0,0.081,0.0,0.0,0.001 23 | 22,0.009,0.42,0.0,0.0,0.0,0.0,0.343,0.005,0.084,0.0,0.005,0.0,0.07,0.061,0.0,0.004,0.0 24 | 23,0.0,0.106,0.0,0.143,0.042,0.024,0.001,0.01,0.0,0.001,0.403,0.001,0.122,0.001,0.003,0.141,0.0 25 | 24,0.0,0.011,0.008,0.006,0.36,0.0,0.533,0.004,0.0,0.0,0.003,0.0,0.02,0.0,0.0,0.055,0.0 26 | 25,0.025,0.064,0.002,0.001,0.007,0.038,0.011,0.051,0.002,0.0,0.015,0.106,0.005,0.007,0.0,0.086,0.58 27 | 
26,0.961,0.0,0.0,0.0,0.0,0.001,0.001,0.03,0.007,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 28 | 27,0.0,0.181,0.006,0.0,0.002,0.0,0.009,0.0,0.08,0.0,0.0,0.0,0.158,0.086,0.0,0.017,0.461 29 | 28,0.006,0.207,0.0,0.103,0.087,0.016,0.22,0.001,0.034,0.001,0.203,0.012,0.057,0.0,0.0,0.005,0.047 30 | 29,0.0,0.001,0.0,0.0,0.0,0.024,0.717,0.003,0.0,0.0,0.0,0.249,0.0,0.0,0.002,0.0,0.004 31 | 30,0.003,0.0,0.0,0.0,0.008,0.0,0.02,0.0,0.167,0.0,0.105,0.026,0.0,0.0,0.025,0.449,0.195 32 | 31,0.0,0.0,0.002,0.0,0.0,0.0,0.442,0.0,0.029,0.0,0.0,0.013,0.0,0.0,0.021,0.348,0.145 33 | 32,0.061,0.007,0.0,0.367,0.0,0.106,0.0,0.0,0.18,0.008,0.001,0.0,0.086,0.001,0.0,0.001,0.182 34 | 33,0.148,0.001,0.035,0.037,0.258,0.066,0.039,0.022,0.002,0.0,0.003,0.342,0.018,0.0,0.015,0.002,0.011 35 | 34,0.017,0.012,0.0,0.0,0.43,0.0,0.0,0.0,0.003,0.008,0.275,0.0,0.001,0.0,0.013,0.241,0.0 36 | 35,0.049,0.0,0.009,0.061,0.132,0.146,0.062,0.003,0.154,0.0,0.0,0.198,0.002,0.001,0.02,0.163,0.0 37 | 36,0.138,0.049,0.012,0.324,0.049,0.071,0.194,0.0,0.0,0.017,0.031,0.017,0.0,0.005,0.0,0.093,0.0 38 | 37,0.0,0.0,0.0,0.005,0.0,0.026,0.537,0.0,0.0,0.0,0.139,0.01,0.0,0.012,0.011,0.26,0.0 39 | 38,0.173,0.001,0.0,0.035,0.0,0.0,0.405,0.0,0.0,0.0,0.001,0.004,0.0,0.0,0.0,0.38,0.0 40 | 39,0.234,0.0,0.0,0.313,0.0,0.0,0.0,0.0,0.351,0.0,0.0,0.021,0.007,0.0,0.074,0.0,0.0 41 | 40,0.03,0.091,0.001,0.001,0.094,0.022,0.085,0.0,0.298,0.007,0.105,0.225,0.023,0.003,0.0,0.002,0.012 42 | 41,0.028,0.001,0.001,0.029,0.065,0.021,0.09,0.0,0.066,0.001,0.015,0.185,0.002,0.004,0.001,0.349,0.142 43 | 42,0.0,0.118,0.021,0.003,0.513,0.0,0.004,0.0,0.001,0.025,0.0,0.0,0.008,0.091,0.0,0.217,0.0 44 | 43,0.0,0.0,0.005,0.001,0.762,0.0,0.001,0.0,0.127,0.0,0.0,0.104,0.0,0.0,0.0,0.0,0.0 45 | 44,0.194,0.255,0.0,0.035,0.0,0.003,0.0,0.0,0.025,0.0,0.12,0.0,0.008,0.079,0.0,0.0,0.282 46 | 45,0.22,0.031,0.001,0.381,0.0,0.014,0.001,0.0,0.02,0.004,0.004,0.08,0.152,0.022,0.04,0.03,0.0 47 | 46,0.0,0.0,0.005,0.0,0.342,0.057,0.0,0.001,0.524,0.0,0.001,0.023,0.0,0.001,0.019,0.022,0.004 
48 | 47,0.0,0.374,0.007,0.22,0.0,0.034,0.125,0.007,0.013,0.001,0.0,0.0,0.0,0.052,0.053,0.113,0.0 49 | 48,0.001,0.0,0.0,0.158,0.0,0.024,0.553,0.0,0.013,0.0,0.159,0.039,0.0,0.041,0.0,0.0,0.012 50 | 49,0.019,0.062,0.0,0.0,0.004,0.0,0.0,0.0,0.9,0.0,0.0,0.0,0.0,0.0,0.012,0.003,0.0 51 | 50,0.006,0.029,0.0,0.028,0.0,0.001,0.08,0.001,0.17,0.001,0.216,0.0,0.028,0.036,0.022,0.351,0.03 52 | 51,0.0,0.0,0.0,0.273,0.0,0.093,0.031,0.0,0.001,0.0,0.34,0.206,0.0,0.04,0.0,0.016,0.0 53 | 52,0.0,0.0,0.0,0.805,0.001,0.0,0.137,0.0,0.0,0.0,0.0,0.007,0.0,0.0,0.0,0.0,0.05 54 | 53,0.008,0.003,0.0,0.076,0.061,0.126,0.001,0.025,0.004,0.001,0.09,0.154,0.177,0.0,0.0,0.01,0.264 55 | 54,0.669,0.015,0.004,0.013,0.081,0.0,0.0,0.0,0.0,0.015,0.0,0.001,0.152,0.001,0.05,0.0,0.0 56 | 55,0.0,0.003,0.0,0.594,0.021,0.0,0.004,0.0,0.326,0.0,0.0,0.0,0.0,0.0,0.0,0.053,0.0 57 | 56,0.185,0.004,0.002,0.384,0.142,0.108,0.002,0.023,0.096,0.0,0.001,0.006,0.0,0.033,0.011,0.0,0.003 58 | 57,0.0,0.0,0.0,0.901,0.0,0.0,0.001,0.0,0.0,0.0,0.0,0.0,0.0,0.032,0.001,0.0,0.065 59 | 58,0.0,0.0,0.023,0.061,0.068,0.087,0.007,0.0,0.349,0.009,0.015,0.196,0.087,0.023,0.0,0.068,0.007 60 | 59,0.028,0.065,0.023,0.006,0.002,0.239,0.001,0.0,0.136,0.022,0.064,0.09,0.016,0.0,0.0,0.0,0.307 61 | 60,0.368,0.013,0.007,0.0,0.001,0.0,0.0,0.0,0.513,0.0,0.027,0.001,0.0,0.003,0.064,0.002,0.0 62 | 61,0.011,0.0,0.0,0.0,0.236,0.0,0.002,0.0,0.0,0.0,0.001,0.75,0.0,0.0,0.0,0.0,0.0 63 | 62,0.14,0.0,0.0,0.0,0.001,0.0,0.047,0.002,0.197,0.0,0.0,0.583,0.0,0.0,0.0,0.005,0.024 64 | 63,0.001,0.834,0.019,0.007,0.0,0.0,0.003,0.003,0.0,0.0,0.0,0.0,0.116,0.003,0.0,0.0,0.014 65 | 64,0.028,0.138,0.0,0.395,0.001,0.087,0.177,0.001,0.018,0.0,0.0,0.0,0.072,0.008,0.006,0.042,0.027 66 | 65,0.099,0.0,0.004,0.0,0.0,0.0,0.056,0.0,0.0,0.0,0.14,0.7,0.0,0.0,0.0,0.0,0.0 67 | 66,0.0,0.725,0.0,0.0,0.0,0.0,0.0,0.0,0.008,0.0,0.0,0.109,0.144,0.014,0.0,0.0,0.0 68 | 67,0.753,0.0,0.009,0.001,0.001,0.0,0.0,0.004,0.0,0.0,0.198,0.0,0.0,0.0,0.0,0.034,0.0 69 | 
68,0.125,0.001,0.001,0.003,0.068,0.009,0.082,0.001,0.324,0.001,0.014,0.251,0.006,0.02,0.037,0.017,0.042 70 | 69,0.001,0.32,0.001,0.188,0.0,0.11,0.08,0.01,0.006,0.025,0.127,0.007,0.008,0.073,0.003,0.0,0.041 71 | 70,0.0,0.0,0.0,0.0,0.002,0.0,0.002,0.003,0.402,0.0,0.003,0.579,0.0,0.0,0.008,0.0,0.001 72 | 71,0.0,0.0,0.0,0.274,0.45,0.113,0.144,0.019,0.0,0.0,0.0,0.0,0.001,0.0,0.0,0.0,0.0 73 | 72,0.251,0.471,0.0,0.0,0.005,0.168,0.002,0.0,0.074,0.003,0.0,0.021,0.0,0.0,0.0,0.002,0.004 74 | 73,0.002,0.0,0.008,0.551,0.002,0.0,0.004,0.0,0.0,0.009,0.0,0.068,0.0,0.007,0.005,0.342,0.0 75 | 74,0.0,0.362,0.0,0.0,0.0,0.002,0.0,0.0,0.0,0.0,0.004,0.575,0.0,0.03,0.001,0.0,0.025 76 | 75,0.045,0.019,0.006,0.089,0.394,0.048,0.0,0.004,0.0,0.014,0.0,0.05,0.0,0.072,0.066,0.192,0.0 77 | 76,0.058,0.019,0.0,0.102,0.025,0.001,0.025,0.019,0.225,0.0,0.079,0.213,0.0,0.02,0.0,0.001,0.214 78 | 77,0.001,0.001,0.0,0.0,0.624,0.043,0.0,0.001,0.0,0.001,0.0,0.331,0.0,0.0,0.0,0.0,0.0 79 | 78,0.021,0.001,0.032,0.575,0.001,0.096,0.192,0.0,0.0,0.024,0.006,0.0,0.031,0.0,0.0,0.02,0.0 80 | 79,0.004,0.0,0.0,0.003,0.0,0.0,0.913,0.0,0.0,0.0,0.08,0.0,0.0,0.0,0.0,0.0,0.0 81 | 80,0.06,0.0,0.0,0.475,0.101,0.002,0.001,0.0,0.0,0.0,0.001,0.005,0.0,0.013,0.0,0.313,0.029 82 | 81,0.0,0.074,0.001,0.065,0.055,0.013,0.005,0.0,0.023,0.0,0.0,0.079,0.006,0.073,0.016,0.561,0.028 83 | 82,0.007,0.003,0.0,0.201,0.001,0.0,0.356,0.011,0.096,0.001,0.087,0.0,0.024,0.0,0.0,0.213,0.001 84 | 83,0.0,0.005,0.0,0.647,0.129,0.0,0.0,0.0,0.053,0.0,0.11,0.0,0.0,0.0,0.056,0.0,0.0 85 | 84,0.002,0.0,0.0,0.16,0.0,0.055,0.0,0.001,0.248,0.0,0.01,0.522,0.0,0.0,0.0,0.003,0.0 86 | 85,0.0,0.0,0.002,0.008,0.0,0.0,0.491,0.0,0.027,0.0,0.0,0.001,0.002,0.026,0.0,0.443,0.0 87 | 86,0.879,0.0,0.0,0.0,0.0,0.008,0.017,0.0,0.0,0.0,0.032,0.035,0.002,0.001,0.0,0.004,0.021 88 | 87,0.0,0.955,0.0,0.0,0.033,0.0,0.0,0.001,0.0,0.0,0.0,0.0,0.0,0.0,0.01,0.0,0.0 89 | 88,0.008,0.002,0.0,0.0,0.03,0.148,0.0,0.0,0.002,0.0,0.0,0.356,0.0,0.0,0.0,0.0,0.453 90 | 
89,0.0,0.029,0.001,0.002,0.467,0.002,0.0,0.0,0.108,0.008,0.018,0.322,0.0,0.018,0.022,0.001,0.002 91 | 90,0.294,0.079,0.0,0.016,0.002,0.0,0.257,0.013,0.012,0.012,0.066,0.007,0.011,0.0,0.0,0.23,0.0 92 | 91,0.445,0.008,0.0,0.0,0.0,0.002,0.005,0.007,0.004,0.008,0.122,0.312,0.007,0.0,0.077,0.001,0.0 93 | 92,0.094,0.0,0.0,0.0,0.001,0.004,0.037,0.0,0.309,0.001,0.022,0.001,0.016,0.0,0.001,0.02,0.494 94 | 93,0.067,0.0,0.01,0.036,0.0,0.225,0.518,0.026,0.0,0.0,0.001,0.004,0.104,0.008,0.0,0.0,0.0 95 | 94,0.376,0.008,0.0,0.0,0.367,0.0,0.001,0.009,0.002,0.0,0.0,0.007,0.0,0.0,0.0,0.121,0.11 96 | 95,0.0,0.0,0.0,0.299,0.525,0.142,0.0,0.0,0.019,0.0,0.0,0.001,0.006,0.0,0.0,0.0,0.008 97 | 96,0.0,0.0,0.0,0.322,0.0,0.0,0.006,0.012,0.614,0.0,0.0,0.0,0.0,0.0,0.046,0.0,0.0 98 | 97,0.0,0.366,0.0,0.0,0.0,0.0,0.023,0.005,0.0,0.0,0.031,0.443,0.002,0.114,0.0,0.0,0.017 99 | 98,0.0,0.093,0.004,0.159,0.005,0.0,0.095,0.0,0.0,0.0,0.0,0.0,0.148,0.0,0.003,0.316,0.177 100 | 99,0.0,0.328,0.0,0.017,0.0,0.0,0.015,0.01,0.459,0.002,0.0,0.0,0.129,0.037,0.0,0.0,0.003 101 | 100,0.173,0.2,0.002,0.01,0.0,0.0,0.0,0.055,0.199,0.0,0.0,0.0,0.0,0.018,0.0,0.328,0.015 102 | 101,0.126,0.006,0.011,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.087,0.005,0.0,0.046,0.026,0.0,0.695 103 | 102,0.082,0.031,0.004,0.741,0.0,0.006,0.01,0.0,0.0,0.0,0.116,0.0,0.0,0.0,0.0,0.0,0.01 104 | 103,0.164,0.001,0.0,0.0,0.0,0.047,0.005,0.0,0.0,0.0,0.239,0.0,0.0,0.054,0.0,0.194,0.294 105 | 104,0.577,0.155,0.0,0.005,0.034,0.008,0.046,0.001,0.005,0.005,0.021,0.117,0.005,0.0,0.02,0.0,0.0 106 | 105,0.018,0.093,0.0,0.0,0.352,0.0,0.004,0.0,0.059,0.0,0.0,0.188,0.0,0.0,0.0,0.0,0.286 107 | 106,0.0,0.001,0.0,0.0,0.038,0.0,0.018,0.0,0.82,0.0,0.0,0.077,0.0,0.002,0.045,0.0,0.0 108 | 107,0.04,0.0,0.0,0.125,0.006,0.01,0.017,0.001,0.0,0.003,0.002,0.664,0.0,0.0,0.008,0.064,0.062 109 | 108,0.001,0.003,0.0,0.263,0.001,0.0,0.082,0.0,0.004,0.0,0.295,0.112,0.002,0.0,0.0,0.086,0.151 110 | 109,0.002,0.0,0.0,0.001,0.0,0.079,0.01,0.001,0.0,0.001,0.0,0.901,0.0,0.0,0.0,0.0,0.005 111 | 
110,0.001,0.0,0.0,0.0,0.0,0.046,0.016,0.013,0.491,0.002,0.001,0.0,0.0,0.001,0.0,0.429,0.0 112 | 111,0.498,0.0,0.0,0.012,0.0,0.004,0.001,0.022,0.416,0.0,0.0,0.0,0.0,0.004,0.001,0.041,0.0 113 | 112,0.006,0.262,0.021,0.0,0.0,0.002,0.0,0.0,0.122,0.001,0.0,0.0,0.159,0.0,0.0,0.0,0.426 114 | 113,0.051,0.096,0.001,0.01,0.018,0.0,0.169,0.001,0.033,0.0,0.169,0.257,0.002,0.0,0.0,0.002,0.188 115 | 114,0.001,0.006,0.0,0.532,0.178,0.002,0.0,0.023,0.112,0.0,0.001,0.0,0.144,0.0,0.0,0.0,0.001 116 | 115,0.005,0.0,0.0,0.29,0.001,0.0,0.0,0.0,0.108,0.0,0.0,0.596,0.0,0.0,0.0,0.0,0.0 117 | 116,0.294,0.033,0.002,0.001,0.555,0.018,0.012,0.028,0.0,0.0,0.008,0.013,0.0,0.03,0.0,0.004,0.0 118 | 117,0.0,0.001,0.007,0.089,0.0,0.0,0.009,0.005,0.872,0.0,0.0,0.013,0.0,0.0,0.004,0.0,0.0 119 | 118,0.0,0.029,0.0,0.798,0.0,0.0,0.0,0.0,0.0,0.003,0.149,0.006,0.005,0.0,0.0,0.009,0.0 120 | 119,0.046,0.015,0.0,0.027,0.548,0.134,0.111,0.0,0.002,0.0,0.014,0.0,0.03,0.005,0.022,0.033,0.012 121 | 120,0.012,0.497,0.028,0.005,0.396,0.0,0.0,0.0,0.003,0.0,0.004,0.012,0.03,0.006,0.001,0.008,0.0 122 | 121,0.167,0.0,0.05,0.001,0.019,0.0,0.482,0.003,0.0,0.001,0.202,0.01,0.0,0.0,0.007,0.059,0.0 123 | 122,0.012,0.0,0.046,0.023,0.064,0.064,0.27,0.0,0.0,0.002,0.0,0.097,0.049,0.007,0.0,0.365,0.001 124 | 123,0.341,0.0,0.0,0.0,0.0,0.0,0.004,0.024,0.229,0.0,0.0,0.0,0.003,0.0,0.0,0.398,0.0 125 | 124,0.0,0.016,0.0,0.165,0.0,0.157,0.037,0.0,0.0,0.005,0.276,0.003,0.001,0.0,0.0,0.34,0.0 126 | 125,0.002,0.0,0.0,0.275,0.0,0.001,0.0,0.018,0.068,0.003,0.154,0.082,0.014,0.009,0.0,0.024,0.349 127 | 126,0.015,0.515,0.029,0.001,0.002,0.005,0.148,0.006,0.014,0.0,0.002,0.0,0.002,0.0,0.063,0.002,0.195 128 | 127,0.0,0.433,0.007,0.307,0.0,0.006,0.0,0.014,0.011,0.0,0.003,0.0,0.145,0.009,0.0,0.063,0.0 129 | 128,0.302,0.027,0.036,0.082,0.0,0.0,0.061,0.002,0.064,0.0,0.073,0.145,0.084,0.026,0.013,0.085,0.0 130 | 129,0.0,0.0,0.008,0.0,0.031,0.004,0.206,0.0,0.0,0.0,0.001,0.745,0.0,0.002,0.0,0.0,0.004 131 | 
130,0.0,0.022,0.04,0.179,0.026,0.053,0.45,0.002,0.002,0.023,0.005,0.054,0.097,0.006,0.0,0.041,0.0 132 | 131,0.047,0.129,0.003,0.001,0.025,0.009,0.006,0.0,0.0,0.004,0.198,0.196,0.001,0.0,0.08,0.0,0.301 133 | 132,0.373,0.0,0.0,0.001,0.005,0.0,0.117,0.0,0.384,0.0,0.0,0.001,0.006,0.001,0.055,0.04,0.017 134 | 133,0.002,0.001,0.0,0.251,0.034,0.032,0.092,0.046,0.026,0.011,0.022,0.0,0.001,0.0,0.06,0.416,0.009 135 | 134,0.0,0.0,0.0,0.005,0.022,0.0,0.0,0.02,0.0,0.0,0.059,0.689,0.006,0.014,0.0,0.078,0.108 136 | 135,0.0,0.418,0.0,0.002,0.0,0.0,0.58,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 137 | 136,0.0,0.019,0.0,0.727,0.101,0.0,0.0,0.0,0.0,0.0,0.123,0.0,0.0,0.0,0.029,0.0,0.0 138 | 137,0.0,0.0,0.0,0.0,0.239,0.0,0.157,0.0,0.203,0.0,0.007,0.003,0.003,0.0,0.0,0.0,0.387 139 | 138,0.0,0.0,0.0,0.0,0.0,0.0,0.725,0.0,0.004,0.0,0.0,0.051,0.0,0.0,0.0,0.221,0.0 140 | 139,0.873,0.0,0.0,0.01,0.0,0.0,0.101,0.0,0.015,0.0,0.0,0.0,0.001,0.0,0.0,0.0,0.0 141 | 140,0.04,0.476,0.027,0.0,0.045,0.006,0.007,0.002,0.0,0.02,0.064,0.019,0.149,0.051,0.0,0.094,0.0 142 | 141,0.249,0.0,0.018,0.001,0.001,0.051,0.572,0.006,0.058,0.0,0.001,0.0,0.0,0.003,0.0,0.0,0.039 143 | 142,0.009,0.015,0.0,0.135,0.683,0.001,0.009,0.0,0.001,0.005,0.134,0.0,0.0,0.0,0.002,0.006,0.0 144 | 143,0.064,0.21,0.041,0.001,0.0,0.018,0.005,0.0,0.174,0.0,0.216,0.03,0.072,0.097,0.045,0.004,0.022 145 | 144,0.049,0.028,0.0,0.601,0.037,0.04,0.0,0.008,0.0,0.006,0.211,0.0,0.0,0.0,0.0,0.021,0.0 146 | 145,0.016,0.0,0.001,0.002,0.572,0.012,0.321,0.0,0.0,0.0,0.001,0.011,0.036,0.012,0.0,0.001,0.016 147 | 146,0.094,0.0,0.0,0.0,0.552,0.175,0.0,0.05,0.0,0.0,0.0,0.0,0.128,0.0,0.0,0.0,0.001 148 | 147,0.0,0.013,0.0,0.0,0.001,0.0,0.304,0.0,0.564,0.024,0.001,0.0,0.03,0.004,0.012,0.021,0.027 149 | 148,0.002,0.0,0.04,0.22,0.0,0.004,0.133,0.0,0.044,0.0,0.0,0.001,0.0,0.049,0.116,0.0,0.39 150 | 149,0.0,0.006,0.0,0.0,0.0,0.001,0.571,0.0,0.137,0.0,0.001,0.0,0.0,0.0,0.0,0.169,0.115 151 | 
150,0.007,0.001,0.019,0.447,0.013,0.0,0.001,0.0,0.077,0.004,0.0,0.0,0.0,0.095,0.002,0.005,0.33 152 | 151,0.006,0.061,0.001,0.323,0.016,0.0,0.017,0.005,0.548,0.007,0.016,0.0,0.0,0.0,0.0,0.0,0.001 153 | 152,0.11,0.0,0.05,0.15,0.071,0.118,0.018,0.0,0.208,0.002,0.0,0.0,0.042,0.027,0.002,0.0,0.203 154 | 153,0.112,0.0,0.009,0.002,0.078,0.004,0.039,0.041,0.02,0.003,0.354,0.0,0.029,0.011,0.0,0.247,0.052 155 | 154,0.199,0.004,0.016,0.004,0.167,0.0,0.432,0.001,0.073,0.0,0.025,0.004,0.0,0.001,0.047,0.023,0.004 156 | 155,0.001,0.385,0.006,0.144,0.123,0.037,0.0,0.001,0.034,0.0,0.106,0.004,0.006,0.007,0.014,0.133,0.002 157 | 156,0.0,0.037,0.0,0.776,0.0,0.0,0.044,0.0,0.108,0.0,0.0,0.033,0.001,0.0,0.0,0.001,0.0 158 | 157,0.0,0.045,0.0,0.042,0.185,0.013,0.0,0.003,0.009,0.0,0.353,0.298,0.01,0.0,0.013,0.0,0.028 159 | 158,0.46,0.001,0.0,0.0,0.028,0.008,0.0,0.0,0.007,0.0,0.019,0.052,0.002,0.0,0.0,0.421,0.002 160 | 159,0.0,0.0,0.0,0.01,0.0,0.096,0.742,0.0,0.0,0.001,0.0,0.0,0.01,0.0,0.0,0.0,0.14 161 | 160,0.0,0.005,0.006,0.104,0.0,0.08,0.134,0.0,0.0,0.003,0.052,0.045,0.0,0.065,0.002,0.38,0.124 162 | 161,0.001,0.038,0.001,0.046,0.062,0.0,0.0,0.019,0.0,0.0,0.0,0.689,0.114,0.008,0.022,0.0,0.0 163 | 162,0.01,0.085,0.0,0.0,0.006,0.001,0.65,0.0,0.0,0.005,0.0,0.01,0.0,0.0,0.03,0.0,0.202 164 | 163,0.199,0.001,0.03,0.112,0.0,0.024,0.012,0.003,0.563,0.006,0.0,0.016,0.001,0.007,0.008,0.019,0.0 165 | 164,0.0,0.026,0.018,0.074,0.427,0.025,0.011,0.0,0.0,0.0,0.137,0.001,0.08,0.002,0.003,0.0,0.197 166 | 165,0.003,0.079,0.028,0.059,0.004,0.131,0.002,0.0,0.007,0.001,0.14,0.01,0.087,0.001,0.005,0.05,0.392 167 | 166,0.021,0.276,0.019,0.056,0.033,0.0,0.077,0.0,0.005,0.001,0.034,0.001,0.0,0.0,0.0,0.461,0.017 168 | 167,0.195,0.023,0.002,0.022,0.0,0.238,0.258,0.0,0.002,0.0,0.04,0.002,0.0,0.107,0.0,0.073,0.037 169 | 168,0.0,0.003,0.0,0.0,0.036,0.002,0.003,0.009,0.638,0.0,0.12,0.087,0.001,0.0,0.001,0.0,0.1 170 | 169,0.004,0.0,0.001,0.634,0.001,0.213,0.009,0.0,0.102,0.002,0.025,0.0,0.008,0.0,0.0,0.001,0.0 171 | 
170,0.0,0.001,0.001,0.163,0.08,0.0,0.0,0.0,0.03,0.001,0.005,0.513,0.0,0.0,0.1,0.104,0.003 172 | 171,0.016,0.011,0.0,0.006,0.016,0.0,0.018,0.0,0.696,0.0,0.0,0.001,0.0,0.049,0.0,0.187,0.0 173 | 172,0.138,0.024,0.0,0.11,0.001,0.002,0.013,0.0,0.311,0.0,0.0,0.29,0.041,0.01,0.051,0.001,0.007 174 | 173,0.477,0.004,0.0,0.245,0.002,0.001,0.0,0.0,0.021,0.0,0.0,0.219,0.0,0.0,0.001,0.0,0.03 175 | 174,0.0,0.087,0.001,0.008,0.13,0.0,0.557,0.0,0.004,0.001,0.182,0.001,0.0,0.022,0.0,0.0,0.006 176 | 175,0.0,0.044,0.02,0.531,0.0,0.0,0.352,0.0,0.038,0.001,0.0,0.0,0.0,0.014,0.0,0.0,0.0 177 | 176,0.0,0.141,0.0,0.038,0.044,0.181,0.0,0.0,0.029,0.004,0.001,0.19,0.004,0.037,0.0,0.265,0.067 178 | 177,0.0,0.008,0.0,0.253,0.0,0.001,0.0,0.0,0.0,0.003,0.016,0.606,0.0,0.0,0.0,0.0,0.113 179 | 178,0.0,0.264,0.0,0.061,0.171,0.0,0.149,0.004,0.0,0.0,0.09,0.066,0.087,0.052,0.008,0.0,0.048 180 | 179,0.35,0.001,0.0,0.0,0.096,0.034,0.0,0.01,0.0,0.0,0.001,0.226,0.006,0.067,0.004,0.018,0.186 181 | 180,0.065,0.668,0.015,0.006,0.0,0.0,0.005,0.0,0.0,0.003,0.107,0.018,0.0,0.001,0.0,0.113,0.0 182 | 181,0.018,0.22,0.0,0.001,0.518,0.112,0.057,0.053,0.002,0.001,0.0,0.003,0.0,0.001,0.0,0.0,0.014 183 | 182,0.553,0.062,0.005,0.004,0.001,0.0,0.009,0.0,0.078,0.001,0.031,0.065,0.0,0.001,0.103,0.001,0.087 184 | 183,0.055,0.761,0.0,0.0,0.0,0.0,0.003,0.002,0.124,0.0,0.0,0.032,0.012,0.0,0.001,0.01,0.0 185 | 184,0.032,0.0,0.036,0.0,0.0,0.0,0.054,0.0,0.001,0.0,0.089,0.661,0.085,0.0,0.043,0.0,0.0 186 | 185,0.0,0.006,0.0,0.0,0.0,0.0,0.0,0.001,0.034,0.0,0.003,0.956,0.0,0.0,0.0,0.0,0.0 187 | 186,0.742,0.0,0.0,0.0,0.055,0.0,0.0,0.0,0.19,0.0,0.004,0.0,0.0,0.0,0.009,0.0,0.0 188 | 187,0.01,0.06,0.001,0.094,0.733,0.0,0.001,0.008,0.018,0.0,0.055,0.002,0.0,0.004,0.007,0.0,0.007 189 | 188,0.094,0.0,0.0,0.209,0.072,0.0,0.066,0.007,0.009,0.004,0.039,0.175,0.02,0.117,0.007,0.07,0.112 190 | 189,0.07,0.128,0.0,0.0,0.001,0.0,0.0,0.0,0.002,0.003,0.0,0.684,0.0,0.0,0.039,0.072,0.0 191 | 
190,0.182,0.102,0.0,0.0,0.018,0.0,0.0,0.0,0.385,0.0,0.067,0.016,0.145,0.069,0.0,0.017,0.0 192 | 191,0.063,0.0,0.017,0.787,0.021,0.003,0.05,0.0,0.0,0.0,0.0,0.005,0.0,0.054,0.0,0.0,0.0 193 | 192,0.008,0.047,0.039,0.028,0.002,0.046,0.025,0.028,0.266,0.001,0.342,0.0,0.064,0.001,0.056,0.0,0.047 194 | 193,0.783,0.0,0.0,0.0,0.006,0.096,0.002,0.001,0.026,0.0,0.0,0.049,0.0,0.0,0.016,0.0,0.02 195 | 194,0.0,0.0,0.002,0.0,0.0,0.0,0.003,0.0,0.223,0.0,0.0,0.771,0.0,0.0,0.0,0.001,0.0 196 | 195,0.026,0.009,0.055,0.077,0.005,0.0,0.001,0.018,0.0,0.0,0.0,0.086,0.115,0.001,0.009,0.338,0.259 197 | 196,0.191,0.0,0.002,0.209,0.0,0.152,0.117,0.0,0.005,0.0,0.0,0.0,0.029,0.017,0.0,0.278,0.0 198 | 197,0.197,0.001,0.0,0.0,0.539,0.0,0.074,0.019,0.0,0.001,0.0,0.02,0.0,0.087,0.0,0.0,0.063 199 | 198,0.0,0.087,0.0,0.016,0.003,0.0,0.6,0.0,0.292,0.0,0.0,0.002,0.0,0.0,0.0,0.0,0.0 200 | 199,0.104,0.023,0.019,0.066,0.009,0.199,0.027,0.0,0.142,0.0,0.224,0.002,0.004,0.005,0.018,0.087,0.071 201 | 200,0.09,0.167,0.0,0.038,0.0,0.0,0.033,0.05,0.088,0.0,0.194,0.002,0.002,0.001,0.012,0.305,0.017 202 | 201,0.259,0.079,0.0,0.509,0.0,0.0,0.114,0.0,0.0,0.0,0.0,0.022,0.0,0.0,0.017,0.0,0.0 203 | 202,0.07,0.0,0.0,0.13,0.0,0.0,0.002,0.0,0.0,0.0,0.0,0.0,0.034,0.0,0.0,0.304,0.46 204 | 203,0.099,0.001,0.013,0.006,0.17,0.18,0.0,0.0,0.179,0.0,0.0,0.045,0.0,0.047,0.005,0.01,0.246 205 | 204,0.048,0.485,0.007,0.008,0.097,0.014,0.034,0.009,0.169,0.009,0.026,0.044,0.029,0.008,0.0,0.0,0.013 206 | 205,0.533,0.01,0.0,0.0,0.022,0.047,0.0,0.0,0.0,0.0,0.118,0.228,0.003,0.0,0.0,0.0,0.039 207 | 206,0.0,0.252,0.0,0.048,0.0,0.079,0.067,0.043,0.0,0.0,0.199,0.0,0.0,0.011,0.0,0.156,0.144 208 | 207,0.02,0.225,0.0,0.167,0.001,0.016,0.523,0.0,0.017,0.006,0.0,0.0,0.0,0.0,0.001,0.0,0.024 209 | 208,0.0,0.025,0.0,0.081,0.0,0.0,0.742,0.0,0.0,0.0,0.0,0.0,0.0,0.115,0.0,0.0,0.037 210 | 209,0.0,0.0,0.0,0.176,0.0,0.0,0.0,0.0,0.03,0.0,0.0,0.088,0.06,0.0,0.0,0.0,0.646 211 | 
210,0.022,0.033,0.002,0.033,0.274,0.0,0.0,0.009,0.003,0.0,0.2,0.058,0.005,0.021,0.0,0.001,0.34 212 | 211,0.008,0.0,0.026,0.0,0.108,0.0,0.055,0.0,0.662,0.003,0.0,0.018,0.119,0.0,0.0,0.0,0.0 213 | 212,0.062,0.714,0.008,0.012,0.008,0.0,0.015,0.0,0.088,0.0,0.0,0.0,0.002,0.0,0.068,0.0,0.022 214 | 213,0.532,0.002,0.0,0.038,0.0,0.0,0.001,0.006,0.0,0.0,0.0,0.0,0.0,0.0,0.002,0.419,0.0 215 | 214,0.0,0.0,0.0,0.633,0.0,0.002,0.02,0.007,0.003,0.0,0.0,0.222,0.0,0.0,0.0,0.0,0.112 216 | 215,0.043,0.617,0.0,0.0,0.026,0.0,0.0,0.025,0.0,0.001,0.008,0.2,0.0,0.001,0.005,0.001,0.075 217 | 216,0.0,0.0,0.0,0.0,0.013,0.002,0.827,0.014,0.073,0.001,0.07,0.0,0.0,0.0,0.0,0.0,0.0 218 | 217,0.0,0.007,0.014,0.03,0.0,0.0,0.003,0.0,0.0,0.001,0.157,0.78,0.006,0.0,0.0,0.002,0.0 219 | 218,0.574,0.0,0.0,0.003,0.006,0.0,0.0,0.0,0.001,0.0,0.0,0.002,0.0,0.064,0.0,0.0,0.351 220 | 219,0.001,0.027,0.001,0.034,0.27,0.017,0.0,0.0,0.225,0.0,0.038,0.025,0.004,0.0,0.005,0.043,0.308 221 | 220,0.0,0.014,0.006,0.576,0.001,0.149,0.004,0.0,0.007,0.004,0.171,0.008,0.0,0.0,0.0,0.04,0.02 222 | 221,0.0,0.022,0.0,0.042,0.0,0.008,0.0,0.0,0.004,0.004,0.06,0.0,0.003,0.096,0.051,0.105,0.607 223 | 222,0.0,0.016,0.0,0.0,0.321,0.007,0.131,0.0,0.019,0.002,0.0,0.463,0.0,0.0,0.04,0.0,0.001 224 | 223,0.425,0.0,0.0,0.0,0.0,0.115,0.343,0.0,0.0,0.0,0.084,0.0,0.021,0.0,0.0,0.0,0.013 225 | 224,0.001,0.0,0.054,0.723,0.0,0.155,0.0,0.0,0.002,0.0,0.0,0.019,0.0,0.0,0.047,0.0,0.0 226 | 225,0.286,0.004,0.04,0.001,0.0,0.017,0.087,0.028,0.264,0.0,0.023,0.209,0.0,0.0,0.0,0.04,0.0 227 | 226,0.085,0.006,0.0,0.525,0.026,0.022,0.047,0.015,0.029,0.005,0.006,0.0,0.012,0.0,0.058,0.003,0.161 228 | 227,0.156,0.0,0.004,0.0,0.0,0.0,0.605,0.0,0.181,0.0,0.0,0.027,0.016,0.01,0.0,0.0,0.001 229 | 228,0.604,0.002,0.002,0.0,0.001,0.001,0.0,0.0,0.023,0.0,0.058,0.295,0.0,0.015,0.0,0.0,0.0 230 | 229,0.094,0.0,0.0,0.0,0.292,0.0,0.004,0.0,0.0,0.014,0.007,0.571,0.0,0.01,0.004,0.002,0.0 231 | 
230,0.068,0.0,0.029,0.138,0.0,0.099,0.338,0.022,0.0,0.014,0.001,0.058,0.0,0.019,0.0,0.214,0.0 232 | 231,0.253,0.081,0.008,0.0,0.044,0.0,0.116,0.0,0.002,0.005,0.365,0.013,0.0,0.0,0.002,0.054,0.06 233 | 232,0.007,0.024,0.0,0.511,0.084,0.206,0.025,0.0,0.006,0.023,0.001,0.022,0.006,0.006,0.0,0.0,0.079 234 | 233,0.0,0.013,0.002,0.599,0.019,0.0,0.273,0.0,0.0,0.0,0.0,0.001,0.0,0.0,0.0,0.0,0.092 235 | 234,0.82,0.0,0.0,0.114,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.002,0.058,0.0,0.0,0.006,0.0 236 | 235,0.0,0.007,0.0,0.225,0.016,0.038,0.0,0.005,0.0,0.0,0.0,0.276,0.003,0.0,0.0,0.034,0.396 237 | 236,0.04,0.0,0.0,0.009,0.002,0.008,0.18,0.0,0.0,0.0,0.245,0.0,0.0,0.028,0.012,0.472,0.004 238 | 237,0.225,0.0,0.0,0.0,0.0,0.0,0.64,0.048,0.0,0.002,0.0,0.0,0.0,0.0,0.003,0.0,0.083 239 | 238,0.037,0.415,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.031,0.0,0.516 240 | 239,0.096,0.084,0.0,0.019,0.069,0.028,0.03,0.0,0.127,0.0,0.002,0.034,0.132,0.025,0.0,0.349,0.005 241 | 240,0.091,0.022,0.005,0.003,0.0,0.018,0.0,0.03,0.615,0.0,0.0,0.007,0.04,0.036,0.106,0.025,0.002 242 | 241,0.034,0.073,0.0,0.0,0.0,0.0,0.76,0.0,0.0,0.0,0.068,0.0,0.0,0.0,0.062,0.0,0.004 243 | 242,0.001,0.039,0.0,0.205,0.0,0.097,0.15,0.0,0.474,0.0,0.0,0.0,0.0,0.034,0.0,0.0,0.0 244 | 243,0.002,0.377,0.034,0.005,0.001,0.025,0.0,0.006,0.151,0.002,0.076,0.235,0.015,0.005,0.002,0.0,0.064 245 | 244,0.061,0.0,0.051,0.248,0.299,0.01,0.001,0.0,0.238,0.004,0.0,0.008,0.0,0.0,0.001,0.006,0.073 246 | 245,0.322,0.001,0.0,0.136,0.0,0.021,0.0,0.0,0.124,0.003,0.001,0.08,0.033,0.001,0.018,0.245,0.015 247 | 246,0.005,0.014,0.0,0.021,0.0,0.0,0.136,0.0,0.669,0.0,0.074,0.0,0.059,0.021,0.0,0.0,0.001 248 | 247,0.001,0.0,0.0,0.013,0.055,0.0,0.435,0.0,0.0,0.0,0.26,0.233,0.0,0.004,0.001,0.0,0.0 249 | 248,0.002,0.0,0.0,0.0,0.729,0.0,0.004,0.0,0.258,0.0,0.001,0.0,0.0,0.0,0.0,0.006,0.0 250 | 249,0.004,0.093,0.0,0.002,0.008,0.0,0.206,0.013,0.328,0.011,0.131,0.081,0.043,0.053,0.026,0.0,0.001 251 | 
250,0.0,0.003,0.0,0.007,0.0,0.205,0.679,0.0,0.008,0.024,0.0,0.057,0.001,0.015,0.0,0.0,0.0 252 | 251,0.078,0.202,0.0,0.456,0.243,0.0,0.0,0.0,0.008,0.0,0.011,0.0,0.0,0.002,0.0,0.0,0.0 253 | 252,0.008,0.0,0.002,0.009,0.346,0.102,0.176,0.04,0.004,0.0,0.0,0.005,0.01,0.024,0.045,0.018,0.21 254 | 253,0.002,0.014,0.0,0.355,0.01,0.074,0.115,0.0,0.038,0.009,0.0,0.107,0.037,0.015,0.0,0.225,0.0 255 | 254,0.004,0.0,0.034,0.001,0.0,0.046,0.0,0.0,0.294,0.001,0.0,0.131,0.011,0.045,0.0,0.433,0.0 256 | 255,0.0,0.104,0.0,0.08,0.578,0.001,0.0,0.001,0.0,0.0,0.13,0.0,0.029,0.0,0.0,0.0,0.076 257 | 256,0.031,0.0,0.0,0.0,0.001,0.0,0.833,0.0,0.0,0.0,0.0,0.134,0.0,0.001,0.0,0.0,0.0 258 | -------------------------------------------------------------------------------- /data/test_mixture_60m.csv: -------------------------------------------------------------------------------- 1 | index,train_the_pile_arxiv,train_the_pile_freelaw,train_the_pile_nih_exporter,train_the_pile_pubmed_central,train_the_pile_wikipedia_en,train_the_pile_dm_mathematics,train_the_pile_github,train_the_pile_philpapers,train_the_pile_stackexchange,train_the_pile_enron_emails,train_the_pile_gutenberg_pg_19,train_the_pile_pile_cc,train_the_pile_ubuntu_irc,train_the_pile_europarl,train_the_pile_hackernews,train_the_pile_pubmed_abstracts,train_the_pile_uspto_backgrounds 2 | 1,0.08,0.0,0.0,0.464,0.0,0.0,0.0,0.0,0.0,0.0,0.015,0.353,0.001,0.0,0.042,0.039,0.005 3 | 2,0.035,0.016,0.0,0.022,0.0,0.04,0.0,0.001,0.097,0.0,0.0,0.632,0.066,0.0,0.0,0.0,0.092 4 | 3,0.014,0.0,0.004,0.022,0.003,0.084,0.477,0.0,0.0,0.0,0.209,0.052,0.0,0.002,0.013,0.093,0.027 5 | 4,0.215,0.0,0.0,0.238,0.012,0.0,0.05,0.009,0.186,0.0,0.258,0.0,0.0,0.001,0.032,0.0,0.0 6 | 5,0.732,0.0,0.0,0.0,0.004,0.027,0.035,0.001,0.113,0.004,0.0,0.005,0.04,0.037,0.0,0.0,0.003 7 | 6,0.024,0.034,0.0,0.097,0.001,0.015,0.338,0.0,0.089,0.001,0.086,0.117,0.005,0.0,0.105,0.0,0.089 8 | 7,0.007,0.0,0.0,0.0,0.0,0.0,0.988,0.0,0.0,0.002,0.0,0.0,0.0,0.0,0.0,0.002,0.0 9 | 
8,0.0,0.0,0.0,0.005,0.0,0.0,0.0,0.0,0.865,0.0,0.111,0.0,0.0,0.0,0.0,0.0,0.019 10 | 9,0.0,0.0,0.021,0.001,0.0,0.013,0.005,0.001,0.0,0.0,0.0,0.489,0.016,0.0,0.0,0.001,0.454 11 | 10,0.189,0.0,0.0,0.752,0.0,0.0,0.0,0.0,0.0,0.0,0.021,0.019,0.003,0.0,0.004,0.0,0.012 12 | 11,0.0,0.011,0.0,0.0,0.02,0.0,0.328,0.0,0.204,0.023,0.023,0.0,0.008,0.0,0.0,0.0,0.383 13 | 12,0.0,0.021,0.0,0.0,0.0,0.008,0.0,0.006,0.026,0.0,0.0,0.778,0.0,0.002,0.012,0.016,0.131 14 | 13,0.0,0.004,0.0,0.331,0.037,0.001,0.562,0.015,0.0,0.0,0.039,0.01,0.0,0.0,0.0,0.001,0.0 15 | 14,0.0,0.0,0.006,0.0,0.507,0.0,0.002,0.0,0.03,0.0,0.001,0.084,0.0,0.0,0.0,0.0,0.371 16 | 15,0.017,0.0,0.004,0.0,0.4,0.001,0.229,0.004,0.182,0.017,0.002,0.0,0.031,0.001,0.001,0.0,0.111 17 | 16,0.174,0.0,0.0,0.504,0.0,0.0,0.0,0.0,0.081,0.0,0.125,0.0,0.0,0.044,0.0,0.0,0.071 18 | 17,0.24,0.001,0.0,0.0,0.0,0.011,0.038,0.0,0.0,0.0,0.319,0.144,0.0,0.001,0.022,0.223,0.0 19 | 18,0.109,0.116,0.004,0.0,0.0,0.16,0.171,0.0,0.0,0.0,0.0,0.0,0.003,0.103,0.035,0.031,0.265 20 | 19,0.007,0.0,0.0,0.001,0.0,0.0,0.679,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.047,0.09,0.176 21 | 20,0.547,0.015,0.0,0.011,0.0,0.0,0.332,0.006,0.036,0.0,0.0,0.0,0.001,0.036,0.01,0.005,0.0 22 | 21,0.0,0.142,0.014,0.509,0.001,0.002,0.148,0.0,0.034,0.0,0.062,0.005,0.0,0.081,0.0,0.0,0.001 23 | 22,0.009,0.42,0.0,0.0,0.0,0.0,0.343,0.005,0.084,0.0,0.005,0.0,0.07,0.061,0.0,0.004,0.0 24 | 23,0.0,0.106,0.0,0.143,0.042,0.024,0.001,0.01,0.0,0.001,0.403,0.001,0.122,0.001,0.003,0.141,0.0 25 | 24,0.0,0.011,0.008,0.006,0.36,0.0,0.533,0.004,0.0,0.0,0.003,0.0,0.02,0.0,0.0,0.055,0.0 26 | 25,0.025,0.064,0.002,0.001,0.007,0.038,0.011,0.051,0.002,0.0,0.015,0.106,0.005,0.007,0.0,0.086,0.58 27 | 26,0.961,0.0,0.0,0.0,0.0,0.001,0.001,0.03,0.007,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 28 | 27,0.0,0.181,0.006,0.0,0.002,0.0,0.009,0.0,0.08,0.0,0.0,0.0,0.158,0.086,0.0,0.017,0.461 29 | 28,0.006,0.207,0.0,0.103,0.087,0.016,0.22,0.001,0.034,0.001,0.203,0.012,0.057,0.0,0.0,0.005,0.047 30 | 
29,0.0,0.001,0.0,0.0,0.0,0.024,0.717,0.003,0.0,0.0,0.0,0.249,0.0,0.0,0.002,0.0,0.004 31 | 30,0.003,0.0,0.0,0.0,0.008,0.0,0.02,0.0,0.167,0.0,0.105,0.026,0.0,0.0,0.025,0.449,0.195 32 | 31,0.0,0.0,0.002,0.0,0.0,0.0,0.442,0.0,0.029,0.0,0.0,0.013,0.0,0.0,0.021,0.348,0.145 33 | 32,0.061,0.007,0.0,0.367,0.0,0.106,0.0,0.0,0.18,0.008,0.001,0.0,0.086,0.001,0.0,0.001,0.182 34 | 33,0.148,0.001,0.035,0.037,0.258,0.066,0.039,0.022,0.002,0.0,0.003,0.342,0.018,0.0,0.015,0.002,0.011 35 | 34,0.017,0.012,0.0,0.0,0.43,0.0,0.0,0.0,0.003,0.008,0.275,0.0,0.001,0.0,0.013,0.241,0.0 36 | 35,0.049,0.0,0.009,0.061,0.132,0.146,0.062,0.003,0.154,0.0,0.0,0.198,0.002,0.001,0.02,0.163,0.0 37 | 36,0.138,0.049,0.012,0.324,0.049,0.071,0.194,0.0,0.0,0.017,0.031,0.017,0.0,0.005,0.0,0.093,0.0 38 | 37,0.0,0.0,0.0,0.005,0.0,0.026,0.537,0.0,0.0,0.0,0.139,0.01,0.0,0.012,0.011,0.26,0.0 39 | 38,0.173,0.001,0.0,0.035,0.0,0.0,0.405,0.0,0.0,0.0,0.001,0.004,0.0,0.0,0.0,0.38,0.0 40 | 39,0.234,0.0,0.0,0.313,0.0,0.0,0.0,0.0,0.351,0.0,0.0,0.021,0.007,0.0,0.074,0.0,0.0 41 | 40,0.03,0.091,0.001,0.001,0.094,0.022,0.085,0.0,0.298,0.007,0.105,0.225,0.023,0.003,0.0,0.002,0.012 42 | 41,0.028,0.001,0.001,0.029,0.065,0.021,0.09,0.0,0.066,0.001,0.015,0.185,0.002,0.004,0.001,0.349,0.142 43 | 42,0.0,0.118,0.021,0.003,0.513,0.0,0.004,0.0,0.001,0.025,0.0,0.0,0.008,0.091,0.0,0.217,0.0 44 | 43,0.0,0.0,0.005,0.001,0.762,0.0,0.001,0.0,0.127,0.0,0.0,0.104,0.0,0.0,0.0,0.0,0.0 45 | 44,0.194,0.255,0.0,0.035,0.0,0.003,0.0,0.0,0.025,0.0,0.12,0.0,0.008,0.079,0.0,0.0,0.282 46 | 45,0.22,0.031,0.001,0.381,0.0,0.014,0.001,0.0,0.02,0.004,0.004,0.08,0.152,0.022,0.04,0.03,0.0 47 | 46,0.0,0.0,0.005,0.0,0.342,0.057,0.0,0.001,0.524,0.0,0.001,0.023,0.0,0.001,0.019,0.022,0.004 48 | 47,0.0,0.374,0.007,0.22,0.0,0.034,0.125,0.007,0.013,0.001,0.0,0.0,0.0,0.052,0.053,0.113,0.0 49 | 48,0.001,0.0,0.0,0.158,0.0,0.024,0.553,0.0,0.013,0.0,0.159,0.039,0.0,0.041,0.0,0.0,0.012 50 | 49,0.019,0.062,0.0,0.0,0.004,0.0,0.0,0.0,0.9,0.0,0.0,0.0,0.0,0.0,0.012,0.003,0.0 51 | 
50,0.006,0.029,0.0,0.028,0.0,0.001,0.08,0.001,0.17,0.001,0.216,0.0,0.028,0.036,0.022,0.351,0.03 52 | 51,0.0,0.0,0.0,0.273,0.0,0.093,0.031,0.0,0.001,0.0,0.34,0.206,0.0,0.04,0.0,0.016,0.0 53 | 52,0.0,0.0,0.0,0.805,0.001,0.0,0.137,0.0,0.0,0.0,0.0,0.007,0.0,0.0,0.0,0.0,0.05 54 | 53,0.008,0.003,0.0,0.076,0.061,0.126,0.001,0.025,0.004,0.001,0.09,0.154,0.177,0.0,0.0,0.01,0.264 55 | 54,0.669,0.015,0.004,0.013,0.081,0.0,0.0,0.0,0.0,0.015,0.0,0.001,0.152,0.001,0.05,0.0,0.0 56 | 55,0.0,0.003,0.0,0.594,0.021,0.0,0.004,0.0,0.326,0.0,0.0,0.0,0.0,0.0,0.0,0.053,0.0 57 | 56,0.185,0.004,0.002,0.384,0.142,0.108,0.002,0.023,0.096,0.0,0.001,0.006,0.0,0.033,0.011,0.0,0.003 58 | 57,0.0,0.0,0.0,0.901,0.0,0.0,0.001,0.0,0.0,0.0,0.0,0.0,0.0,0.032,0.001,0.0,0.065 59 | 58,0.0,0.0,0.023,0.061,0.068,0.087,0.007,0.0,0.349,0.009,0.015,0.196,0.087,0.023,0.0,0.068,0.007 60 | 59,0.028,0.065,0.023,0.006,0.002,0.239,0.001,0.0,0.136,0.022,0.064,0.09,0.016,0.0,0.0,0.0,0.307 61 | 60,0.368,0.013,0.007,0.0,0.001,0.0,0.0,0.0,0.513,0.0,0.027,0.001,0.0,0.003,0.064,0.002,0.0 62 | 61,0.011,0.0,0.0,0.0,0.236,0.0,0.002,0.0,0.0,0.0,0.001,0.75,0.0,0.0,0.0,0.0,0.0 63 | 62,0.14,0.0,0.0,0.0,0.001,0.0,0.047,0.002,0.197,0.0,0.0,0.583,0.0,0.0,0.0,0.005,0.024 64 | 63,0.001,0.834,0.019,0.007,0.0,0.0,0.003,0.003,0.0,0.0,0.0,0.0,0.116,0.003,0.0,0.0,0.014 65 | 64,0.028,0.138,0.0,0.395,0.001,0.087,0.177,0.001,0.018,0.0,0.0,0.0,0.072,0.008,0.006,0.042,0.027 66 | 65,0.099,0.0,0.004,0.0,0.0,0.0,0.056,0.0,0.0,0.0,0.14,0.7,0.0,0.0,0.0,0.0,0.0 67 | 66,0.0,0.725,0.0,0.0,0.0,0.0,0.0,0.0,0.008,0.0,0.0,0.109,0.144,0.014,0.0,0.0,0.0 68 | 67,0.753,0.0,0.009,0.001,0.001,0.0,0.0,0.004,0.0,0.0,0.198,0.0,0.0,0.0,0.0,0.034,0.0 69 | 68,0.125,0.001,0.001,0.003,0.068,0.009,0.082,0.001,0.324,0.001,0.014,0.251,0.006,0.02,0.037,0.017,0.042 70 | 69,0.001,0.32,0.001,0.188,0.0,0.11,0.08,0.01,0.006,0.025,0.127,0.007,0.008,0.073,0.003,0.0,0.041 71 | 70,0.0,0.0,0.0,0.0,0.002,0.0,0.002,0.003,0.402,0.0,0.003,0.579,0.0,0.0,0.008,0.0,0.001 72 | 
71,0.0,0.0,0.0,0.274,0.45,0.113,0.144,0.019,0.0,0.0,0.0,0.0,0.001,0.0,0.0,0.0,0.0 73 | 72,0.251,0.471,0.0,0.0,0.005,0.168,0.002,0.0,0.074,0.003,0.0,0.021,0.0,0.0,0.0,0.002,0.004 74 | 73,0.002,0.0,0.008,0.551,0.002,0.0,0.004,0.0,0.0,0.009,0.0,0.068,0.0,0.007,0.005,0.342,0.0 75 | 74,0.0,0.362,0.0,0.0,0.0,0.002,0.0,0.0,0.0,0.0,0.004,0.575,0.0,0.03,0.001,0.0,0.025 76 | 75,0.045,0.019,0.006,0.089,0.394,0.048,0.0,0.004,0.0,0.014,0.0,0.05,0.0,0.072,0.066,0.192,0.0 77 | 76,0.058,0.019,0.0,0.102,0.025,0.001,0.025,0.019,0.225,0.0,0.079,0.213,0.0,0.02,0.0,0.001,0.214 78 | 77,0.001,0.001,0.0,0.0,0.624,0.043,0.0,0.001,0.0,0.001,0.0,0.331,0.0,0.0,0.0,0.0,0.0 79 | 78,0.021,0.001,0.032,0.575,0.001,0.096,0.192,0.0,0.0,0.024,0.006,0.0,0.031,0.0,0.0,0.02,0.0 80 | 79,0.004,0.0,0.0,0.003,0.0,0.0,0.913,0.0,0.0,0.0,0.08,0.0,0.0,0.0,0.0,0.0,0.0 81 | 80,0.06,0.0,0.0,0.475,0.101,0.002,0.001,0.0,0.0,0.0,0.001,0.005,0.0,0.013,0.0,0.313,0.029 82 | 81,0.0,0.074,0.001,0.065,0.055,0.013,0.005,0.0,0.023,0.0,0.0,0.079,0.006,0.073,0.016,0.561,0.028 83 | 82,0.007,0.003,0.0,0.201,0.001,0.0,0.356,0.011,0.096,0.001,0.087,0.0,0.024,0.0,0.0,0.213,0.001 84 | 83,0.0,0.005,0.0,0.647,0.129,0.0,0.0,0.0,0.053,0.0,0.11,0.0,0.0,0.0,0.056,0.0,0.0 85 | 84,0.002,0.0,0.0,0.16,0.0,0.055,0.0,0.001,0.248,0.0,0.01,0.522,0.0,0.0,0.0,0.003,0.0 86 | 85,0.0,0.0,0.002,0.008,0.0,0.0,0.491,0.0,0.027,0.0,0.0,0.001,0.002,0.026,0.0,0.443,0.0 87 | 86,0.879,0.0,0.0,0.0,0.0,0.008,0.017,0.0,0.0,0.0,0.032,0.035,0.002,0.001,0.0,0.004,0.021 88 | 87,0.0,0.955,0.0,0.0,0.033,0.0,0.0,0.001,0.0,0.0,0.0,0.0,0.0,0.0,0.01,0.0,0.0 89 | 88,0.008,0.002,0.0,0.0,0.03,0.148,0.0,0.0,0.002,0.0,0.0,0.356,0.0,0.0,0.0,0.0,0.453 90 | 89,0.0,0.029,0.001,0.002,0.467,0.002,0.0,0.0,0.108,0.008,0.018,0.322,0.0,0.018,0.022,0.001,0.002 91 | 90,0.294,0.079,0.0,0.016,0.002,0.0,0.257,0.013,0.012,0.012,0.066,0.007,0.011,0.0,0.0,0.23,0.0 92 | 91,0.445,0.008,0.0,0.0,0.0,0.002,0.005,0.007,0.004,0.008,0.122,0.312,0.007,0.0,0.077,0.001,0.0 93 | 
92,0.094,0.0,0.0,0.0,0.001,0.004,0.037,0.0,0.309,0.001,0.022,0.001,0.016,0.0,0.001,0.02,0.494 94 | 93,0.067,0.0,0.01,0.036,0.0,0.225,0.518,0.026,0.0,0.0,0.001,0.004,0.104,0.008,0.0,0.0,0.0 95 | 94,0.376,0.008,0.0,0.0,0.367,0.0,0.001,0.009,0.002,0.0,0.0,0.007,0.0,0.0,0.0,0.121,0.11 96 | 95,0.0,0.0,0.0,0.299,0.525,0.142,0.0,0.0,0.019,0.0,0.0,0.001,0.006,0.0,0.0,0.0,0.008 97 | 96,0.0,0.0,0.0,0.322,0.0,0.0,0.006,0.012,0.614,0.0,0.0,0.0,0.0,0.0,0.046,0.0,0.0 98 | 97,0.0,0.366,0.0,0.0,0.0,0.0,0.023,0.005,0.0,0.0,0.031,0.443,0.002,0.114,0.0,0.0,0.017 99 | 98,0.0,0.093,0.004,0.159,0.005,0.0,0.095,0.0,0.0,0.0,0.0,0.0,0.148,0.0,0.003,0.316,0.177 100 | 99,0.0,0.328,0.0,0.017,0.0,0.0,0.015,0.01,0.459,0.002,0.0,0.0,0.129,0.037,0.0,0.0,0.003 101 | 100,0.173,0.2,0.002,0.01,0.0,0.0,0.0,0.055,0.199,0.0,0.0,0.0,0.0,0.018,0.0,0.328,0.015 102 | 101,0.126,0.006,0.011,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.087,0.005,0.0,0.046,0.026,0.0,0.695 103 | 102,0.082,0.031,0.004,0.741,0.0,0.006,0.01,0.0,0.0,0.0,0.116,0.0,0.0,0.0,0.0,0.0,0.01 104 | 103,0.164,0.001,0.0,0.0,0.0,0.047,0.005,0.0,0.0,0.0,0.239,0.0,0.0,0.054,0.0,0.194,0.294 105 | 104,0.577,0.155,0.0,0.005,0.034,0.008,0.046,0.001,0.005,0.005,0.021,0.117,0.005,0.0,0.02,0.0,0.0 106 | 105,0.018,0.093,0.0,0.0,0.352,0.0,0.004,0.0,0.059,0.0,0.0,0.188,0.0,0.0,0.0,0.0,0.286 107 | 106,0.0,0.001,0.0,0.0,0.038,0.0,0.018,0.0,0.82,0.0,0.0,0.077,0.0,0.002,0.045,0.0,0.0 108 | 107,0.04,0.0,0.0,0.125,0.006,0.01,0.017,0.001,0.0,0.003,0.002,0.664,0.0,0.0,0.008,0.064,0.062 109 | 108,0.001,0.003,0.0,0.263,0.001,0.0,0.082,0.0,0.004,0.0,0.295,0.112,0.002,0.0,0.0,0.086,0.151 110 | 109,0.002,0.0,0.0,0.001,0.0,0.079,0.01,0.001,0.0,0.001,0.0,0.901,0.0,0.0,0.0,0.0,0.005 111 | 110,0.001,0.0,0.0,0.0,0.0,0.046,0.016,0.013,0.491,0.002,0.001,0.0,0.0,0.001,0.0,0.429,0.0 112 | 111,0.498,0.0,0.0,0.012,0.0,0.004,0.001,0.022,0.416,0.0,0.0,0.0,0.0,0.004,0.001,0.041,0.0 113 | 112,0.006,0.262,0.021,0.0,0.0,0.002,0.0,0.0,0.122,0.001,0.0,0.0,0.159,0.0,0.0,0.0,0.426 114 | 
113,0.051,0.096,0.001,0.01,0.018,0.0,0.169,0.001,0.033,0.0,0.169,0.257,0.002,0.0,0.0,0.002,0.188 115 | 114,0.001,0.006,0.0,0.532,0.178,0.002,0.0,0.023,0.112,0.0,0.001,0.0,0.144,0.0,0.0,0.0,0.001 116 | 115,0.005,0.0,0.0,0.29,0.001,0.0,0.0,0.0,0.108,0.0,0.0,0.596,0.0,0.0,0.0,0.0,0.0 117 | 116,0.294,0.033,0.002,0.001,0.555,0.018,0.012,0.028,0.0,0.0,0.008,0.013,0.0,0.03,0.0,0.004,0.0 118 | 117,0.0,0.001,0.007,0.089,0.0,0.0,0.009,0.005,0.872,0.0,0.0,0.013,0.0,0.0,0.004,0.0,0.0 119 | 118,0.0,0.029,0.0,0.798,0.0,0.0,0.0,0.0,0.0,0.003,0.149,0.006,0.005,0.0,0.0,0.009,0.0 120 | 119,0.046,0.015,0.0,0.027,0.548,0.134,0.111,0.0,0.002,0.0,0.014,0.0,0.03,0.005,0.022,0.033,0.012 121 | 120,0.012,0.497,0.028,0.005,0.396,0.0,0.0,0.0,0.003,0.0,0.004,0.012,0.03,0.006,0.001,0.008,0.0 122 | 121,0.167,0.0,0.05,0.001,0.019,0.0,0.482,0.003,0.0,0.001,0.202,0.01,0.0,0.0,0.007,0.059,0.0 123 | 122,0.012,0.0,0.046,0.023,0.064,0.064,0.27,0.0,0.0,0.002,0.0,0.097,0.049,0.007,0.0,0.365,0.001 124 | 123,0.341,0.0,0.0,0.0,0.0,0.0,0.004,0.024,0.229,0.0,0.0,0.0,0.003,0.0,0.0,0.398,0.0 125 | 124,0.0,0.016,0.0,0.165,0.0,0.157,0.037,0.0,0.0,0.005,0.276,0.003,0.001,0.0,0.0,0.34,0.0 126 | 125,0.002,0.0,0.0,0.275,0.0,0.001,0.0,0.018,0.068,0.003,0.154,0.082,0.014,0.009,0.0,0.024,0.349 127 | 126,0.015,0.515,0.029,0.001,0.002,0.005,0.148,0.006,0.014,0.0,0.002,0.0,0.002,0.0,0.063,0.002,0.195 128 | 127,0.0,0.433,0.007,0.307,0.0,0.006,0.0,0.014,0.011,0.0,0.003,0.0,0.145,0.009,0.0,0.063,0.0 129 | 128,0.302,0.027,0.036,0.082,0.0,0.0,0.061,0.002,0.064,0.0,0.073,0.145,0.084,0.026,0.013,0.085,0.0 130 | 129,0.0,0.0,0.008,0.0,0.031,0.004,0.206,0.0,0.0,0.0,0.001,0.745,0.0,0.002,0.0,0.0,0.004 131 | 130,0.0,0.022,0.04,0.179,0.026,0.053,0.45,0.002,0.002,0.023,0.005,0.054,0.097,0.006,0.0,0.041,0.0 132 | 131,0.047,0.129,0.003,0.001,0.025,0.009,0.006,0.0,0.0,0.004,0.198,0.196,0.001,0.0,0.08,0.0,0.301 133 | 132,0.373,0.0,0.0,0.001,0.005,0.0,0.117,0.0,0.384,0.0,0.0,0.001,0.006,0.001,0.055,0.04,0.017 134 | 
133,0.002,0.001,0.0,0.251,0.034,0.032,0.092,0.046,0.026,0.011,0.022,0.0,0.001,0.0,0.06,0.416,0.009 135 | 134,0.0,0.0,0.0,0.005,0.022,0.0,0.0,0.02,0.0,0.0,0.059,0.689,0.006,0.014,0.0,0.078,0.108 136 | 135,0.0,0.418,0.0,0.002,0.0,0.0,0.58,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 137 | 136,0.0,0.019,0.0,0.727,0.101,0.0,0.0,0.0,0.0,0.0,0.123,0.0,0.0,0.0,0.029,0.0,0.0 138 | 137,0.0,0.0,0.0,0.0,0.239,0.0,0.157,0.0,0.203,0.0,0.007,0.003,0.003,0.0,0.0,0.0,0.387 139 | 138,0.0,0.0,0.0,0.0,0.0,0.0,0.725,0.0,0.004,0.0,0.0,0.051,0.0,0.0,0.0,0.221,0.0 140 | 139,0.873,0.0,0.0,0.01,0.0,0.0,0.101,0.0,0.015,0.0,0.0,0.0,0.001,0.0,0.0,0.0,0.0 141 | 140,0.04,0.476,0.027,0.0,0.045,0.006,0.007,0.002,0.0,0.02,0.064,0.019,0.149,0.051,0.0,0.094,0.0 142 | 141,0.249,0.0,0.018,0.001,0.001,0.051,0.572,0.006,0.058,0.0,0.001,0.0,0.0,0.003,0.0,0.0,0.039 143 | 142,0.009,0.015,0.0,0.135,0.683,0.001,0.009,0.0,0.001,0.005,0.134,0.0,0.0,0.0,0.002,0.006,0.0 144 | 143,0.064,0.21,0.041,0.001,0.0,0.018,0.005,0.0,0.174,0.0,0.216,0.03,0.072,0.097,0.045,0.004,0.022 145 | 144,0.049,0.028,0.0,0.601,0.037,0.04,0.0,0.008,0.0,0.006,0.211,0.0,0.0,0.0,0.0,0.021,0.0 146 | 145,0.016,0.0,0.001,0.002,0.572,0.012,0.321,0.0,0.0,0.0,0.001,0.011,0.036,0.012,0.0,0.001,0.016 147 | 146,0.094,0.0,0.0,0.0,0.552,0.175,0.0,0.05,0.0,0.0,0.0,0.0,0.128,0.0,0.0,0.0,0.001 148 | 147,0.0,0.013,0.0,0.0,0.001,0.0,0.304,0.0,0.564,0.024,0.001,0.0,0.03,0.004,0.012,0.021,0.027 149 | 148,0.002,0.0,0.04,0.22,0.0,0.004,0.133,0.0,0.044,0.0,0.0,0.001,0.0,0.049,0.116,0.0,0.39 150 | 149,0.0,0.006,0.0,0.0,0.0,0.001,0.571,0.0,0.137,0.0,0.001,0.0,0.0,0.0,0.0,0.169,0.115 151 | 150,0.007,0.001,0.019,0.447,0.013,0.0,0.001,0.0,0.077,0.004,0.0,0.0,0.0,0.095,0.002,0.005,0.33 152 | 151,0.006,0.061,0.001,0.323,0.016,0.0,0.017,0.005,0.548,0.007,0.016,0.0,0.0,0.0,0.0,0.0,0.001 153 | 152,0.11,0.0,0.05,0.15,0.071,0.118,0.018,0.0,0.208,0.002,0.0,0.0,0.042,0.027,0.002,0.0,0.203 154 | 
153,0.112,0.0,0.009,0.002,0.078,0.004,0.039,0.041,0.02,0.003,0.354,0.0,0.029,0.011,0.0,0.247,0.052 155 | 154,0.199,0.004,0.016,0.004,0.167,0.0,0.432,0.001,0.073,0.0,0.025,0.004,0.0,0.001,0.047,0.023,0.004 156 | 155,0.001,0.385,0.006,0.144,0.123,0.037,0.0,0.001,0.034,0.0,0.106,0.004,0.006,0.007,0.014,0.133,0.002 157 | 156,0.0,0.037,0.0,0.776,0.0,0.0,0.044,0.0,0.108,0.0,0.0,0.033,0.001,0.0,0.0,0.001,0.0 158 | 157,0.0,0.045,0.0,0.042,0.185,0.013,0.0,0.003,0.009,0.0,0.353,0.298,0.01,0.0,0.013,0.0,0.028 159 | 158,0.46,0.001,0.0,0.0,0.028,0.008,0.0,0.0,0.007,0.0,0.019,0.052,0.002,0.0,0.0,0.421,0.002 160 | 159,0.0,0.0,0.0,0.01,0.0,0.096,0.742,0.0,0.0,0.001,0.0,0.0,0.01,0.0,0.0,0.0,0.14 161 | 160,0.0,0.005,0.006,0.104,0.0,0.08,0.134,0.0,0.0,0.003,0.052,0.045,0.0,0.065,0.002,0.38,0.124 162 | 161,0.001,0.038,0.001,0.046,0.062,0.0,0.0,0.019,0.0,0.0,0.0,0.689,0.114,0.008,0.022,0.0,0.0 163 | 162,0.01,0.085,0.0,0.0,0.006,0.001,0.65,0.0,0.0,0.005,0.0,0.01,0.0,0.0,0.03,0.0,0.202 164 | 163,0.199,0.001,0.03,0.112,0.0,0.024,0.012,0.003,0.563,0.006,0.0,0.016,0.001,0.007,0.008,0.019,0.0 165 | 164,0.0,0.026,0.018,0.074,0.427,0.025,0.011,0.0,0.0,0.0,0.137,0.001,0.08,0.002,0.003,0.0,0.197 166 | 165,0.003,0.079,0.028,0.059,0.004,0.131,0.002,0.0,0.007,0.001,0.14,0.01,0.087,0.001,0.005,0.05,0.392 167 | 166,0.021,0.276,0.019,0.056,0.033,0.0,0.077,0.0,0.005,0.001,0.034,0.001,0.0,0.0,0.0,0.461,0.017 168 | 167,0.195,0.023,0.002,0.022,0.0,0.238,0.258,0.0,0.002,0.0,0.04,0.002,0.0,0.107,0.0,0.073,0.037 169 | 168,0.0,0.003,0.0,0.0,0.036,0.002,0.003,0.009,0.638,0.0,0.12,0.087,0.001,0.0,0.001,0.0,0.1 170 | 169,0.004,0.0,0.001,0.634,0.001,0.213,0.009,0.0,0.102,0.002,0.025,0.0,0.008,0.0,0.0,0.001,0.0 171 | 170,0.0,0.001,0.001,0.163,0.08,0.0,0.0,0.0,0.03,0.001,0.005,0.513,0.0,0.0,0.1,0.104,0.003 172 | 171,0.016,0.011,0.0,0.006,0.016,0.0,0.018,0.0,0.696,0.0,0.0,0.001,0.0,0.049,0.0,0.187,0.0 173 | 172,0.138,0.024,0.0,0.11,0.001,0.002,0.013,0.0,0.311,0.0,0.0,0.29,0.041,0.01,0.051,0.001,0.007 174 | 
173,0.477,0.004,0.0,0.245,0.002,0.001,0.0,0.0,0.021,0.0,0.0,0.219,0.0,0.0,0.001,0.0,0.03 175 | 174,0.0,0.087,0.001,0.008,0.13,0.0,0.557,0.0,0.004,0.001,0.182,0.001,0.0,0.022,0.0,0.0,0.006 176 | 175,0.0,0.044,0.02,0.531,0.0,0.0,0.352,0.0,0.038,0.001,0.0,0.0,0.0,0.014,0.0,0.0,0.0 177 | 176,0.0,0.141,0.0,0.038,0.044,0.181,0.0,0.0,0.029,0.004,0.001,0.19,0.004,0.037,0.0,0.265,0.067 178 | 177,0.0,0.008,0.0,0.253,0.0,0.001,0.0,0.0,0.0,0.003,0.016,0.606,0.0,0.0,0.0,0.0,0.113 179 | 178,0.0,0.264,0.0,0.061,0.171,0.0,0.149,0.004,0.0,0.0,0.09,0.066,0.087,0.052,0.008,0.0,0.048 180 | 179,0.35,0.001,0.0,0.0,0.096,0.034,0.0,0.01,0.0,0.0,0.001,0.226,0.006,0.067,0.004,0.018,0.186 181 | 180,0.065,0.668,0.015,0.006,0.0,0.0,0.005,0.0,0.0,0.003,0.107,0.018,0.0,0.001,0.0,0.113,0.0 182 | 181,0.018,0.22,0.0,0.001,0.518,0.112,0.057,0.053,0.002,0.001,0.0,0.003,0.0,0.001,0.0,0.0,0.014 183 | 182,0.553,0.062,0.005,0.004,0.001,0.0,0.009,0.0,0.078,0.001,0.031,0.065,0.0,0.001,0.103,0.001,0.087 184 | 183,0.055,0.761,0.0,0.0,0.0,0.0,0.003,0.002,0.124,0.0,0.0,0.032,0.012,0.0,0.001,0.01,0.0 185 | 184,0.032,0.0,0.036,0.0,0.0,0.0,0.054,0.0,0.001,0.0,0.089,0.661,0.085,0.0,0.043,0.0,0.0 186 | 185,0.0,0.006,0.0,0.0,0.0,0.0,0.0,0.001,0.034,0.0,0.003,0.956,0.0,0.0,0.0,0.0,0.0 187 | 186,0.742,0.0,0.0,0.0,0.055,0.0,0.0,0.0,0.19,0.0,0.004,0.0,0.0,0.0,0.009,0.0,0.0 188 | 187,0.01,0.06,0.001,0.094,0.733,0.0,0.001,0.008,0.018,0.0,0.055,0.002,0.0,0.004,0.007,0.0,0.007 189 | 188,0.094,0.0,0.0,0.209,0.072,0.0,0.066,0.007,0.009,0.004,0.039,0.175,0.02,0.117,0.007,0.07,0.112 190 | 189,0.07,0.128,0.0,0.0,0.001,0.0,0.0,0.0,0.002,0.003,0.0,0.684,0.0,0.0,0.039,0.072,0.0 191 | 190,0.182,0.102,0.0,0.0,0.018,0.0,0.0,0.0,0.385,0.0,0.067,0.016,0.145,0.069,0.0,0.017,0.0 192 | 191,0.063,0.0,0.017,0.787,0.021,0.003,0.05,0.0,0.0,0.0,0.0,0.005,0.0,0.054,0.0,0.0,0.0 193 | 192,0.008,0.047,0.039,0.028,0.002,0.046,0.025,0.028,0.266,0.001,0.342,0.0,0.064,0.001,0.056,0.0,0.047 194 | 
193,0.783,0.0,0.0,0.0,0.006,0.096,0.002,0.001,0.026,0.0,0.0,0.049,0.0,0.0,0.016,0.0,0.02 195 | 194,0.0,0.0,0.002,0.0,0.0,0.0,0.003,0.0,0.223,0.0,0.0,0.771,0.0,0.0,0.0,0.001,0.0 196 | 195,0.026,0.009,0.055,0.077,0.005,0.0,0.001,0.018,0.0,0.0,0.0,0.086,0.115,0.001,0.009,0.338,0.259 197 | 196,0.191,0.0,0.002,0.209,0.0,0.152,0.117,0.0,0.005,0.0,0.0,0.0,0.029,0.017,0.0,0.278,0.0 198 | 197,0.197,0.001,0.0,0.0,0.539,0.0,0.074,0.019,0.0,0.001,0.0,0.02,0.0,0.087,0.0,0.0,0.063 199 | 198,0.0,0.087,0.0,0.016,0.003,0.0,0.6,0.0,0.292,0.0,0.0,0.002,0.0,0.0,0.0,0.0,0.0 200 | 199,0.104,0.023,0.019,0.066,0.009,0.199,0.027,0.0,0.142,0.0,0.224,0.002,0.004,0.005,0.018,0.087,0.071 201 | 200,0.09,0.167,0.0,0.038,0.0,0.0,0.033,0.05,0.088,0.0,0.194,0.002,0.002,0.001,0.012,0.305,0.017 202 | 201,0.259,0.079,0.0,0.509,0.0,0.0,0.114,0.0,0.0,0.0,0.0,0.022,0.0,0.0,0.017,0.0,0.0 203 | 202,0.07,0.0,0.0,0.13,0.0,0.0,0.002,0.0,0.0,0.0,0.0,0.0,0.034,0.0,0.0,0.304,0.46 204 | 203,0.099,0.001,0.013,0.006,0.17,0.18,0.0,0.0,0.179,0.0,0.0,0.045,0.0,0.047,0.005,0.01,0.246 205 | 204,0.048,0.485,0.007,0.008,0.097,0.014,0.034,0.009,0.169,0.009,0.026,0.044,0.029,0.008,0.0,0.0,0.013 206 | 205,0.533,0.01,0.0,0.0,0.022,0.047,0.0,0.0,0.0,0.0,0.118,0.228,0.003,0.0,0.0,0.0,0.039 207 | 206,0.0,0.252,0.0,0.048,0.0,0.079,0.067,0.043,0.0,0.0,0.199,0.0,0.0,0.011,0.0,0.156,0.144 208 | 207,0.02,0.225,0.0,0.167,0.001,0.016,0.523,0.0,0.017,0.006,0.0,0.0,0.0,0.0,0.001,0.0,0.024 209 | 208,0.0,0.025,0.0,0.081,0.0,0.0,0.742,0.0,0.0,0.0,0.0,0.0,0.0,0.115,0.0,0.0,0.037 210 | 209,0.0,0.0,0.0,0.176,0.0,0.0,0.0,0.0,0.03,0.0,0.0,0.088,0.06,0.0,0.0,0.0,0.646 211 | 210,0.022,0.033,0.002,0.033,0.274,0.0,0.0,0.009,0.003,0.0,0.2,0.058,0.005,0.021,0.0,0.001,0.34 212 | 211,0.008,0.0,0.026,0.0,0.108,0.0,0.055,0.0,0.662,0.003,0.0,0.018,0.119,0.0,0.0,0.0,0.0 213 | 212,0.062,0.714,0.008,0.012,0.008,0.0,0.015,0.0,0.088,0.0,0.0,0.0,0.002,0.0,0.068,0.0,0.022 214 | 213,0.532,0.002,0.0,0.038,0.0,0.0,0.001,0.006,0.0,0.0,0.0,0.0,0.0,0.0,0.002,0.419,0.0 
215 | 214,0.0,0.0,0.0,0.633,0.0,0.002,0.02,0.007,0.003,0.0,0.0,0.222,0.0,0.0,0.0,0.0,0.112 216 | 215,0.043,0.617,0.0,0.0,0.026,0.0,0.0,0.025,0.0,0.001,0.008,0.2,0.0,0.001,0.005,0.001,0.075 217 | 216,0.0,0.0,0.0,0.0,0.013,0.002,0.827,0.014,0.073,0.001,0.07,0.0,0.0,0.0,0.0,0.0,0.0 218 | 217,0.0,0.007,0.014,0.03,0.0,0.0,0.003,0.0,0.0,0.001,0.157,0.78,0.006,0.0,0.0,0.002,0.0 219 | 218,0.574,0.0,0.0,0.003,0.006,0.0,0.0,0.0,0.001,0.0,0.0,0.002,0.0,0.064,0.0,0.0,0.351 220 | 219,0.001,0.027,0.001,0.034,0.27,0.017,0.0,0.0,0.225,0.0,0.038,0.025,0.004,0.0,0.005,0.043,0.308 221 | 220,0.0,0.014,0.006,0.576,0.001,0.149,0.004,0.0,0.007,0.004,0.171,0.008,0.0,0.0,0.0,0.04,0.02 222 | 221,0.0,0.022,0.0,0.042,0.0,0.008,0.0,0.0,0.004,0.004,0.06,0.0,0.003,0.096,0.051,0.105,0.607 223 | 222,0.0,0.016,0.0,0.0,0.321,0.007,0.131,0.0,0.019,0.002,0.0,0.463,0.0,0.0,0.04,0.0,0.001 224 | 223,0.425,0.0,0.0,0.0,0.0,0.115,0.343,0.0,0.0,0.0,0.084,0.0,0.021,0.0,0.0,0.0,0.013 225 | 224,0.001,0.0,0.054,0.723,0.0,0.155,0.0,0.0,0.002,0.0,0.0,0.019,0.0,0.0,0.047,0.0,0.0 226 | 225,0.286,0.004,0.04,0.001,0.0,0.017,0.087,0.028,0.264,0.0,0.023,0.209,0.0,0.0,0.0,0.04,0.0 227 | 226,0.085,0.006,0.0,0.525,0.026,0.022,0.047,0.015,0.029,0.005,0.006,0.0,0.012,0.0,0.058,0.003,0.161 228 | 227,0.156,0.0,0.004,0.0,0.0,0.0,0.605,0.0,0.181,0.0,0.0,0.027,0.016,0.01,0.0,0.0,0.001 229 | 228,0.604,0.002,0.002,0.0,0.001,0.001,0.0,0.0,0.023,0.0,0.058,0.295,0.0,0.015,0.0,0.0,0.0 230 | 229,0.094,0.0,0.0,0.0,0.292,0.0,0.004,0.0,0.0,0.014,0.007,0.571,0.0,0.01,0.004,0.002,0.0 231 | 230,0.068,0.0,0.029,0.138,0.0,0.099,0.338,0.022,0.0,0.014,0.001,0.058,0.0,0.019,0.0,0.214,0.0 232 | 231,0.253,0.081,0.008,0.0,0.044,0.0,0.116,0.0,0.002,0.005,0.365,0.013,0.0,0.0,0.002,0.054,0.06 233 | 232,0.007,0.024,0.0,0.511,0.084,0.206,0.025,0.0,0.006,0.023,0.001,0.022,0.006,0.006,0.0,0.0,0.079 234 | 233,0.0,0.013,0.002,0.599,0.019,0.0,0.273,0.0,0.0,0.0,0.0,0.001,0.0,0.0,0.0,0.0,0.092 235 | 
234,0.82,0.0,0.0,0.114,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.002,0.058,0.0,0.0,0.006,0.0 236 | 235,0.0,0.007,0.0,0.225,0.016,0.038,0.0,0.005,0.0,0.0,0.0,0.276,0.003,0.0,0.0,0.034,0.396 237 | 236,0.04,0.0,0.0,0.009,0.002,0.008,0.18,0.0,0.0,0.0,0.245,0.0,0.0,0.028,0.012,0.472,0.004 238 | 237,0.225,0.0,0.0,0.0,0.0,0.0,0.64,0.048,0.0,0.002,0.0,0.0,0.0,0.0,0.003,0.0,0.083 239 | 238,0.037,0.415,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.031,0.0,0.516 240 | 239,0.096,0.084,0.0,0.019,0.069,0.028,0.03,0.0,0.127,0.0,0.002,0.034,0.132,0.025,0.0,0.349,0.005 241 | 240,0.091,0.022,0.005,0.003,0.0,0.018,0.0,0.03,0.615,0.0,0.0,0.007,0.04,0.036,0.106,0.025,0.002 242 | 241,0.034,0.073,0.0,0.0,0.0,0.0,0.76,0.0,0.0,0.0,0.068,0.0,0.0,0.0,0.062,0.0,0.004 243 | 242,0.001,0.039,0.0,0.205,0.0,0.097,0.15,0.0,0.474,0.0,0.0,0.0,0.0,0.034,0.0,0.0,0.0 244 | 243,0.002,0.377,0.034,0.005,0.001,0.025,0.0,0.006,0.151,0.002,0.076,0.235,0.015,0.005,0.002,0.0,0.064 245 | 244,0.061,0.0,0.051,0.248,0.299,0.01,0.001,0.0,0.238,0.004,0.0,0.008,0.0,0.0,0.001,0.006,0.073 246 | 245,0.322,0.001,0.0,0.136,0.0,0.021,0.0,0.0,0.124,0.003,0.001,0.08,0.033,0.001,0.018,0.245,0.015 247 | 246,0.005,0.014,0.0,0.021,0.0,0.0,0.136,0.0,0.669,0.0,0.074,0.0,0.059,0.021,0.0,0.0,0.001 248 | 247,0.001,0.0,0.0,0.013,0.055,0.0,0.435,0.0,0.0,0.0,0.26,0.233,0.0,0.004,0.001,0.0,0.0 249 | 248,0.002,0.0,0.0,0.0,0.729,0.0,0.004,0.0,0.258,0.0,0.001,0.0,0.0,0.0,0.0,0.006,0.0 250 | 249,0.004,0.093,0.0,0.002,0.008,0.0,0.206,0.013,0.328,0.011,0.131,0.081,0.043,0.053,0.026,0.0,0.001 251 | 250,0.0,0.003,0.0,0.007,0.0,0.205,0.679,0.0,0.008,0.024,0.0,0.057,0.001,0.015,0.0,0.0,0.0 252 | 251,0.078,0.202,0.0,0.456,0.243,0.0,0.0,0.0,0.008,0.0,0.011,0.0,0.0,0.002,0.0,0.0,0.0 253 | 252,0.008,0.0,0.002,0.009,0.346,0.102,0.176,0.04,0.004,0.0,0.0,0.005,0.01,0.024,0.045,0.018,0.21 254 | 253,0.002,0.014,0.0,0.355,0.01,0.074,0.115,0.0,0.038,0.009,0.0,0.107,0.037,0.015,0.0,0.225,0.0 255 | 
254,0.004,0.0,0.034,0.001,0.0,0.046,0.0,0.0,0.294,0.001,0.0,0.131,0.011,0.045,0.0,0.433,0.0 256 | 255,0.0,0.104,0.0,0.08,0.578,0.001,0.0,0.001,0.0,0.0,0.13,0.0,0.029,0.0,0.0,0.0,0.076 257 | 256,0.031,0.0,0.0,0.0,0.001,0.0,0.833,0.0,0.0,0.0,0.0,0.134,0.0,0.001,0.0,0.0,0.0 258 | --------------------------------------------------------------------------------