├── requirements.txt ├── scripts ├── data_utils.py ├── convert2dialog_mode_data_format.py ├── convert2hf.py ├── convert2ckpt.py ├── pretokenize.py └── check_data.py ├── setup.py ├── examples ├── train_llama_deepspeed.sh ├── train_slurm.sh └── train.py ├── configs ├── ds_config.json └── ds_config_zero1.json ├── README.md ├── src └── transpeeder │ ├── utils.py │ ├── models │ ├── modeling_llama.py │ ├── llama_pipeline_model.py │ └── patching.py │ └── feeder.py ├── LICENSE └── data └── alpaca_data_sample_oneline_format.json /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | sentencepiece 3 | transformers >= 4.31.0 4 | deepspeed @ git+https://github.com/HuangLK/DeepSpeed.git@dev 5 | flash_attn >= 2.0 6 | -------------------------------------------------------------------------------- /scripts/data_utils.py: -------------------------------------------------------------------------------- 1 | from lingua import Language, LanguageDetectorBuilder 2 | 3 | 4 | class LinguaLid: 5 | """ 6 | https://github.com/pemistahl/lingua-py 7 | """ 8 | def __init__(self): 9 | super().__init__() 10 | self.detector = LanguageDetectorBuilder.from_all_languages().with_preloaded_language_models().build() 11 | 12 | def detect(self, text: str): 13 | result = self.detector.detect_language_of(text) 14 | if result is None: 15 | return None 16 | else: 17 | return result.iso_code_639_1.name.lower() 18 | 19 | 20 | if __name__ == '__main__': 21 | lid = LinguaLid() 22 | print(lid.detect('中の文.')) 23 | print(lid.detect('123')) 24 | print(lid.detect('NHANES III数据库中的期望死亡率可以通过使用Cox比例风险回归模型来确定。')) 25 | print(lid.detect('HistAuGAN的代码和模型可以在https://github.com/sophiajw/HistAuGAN上公开获取。')) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | def fetch_requirements(path): 5 | 
with open(path, 'r') as fd: 6 | return [r.strip() for r in fd.readlines()] 7 | 8 | 9 | setup( 10 | name='transpeeder', 11 | version='1.0.0', 12 | package_dir={"": "src"}, 13 | packages=find_packages("src"), 14 | description='🤗 transformers and 🚀deepspeed', 15 | long_description='', 16 | long_description_content_type='text/markdown', 17 | license='Apache Software License 2.0', 18 | install_requires=fetch_requirements('requirements.txt'), 19 | python_requires='>=3.10', 20 | classifiers=[ 21 | 'Programming Language :: Python :: 3', 22 | 'License :: OSI Approved :: Apache Software License', 23 | 'Environment :: GPU :: NVIDIA CUDA', 24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 25 | 'Topic :: System :: Distributed Computing', 26 | ], 27 | ) -------------------------------------------------------------------------------- /examples/train_llama_deepspeed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Train script. 3 | set -eux 4 | 5 | name=$1 6 | gpus=$2 7 | 8 | export MASTER_PORT=23857 9 | export WORK_DIR=`pwd` 10 | 11 | #ds_report 12 | 13 | OUTPUT=${WORK_DIR}/output/${name} 14 | if [ -d $OUTPUT ]; then 15 | # rm 16 | echo "${OUTPUT} exist." 
17 | else 18 | mkdir -p ${OUTPUT} 19 | fi 20 | 21 | echo "conda env: $CONDA_PREFIX" 22 | deepspeed --include localhost:$2 --master_port ${MASTER_PORT} ${WORK_DIR}/train.py \ 23 | --output_dir ${OUTPUT} \ 24 | --init_ckpt /path/to/models/llama-30b-init-ckpt/ \ 25 | --data_path /path/to/alpaca_en_zh_oneline_format.json \ 26 | --max_seq_len 8192 \ 27 | --train_steps 1000 \ 28 | --eval_steps 10 \ 29 | --save_steps 200 \ 30 | --log_steps 1 \ 31 | --pipe_parallel_size 8 \ 32 | --model_parallel_size 1 \ 33 | --use_flash_attn true \ 34 | --ntk true \ 35 | --deepspeed_config ${WORK_DIR}/../configs/ds_config_zero1.json 36 | -------------------------------------------------------------------------------- /configs/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 16, 3 | "train_batch_size": 128, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "bf16": { 7 | "enabled": false 8 | }, 9 | "fp16": { 10 | "enabled": true, 11 | "loss_scale": 0, 12 | "loss_scale_window": 1000, 13 | "initial_scale_power": 12, 14 | "hysteresis": 2, 15 | "min_loss_scale": 1 16 | }, 17 | "optimizer": { 18 | "type": "Adam", 19 | "params": { 20 | "lr": 1e-5, 21 | "betas": [0.9, 0.95], 22 | "eps": 1.0e-8 23 | } 24 | }, 25 | "scheduler": { 26 | "type": "WarmupDecayLR", 27 | "params": { 28 | "warmup_min_lr": 1e-6, 29 | "warmup_max_lr": 1e-5, 30 | "warmup_num_steps": 100, 31 | "total_num_steps": 1000 32 | } 33 | }, 34 | "activation_checkpointing": { 35 | "partition_activations": false, 36 | "cpu_checkpointing": false, 37 | "contiguous_memory_optimization": false, 38 | "number_checkpoints": null, 39 | "synchronize_checkpoint_boundary": false, 40 | "profile": false 41 | }, 42 | "wandb": { 43 | "enabled": true, 44 | "team": null, 45 | "group": null, 46 | "project": "llama-7B-test" 47 | }, 48 | "wall_clock_breakdown": true 49 | } 50 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # transpeeder 2 | This is a project under development that aims to fine-tune LLaMA models (7B–70B) with 🤗 transformers and 🚀 deepspeed, and to provide simple, convenient training scripts. 3 | 4 | ## installation 5 | ``` 6 | pip install -e . 7 | ``` 8 | 9 | ## data 10 | Each line of the data file is a **JSON string** encoding an object that must have `prompt` and `output` fields. The example below is pretty-printed for readability; in the data file, each object occupies a single line. 11 | ``` 12 | { 13 | "prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:", 14 | "output": "The capital of France is Paris." 15 | } 16 | ``` 17 | 18 | ## convert hf model to ckpt 19 | ```bash 20 | # llama-7B 21 | python -m scripts.convert2ckpt --mp_world_size 4 \ 22 | --model_name_or_path /path/to/llama-7b-hf \ 23 | --output_dir /path/to/llama-7b-init-ckpt 24 | 25 | # llama-30B 26 | python -m scripts.convert2ckpt --mp_world_size 8 \ 27 | --model_name_or_path /path/to/llama-30b-hf \ 28 | --output_dir /path/to/llama-30b-init-ckpt 29 | ``` 30 | 31 | ## finetune 32 | See `examples/train_llama_deepspeed.sh`.
33 | 34 | 35 | ## convert ckpt to hf model 36 | ```bash 37 | python -m scripts.convert2hf --model_size 7B \ 38 | --input_dir ./output/llama-7B-ckpt/global_step1000/ \ 39 | --output_dir ./output/llama_hf_7B \ 40 | --tokenizer_size 32001 41 | cp /path/to/llama-7b-hf/*.json ./output/llama_hf_7B 42 | cp /path/to/llama-7b-hf/tokenizer.model ./output/llama_hf_7B 43 | ``` 44 | -------------------------------------------------------------------------------- /configs/ds_config_zero1.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "train_batch_size": 128, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "bf16": { 7 | "enabled": false 8 | }, 9 | "fp16": { 10 | "enabled": true, 11 | "loss_scale": 0, 12 | "loss_scale_window": 1000, 13 | "initial_scale_power": 12, 14 | "hysteresis": 2, 15 | "min_loss_scale": 1 16 | }, 17 | "optimizer": { 18 | "type": "Adam", 19 | "params": { 20 | "lr": 1e-5, 21 | "betas": [0.9, 0.95], 22 | "eps": 1.0e-8 23 | } 24 | }, 25 | "scheduler": { 26 | "type": "WarmupDecayLR", 27 | "params": { 28 | "warmup_min_lr": 1e-6, 29 | "warmup_max_lr": 1e-5, 30 | "warmup_num_steps": 100, 31 | "total_num_steps": 1000 32 | } 33 | }, 34 | "zero_optimization": { 35 | "stage": 1, 36 | "offload_optimizer": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "allgather_partitions": true, 41 | "allgather_bucket_size": 5e8, 42 | "overlap_comm": false, 43 | "reduce_scatter": true, 44 | "reduce_bucket_size": 5e8, 45 | "contiguous_gradients": true 46 | }, 47 | "activation_checkpointing": { 48 | "partition_activations": false, 49 | "cpu_checkpointing": false, 50 | "contiguous_memory_optimization": false, 51 | "number_checkpoints": null, 52 | "synchronize_checkpoint_boundary": false, 53 | "profile": false 54 | }, 55 | "wandb": { 56 | "enabled": true, 57 | "team": null, 58 | "group": null, 59 | "project": "llama-65B-test-ckpt" 60 | }, 61 | "wall_clock_breakdown": true 
62 | } 63 | -------------------------------------------------------------------------------- /scripts/convert2dialog_mode_data_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import os 4 | import argparse 5 | 6 | GENERAL_PROMPT = "Add-Your-System-Prompt-Here" 7 | UNIDIED_PROMPT = "<|prefix_begin|>{system_prompt}<|prefix_end|><|prompter|>{user_prompt}<|endoftext|><|assistant|>" 8 | PREFIX_PROMPT = f"<|prefix_begin|>{GENERAL_PROMPT}<|prefix_end|>" 9 | 10 | def read_data(path): 11 | data = [] 12 | with open(path, "r") as f: 13 | for line in f.readlines(): 14 | data.append(json.loads(line.strip())) 15 | return data 16 | 17 | def convert(input_path, output_path): 18 | ''' 19 | Convert the data format from {"prompt":[p1,p2,p3...], "output":[o1,o2,o3...]} 20 | to multi-round dialogue data format for transpeeder's dialog mode 21 | ''' 22 | if os.path.exists(output_path): 23 | print("output dir exists!") 24 | return 25 | 26 | all_data = read_data(input_path) 27 | train = [] 28 | for data in tqdm(all_data): 29 | prompt = data["prompt"] 30 | output = data["output"] 31 | q_prompt = PREFIX_PROMPT 32 | for idx in range(len(prompt) - 1): 33 | q_prompt += f"<|prompter|>{prompt[idx]}<|endoftext|><|assistant|>{output[idx]}<|endoftext|>" 34 | q_prompt += f"<|prompter|>{prompt[-1]}<|endoftext|><|assistant|>" 35 | out = {} 36 | out["prompt"] = q_prompt 37 | out["output"] = output[-1] 38 | train.append(out) 39 | 40 | with open(output_path, "w") as f: 41 | for line in train: 42 | f.write(json.dumps(line, ensure_ascii=False) + "\n") 43 | 44 | 45 | def main(): 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument( 49 | "--input_path", 50 | help="Location of original data", 51 | ) 52 | 53 | parser.add_argument( 54 | "--output_path", 55 | help="Location of data translated to the format of multi-round dialogue", 56 | ) 57 | args = parser.parse_args() 58 | 59 | convert( 60 | 
input_path=args.input_path, 61 | output_path=args.output_path 62 | ) 63 | 64 | if __name__ == "__main__": 65 | main() -------------------------------------------------------------------------------- /examples/train_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage: srun train_slurm.sh 3 | 4 | set -eux 5 | 6 | # 任务名 7 | name=llama-30B-slurm-test 8 | 9 | export MASTER_ADDR=`perl -le '$_=$ENV{"SLURM_JOB_NODELIST"}; s/,.*//; s/-.*//; s/\[//; print'` 10 | export MASTER_PORT=23857 11 | export OMP_NUM_THREADS=8 12 | #export CUDA_LAUNCH_BLOCKING=1 13 | export WORK_DIR=`pwd` 14 | 15 | # 日志路径 16 | LOG_PATH=./logs/slurm_log_$(date '+%m%d%H%M').txt 17 | GPUS_PER_NODE=8 18 | # 节点数 19 | NNODES=4 20 | # 总gpu数 21 | N_GPUS=32 22 | 23 | # testing for potential faulty nodes 24 | # srun --jobid $SLURM_JOB_ID bash -c 'python -c "import torch, socket; print(socket.gethostname(), torch.cuda.is_available())"' 25 | # exit 0 26 | 27 | # 模型保存路径 28 | OUTPUT=${WORK_DIR}/output/${name} 29 | if [ -d $OUTPUT ]; then 30 | # rm 31 | echo "${OUTPUT} exist." 
32 | else 33 | mkdir -p ${OUTPUT} 34 | fi 35 | 36 | echo "conda env: $CONDA_PREFIX" 37 | 38 | export LAUNCHER="python -u -m torch.distributed.run \ 39 | --nproc_per_node $GPUS_PER_NODE \ 40 | --nnodes $NNODES \ 41 | --master_addr $MASTER_ADDR \ 42 | --master_port $MASTER_PORT \ 43 | --max_restarts 0 \ 44 | --tee 3 \ 45 | " 46 | 47 | # 训练任务 48 | export CMD=" \ 49 | ${WORK_DIR}/train.py \ 50 | --output_dir ${OUTPUT} \ 51 | --init_ckpt /path/to/llama-30b-init-ckpt/ \ 52 | --data_path /path/to/alpaca_en_zh_oneline_format.json \ 53 | --max_seq_len 8192 \ 54 | --train_steps 1000 \ 55 | --eval_steps 10 \ 56 | --save_steps 200 \ 57 | --log_steps 1 \ 58 | --pipe_parallel_size 8 \ 59 | --model_parallel_size 1 \ 60 | --use_flash_attn true \ 61 | --ntk true \ 62 | --deepspeed_config ${WORK_DIR}/../configs/ds_config_zero1.json 63 | " 64 | 65 | # do not remove or the training will hang and nodes will be lost w/o this workaround 66 | export CUDA_LAUNCH_BLOCKING=1 67 | 68 | # hide duplicated errors using this hack - will be properly fixed in pt-1.12 69 | export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json 70 | 71 | # force crashing on nccl issues like hanging broadcast 72 | export NCCL_ASYNC_ERROR_HANDLING=1 73 | 74 | echo "START TIME: $(date)" 75 | 76 | bash -c "$LAUNCHER --node_rank $SLURM_PROCID $CMD" 2>&1 | tee -a $LOG_PATH 77 | 78 | echo "END TIME: $(date)" 79 | -------------------------------------------------------------------------------- /src/transpeeder/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import json 4 | 5 | import torch.distributed as dist 6 | from loguru import logger as logger 7 | 8 | 9 | logger.add(f'ds_training.log') 10 | 11 | 12 | def is_rank_0() -> bool: 13 | return not dist.is_initialized() or dist.get_rank() == 0 14 | 15 | 16 | class LoggerRank0: 17 | def trace(self, *args, **kwargs): 18 | if not is_rank_0(): 19 | return 20 | logger.trace(*args, **kwargs) 21 | 22 | 
def debug(self, *args, **kwargs): 23 | if not is_rank_0(): 24 | return 25 | logger.debug(*args, **kwargs) 26 | 27 | def info(self, *args, **kwargs): 28 | if not is_rank_0(): 29 | return 30 | logger.info(*args, **kwargs) 31 | 32 | def warning(self, *args, **kwargs): 33 | if not is_rank_0(): 34 | return 35 | logger.warning(*args, **kwargs) 36 | 37 | def error(self, *args, **kwargs): 38 | if not is_rank_0(): 39 | return 40 | logger.error(*args, **kwargs) 41 | 42 | logger_rank0 = LoggerRank0() 43 | 44 | 45 | def _make_w_io_base(f, mode: str): 46 | if not isinstance(f, io.IOBase): 47 | f_dirname = os.path.dirname(f) 48 | if f_dirname != "": 49 | os.makedirs(f_dirname, exist_ok=True) 50 | f = open(f, mode=mode) 51 | return f 52 | 53 | 54 | def _make_r_io_base(f, mode: str): 55 | if not isinstance(f, io.IOBase): 56 | f = open(f, mode=mode) 57 | return f 58 | 59 | 60 | def jdump(obj, f, mode="w", indent=4, default=str): 61 | """Dump a str or dictionary to a file in json format. 62 | 63 | Args: 64 | obj: An object to be written. 65 | f: A string path to the location on disk. 66 | mode: Mode for opening the file. 67 | indent: Indent for storing json dictionaries. 68 | default: A function to handle non-serializable entries; defaults to `str`. 
69 | """ 70 | f = _make_w_io_base(f, mode) 71 | if isinstance(obj, (dict, list)): 72 | json.dump(obj, f, indent=indent, default=default) 73 | elif isinstance(obj, str): 74 | f.write(obj) 75 | else: 76 | raise ValueError(f"Unexpected type: {type(obj)}") 77 | f.close() 78 | 79 | 80 | def jload(f, mode="r"): 81 | """Load a .json file into a dictionary.""" 82 | f = _make_r_io_base(f, mode) 83 | jdict = json.load(f) 84 | f.close() 85 | return jdict 86 | -------------------------------------------------------------------------------- /scripts/convert2hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | 9 | PARAM_MAP = { 10 | "7B": { 11 | "n_layers": 32, 12 | }, 13 | "13B": { 14 | "n_layers": 40, 15 | }, 16 | "30B": { 17 | "n_layers": 60, 18 | }, 19 | "65B": { 20 | "n_layers": 80, 21 | }, 22 | "70B": { 23 | "n_layers": 80, 24 | }, 25 | } 26 | 27 | 28 | def read_json(path): 29 | with open(path, "r") as f: 30 | return json.load(f) 31 | 32 | 33 | def write_json(text, path): 34 | with open(path, "w") as f: 35 | json.dump(text, f) 36 | 37 | 38 | def write_model(model_path, input_base_path, model_size, tokenizer_size): 39 | assert model_size in PARAM_MAP 40 | os.makedirs(model_path, exist_ok=True) 41 | 42 | params = PARAM_MAP[model_size] 43 | n_layers = params["n_layers"] 44 | 45 | loaded = {} 46 | ORIGINAL_TOKENIZER_SIZE = tokenizer_size 47 | for pt in Path(input_base_path).iterdir(): 48 | # assert tp/mp == 1 49 | if not pt.name.startswith('layer_'): 50 | continue 51 | sd = torch.load(pt, map_location="cpu") 52 | if pt.name == 'layer_00-model_00-model_states.pt': 53 | loaded['model.embed_tokens.weight'] = sd['weight'][: ORIGINAL_TOKENIZER_SIZE, :] 54 | continue 55 | if pt.name == f'layer_{n_layers + 1}-model_00-model_states.pt': 56 | loaded['model.norm.weight'] = sd['weight'] 57 | continue 58 | if pt.name == f'layer_{n_layers + 
2}-model_00-model_states.pt': 59 | loaded['lm_head.weight'] = sd['weight'][: ORIGINAL_TOKENIZER_SIZE, :] 60 | continue 61 | 62 | layer_i = int(pt.name.split('-')[0].replace('layer_', '')) - 1 63 | layer_sd = { f"model.layers.{layer_i}.{nm}": weight for nm, weight in sd.items() } 64 | loaded.update(layer_sd) 65 | 66 | 67 | torch.save(loaded, os.path.join(model_path, "pytorch_model.bin")) 68 | 69 | 70 | def main(): 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument( 73 | "--input_dir", 74 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 75 | ) 76 | parser.add_argument( 77 | "--model_size", 78 | choices=["7B", "13B", "30B", "65B", "70B"], 79 | ) 80 | parser.add_argument( 81 | "--output_dir", 82 | help="Location to write HF model and tokenizer", 83 | ) 84 | parser.add_argument( 85 | "--tokenizer_size", 86 | help="Size of tokenizer", 87 | type=int, 88 | ) 89 | args = parser.parse_args() 90 | write_model( 91 | model_path=args.output_dir, 92 | input_base_path=args.input_dir, 93 | model_size=args.model_size, 94 | tokenizer_size=args.tokenizer_size, 95 | ) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /src/transpeeder/models/modeling_llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LlamaRotaryEmbedding(torch.nn.Module): 4 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 5 | super().__init__() 6 | self.dim = dim 7 | self.max_position_embeddings = max_position_embeddings 8 | self.base = base 9 | self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 10 | 11 | # Build here to make `torch.jit.trace` work. 
12 | self._set_cos_sin_cache( 13 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 14 | ) 15 | 16 | def _set_cos_sin_cache(self, seq_len, device, dtype): 17 | self.max_seq_len_cached = seq_len 18 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) 19 | 20 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 21 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 22 | emb = torch.cat((freqs, freqs), dim=-1) 23 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) 24 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) 25 | 26 | def forward(self, x, seq_len=None): 27 | # x: [bs, num_attention_heads, seq_len, head_size] 28 | if seq_len > self.max_seq_len_cached: 29 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 30 | 31 | return ( 32 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 33 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 34 | ) 35 | 36 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 37 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" 38 | 39 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 40 | self.scaling_factor = scaling_factor 41 | super().__init__(dim, max_position_embeddings, base, device) 42 | 43 | def _set_cos_sin_cache(self, seq_len, device, dtype): 44 | self.max_seq_len_cached = seq_len 45 | 46 | if seq_len > self.max_position_embeddings: 47 | base = self.base * ( 48 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 49 | ) ** (self.dim / (self.dim - 2)) 50 | self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 51 | 52 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) 53 | 54 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 55 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 56 | emb = torch.cat((freqs, freqs), dim=-1) 57 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) 58 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) -------------------------------------------------------------------------------- /scripts/convert2ckpt.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Literal 3 | from dataclasses import dataclass, field 4 | 5 | import torch 6 | import transformers 7 | 8 | from transpeeder.models.patching import ( 9 | smart_tokenizer_and_embedding_resize, 10 | ) 11 | from transpeeder.feeder import ( 12 | DEFAULT_BOS_TOKEN, 13 | DEFAULT_PAD_TOKEN, 14 | DEFAULT_EOS_TOKEN, 15 | DEFAULT_UNK_TOKEN, 16 | ) 17 | 18 | 19 | @dataclass 20 | class Arguments: 21 | model_name_or_path: Optional[str] = field(default="/path/to/llama-7b-hf") 22 | output_dir: str = field(default="./llama-7B-init-ckpt") 23 | mp_world_size: int = field(default=8) 24 | 25 | 26 | def 
write_ckpt(outpath: Path, model: torch.nn.Module, model_config: transformers.AutoConfig, mp: int): 27 | loaded = model.state_dict() 28 | 29 | n_layers = model_config.num_hidden_layers 30 | # embedding 31 | sd = {"weight": loaded['model.embed_tokens.weight']} 32 | torch.save(sd, outpath / "layer_00-model_00-model_states.pt") 33 | # norm 34 | sd = {f"weight": loaded['model.norm.weight']} 35 | torch.save(sd, outpath / f"layer_{n_layers + 1}-model_00-model_states.pt") 36 | # lm head 37 | sd = {f"weight": loaded['lm_head.weight']} 38 | torch.save(sd, outpath / f"layer_{n_layers + 2}-model_00-model_states.pt") 39 | # decoder layers 40 | for layer_i in range(n_layers): 41 | sd = {nm.replace(f"model.layers.{layer_i}.", f""): weight for nm, weight in loaded.items() if nm.startswith(f"model.layers.{layer_i}.") and not nm.endswith("inv_freq")} 42 | torch.save(sd, outpath / f"layer_{layer_i + 1:02d}-model_00-model_states.pt") 43 | 44 | model_state = { 45 | "dp_world_size": 1, 46 | "mp_world_size": mp, 47 | "module": None, 48 | "optimizer": None, 49 | "global_steps": 1, 50 | "skipped_steps": 1, 51 | "iteration": 1, 52 | } 53 | for rank in range(mp): 54 | torch.save(model_state, outpath / f"mp_rank_{rank:02d}_model_states.pt") 55 | 56 | 57 | def main(): 58 | parser = transformers.HfArgumentParser((Arguments,)) 59 | args, = parser.parse_args_into_dataclasses() 60 | 61 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path) 62 | model_config = transformers.AutoConfig.from_pretrained(args.model_name_or_path) 63 | model = transformers.AutoModelForCausalLM.from_pretrained(args.model_name_or_path) 64 | 65 | if tokenizer.pad_token is None: 66 | smart_tokenizer_and_embedding_resize( 67 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 68 | tokenizer=tokenizer, 69 | model=model, 70 | ) 71 | 72 | special_tokens_map = { 73 | "additional_special_tokens": [ 74 | "<|prefix_begin|>", 75 | "<|prefix_end|>", 76 | "<|prompter|>", 77 | "<|endoftext|>", 78 | 
"<|assistant|>" 79 | ], 80 | } 81 | if tokenizer.pad_token is None: 82 | special_tokens_map.update(dict(pad_token=DEFAULT_PAD_TOKEN)) 83 | 84 | smart_tokenizer_and_embedding_resize( 85 | special_tokens_dict=special_tokens_map, 86 | tokenizer=tokenizer, 87 | model=model, 88 | ) 89 | 90 | if "llama" in args.model_name_or_path: 91 | tokenizer.add_special_tokens( 92 | { 93 | "eos_token": DEFAULT_EOS_TOKEN, 94 | "bos_token": DEFAULT_BOS_TOKEN, 95 | "unk_token": DEFAULT_UNK_TOKEN, 96 | } 97 | ) 98 | 99 | model_config.vocab_size = len(tokenizer) 100 | 101 | outpath = Path(args.output_dir) 102 | if outpath.exists(): 103 | print(f"{outpath} exists. Do nothing.") 104 | exit(0) 105 | 106 | print(f"create {outpath}") 107 | outpath.mkdir() 108 | steppath = outpath / "global_step001" 109 | steppath.mkdir() 110 | 111 | write_ckpt(steppath, model, model_config, args.mp_world_size) 112 | with open(outpath / "latest", "w") as fout: 113 | fout.write("global_step001") 114 | 115 | tokenizer.save_pretrained(outpath) 116 | model_config.save_pretrained(outpath) 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /scripts/pretokenize.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import multiprocessing 4 | from pprint import pprint 5 | from pathlib import Path 6 | from typing import Optional, Literal 7 | from dataclasses import dataclass, field 8 | 9 | import torch 10 | import numpy as np 11 | import pandas as pd 12 | import transformers 13 | from tqdm import tqdm 14 | 15 | from transpeeder.feeder import ( 16 | preprocess, PROMPT_FIELD, OUTPUT_FIELD, IGNORE_INDEX 17 | ) 18 | 19 | def _chunk(lst, n): 20 | for i in range(0, len(lst), n): 21 | yield lst[i: i + n] 22 | 23 | 24 | @dataclass 25 | class Arguments: 26 | seed: int = field(default=42) 27 | tokenizer_name_or_path: str = field(default="/path/to/llama-7b-hf") 28 | 
data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 29 | output_path: str = field(default="/path/to/output.pt") 30 | mode: Literal['sft', 'pretrain', 'dialog'] = 'sft' 31 | max_seq_len: int = field(default=8192) 32 | batch_size: int = field(default=16) 33 | workers: int = field(default=64) 34 | 35 | 36 | def load_samples(args, eos=''): 37 | samples = [] 38 | data_path = Path(args.data_path) 39 | all_files = list(data_path.glob('**/*.json') if data_path.is_dir() else [data_path]) 40 | 41 | for single_file in tqdm(all_files): 42 | with (single_file).open(encoding='utf-8') as f: 43 | for lnum, ln in enumerate(f): 44 | sample = json.loads(ln) 45 | prompt, output = sample[PROMPT_FIELD], sample[OUTPUT_FIELD] 46 | if not isinstance(prompt, str) or not isinstance(output, str): 47 | raise ValueError() 48 | samples.append(dict( 49 | prompt=prompt, 50 | output=output + eos, 51 | )) 52 | 53 | print(f'total samples num: {len(samples)}') 54 | return samples 55 | 56 | 57 | class Encoder(object): 58 | def __init__(self, args): 59 | self.args = args 60 | 61 | def initializer(self): 62 | # Use Encoder class as a container for global data 63 | Encoder.tokenizer = transformers.AutoTokenizer.from_pretrained( 64 | self.args.tokenizer_name_or_path, 65 | model_max_length=self.args.max_seq_len, 66 | padding_side="right", 67 | use_fast=True, 68 | ) 69 | 70 | def batch_encode(self, batch): 71 | sources = [sample[PROMPT_FIELD] for sample in batch] 72 | targets = [sample[OUTPUT_FIELD] for sample in batch] 73 | 74 | data_dict = preprocess(sources, targets, Encoder.tokenizer, self.args.mode) 75 | input_ids = data_dict["input_ids"] 76 | labels = data_dict["labels"] 77 | 78 | input_ids = torch.stack(input_ids) 79 | labels = torch.stack(labels) 80 | labels = torch.where(labels == Encoder.tokenizer.pad_token_id, IGNORE_INDEX, labels) 81 | 82 | return [ 83 | dict( 84 | input_ids=iid, 85 | labels=lbl, 86 | ) for iid, lbl in zip(input_ids, labels) 87 | ] 88 | 89 | 90 | 
def main(): 91 | parser = transformers.HfArgumentParser((Arguments,)) 92 | args, = parser.parse_args_into_dataclasses() 93 | 94 | random.seed(args.seed) 95 | np.random.seed(args.seed) 96 | torch.manual_seed(args.seed) 97 | 98 | tokenizer = transformers.AutoTokenizer.from_pretrained( 99 | args.tokenizer_name_or_path, 100 | model_max_length=args.max_seq_len, 101 | padding_side="right", 102 | use_fast=True, 103 | ) 104 | 105 | samples = load_samples(args, tokenizer.eos_token) 106 | batches = _chunk(samples, args.batch_size) 107 | 108 | encoder = Encoder(args) 109 | if args.workers > 1: 110 | pool = multiprocessing.Pool(args.workers, encoder.initializer) 111 | encoded_rlt = pool.imap(encoder.batch_encode, batches) 112 | else: 113 | encoder.initializer() 114 | encoded_rlt = (encoder.batch_encode(batch) for batch in batches) 115 | 116 | data = [] 117 | for encoded_batch in tqdm(encoded_rlt, total=len(samples) // args.batch_size + 1): 118 | data.extend(encoded_batch) 119 | torch.save(data, args.output_path) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /src/transpeeder/models/llama_pipeline_model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaConfig 5 | import deepspeed 6 | from deepspeed.pipe import PipelineModule, LayerSpec 7 | 8 | 9 | class EmbeddingPipe(torch.nn.Embedding): 10 | def forward(self, args): 11 | input_ids, position_ids, attention_mask = args 12 | inputs_embeds = super().forward(input_ids) 13 | return (inputs_embeds, position_ids, attention_mask) 14 | 15 | 16 | class ParallelTransformerLayerPipe(LlamaDecoderLayer): 17 | def __init__(self, config: LlamaConfig): 18 | super().__init__(config) 19 | 20 | def forward(self, args): 21 | hidden_states, position_ids, mask = args 22 
| attention_mask = torch.where(mask == True, float("-inf"), 0).long() 23 | 24 | outputs = LlamaDecoderLayer.forward(self, 25 | hidden_states, 26 | attention_mask, 27 | position_ids, 28 | ) 29 | return (outputs[0], position_ids, mask) 30 | 31 | 32 | class LayerNormPipe(LlamaRMSNorm): 33 | def forward(self, args): 34 | hidden_states, *_ = args 35 | last_hidden_states = super().forward(hidden_states) 36 | return (last_hidden_states,) 37 | 38 | 39 | class LMLayerPipe(torch.nn.Linear): 40 | def forward(self, args): 41 | hidden_states, = args 42 | logits = super().forward(hidden_states) 43 | return (logits,) 44 | 45 | 46 | def loss_fn(outputs, labels): 47 | # unpack 48 | logits, = outputs 49 | # all labels are `ignore_index` will cause nan 50 | return F.cross_entropy( 51 | logits.view(-1, logits.shape[-1]), 52 | labels.view(-1), 53 | ) 54 | 55 | 56 | def get_model(model_config: LlamaConfig, args, activation_checkpointing_config=None, **kwargs): 57 | class GPT2ModelPipe(PipelineModule): 58 | def __init__(self, model_config, **kwargs): 59 | if activation_checkpointing_config: 60 | deepspeed.checkpointing.configure( 61 | None, 62 | partition_activations=activation_checkpointing_config.get("partition_activations", False), 63 | contiguous_checkpointing=activation_checkpointing_config.get("contiguous_memory_optimization", False), 64 | checkpoint_in_cpu=activation_checkpointing_config.get("cpu_checkpointing", False), 65 | num_checkpoints=activation_checkpointing_config.get("number_checkpoints", None), 66 | synchronize=activation_checkpointing_config.get("synchronize_checkpoint_boundary", False), 67 | profile=activation_checkpointing_config.get("profile", False), 68 | ) 69 | super().__init__( 70 | layers=[ 71 | LayerSpec(EmbeddingPipe, model_config.vocab_size, model_config.hidden_size), 72 | *[LayerSpec(ParallelTransformerLayerPipe, model_config) 73 | for _ in range(model_config.num_hidden_layers)], 74 | LayerSpec(LayerNormPipe, model_config.hidden_size, 
model_config.rms_norm_eps), 75 | LayerSpec(LMLayerPipe, model_config.hidden_size, model_config.vocab_size, bias=False), 76 | ], 77 | activation_checkpoint_interval=(1 if activation_checkpointing_config else 0), 78 | checkpointable_layers=["ParallelTransformerLayerPipe"], 79 | **kwargs 80 | ) 81 | 82 | pp = args.pipe_parallel_size 83 | mp = args.model_parallel_size 84 | assert args.world_size % (pp * mp) == 0 85 | dp = args.world_size // (pp * mp) 86 | 87 | from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology 88 | topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp) 89 | # Offset base seeds for the interior pipeline stages. 90 | stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe 91 | if 0 < stage_id < topo.get_dim('pipe') - 1: 92 | args.seed = args.seed + (stage_id * mp) 93 | 94 | return GPT2ModelPipe(model_config, 95 | loss_fn=loss_fn, 96 | topology=topo, 97 | base_seed=args.seed, 98 | **kwargs) 99 | -------------------------------------------------------------------------------- /examples/train.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import random 4 | import warnings 5 | from dataclasses import dataclass, field 6 | from typing import Optional, Literal 7 | 8 | import torch 9 | import transformers 10 | import numpy as np 11 | import deepspeed 12 | 13 | from transpeeder.models.llama_pipeline_model import get_model 14 | from transpeeder.models.patching import ( 15 | replace_llama_attn_with_flash_attn, 16 | refine_rope, 17 | ) 18 | from transpeeder.feeder import ( 19 | make_prompt_dataloader, 20 | make_tokenized_dataloader, 21 | ) 22 | from transpeeder.utils import jload 23 | from transpeeder.utils import logger_rank0 as logger 24 | 25 | warnings.filterwarnings("ignore") 26 | 27 | @dataclass 28 | class TrainerArguments: 29 | init_ckpt: str = field(default="llama-7B-init-test-ckpt") 30 | use_flash_attn: Optional[bool] = field(default=False) 31 | 32 | 
rank: int = field(default=None) 33 | local_rank: int = field(default=None) 34 | pipe_parallel_size: int = field(default=1) 35 | model_parallel_size: int = field(default=1) 36 | world_size: int = field(default=None) 37 | seed: int = field(default=42) 38 | deepspeed_config: Optional[str] = field(default=None) 39 | 40 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 41 | input_format: Literal['raw', 'tokenized'] = 'raw' 42 | mode: Literal['sft', 'pretrain', 'dialog'] = 'sft' 43 | num_workers: int = field(default=1) 44 | 45 | cache_dir: Optional[str] = field(default=None) 46 | output_dir: str = field(default="./output") 47 | max_seq_len: int = field(default=128) 48 | train_steps: int = field(default=100) 49 | eval_steps: int = field(default=100) 50 | save_steps: int = field(default=100) 51 | log_steps: int = field(default=1) 52 | 53 | resume_step: int = field(default=-1) 54 | resume_ckpt: str = field(default="llama-7B-init-test-ckpt") 55 | ntk : Optional[bool] = field(default=False) 56 | 57 | def read_ds_config(config_path): 58 | config = jload(config_path) 59 | return config 60 | 61 | 62 | def main(): 63 | parser = transformers.HfArgumentParser(TrainerArguments) 64 | args, = parser.parse_args_into_dataclasses() 65 | 66 | # setup deepspeed and other stuff 67 | deepspeed.init_distributed(dist_backend="nccl") 68 | args.world_size = torch.distributed.get_world_size() 69 | 70 | ds_config = read_ds_config(args.deepspeed_config) 71 | args.num_workers = 2 * args.world_size // args.pipe_parallel_size // args.model_parallel_size 72 | args.batch_size = ds_config.get("train_micro_batch_size_per_gpu", 1) 73 | activation_checkpointing_config = ds_config.pop("activation_checkpointing", None) 74 | 75 | random.seed(args.seed) 76 | np.random.seed(args.seed) 77 | torch.manual_seed(args.seed) 78 | deepspeed.runtime.utils.set_random_seed(args.seed) 79 | 80 | if args.use_flash_attn: 81 | logger.info("⚡⚡⚡ enable flash attention.") 82 | 
replace_llama_attn_with_flash_attn() 83 | refine_rope() 84 | 85 | tokenizer = transformers.AutoTokenizer.from_pretrained( 86 | args.init_ckpt, 87 | model_max_length=args.max_seq_len, 88 | padding_side="right", 89 | use_fast=False, 90 | ) 91 | model_config = transformers.AutoConfig.from_pretrained(args.init_ckpt) 92 | 93 | if args.ntk: 94 | rope_scaling = { 95 | "type": "dynamic", 96 | "factor": 2, 97 | } 98 | model_config.rope_scaling = rope_scaling 99 | logger.info(f"Turn on dynamic rope for llama2") 100 | 101 | # pipeline model 102 | model = get_model(model_config, args, activation_checkpointing_config, partition_method="type:ParallelTransformerLayerPipe") 103 | 104 | engine, _, _, _ = deepspeed.initialize( 105 | args, 106 | model=model, 107 | model_parameters=[p for p in model.parameters() if p.requires_grad], 108 | ) 109 | 110 | # dataset 111 | dataloader_maker = make_tokenized_dataloader if args.input_format == 'tokenized' else make_prompt_dataloader 112 | train_dataloader = dataloader_maker(tokenizer=tokenizer, data_args=args, engine=engine) 113 | 114 | # use `convert2ckpt.py` 115 | if args.resume_step < 0: 116 | engine.load_checkpoint(args.init_ckpt, 117 | load_module_only=True, 118 | load_optimizer_states=False, 119 | load_lr_scheduler_states=False, 120 | ) 121 | else: 122 | engine.load_checkpoint(args.resume_ckpt) 123 | 124 | start = time.time() 125 | for step in range(1, args.train_steps + 1): 126 | if step <= args.resume_step: 127 | micro_batch_num = ds_config['train_batch_size'] // ds_config['train_micro_batch_size_per_gpu'] 128 | [next(train_dataloader) for _ in range(micro_batch_num)] 129 | logger.info(f"Step={step:>6}, skipped.") 130 | continue 131 | 132 | loss = engine.train_batch(data_iter=train_dataloader) 133 | if args.local_rank == 0: 134 | if step % args.log_steps == 0: 135 | now = time.time() 136 | avg_time = (now-start) / args.log_steps 137 | logger.info(f"Step={step:>6}, loss={loss.item():.4f}, {avg_time:.2f} it/s") 138 | start = now 139 | 
140 | if step % args.eval_steps == 0: 141 | # TODO 142 | pass 143 | 144 | if step % args.save_steps == 0: 145 | logger.info(f"Saving at step {step}") 146 | engine.save_checkpoint(args.output_dir) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /scripts/check_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | from pprint import pprint, pformat 4 | from pathlib import Path 5 | from typing import Optional, Literal 6 | from collections import defaultdict 7 | from dataclasses import dataclass, field 8 | 9 | import torch 10 | import transformers 11 | import pandas as pd 12 | from tqdm import tqdm 13 | from loguru import logger 14 | 15 | from .pretokenize import load_samples, Encoder, _chunk 16 | from .data_utils import LinguaLid 17 | 18 | logger.add('./check_report.log') 19 | 20 | DEFAULT_PAD_TOKEN = "[PAD]" 21 | DEFAULT_EOS_TOKEN = "" 22 | DEFAULT_BOS_TOKEN = "" 23 | DEFAULT_UNK_TOKEN = "" 24 | 25 | PREFIX_BEGIN_TOKEN = "<|prefix_begin|>" 26 | PREFIX_END_TOKEN = "<|prefix_end|>" 27 | PROMPTER_TOKEN = "<|prompter|>" 28 | ASSISTANT_TOKEN = "<|assistant|>" 29 | ENDOFTEXT_TOKEN = "<|endoftext|>" 30 | 31 | 32 | @dataclass 33 | class Arguments: 34 | seed: int = field(default=42) 35 | tokenizer_name_or_path: str = field(default="/path/to/llama-7b-hf") 36 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 37 | output_path: str = field(default="/path/to/output.pt") 38 | mode: Literal['sft', 'pretrain', 'dialog'] = 'sft' 39 | max_seq_len: int = field(default=8192) 40 | batch_size: int = field(default=16) 41 | workers: int = field(default=64) 42 | 43 | 44 | class Summarizer: 45 | def __init__(self): 46 | self.lines = defaultdict(list) 47 | 48 | def update(self, info): 49 | for lnum, msg in info.items(): 50 | self.lines[lnum].append(msg) 51 | 52 | def summary(self): 
53 | msgs = [] 54 | for li in range(999999): 55 | if li not in self.lines: 56 | continue 57 | msg = "; ".join(self.lines[li]) 58 | msgs.append(f'line {li}: {msg}') 59 | 60 | return '\n'.join(msgs) 61 | 62 | 63 | def _count_tokens(input_ids): 64 | return [len(x) for x in input_ids] 65 | 66 | 67 | def _check_oasst_format(tokenizeds, tokenizer): 68 | BOS_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_BOS_TOKEN) 69 | EOS_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_EOS_TOKEN) 70 | PREFIX_BEGIN_TOKEN_ID = tokenizer.convert_tokens_to_ids(PREFIX_BEGIN_TOKEN) 71 | PREFIX_END_TOKEN_ID = tokenizer.convert_tokens_to_ids(PREFIX_END_TOKEN) 72 | PROMPTER_TOKEN_ID = tokenizer.convert_tokens_to_ids(PROMPTER_TOKEN) 73 | ASSISTANT_TOKEN_ID = tokenizer.convert_tokens_to_ids(ASSISTANT_TOKEN) 74 | ENDOFTEXT_TOKEN_ID = tokenizer.convert_tokens_to_ids(ENDOFTEXT_TOKEN) 75 | 76 | special_tokens = ( 77 | BOS_TOKEN_ID, EOS_TOKEN_ID, PREFIX_BEGIN_TOKEN_ID, PREFIX_END_TOKEN_ID, 78 | PROMPTER_TOKEN_ID, ASSISTANT_TOKEN_ID, ENDOFTEXT_TOKEN_ID, 79 | ) 80 | def _find_next_special_token(input_ids, st): 81 | while st < len(input_ids) - 1 and input_ids[st] not in special_tokens: 82 | st += 1 83 | if input_ids[st] in special_tokens: 84 | return st 85 | else: 86 | return -1 87 | 88 | msgs = {} 89 | for lnum, item in tqdm(enumerate(tokenizeds), desc='checking oasst format'): 90 | input_ids, labels = item['input_ids'], item['labels'] 91 | if input_ids[0] != BOS_TOKEN_ID: 92 | msgs[lnum + 1] = 'bos error' 93 | continue 94 | if input_ids[1] != PREFIX_BEGIN_TOKEN_ID: 95 | msgs[lnum + 1] = 'prefix begin error' 96 | continue 97 | tidx = _find_next_special_token(input_ids, 2) 98 | if tidx < 0 or input_ids[tidx] != PREFIX_END_TOKEN_ID: 99 | msgs[lnum + 1] = 'prefix end error' 100 | continue 101 | # 检查多轮对话 102 | for _ in range(100): 103 | tidx += 1 104 | if tidx < 0 or input_ids[tidx] != PROMPTER_TOKEN_ID: 105 | msgs[lnum + 1] = 'prompter error' 106 | break 107 | tidx += 1 108 | tidx = 
_find_next_special_token(input_ids, tidx) 109 | if tidx < 0 or input_ids[tidx] != ENDOFTEXT_TOKEN_ID: 110 | msgs[lnum + 1] = 'endoftext error' 111 | break 112 | tidx += 1 113 | if input_ids[tidx] != ASSISTANT_TOKEN_ID: 114 | msgs[lnum + 1] = 'assistant error' 115 | break 116 | tidx += 1 117 | tidx = _find_next_special_token(input_ids, tidx) 118 | if tidx > 0 and input_ids[tidx] not in (ENDOFTEXT_TOKEN_ID, EOS_TOKEN_ID,): 119 | msgs[lnum + 1] = 'endoftext error' 120 | break 121 | if tidx < 0 or tidx >= len(input_ids) - 1 or input_ids[tidx] == EOS_TOKEN_ID: 122 | break 123 | 124 | return msgs 125 | 126 | 127 | def _check_language(samples): 128 | allowed_langs = ('en', 'zh',) 129 | lid = LinguaLid() 130 | lang_lst = [] 131 | for s in tqdm(samples, desc="checking language"): 132 | text = s['output'] 133 | lang = lid.detect(text) 134 | lang_lst.append(lang) 135 | 136 | df = pd.DataFrame({'lang_stat': lang_lst}) 137 | logger.info(f'lang stat:\n' + pformat(df.value_counts())) 138 | unvalid = {lnum + 1: f'unallowed language [{lang}] detected.' 
\ 139 | for lnum, lang in enumerate(lang_lst) if lang not in allowed_langs} 140 | return unvalid 141 | 142 | 143 | def _check_length(): 144 | pass 145 | 146 | 147 | def _check_encoding(): 148 | pass 149 | 150 | 151 | def _tokenize(samples, args): 152 | batches = _chunk(samples, args.batch_size) 153 | encoder = Encoder(args) 154 | if args.workers > 1: 155 | pool = multiprocessing.Pool(args.workers, encoder.initializer) 156 | encoded_rlt = pool.imap(encoder.batch_encode, batches) 157 | else: 158 | encoder.initializer() 159 | encoded_rlt = (encoder.batch_encode(batch) for batch in batches) 160 | 161 | data = [] 162 | for encoded_batch in tqdm(encoded_rlt, total=len(samples) // args.batch_size + 1, desc='tokenizing'): 163 | data.extend(encoded_batch) 164 | return data 165 | 166 | 167 | def main(): 168 | parser = transformers.HfArgumentParser((Arguments,)) 169 | args, = parser.parse_args_into_dataclasses() 170 | 171 | tokenizer = transformers.AutoTokenizer.from_pretrained( 172 | args.tokenizer_name_or_path, 173 | model_max_length=args.max_seq_len, 174 | padding_side="right", 175 | use_fast=True, 176 | ) 177 | 178 | samples = load_samples(args, tokenizer.eos_token) 179 | tokenizeds = _tokenize(samples, args) 180 | 181 | summarizer = Summarizer() 182 | summarizer.update(_check_language(samples)) 183 | summarizer.update(_check_oasst_format(tokenizeds, tokenizer)) 184 | 185 | logger.info(f"summary: \n" + summarizer.summary()) 186 | 187 | 188 | 189 | token_nums = [] 190 | df = pd.DataFrame({'token_num': token_nums}) 191 | desc = df.describe( 192 | percentiles=[.5, .75, .85, .90, .95], 193 | ) 194 | pprint(desc) 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | -------------------------------------------------------------------------------- /src/transpeeder/models/patching.py: -------------------------------------------------------------------------------- 1 | """ https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py 2 | 
""" 3 | 4 | from typing import List, Optional, Tuple, Dict 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import transformers 9 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 10 | 11 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 12 | from flash_attn.bert_padding import unpad_input, pad_input 13 | 14 | 15 | def smart_tokenizer_and_embedding_resize( 16 | special_tokens_dict: Dict, 17 | tokenizer: transformers.PreTrainedTokenizer, 18 | model: transformers.PreTrainedModel, 19 | ): 20 | """Resize tokenizer and embedding. 21 | 22 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 23 | """ 24 | # TODO: padding embedding size for being divisible by 64. 25 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 26 | model.resize_token_embeddings(len(tokenizer)) 27 | 28 | # NOTE: 多个special tokens用该方式初始化会导致模型初始loss特别大 29 | # if num_new_tokens > 0: 30 | # input_embeddings = model.get_input_embeddings().weight.data 31 | # output_embeddings = model.get_output_embeddings().weight.data 32 | 33 | # input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 34 | # output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 35 | 36 | # input_embeddings[-num_new_tokens:] = input_embeddings_avg 37 | # output_embeddings[-num_new_tokens:] = output_embeddings_avg 38 | 39 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 40 | """ 41 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 42 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 43 | """ 44 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 45 | if n_rep == 1: 46 | return hidden_states 47 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 48 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 49 | 50 | def llama_flash_attn_forward( 51 | self, 52 | hidden_states: torch.Tensor, 53 | attention_mask: Optional[torch.Tensor] = None, 54 | position_ids: Optional[torch.Tensor] = None, 55 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 56 | output_attentions: bool = False, 57 | use_cache: bool = False, 58 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 59 | Optional[Tuple[torch.Tensor]]]: 60 | """Input shape: Batch x Time x Channel 61 | 62 | attention_mask: [bsz, q_len] 63 | """ 64 | bsz, q_len, _ = hidden_states.size() 65 | 66 | if self.config.pretraining_tp > 1: 67 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp 68 | query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) 69 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 70 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 71 | 72 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] 73 | query_states = torch.cat(query_states, dim=-1) 74 | 75 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] 76 | key_states = torch.cat(key_states, dim=-1) 77 | 78 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] 79 | value_states = torch.cat(value_states, dim=-1) 80 | 81 | else: 82 | query_states = self.q_proj(hidden_states) 83 | key_states = self.k_proj(hidden_states) 84 | value_states = 
self.v_proj(hidden_states) 85 | 86 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 87 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 88 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 89 | 90 | kv_seq_len = key_states.shape[-2] 91 | if past_key_value is not None: 92 | kv_seq_len += past_key_value[0].shape[-2] 93 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 94 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 95 | 96 | if past_key_value is not None: 97 | # reuse k, v, self_attention 98 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 99 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 100 | 101 | past_key_value = (key_states, value_states) if use_cache else None 102 | 103 | # repeat k/v heads if n_kv_heads < n_heads 104 | key_states = repeat_kv(key_states, self.num_key_value_groups) 105 | value_states = repeat_kv(value_states, self.num_key_value_groups) 106 | 107 | 108 | # Flash attention codes from 109 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 110 | 111 | # transform the data into the format required by flash attention 112 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 113 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 114 | 115 | attention_mask = torch.ones((bsz, q_len), device=qkv.device) 116 | key_padding_mask = attention_mask 117 | 118 | if key_padding_mask is None: 119 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 120 | cu_q_lens = torch.arange( 121 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 122 | ) 123 | max_s = q_len 124 | output = flash_attn_varlen_qkvpacked_func( 125 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 126 | ) 127 | output = output.view(bsz, 
q_len, -1) 128 | else: 129 | qkv = qkv.reshape(bsz, q_len, -1) 130 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 131 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 132 | output_unpad = flash_attn_varlen_qkvpacked_func( 133 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 134 | ) 135 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 136 | output = pad_input(output_unpad, indices, bsz, q_len) 137 | 138 | return self.o_proj(output), None, past_key_value 139 | 140 | 141 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 142 | # requires the attention mask to be the same as the key_padding_mask 143 | def _prepare_decoder_attention_mask(self, 144 | attention_mask, 145 | input_shape, 146 | inputs_embeds, 147 | past_key_values_length): 148 | # [bsz, seq_len] 149 | return attention_mask 150 | 151 | 152 | def replace_llama_attn_with_flash_attn(): 153 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 154 | transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_flash_attn_forward 155 | 156 | def refine_rope(): 157 | from .modeling_llama import LlamaRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding 158 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding 159 | transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding = LlamaDynamicNTKScalingRotaryEmbedding -------------------------------------------------------------------------------- /src/transpeeder/feeder.py: -------------------------------------------------------------------------------- 1 | """ feader.py """ 2 | 3 | import copy 4 | import json 5 | from pathlib import Path 6 | from functools import cache 7 | from dataclasses import dataclass 8 | from typing import Dict, Sequence, Union, List 9 | from collections import defaultdict 10 | 11 | import torch 12 | import 
deepspeed 13 | import transformers 14 | from tqdm import tqdm 15 | from torch.utils.data import Dataset, Subset, DataLoader, DistributedSampler 16 | from sklearn.model_selection import train_test_split 17 | 18 | from .utils import is_rank_0 19 | from .utils import logger_rank0 as logger 20 | 21 | 22 | IGNORE_INDEX = -100 23 | DEFAULT_PAD_TOKEN = "[PAD]" 24 | DEFAULT_EOS_TOKEN = "" 25 | DEFAULT_BOS_TOKEN = "" 26 | DEFAULT_UNK_TOKEN = "" 27 | 28 | PREFIX_BEGIN_TOKEN = "<|prefix_begin|>" 29 | PREFIX_END_TOKEN = "<|prefix_end|>" 30 | PROMPTER_TOKEN = "<|prompter|>" 31 | ASSISTANT_TOKEN = "<|assistant|>" 32 | ENDOFTEXT_TOKEN = "<|endoftext|>" 33 | 34 | PROMPT_FIELD = 'prompt' 35 | OUTPUT_FIELD = 'output' 36 | 37 | 38 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 39 | """Tokenize a list of strings.""" 40 | batch_tokenized = tokenizer( 41 | strings, 42 | return_tensors="pt", 43 | padding="max_length", 44 | max_length=tokenizer.model_max_length, 45 | truncation=True, 46 | ).input_ids 47 | 48 | input_ids = labels = batch_tokenized 49 | input_ids_lens = labels_lens = [ 50 | tokenized.ne(tokenizer.pad_token_id).sum().item() for tokenized in batch_tokenized 51 | ] 52 | return dict( 53 | input_ids=input_ids, 54 | labels=labels, 55 | input_ids_lens=input_ids_lens, 56 | labels_lens=labels_lens, 57 | ) 58 | 59 | def _make_labels(input_ids, tokenizer: transformers.PreTrainedTokenizer, mode: str = "sft", **kwargs): 60 | if mode == "sft": 61 | assert "source_lens" in kwargs, f"miss parameter: source_lens" 62 | labels = copy.deepcopy(input_ids) 63 | for label, source_len in zip(labels, kwargs["source_lens"]): 64 | label[: source_len] = IGNORE_INDEX 65 | return labels 66 | elif mode == "pretrain": 67 | return copy.deepcopy(input_ids) 68 | elif mode == "dialog": 69 | labels = torch.full_like(input_ids, IGNORE_INDEX, dtype=input_ids.dtype) 70 | # <|assistant|> ... 
<|endoftext|> 71 | ASSISTANT_TOKEN_ID = tokenizer.convert_tokens_to_ids(ASSISTANT_TOKEN) 72 | ENDOFTEXT_TOKEN_ID = tokenizer.convert_tokens_to_ids(ENDOFTEXT_TOKEN) 73 | PROMPTER_TOKEN_ID = tokenizer.convert_tokens_to_ids(PROMPTER_TOKEN) 74 | for input_row, label_row in zip(input_ids, labels): 75 | begin_indices = torch.nonzero(input_row == ASSISTANT_TOKEN_ID) 76 | for idx in begin_indices: 77 | edi = idx + 1 78 | while edi < len(input_row) and input_row[edi] != ENDOFTEXT_TOKEN_ID: 79 | edi += 1 80 | if edi < len(input_row) and \ 81 | input_row[edi + 1] != PROMPTER_TOKEN_ID: 82 | logger.warning(f'expect {PROMPTER_TOKEN} after {ENDOFTEXT_TOKEN}, get {input_row[edi + 1]}.') 83 | label_row[idx + 1: edi + 1] = input_row[idx + 1: edi + 1] 84 | 85 | return labels 86 | else: 87 | raise ValueError('Unvalid training mode.') 88 | 89 | 90 | def preprocess( 91 | sources: Sequence[str], 92 | targets: Sequence[str], 93 | tokenizer: transformers.PreTrainedTokenizer, 94 | mode: str 95 | ) -> Dict: 96 | """Preprocess the data by tokenizing.""" 97 | samples = [s + t for s, t in zip(sources, targets)] 98 | samples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (samples, sources)] 99 | input_ids = samples_tokenized["input_ids"] 100 | labels = _make_labels(input_ids, tokenizer, mode, 101 | source_lens=sources_tokenized["input_ids_lens"]) 102 | 103 | # shift 104 | return dict( 105 | input_ids=[ids[: -1] for ids in input_ids], 106 | labels=[lbs[1: ]for lbs in labels] 107 | ) 108 | 109 | 110 | class PromptDataset(Dataset): 111 | """ Dataset for prompt-tuning. """ 112 | 113 | def __init__(self, data_path: Union[str, Path], eos: str = ""): 114 | super().__init__() 115 | if isinstance(data_path, str): 116 | data_path = Path(data_path) 117 | assert data_path.exists(), f'{data_path} does not exists.' 
118 | 119 | self.samples = [] 120 | all_files = list(data_path.glob('**/*.json') if data_path.is_dir() else [data_path]) 121 | 122 | error_count = defaultdict(int) 123 | ERROR_THRESHOLD = 10 124 | for single_file in tqdm(all_files, disable=not is_rank_0()): 125 | with (single_file).open(encoding='utf-8') as f: 126 | for lnum, ln in enumerate(f): 127 | try: 128 | sample = json.loads(ln) 129 | prompt, output = sample[PROMPT_FIELD], sample[OUTPUT_FIELD] 130 | if not isinstance(prompt, str) or not isinstance(output, str): 131 | raise ValueError() 132 | self.samples.append(dict( 133 | prompt=prompt, 134 | output=output + eos, 135 | )) 136 | except: 137 | logger.warning(f'{single_file}: {lnum} unvalid.') 138 | error_count[str(single_file)] += 1 139 | 140 | if error_count[str(single_file)] > ERROR_THRESHOLD: 141 | logger.warning(f'{single_file} exceeds max error number. skipped.') 142 | break 143 | 144 | logger.info(f'total samples num: {len(self.samples)}') 145 | 146 | def __len__(self): 147 | return len(self.samples) 148 | 149 | def __getitem__(self, index) -> Dict[str, str]: 150 | # TODO: preprocess here and caching on the fly. 151 | return self.samples[index] 152 | 153 | 154 | @dataclass 155 | class DataCollatorForPromptDataset(object): 156 | """Collate for supervised fine-tuning.""" 157 | 158 | tokenizer: transformers.PreTrainedTokenizer 159 | mode: str 160 | 161 | @cache 162 | @staticmethod 163 | def get_attn_mask(bs, seq_length): 164 | """ 165 | Get triangular attention mask. 166 | """ 167 | # lower triangular attention mask 168 | mask = torch.tril(torch.ones((bs, seq_length, seq_length))).view( 169 | bs, 1, seq_length, seq_length 170 | ) 171 | # convert to binary 172 | return mask < 0.5 173 | 174 | @staticmethod 175 | def get_position_ids(input_ids): 176 | seq_length = input_ids.shape[1] 177 | # Position ids. 
178 | position_ids = torch.arange(seq_length, dtype=torch.long) 179 | return position_ids.unsqueeze(0).expand_as(input_ids) 180 | 181 | def __call__(self, samples: Sequence[Dict]) -> Dict[str, torch.Tensor]: 182 | sources = [sample[PROMPT_FIELD] for sample in samples] 183 | targets = [sample[OUTPUT_FIELD] for sample in samples] 184 | 185 | data_dict = preprocess(sources, targets, self.tokenizer, self.mode) 186 | input_ids = data_dict["input_ids"] 187 | labels = data_dict["labels"] 188 | 189 | input_ids = torch.stack(input_ids) 190 | labels = torch.stack(labels) 191 | labels = torch.where(labels == self.tokenizer.pad_token_id, IGNORE_INDEX, labels) 192 | 193 | return ( 194 | ( 195 | input_ids, 196 | DataCollatorForPromptDataset.get_position_ids(input_ids), 197 | DataCollatorForPromptDataset.get_attn_mask(input_ids.shape[0], input_ids.shape[1]), 198 | ), 199 | labels 200 | ) 201 | 202 | 203 | class TokenizedDataset(Dataset): 204 | def __init__(self, data_path: Union[str, Path]): 205 | super().__init__() 206 | if isinstance(data_path, str): 207 | data_path = Path(data_path) 208 | assert data_path.exists(), f'{data_path} does not exists.' 
209 | 210 | self.samples = [] 211 | all_files = list(data_path.glob('**/*.pt') if data_path.is_dir() else [data_path]) 212 | 213 | for single_file in tqdm(all_files, disable=not is_rank_0()): 214 | self.samples.extend(torch.load(single_file)) 215 | 216 | logger.info(f'total samples num: {len(self.samples)}') 217 | 218 | def __len__(self): 219 | return len(self.samples) 220 | 221 | def __getitem__(self, index) -> Dict[str, str]: 222 | return self.samples[index] 223 | 224 | 225 | @dataclass 226 | class DataCollatorForTokenizedDataset(DataCollatorForPromptDataset): 227 | 228 | def __call__(self, samples: Sequence[Dict]) -> Dict[str, torch.Tensor]: 229 | input_ids = torch.stack([s['input_ids'] for s in samples]) 230 | labels = torch.stack([s['labels'] for s in samples]) 231 | return ( 232 | ( 233 | input_ids, 234 | self.get_position_ids(input_ids.shape[0], input_ids.shape[1]), 235 | self.get_attn_mask(input_ids), 236 | ), 237 | labels 238 | ) 239 | 240 | 241 | def train_val_dataset(dataset, val_split=0.2): 242 | train_idx, val_idx = train_test_split( 243 | list(range(len(dataset))), test_size=val_split, random_state=42, shuffle=True 244 | ) 245 | return Subset(dataset, train_idx), Subset(dataset, val_idx) 246 | 247 | 248 | def make_prompt_dataloader(tokenizer: transformers.PreTrainedTokenizer, data_args, engine, val_split=None) -> Dict: 249 | # TODO add eval dataloader 250 | assert val_split is None 251 | dataset = PromptDataset(data_path=data_args.data_path, eos=tokenizer.eos_token) 252 | data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, mode=data_args.mode) 253 | g = torch.Generator() 254 | train_sampler = DistributedSampler(dataset, 255 | num_replicas=engine.dp_world_size, 256 | rank=engine.mpu.get_data_parallel_rank(), 257 | shuffle=True) 258 | dataloader = DataLoader(dataset, 259 | collate_fn=data_collator, 260 | num_workers=data_args.num_workers, 261 | batch_size=data_args.batch_size, 262 | sampler=train_sampler, 263 | drop_last=True, 264 | 
generator=g,) 265 | return iter(deepspeed.utils.RepeatingLoader(dataloader)) 266 | 267 | 268 | def make_tokenized_dataloader(tokenizer: transformers.PreTrainedTokenizer, data_args, engine, val_split=None) -> Dict: 269 | dataset = TokenizedDataset(data_path=data_args.data_path) 270 | data_collator = DataCollatorForTokenizedDataset(tokenizer=tokenizer, mode=data_args.mode) 271 | g = torch.Generator() 272 | train_sampler = DistributedSampler(dataset, 273 | num_replicas=engine.dp_world_size, 274 | rank=engine.mpu.get_data_parallel_rank(), 275 | shuffle=True) 276 | dataloader = DataLoader(dataset, 277 | collate_fn=data_collator, 278 | num_workers=data_args.num_workers, 279 | batch_size=data_args.batch_size, 280 | sampler=train_sampler, 281 | drop_last=True, 282 | generator=g,) 283 | return iter(deepspeed.utils.RepeatingLoader(dataloader)) 284 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /data/alpaca_data_sample_oneline_format.json: -------------------------------------------------------------------------------- 1 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGive three tips for staying healthy.\n\n### Response:", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."} 2 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat are the three primary colors?\n\n### Response:", "output": "The three primary colors are red, blue, and yellow."} 3 | {"prompt": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe the structure of an atom.\n\n### Response:", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."} 4 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nHow can we reduce air pollution?\n\n### Response:", "output": "There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances."} 5 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe a time when you had to make a difficult decision.\n\n### Response:", "output": "I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client’s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team’s resources even further and increase the budget. 
Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client’s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities."} 6 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nIdentify the odd one out.\n\n### Input:\nTwitter, Instagram, Telegram\n\n### Response:", "output": "Telegram"} 7 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain why the following fraction is equivalent to 1/4\n\n### Input:\n4/16\n\n### Response:", "output": "The fraction 4/16 is equivalent to 1/4 because both numerators and denominators are divisible by 4. Dividing both the top and bottom numbers by 4 yields the fraction 1/4."} 8 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWrite a short story in third person narration about a protagonist who has to make an important career decision.\n\n### Response:", "output": "John was at a crossroads in his life. He had just graduated college and was now facing the big decision of what career to pursue. After much deliberation, he decided that he wanted to be an accountant and help the financially disadvantaged. He had always been good with numbers and enjoyed seeing the tangible results of his work. \n\nJohn enrolled in accounting courses and initially found it quite challenging. He had to learn multiple systems and regulations quickly, but he worked hard and eventually excelled in his studies. After a few years, John started working at an accounting firm in his city. 
He was eager to put his knowledge of taxes and accounting to use in a real-world setting.\n\nJohn loved his job, as it let him express his creativity in finding strategies to save his clients money. After a few years at the firm, he became a senior accountant and was asked to manage bigger and more challenging cases. He was now a respected figure in the financial industry, but he still remembers when he was just a recent college graduate, unsure of the direction in which his life would take him."} 9 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nRender a 3D model of a house\n\n### Response:", "output": " This type of instruction cannot be fulfilled by a GPT model."} 10 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nEvaluate this sentence for spelling and grammar mistakes\n\n### Input:\nHe finnished his meal and left the resturant\n\n### Response:", "output": "He finished his meal and left the restaurant."} 11 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nHow did Julius Caesar die?\n\n### Response:", "output": "Julius Caesar was assassinated by a group of up to 60 conspirators, led by Gaius Cassius Longinus and Marcus Junius Brutus, in the Senate House on the Ides of March (15 March) of 44 BC."} 12 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:", "output": "The capital of France is Paris."} 13 | {"prompt": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of ten items a person might need for a camping trip\n\n### Response:", "output": "1. Tent\n2. Sleeping bags\n3. Flashlight\n4. Matches/lighter\n5. Insect repellent\n6. Extra clothes\n7. First aid kit\n8. Camp stove\n9. Cooking utensils\n10. Biodegradable soaps/shampoos"} 14 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDiscuss the causes of the Great Depression\n\n### Response:", "output": "The Great Depression was a severe global economic downturn that lasted from 1929 to 1939. It was caused by several factors, including the stock market crash of 1929, structural weaknesses in the economy, low consumer spending, high unemployment, and a lack of confidence in the banking system. Other contributing factors included overproduction of commodities, reliance on borrowing to finance investments, agricultural overproduction, and a decrease in international trade."} 15 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nClassify the following into animals, plants, and minerals\n\n### Input:\nOak tree, copper ore, elephant\n\n### Response:", "output": "Oak tree: Plant\n Copper ore: Mineral\n Elephant: Animal"} 16 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the use of word embeddings in Natural Language Processing\n\n### Response:", "output": "Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. 
Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies. They allow for words to be processed as numerical values, giving machines an easier way to perform NLP tasks."} 17 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe the function of a computer motherboard\n\n### Response:", "output": "A computer motherboard is the main circuit board in a computer. It is the backbone of the system, connecting all other components such as the processor, RAM, storage, graphics card and peripheral ports. It provides communication pathways between all of these components, allowing them to interact and work together. The motherboard is responsible for regulating power, and also houses slots for expansion cards, as well as ports for USB, audio, Ethernet, and other connection types."} 18 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nReverse engineer this code to create a new version\n\n### Input:\ndef factorialize(num):\n factorial = 1\n for i in range(1, num):\n factorial *= i\n \n return factorial\n\n### Response:", "output": "def factorialize(num):\n factorial = 1\n for i in range(num, 0, -1):\n factorial *= i\n \n return factorial"} 19 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nPropose an ethical solution to the problem of data privacy\n\n### Response:", "output": "One ethical solution to the problem of data privacy is to implement an opt-in data collection policy. This means that companies cannot collect and use personal data without the explicit permission of the data owners. 
This approach gives users the power to determine which data companies can use and for what purpose. Companies should also be transparent about how they use the collected data, so that users can make informed decisions and have control over their own data."} 20 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate three verbs that mean the same as \"to apologize\"\n\n### Response:", "output": "Acknowledge, express regret, make amends."} 21 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat does DNA stand for?\n\n### Response:", "output": "DNA stands for deoxyribonucleic acid."} 22 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nCompare and contrast the Cuban Missile Crisis and the Vietnam War.\n\n### Response:", "output": "The Cuban Missile Crisis and the Vietnam War were two of the largest and deadliest conflicts of the 20th century. The Cuban Missile Crisis involved a standoff between the United States and the Soviet Union over Soviet nuclear missile bases in Cuba, while the Vietnam War was a prolonged military conflict in Indochina between the North, which was backed by the Soviet Union and the South, which was backed by the United States. The Cuban Missile Crisis unfolded over a period of 13 days in October 1962, with the resolution coming in the form of a diplomatic agreement. The Vietnam War lasted for 19 years, beginning in 1954 and ending in 1973, with the signing of the Paris Peace Accords. While both of these conflicts were Cold War related, they were also very different in terms of causes, objectives, strategic objectives and outcomes. 
The Cuban Missile Crisis did not result in any direct military conflict and was resolved primarily through diplomacy and the decision to withdraw nuclear weapons from Cuba. On the other hand, the Vietnam War was a prolonged series of battles resulting in significant losses for both sides. The Vietnam War also had a significant impact on the United States, as it highlighted the realities of a prolonged conflict and the limits of U.S. military intervention. Additionally, the Vietnam War also drastically changed the way public opinion viewed the military and the use of force."} 23 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of random words.\n\n### Response:", "output": "sneeze, conflict, ancestor, thunder, companion, amulet, forge, serpent, magnify, jester, cascade, universe, surge, nausea, oasis, anomaly, pursue, rapture, pagoda, mosaic."} 24 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nTranscribe the recording into text.\n\n### Input:\nRecording about the history of Waterloo\n\n### Response:", "output": "The Battle of Waterloo, fought on June 18, 1815, was the last major battle of the Napoleonic Wars. It marked the end of Napoleon Bonaparte's rule. The battle was fought between Napoleon Bonaparte's French forces and an allied force of British, Dutch, Belgian and Prussian troops, led by the Duke of Wellington. It was fought near the town of Waterloo in the former Austrian Netherlands, now present-day Belgium. The battle was a defeat for Napoleon's forces and ended Napoleon's rule over France. Napoleon himself was exiled to St. Helena, where he would die in 1821. The battle marked the end of French domination of Europe and the beginning of a period of peace in the region. 
It was also the last major battle of the Napoleonic Wars and is remembered as one of the most important battles in history. The victory of the allied forces at Waterloo helped to bring an end to French imperial ambitions in Europe. It also established the British as one of the most powerful nations in Europe. The battle was fought with a combination of infantry, cavalry, and artillery tactics, and showed the beginning of the changing nature of warfare. Although the French forces greatly outnumbered the allies, their strategic and tactical mistake led to the loss of the battle. This defeat signaled the end of French imperial power in Europe. The legacy of Waterloo still stands today, and it cemented the British Empire's position for the next one hundred years."} 25 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWho is the world's most famous painter?\n\n### Response:", "output": "The world's most famous painter is generally considered to be the Italian renaissance artist Leonardo da Vinci. His work includes the Mona Lisa and The Last Supper, both of which are considered masterpieces. Da Vinci was also a master of science, inventing the helicopter and the tank, among other inventions. He was also a prolific musician, sculptor, and architect. His work has inspired generations of artists and continues to be revered today."} 26 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the concept of cogging torque.\n\n### Response:", "output": "Cogging torque is the torque that is applied to a rotating device when the device experiences resistance from the teeth of the gears that are connected to the device. Cogging torque occurs due to the reaction force generated when the teeth of the gear drive mesh. 
This force pushes against each other, resulting in a drag force that is transmitted from the gears to the shaft. The magnitude of the cogging torque can vary depending on the size and number of teeth, as well as the gear's material and design. Cogging torque is undesirable in most motorized applications, as it can lead to loss of efficiency. To reduce cogging torque, gear designs are often symmetrical or have certain tooth profiles which reduce the force of contact between the teeth when they come into contact with each other."} 27 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nLook up the boiling point of water.\n\n### Response:", "output": "The boiling point of water is 100°C (212°F), at standard atmospheric pressure. When heated at higher pressures, water's boiling point increases. For example, water will boil at a temperature of 115°C (239°F) at an atmospheric pressure of 1 bar. At higher pressures, such as those experienced at the bottom of the ocean, the boiling point of water can be as high as 374°C (705°F)."} 28 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDescribe the life and reign of King Charles II.\n\n### Response:", "output": "King Charles II ascended to the throne of England following the English Civil War and the death of his father, Charles I. He was crowned in 1661 and his reign lasted until his death in 1685. Charles II was also a patron of the arts, working with some of the greatest painters and sculptors of his day. He also encouraged theater and sports and had an interest in sciences such as astronomy and natural history. Charles II was a moderate ruler who respected the civil and religious liberties of his citizens. He was tolerant of other religions and gave Jews and Catholics more rights in the kingdom. 
Charles II was also a strong supporter of the Royal Navy, which saw England increase its power in the world. He also promoted the development of industry and commerce, which helped to modernize the country. Charles II's reign was an important step in restoring the monarchy of England after the turmoil of the English Civil War."} 29 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nFind the area of a circle given its radius.\n\n### Input:\nRadius = 4\n\n### Response:", "output": "The area of a circle with a radius of 4 is equal to 12.5664 square units. This is calculated by using the formula A = πr2, where A is the area, π is roughly equal to 3.1416 and r is the radius of the circle."} 30 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nIdentify the lines of longitude that divides North and South America.\n\n### Response:", "output": "The lines of longitude that divides North and South America is the line of the Panama Canal. This line of longitude cuts through the Isthmus of Panama and separates South America from Central America. It is approximately 800 km (500 mi) in total length and was completed in 1914. The Panama Canal is an important shipping route and serves as an international waterway connecting the Caribbean Sea with the Pacific Ocean."} 31 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the process of cellular respiration in plants.\n\n### Response:", "output": "Cellular respiration in plants is a complex process in which the plant converts glucose into usable energy. 
The process begins when the plant takes in carbon dioxide from the environment and combines it with water to form glucose molecules using the energy from sunlight. This glucose is then broken down into simpler molecules of carbon dioxide and water. During this process, the energy released is used to produce adenosine triphosphate (ATP) molecules. ATP is used by the plant to power other processes, such as growth and reproduction. Additionally, the ATP can also be broken down further to produce heat energy, which helps to sustain the plant's temperature. The carbon dioxide and water molecules produced in the process are then released back into the environment to complete the cycle."} 32 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nRewrite the following sentence using active voice.\n\n### Input:\nThe news report was read by the captain.\n\n### Response:", "output": "The captain read the news report."} 33 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a list of adjectives that describe a person as brave.\n\n### Response:", "output": "Courageous, heroic, audacious, vigorous, valorous, resolute, intrepid, bold, brave, daring, tenacious, plucky, stout-hearted, doughty, gallant, fearless, undaunted, spirited, daredevil."} 34 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nOutline the consequences of deforestation.\n\n### Response:", "output": "Deforestation is the process of clearing and removing forests. The consequences of deforestation can include soil erosion, reduced water quality, decreased biodiversity, increased carbon dioxide in the atmosphere and loss of habitat for wildlife. 
Soil erosion occurs when the trees that help retain water and soil are removed, leading to soil particles being washed away. This can cause siltation in rivers and streams, affecting water quality and leading to flooding. Additionally, when trees are removed, wildlife lose their habitat. This can cause species to become endangered, leading to further loss of biodiversity. Deforestation can also lead to increased carbon dioxide in the atmosphere as there are fewer trees to take up the gas during photosynthesis, which can contribute to climate change."} 35 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDevelop a plan to reduce electricity usage in a home.\n\n### Response:", "output": "A plan to reduce electricity usage in a home can include a number of different strategies. These include replacing inefficient lightbulbs with energy-efficient LEDs, unplugging appliances and electronics when not in use, using natural light instead of electric lights where possible, using a touch-free faucet to reduce water heater use, using a smart thermostat to monitor and adjust temperatures automatically, installing a solar hot water heater, and turning off electronics when not in use. Additionally, making sure to install and seal windows and doors correctly to reduce air leakage can also help reduce electricity usage."} 36 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nArrange the words in the given sentence to form a grammatically correct sentence.\n\n### Input:\nquickly the brown fox jumped\n\n### Response:", "output": "The quick brown fox jumped quickly."} 37 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.\n\n### Instruction:\nAnalyze the given text for its tone.\n\n### Input:\nThe world has been greatly impacted by the COVID-19 pandemic and it has drastically changed our lives.\n\n### Response:", "output": "The tone of the text is one of concern and recognition of the dramatic effects of the COVID-19 pandemic on the world. It is an acknowledgement of the upheaval that the virus has caused and that our lives have been drastically changed."} 38 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nUse the given data to calculate the median.\n\n### Input:\n[2, 3, 7, 8, 10]\n\n### Response:", "output": "The median of the given data is 7."} 39 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nRewrite the given paragraph in a shorter, easier to understand form.\n\n### Input:\nAlthough it is generally accepted that the internet has allowed us to connect with people all over the world, there are still those people who are not familiar with its basic functions, who don’t understand why it has become so commonplace, or what its true capabilities are.\n\n### Response:", "output": "The internet has allowed us to connect globally, even though some people are unfamiliar with its functions, don't comprehend its widespread use, and don't know its full potential."} 40 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDesign a logo for a website about recycling.\n\n### Response:", "output": "The logo should feature a green motif that is recognizable and memorable, such as a recycle symbol or a stylized representation of a recycling bin. 
The logo should also include text that communicates the website's purpose, such as \"Recycling\" or \"Go Green\"."} 41 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a poem that expresses joy.\n\n### Response:", "output": "Joy flows through me like a river clear,\nBringing laughter and bliss to my soul to hear.\n\nThe world is bright and wonderful, I can see,\nMaking every moment something happy and free.\n\nMy heart feels light, my cares have flown,\nAs I discover what happiness gives me alone.\n\nI open my arms and accept the joy,\nBringing life to my song, so sweet like a toy.\n\nEvery moment is a precious gift,\nThat I savor and never let drift."} 42 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nConvert the given equation into an algebraic expression.\n\n### Input:\n3x+5y=9\n\n### Response:", "output": "3x + 5y - 9 = 0"} 43 | {"prompt": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\nBrainstorm possible solutions to reduce water pollution.\n\n### Response:", "output": "- Implement stricter regulations on industrial waste disposal\n- Increase the use of green technology in wastewater treatment\n- Provide incentives to businesses that reduce their water consumption and waste generation\n- Educate citizens about water conservation and pollution prevention\n- Offer recycling programs to reduce plastic waste\n- Replace disposable plastics with biodegradable alternatives\n- Improve sewer systems\n- Re-introduce native plant species\n- Foster artificial wetlands and green infrastructure\n- Provide filtration systems at water treatment plants\n- Implement best practices for agricultural runoff\n- Increase research and monitoring of water quality\n- Establish buffer zones between aquatic environments and land-based activities."} 44 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain why the given definition is wrong.\n\n### Input:\nA mole is an animal that lives underground.\n\n### Response:", "output": "The definition is wrong because a mole is actually a small mammal that lives above ground and feeds on insects. It is most distinguishable by its long snout and wide, paddle-like feet. It does, however, often dig burrows in which it lives and shelters."} 45 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nVariable x is defined as “4x + 2y = 10”. Find the value of x.\n\n### Response:", "output": "The value of x is 2. 
To find the value, simplify the equation by subtracting 2y from both sides, giving 4x = 10; dividing both sides by 4, giving x = 2/4, which is equal to 2."} 46 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nWrite a short paragraph about the given topic.\n\n### Input:\nThe importance of using renewable energy\n\n### Response:", "output": "The use of renewable energy is growing rapidly in relevance and importance as the world looks towards solutions to combat climate change. Renewable energy sources, such as solar, wind, and hydropower, are sustainable and have a much lower environmental impact than traditional energy sources like coal and oil. Switching to renewable energy can reduce carbon emissions, improve air quality, lessen the dependence on non-renewable resources, and provide a stable and secure energy future. Renewable energy is essential for achieving a sustainable future and should be an important part of any country’s energy portfolio."} 47 | {"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nExplain the concept of artificial intelligence in simple terms.\n\n### Response:", "output": "Artificial Intelligence (AI) is the ability of machines and computers to make decisions and complete tasks similar to what humans can do. AI can learn from data, recognize patterns, and make predictions. AI technology can automate tedious and complex processes, helping to reduce human effort and errors, and enable more accurate and efficient decisions. AI can be used in a wide range of applications, from robotics and autonomous vehicles to healthcare and finance. It is increasingly becoming an integral part of everyday life."} 48 | {"prompt": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\nDesign an app for a delivery company.\n\n### Response:", "output": "The app should allow customers to order items from local businesses and nearby store, place delivery requests with detailed delivery requirements, and track their orders. It should also allow customers to easily find and pay for services. The app should allow businesses to manage orders, track deliveries and drivers, set delivery boundaries, and set prices. It should also have features to help businesses promote their products and services, as well as options to provide customer feedback and support."} 49 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nSummarize the given passage.\n\n### Input:\nA recent study showed that global climate change is one of the most important challenges facing the world today. The consequences of global warming include rising sea levels, extreme weather events and biodiversity loss. Climate change is caused by human activities such as burning fossil fuels, and if unchecked will have long-term and devastating impacts on the planet.\n\n### Response:", "output": "A recent study revealed that global climate change is one of the world’s most pressing issues. In addition to rising sea levels, extreme weather events and biodiversity loss, it is caused largely by human activities like burning fossil fuels. If left unchecked, climate change will have far-reaching and devastating consequences."} 50 | {"prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nExtract the facts from the paragraph.\n\n### Input:\nOnline education continues to become more popular for schools and students alike. 
Its advantages are generally lower costs, less commitment and the ability to study at a time, place and pace that suits the student.\n\n### Response:", "output": "- Online education is becoming increasingly popular.\n- It has several advantages such as lower costs, less commitment and the ability to study at one’s own time and pace."} 51 | --------------------------------------------------------------------------------