├── configs
│   ├── deepspeed_stage1.json
│   ├── deepspeed_stage2.json
│   └── deepspeed_stage3.json
├── generation
│   ├── register_server.sh
│   ├── merge_data.py
│   ├── gen_hf2.py
│   └── gen_hf.py
├── annotate_data
│   ├── get_bon_data.py
│   └── get_rewards.py
├── sft
│   ├── prepare_model.py
│   ├── gemma-2b-it.yaml
│   ├── gemma-9b-it.yaml
│   └── llama3-8b-it.yaml
├── diffusion-example
│   ├── requirements.txt
│   ├── README.md
│   ├── SD256-RAFT.ipynb
│   └── train_text_to_image_lora.py
└── README.md

--------------------------------------------------------------------------------
/configs/deepspeed_stage1.json:
--------------------------------------------------------------------------------
1 | {
2 |     "zero_optimization": {
3 |         "stage": 1,
4 |         "overlap_comm": true
5 |     },
6 |     "bf16": {
7 |         "enabled": "auto"
8 |     },
9 |     "fp16": {
10 |         "enabled": "auto",
11 |         "auto_cast": false,
12 |         "loss_scale": 0,
13 |         "initial_scale_power": 32,
14 |         "loss_scale_window": 1000,
15 |         "hysteresis": 2,
16 |         "min_loss_scale": 1
17 |     },
18 |     "gradient_accumulation_steps": "auto",
19 |     "gradient_clipping": "auto",
20 |     "train_batch_size": "auto",
21 |     "train_micro_batch_size_per_gpu": "auto",
22 |     "wall_clock_breakdown": false
23 | }
24 | 
--------------------------------------------------------------------------------
/sft/prepare_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 | 
4 | name = 'meta-llama/Meta-Llama-3-8B-Instruct'
5 | tokenizer_name = name
6 | 
7 | model = AutoModelForCausalLM.from_pretrained(
8 |     name,
9 |     torch_dtype=torch.bfloat16,
10 | )
11 | 
12 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
13 | tokenizer.pad_token = tokenizer.eos_token
14 | tokenizer.pad_token_id = tokenizer.eos_token_id
15 | model.config.pad_token_id = tokenizer.pad_token_id
16 | 
17 | model.save_pretrained("./models/llama3_it_with_padding_token")
18 | tokenizer.save_pretrained("./models/llama3_it_with_padding_token")
19 | 
--------------------------------------------------------------------------------
/generation/register_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # check whether the model is provided
4 | if [ $# -eq 0 ]; then
5 |     echo "Usage: $0 <model_name_or_path>"
6 |     exit 1
7 | fi
8 | 
9 | # use the first command-line argument as the model name or path
10 | MODEL_PATH=$1
11 | 
12 | # we use a for loop to create 8 vllm instances, each on a different GPU
13 | for i in 0 1 2 3 4 5 6 7
14 | do
15 |     CUDA_VISIBLE_DEVICES=$i python 
-m vllm.entrypoints.api_server \ 16 | --model $MODEL_PATH \ 17 | --gpu-memory-utilization=0.9 \ 18 | --max-num-seqs=200 \ 19 | --host 127.0.0.1 --tensor-parallel-size 1 \ 20 | --port $((8000+i)) \ 21 | & 22 | done 23 | -------------------------------------------------------------------------------- /configs/deepspeed_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "offload_optimizer": { 5 | "device": "cpu" 6 | }, 7 | "contiguous_gradients": true, 8 | "overlap_comm": true 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "fp16": { 14 | "enabled": "auto", 15 | "auto_cast": false, 16 | "loss_scale": 0, 17 | "initial_scale_power": 32, 18 | "loss_scale_window": 1000, 19 | "hysteresis": 2, 20 | "min_loss_scale": 1 21 | }, 22 | "gradient_accumulation_steps": "auto", 23 | "gradient_clipping": "auto", 24 | "train_batch_size": "auto", 25 | "train_micro_batch_size_per_gpu": "auto", 26 | "wall_clock_breakdown": false 27 | } 28 | -------------------------------------------------------------------------------- /configs/deepspeed_stage3.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 0, 7 | "reduce_bucket_size": "auto", 8 | "stage3_prefetch_bucket_size": "auto", 9 | "stage3_param_persistence_threshold": "auto", 10 | "stage3_max_live_parameters": 0, 11 | "stage3_max_reuse_distance": 0, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "bf16": { 15 | "enabled": true 16 | }, 17 | "fp16": { 18 | "enabled": "auto", 19 | "auto_cast": false, 20 | "loss_scale": 0, 21 | "initial_scale_power": 32, 22 | "loss_scale_window": 1000, 23 | "hysteresis": 2, 24 | "min_loss_scale": 1 25 | }, 26 | "gradient_accumulation_steps": "auto", 27 | "gradient_clipping": "auto", 28 | "train_batch_size": "auto", 29 | "train_micro_batch_size_per_gpu": "auto", 30 | "wall_clock_breakdown": false 31 | } 32 | -------------------------------------------------------------------------------- /annotate_data/get_bon_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import Optional 5 | import numpy as np 6 | from datasets import load_dataset 7 | 8 | 9 | 10 | tqdm.pandas() 11 | 12 | @dataclass 13 | class ScriptArguments: 14 | """ 15 | The arguments for the DPO training script. 
16 | """ 17 | 18 | dataset_name_or_path: Optional[str] = field( 19 | default="uf_split0_responses_K8.jsonl", 20 | metadata={"help": "the location of the dataset name or path"}, 21 | ) 22 | output_dir: Optional[str] = field( 23 | default="uf_split0_responses_K8_reward.json", 24 | metadata={"help": "the location of the output file"}, 25 | ) 26 | 27 | 28 | parser = HfArgumentParser(ScriptArguments) 29 | script_args = parser.parse_args_into_dataclasses()[0] 30 | 31 | ds_dir = script_args.dataset_name_or_path 32 | ds = load_dataset("json", data_files=ds_dir, split="train") 33 | 34 | def modify_sample(example): 35 | idx = np.argmax(example['rewards']) 36 | example["messages"] = example['prompt'] + [{"role":"user", "content":example['responses'][idx] }] 37 | 38 | return example 39 | 40 | ds2 = ds.map(modify_sample) 41 | ds3 = ds.remove_columns(["prompt", "responses", "rewards"]) 42 | ds3.push_to_hub(script_args.output_dir) 43 | 44 | ''' 45 | with open(script_args.output_dir, "w", encoding="utf8") as f: 46 | for i in range(len(gathered_data)): 47 | json.dump(gathered_data[i], f, ensure_ascii=False) 48 | f.write('\n') 49 | ''' 50 | -------------------------------------------------------------------------------- /generation/merge_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from dataclasses import dataclass, field 4 | from typing import Optional 5 | from datasets import load_dataset 6 | from transformers import HfArgumentParser 7 | 8 | """ 9 | If we use multiple VLLM processes to accelerate the generation, we need to use this script to merge them. 10 | """ 11 | 12 | 13 | @dataclass 14 | class ScriptArguments: 15 | """ 16 | The arguments for the DPO training script. 17 | """ 18 | 19 | base_path: Optional[str] = field( 20 | default="/home/xiongwei/gshf/data/gen_data", 21 | metadata={"help": "the location of the dataset name or path"}, 22 | ) 23 | output_dir: Optional[str] = field( 24 | default="", 25 | metadata={"help": "the location of the output file"}, 26 | ) 27 | num_datasets: Optional[int] = field( 28 | default=8, 29 | metadata={"help": "the location of the output file"}, 30 | ) 31 | 32 | 33 | parser = HfArgumentParser(ScriptArguments) 34 | script_args = parser.parse_args_into_dataclasses()[0] 35 | 36 | 37 | all_dirs = [script_args.base_path + str(i) + ".json" for i in range(script_args.num_datasets)] 38 | 39 | gathered_data = [] 40 | for my_dir in all_dirs: 41 | ds = ds = load_dataset("json", data_files=my_dir, split="train") 42 | print(len(ds)) 43 | for sample in ds: 44 | gathered_data.append(sample) 45 | 46 | random.shuffle(gathered_data) 47 | 48 | print("I collect ", len(gathered_data), "samples") 49 | 50 | with open(script_args.output_dir, "w", encoding="utf8") as f: 51 | for i in range(len(gathered_data)): 52 | json.dump(gathered_data[i], f, ensure_ascii=False) 53 | f.write('\n') 54 | -------------------------------------------------------------------------------- /sft/gemma-2b-it.yaml: -------------------------------------------------------------------------------- 1 | base_model: google/gemma-2b-it 2 | model_type: AutoModelForCausalLM 3 | tokenizer_type: AutoTokenizer 4 | 5 | load_in_8bit: false 6 | load_in_4bit: false 7 | strict: false 8 | 9 | datasets: 10 | - path: RLHFlow/Llama3-SFT-RAFT-Ultrafeedback-iter1 11 | conversation: gemma 12 | type: sharegpt.load_ultrachat 13 | split: "train" 14 | train_on_split: "train" 15 | 16 | warmup_steps: 40 17 | val_set_size: 0.0 18 | output_dir: 
./models/gemma-2b-it_64_lr1e-5 19 | #wandb_project: sft-models 20 | #wandb_entity: raft_train 21 | wandb_watch: 22 | wandb_name: "gemma-2b-it_bs64_lr1e-5" 23 | #_response_only 24 | wandb_log_model: 25 | 26 | train_on_inputs: false 27 | 28 | save_safetensors: true 29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type 30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared 31 | 32 | 33 | dataset_processes: 48 34 | #torch_compile: true 35 | sequence_len: 4096 36 | sample_packing: true 37 | pad_to_sequence_len: true 38 | 39 | trust_remote_code: True 40 | adapter: 41 | lora_model_dir: 42 | 43 | 44 | 45 | 46 | gradient_checkpointing: true 47 | 48 | #warmup_ratio: 0.1 49 | gradient_accumulation_steps: 4 50 | micro_batch_size: 2 51 | num_epochs: 2 52 | optimizer: paged_adamw_32bit 53 | lr_scheduler: cosine 54 | learning_rate: 1.e-5 55 | 56 | weight_decay: 0.0 57 | max_grad_norm: 1.0 58 | 59 | 60 | group_by_length: false 61 | bf16: auto 62 | fp16: false 63 | tf32: true 64 | 65 | early_stopping_patience: 66 | local_rank: 67 | logging_steps: 1 68 | xformers_attention: 69 | flash_attention: true 70 | 71 | 72 | eval_steps: 73 | eval_table_size: 74 | eval_table_max_new_tokens: 75 | #save_steps: 100 76 | save_strategy: "epoch" 77 | save_total_limit: 2 78 | debug: 79 | 80 | 81 | ddp: #true 82 | deepspeed: #deepspeed/zero1.json # multi-gpu only 83 | 84 | fsdp: 85 | fsdp_config: 86 | special_tokens: 87 | -------------------------------------------------------------------------------- /sft/gemma-9b-it.yaml: -------------------------------------------------------------------------------- 1 | base_model: google/gemma-2-9b-it 2 | model_type: AutoModelForCausalLM 3 | tokenizer_type: AutoTokenizer 4 | 5 | load_in_8bit: false 6 | load_in_4bit: false 7 | strict: false 8 | 9 | datasets: 10 | - path: RLHFlow/Llama3-SFT-RAFT-Ultrafeedback-iter1 11 | conversation: gemma 12 | type: sharegpt.load_ultrachat 13 | split: "train" 14 | train_on_split: "train" 15 | 16 | warmup_steps: 100 17 | val_set_size: 0.0 18 | output_dir: ./models/gemma-9b-it_bs64_lr5e-7 19 | #wandb_project: preference-models 20 | #wandb_entity: domain-generalization 21 | wandb_watch: 22 | wandb_name: "gemma-9b-it_bs64_lr5e-7" 23 | #_response_only 24 | wandb_log_model: 25 | 26 | train_on_inputs: false 27 | 28 | save_safetensors: true 29 | #noisy_embedding_alpha: 10.0 # default for sharegpt type 30 | dataset_prepared_path: ~/data/preference-models/last_run_prepared 31 | 32 | 33 | dataset_processes: 48 34 | #torch_compile: true 35 | sequence_len: 4096 36 | sample_packing: true 37 | pad_to_sequence_len: true 38 | 39 | trust_remote_code: True 40 | adapter: 41 | lora_model_dir: 42 | 43 | 44 | 45 | 46 | gradient_checkpointing: True 47 | 48 | #warmup_ratio: 0.1 49 | gradient_accumulation_steps: 8 50 | micro_batch_size: 1 51 | num_epochs: 2 52 | optimizer: paged_adamw_32bit 53 | lr_scheduler: cosine 54 | learning_rate: 5.e-7 55 | 56 | weight_decay: 0.0 57 | max_grad_norm: 1.0 58 | 59 | 60 | group_by_length: false 61 | bf16: auto 62 | fp16: false 63 | tf32: true 64 | 65 | early_stopping_patience: 66 | local_rank: 67 | logging_steps: 1 68 | xformers_attention: 69 | flash_attention: true 70 | 71 | 72 | eval_steps: 73 | eval_table_size: 74 | eval_table_max_new_tokens: 75 | #save_steps: 100 76 | save_strategy: "epoch" 77 | save_total_limit: 2 78 | debug: 79 | 80 | 81 | ddp: #true 82 | deepspeed: #deepspeed/zero1.json # multi-gpu only 83 | 84 | fsdp: 85 | fsdp_config: 86 | special_tokens: 87 | 88 | 
-------------------------------------------------------------------------------- /sft/llama3-8b-it.yaml: -------------------------------------------------------------------------------- 1 | base_model: ./models/llama3_it_with_padding_token 2 | model_type: AutoModelForCausalLM 3 | tokenizer_type: AutoTokenizer 4 | 5 | load_in_8bit: false 6 | load_in_4bit: false 7 | strict: false 8 | 9 | datasets: 10 | - path: RLHFlow/Llama3-SFT-RAFT-Ultrafeedback-iter1 11 | conversation: llama-3 12 | type: sharegpt.load_ultrachat 13 | split: "train" 14 | train_on_split: "train" 15 | 16 | warmup_steps: 40 17 | val_set_size: 0.0 18 | output_dir: ./models/llama3-8b-it_bs128_lr5e-7 19 | #wandb_project: raft_train 20 | #wandb_entity: raft 21 | wandb_watch: 22 | wandb_name: "llama-8b-it_bs64_lr5e-7" 23 | wandb_log_model: 24 | 25 | train_on_inputs: false 26 | 27 | save_safetensors: true 28 | #noisy_embedding_alpha: 10.0 # default for sharegpt type 29 | dataset_prepared_path: ~/data/preference-models/last_run_prepared 30 | 31 | 32 | dataset_processes: 48 33 | #torch_compile: true 34 | sequence_len: 4096 35 | sample_packing: true 36 | pad_to_sequence_len: true 37 | 38 | trust_remote_code: True 39 | adapter: 40 | lora_model_dir: 41 | #lora_r: 32 42 | #lora_alpha: 16 43 | #lora_dropout: 0.05 44 | #lora_target_linear: true 45 | #lora_fan_in_fan_out: 46 | 47 | 48 | 49 | 50 | gradient_checkpointing: True 51 | 52 | #warmup_ratio: 0.1 53 | gradient_accumulation_steps: 8 54 | micro_batch_size: 1 55 | num_epochs: 1 56 | #max_steps: 10 57 | #optimizer: adamw_torch_fused 58 | optimizer: paged_adamw_32bit 59 | #lr_scheduler: constant_with_warmup 60 | lr_scheduler: cosine 61 | learning_rate: 5.0e-6 62 | 63 | weight_decay: 0.0 64 | max_grad_norm: 1.0 65 | 66 | 67 | group_by_length: false 68 | bf16: auto 69 | fp16: false 70 | tf32: true 71 | 72 | early_stopping_patience: 73 | local_rank: 74 | logging_steps: 2 75 | xformers_attention: 76 | flash_attention: true 77 | 78 | 79 | eval_steps: 80 | eval_table_size: 81 | eval_table_max_new_tokens: 82 | #save_steps: 100 83 | save_strategy: "epoch" 84 | save_total_limit: 1 85 | #save_safetensors: false 86 | debug: 87 | 88 | 89 | ddp: #true 90 | deepspeed: #deepspeed/zero1.json # multi-gpu only 91 | 92 | fsdp: 93 | fsdp_config: 94 | special_tokens: 95 | 96 | -------------------------------------------------------------------------------- /diffusion-example/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | asttokens==2.2.1 3 | backcall==0.2.0 4 | bitsandbytes==0.37.2 5 | certifi==2022.12.7 6 | charset-normalizer==3.1.0 7 | clip==1.0== 8 | cmake==3.26.1 9 | comm==0.1.3 10 | contourpy==1.0.7 11 | cycler==0.11.0 12 | debugpy==1.6.7 13 | decorator==5.1.1 14 | diffusers==0.14.0 15 | executing==1.2.0 16 | filelock==3.11.0 17 | fonttools==4.39.3 18 | ftfy==6.1.1 19 | huggingface-hub==0.13.4 20 | idna==3.4 21 | importlib-metadata==6.2.0 22 | importlib-resources==5.12.0 23 | ipykernel==6.22.0 24 | ipython==8.12.0 25 | jedi==0.18.2 26 | Jinja2==3.1.2 27 | jupyter_client==8.1.0 28 | jupyter_core==5.3.0 29 | kiwisolver==1.4.4 30 | lit==16.0.0 31 | MarkupSafe==2.1.2 32 | matplotlib==3.7.1 33 | matplotlib-inline==0.1.6 34 | mpmath==1.3.0 35 | mypy-extensions==1.0.0 36 | nest-asyncio==1.5.6 37 | networkx==3.1 38 | numpy==1.24.2 39 | nvidia-cublas-cu11==11.10.3.66 40 | nvidia-cuda-cupti-cu11==11.7.101 41 | nvidia-cuda-nvrtc-cu11==11.7.99 42 | nvidia-cuda-runtime-cu11==11.7.99 43 | nvidia-cudnn-cu11==8.5.0.96 44 | 
nvidia-cufft-cu11==10.9.0.58 45 | nvidia-curand-cu11==10.2.10.91 46 | nvidia-cusolver-cu11==11.4.0.1 47 | nvidia-cusparse-cu11==11.7.4.91 48 | nvidia-nccl-cu11==2.14.3 49 | nvidia-nvtx-cu11==11.7.91 50 | open-clip-torch==2.16.0 51 | packaging==23.0 52 | pandas==2.0.0 53 | parso==0.8.3 54 | pexpect==4.8.0 55 | pickleshare==0.7.5 56 | Pillow==9.5.0 57 | pip==23.0.1 58 | platformdirs==3.2.0 59 | prompt-toolkit==3.0.38 60 | protobuf==3.20.3 61 | psutil==5.9.4 62 | ptyprocess==0.7.0 63 | pure-eval==0.2.2 64 | Pygments==2.14.0 65 | pyparsing==3.0.9 66 | pyre-extensions==0.0.23 67 | python-dateutil==2.8.2 68 | pytz==2023.3 69 | PyYAML==6.0 70 | pyzmq==25.0.2 71 | regex==2023.3.23 72 | requests==2.28.2 73 | sentencepiece==0.1.97 74 | setuptools==65.6.3 75 | six==1.16.0 76 | stack-data==0.6.2 77 | sympy==1.11.1 78 | timm==0.6.13 79 | tokenizers==0.13.3 80 | torch==2.0.0 81 | torchvision==0.15.1 82 | tornado==6.2 83 | tqdm==4.65.0 84 | traitlets==5.9.0 85 | transformers==4.27.4 86 | triton==2.0.0 87 | typing_extensions==4.5.0 88 | typing-inspect==0.8.0 89 | tzdata==2023.3 90 | urllib3==1.26.15 91 | wcwidth==0.2.6 92 | wheel==0.38.4 93 | xformers==0.0.18 94 | zipp==3.15.0 95 | 96 | -------------------------------------------------------------------------------- /diffusion-example/README.md: -------------------------------------------------------------------------------- 1 | # RAFT-Diffusion 2 | 3 | In this folder, we provide an example to show that how does RAFT work on diffusion models. 4 | 5 | The requirements are shown below. 6 | ``` 7 | accelerate 0.18.0 8 | asttokens 2.2.1 9 | backcall 0.2.0 10 | bitsandbytes 0.37.2 11 | certifi 2022.12.7 12 | charset-normalizer 3.1.0 13 | clip 1.0 14 | cmake 3.26.1 15 | comm 0.1.3 16 | contourpy 1.0.7 17 | cycler 0.11.0 18 | debugpy 1.6.7 19 | decorator 5.1.1 20 | diffusers 0.14.0 21 | executing 1.2.0 22 | filelock 3.11.0 23 | fonttools 4.39.3 24 | ftfy 6.1.1 25 | huggingface-hub 0.13.4 26 | idna 3.4 27 | importlib-metadata 6.2.0 28 | importlib-resources 5.12.0 29 | ipykernel 6.22.0 30 | ipython 8.12.0 31 | jedi 0.18.2 32 | Jinja2 3.1.2 33 | jupyter_client 8.1.0 34 | jupyter_core 5.3.0 35 | kiwisolver 1.4.4 36 | lit 16.0.0 37 | MarkupSafe 2.1.2 38 | matplotlib 3.7.1 39 | matplotlib-inline 0.1.6 40 | mpmath 1.3.0 41 | mypy-extensions 1.0.0 42 | nest-asyncio 1.5.6 43 | networkx 3.1 44 | numpy 1.24.2 45 | nvidia-cublas-cu11 11.10.3.66 46 | nvidia-cuda-cupti-cu11 11.7.101 47 | nvidia-cuda-nvrtc-cu11 11.7.99 48 | nvidia-cuda-runtime-cu11 11.7.99 49 | nvidia-cudnn-cu11 8.5.0.96 50 | nvidia-cufft-cu11 10.9.0.58 51 | nvidia-curand-cu11 10.2.10.91 52 | nvidia-cusolver-cu11 11.4.0.1 53 | nvidia-cusparse-cu11 11.7.4.91 54 | nvidia-nccl-cu11 2.14.3 55 | nvidia-nvtx-cu11 11.7.91 56 | open-clip-torch 2.16.0 57 | packaging 23.0 58 | pandas 2.0.0 59 | parso 0.8.3 60 | pexpect 4.8.0 61 | pickleshare 0.7.5 62 | Pillow 9.5.0 63 | pip 23.0.1 64 | platformdirs 3.2.0 65 | prompt-toolkit 3.0.38 66 | protobuf 3.20.3 67 | psutil 5.9.4 68 | ptyprocess 0.7.0 69 | pure-eval 0.2.2 70 | Pygments 2.14.0 71 | pyparsing 3.0.9 72 | pyre-extensions 0.0.23 73 | python-dateutil 2.8.2 74 | pytz 2023.3 75 | PyYAML 6.0 76 | pyzmq 25.0.2 77 | regex 2023.3.23 78 | requests 2.28.2 79 | sentencepiece 0.1.97 80 | setuptools 65.6.3 81 | six 1.16.0 82 | stack-data 0.6.2 83 | sympy 1.11.1 84 | timm 0.6.13 85 | tokenizers 0.13.3 86 | torch 2.0.0 87 | torchvision 0.15.1 88 | tornado 6.2 89 | tqdm 4.65.0 90 | traitlets 5.9.0 91 | transformers 4.27.4 92 | triton 2.0.0 93 | typing_extensions 4.5.0 94 | 
typing-inspect 0.8.0 95 | tzdata 2023.3 96 | urllib3 1.26.15 97 | wcwidth 0.2.6 98 | wheel 0.38.4 99 | xformers 0.0.18 100 | zipp 3.15.0 101 | ``` 102 | 103 | -------------------------------------------------------------------------------- /generation/gen_hf2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from dataclasses import dataclass, field 3 | from typing import List, Optional 4 | import numpy as np 5 | import torch 6 | from datasets import load_dataset 7 | from transformers import ( 8 | AutoTokenizer, 9 | HfArgumentParser, 10 | ) 11 | from vllm import LLM, SamplingParams 12 | import json 13 | 14 | 15 | @dataclass 16 | class ScriptArguments: 17 | """ 18 | The arguments for the DPO training script. 19 | """ 20 | 21 | model_name_or_path: Optional[str] = field( 22 | default="your model", 23 | metadata={"help": "the location of the SFT model name or path"}, 24 | ) 25 | dataset_name_or_path: Optional[str] = field( 26 | default="RLHFlow/test_generation_2k", 27 | metadata={"help": "the location of the dataset name or path"}, 28 | ) 29 | local_index: Optional[int] = field( 30 | default=999, 31 | metadata={"help": "the local index of the agent"}, 32 | ) 33 | output_dir: Optional[str] = field( 34 | default="", 35 | metadata={"help": "the location of the output file"}, 36 | ) 37 | my_world_size: Optional[int] = field( 38 | default=4, 39 | metadata={"help": "the total number of the agents"}, 40 | ) 41 | K: Optional[int] = field( 42 | default=8, 43 | metadata={"help": "the number of generations per prompt"}, 44 | ) 45 | max_input_length: Optional[int] = field( 46 | default=10000, 47 | metadata={"help": "the maximum length of the input tokens"}, 48 | ) 49 | max_new_tokens: Optional[int] = field( 50 | default=2048, 51 | metadata={"help": "the maximum length of the new tokens"}, 52 | ) 53 | seed: Optional[int] = field( 54 | default=42, 55 | metadata={"help": "the random seed"}, 56 | ) 57 | temperature: Optional[float] = field( 58 | default=0.7, 59 | metadata={"help": "the temperature"}, 60 | ) 61 | use_beam_search: Optional[bool] = field( 62 | default=False, 63 | metadata={"help": "the beam search"}, 64 | ) 65 | dataset_key: Optional[str] = field( 66 | default="context_messages", 67 | metadata={"help": "the key of the dataset"}, 68 | ) 69 | eos_ids: List[int] = field(default_factory=lambda: [], metadata={"help": "the ids of the end of sentence tokens"}) 70 | 71 | 72 | parser = HfArgumentParser(ScriptArguments) 73 | script_args = parser.parse_args_into_dataclasses()[0] 74 | 75 | model_path = script_args.model_name_or_path 76 | print("model_path", model_path) 77 | seed = script_args.seed 78 | # set seed 79 | torch.manual_seed(seed) 80 | np.random.seed(seed) 81 | 82 | llm = LLM( 83 | model=model_path, 84 | tokenizer=model_path, 85 | dtype="bfloat16", 86 | max_model_len=script_args.max_input_length, 87 | load_format="auto", 88 | seed=42, 89 | ) 90 | tokenizer = AutoTokenizer.from_pretrained(model_path) 91 | 92 | sampling_params = SamplingParams( 93 | temperature=script_args.temperature, 94 | top_p=1.0, 95 | max_tokens=script_args.max_new_tokens, 96 | n=script_args.K, 97 | stop_token_ids=[tokenizer.eos_token_id] + script_args.eos_ids, 98 | #stop=["<|user|>"], 99 | ) 100 | 101 | 102 | ds = load_dataset(script_args.dataset_name_or_path, split="train") 103 | ds = ds.map( 104 | lambda x: { 105 | "prompt": tokenizer.apply_chat_template(x[script_args.dataset_key], tokenize=False, add_generation_prompt=True) 106 | } 107 | ) 108 | 109 | 
data_size = len(ds["prompt"]) 110 | one_num_share = int(data_size / script_args.my_world_size) 111 | ds = ds.select(np.arange(script_args.local_index * one_num_share, (script_args.local_index + 1) * one_num_share)) 112 | 113 | print([script_args.local_index * one_num_share, (script_args.local_index + 1) * one_num_share]) 114 | print(ds, script_args.dataset_name_or_path) 115 | print(ds[0]) 116 | 117 | 118 | prompts = ds["prompt"] 119 | outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) 120 | 121 | 122 | completions = [] 123 | used_prompts = [] 124 | gathered_data = [] 125 | for i, output in enumerate(outputs): 126 | tmp_data = {"prompt": ds[i][script_args.dataset_key], "responses": [out.text for out in output.outputs]} 127 | gathered_data.append(tmp_data) 128 | 129 | 130 | print("I collect ", len(gathered_data), "samples") 131 | 132 | 133 | with open(script_args.output_dir + str(script_args.local_index) + ".json", "w", encoding="utf8") as f: 134 | for i in range(len(gathered_data)): 135 | json.dump(gathered_data[i], f, ensure_ascii=False) 136 | f.write('\n') 137 | -------------------------------------------------------------------------------- /generation/gen_hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass, field 3 | from typing import List, Optional 4 | from datasets import load_dataset 5 | from tqdm import tqdm 6 | from transformers import AutoTokenizer, HfArgumentParser 7 | from concurrent.futures import ThreadPoolExecutor, as_completed 8 | import requests 9 | 10 | tqdm.pandas() 11 | 12 | 13 | @dataclass 14 | class ScriptArguments: 15 | """ 16 | The arguments for the DPO training script. 17 | """ 18 | 19 | url: Optional[str] = field( 20 | default="http://localhost", 21 | metadata={"help": "url of the model response"}, 22 | ) 23 | tokenizer: Optional[str] = field( 24 | default="HuggingFaceH4/mistral-7b-sft-beta", 25 | metadata={"help": "the tokenizer to use"}, 26 | ) 27 | ports: List[str] = field(default_factory=lambda: ["8000"], metadata={"help": "ports of the model response"}) 28 | eos_ids: List[int] = field(default_factory=lambda: [], metadata={"help": "the ids of the end of sentence tokens"}) 29 | dataset_name_or_path: Optional[str] = field( 30 | default="cornfieldrm/iterative-prompt-v1-iter1-2K", 31 | metadata={"help": "the location of the dataset name or path"}, 32 | ) 33 | output_dir: Optional[str] = field( 34 | default="uf_split0_responses_K8.jsonl", 35 | metadata={"help": "the location of the output file"}, 36 | ) 37 | bos_format: Optional[str] = field( 38 | default="", 39 | metadata={"help": "the format of the beginning of the sentence"}, 40 | ) 41 | K: Optional[int] = field( 42 | default=8, 43 | metadata={"help": "the number of generations per prompt"}, 44 | ) 45 | max_input_length: Optional[int] = field( 46 | default=10000, 47 | metadata={"help": "the maximum length of the input tokens"}, 48 | ) 49 | max_new_tokens: Optional[int] = field( 50 | default=2048, 51 | metadata={"help": "the maximum length of the new tokens"}, 52 | ) 53 | seed: Optional[int] = field( 54 | default=42, 55 | metadata={"help": "the random seed"}, 56 | ) 57 | temperature: Optional[float] = field( 58 | default=0.7, 59 | metadata={"help": "the temperature"}, 60 | ) 61 | use_beam_search: Optional[bool] = field( 62 | default=False, 63 | metadata={"help": "the beam search"}, 64 | ) 65 | dataset_key: Optional[str] = field( 66 | default="context_messages", 67 | metadata={"help": "the key of 
the dataset"}, 68 | ) 69 | max_workers: Optional[int] = field( 70 | default=1024, 71 | metadata={"help": "the number of workers"}, 72 | ) 73 | 74 | 75 | parser = HfArgumentParser(ScriptArguments) 76 | script_args = parser.parse_args_into_dataclasses()[0] 77 | ds_dir = script_args.dataset_name_or_path 78 | output_dir = script_args.output_dir 79 | K = script_args.K 80 | ports = script_args.ports 81 | 82 | tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer) 83 | 84 | 85 | def query_model(prompt, args, port): 86 | json = { 87 | **args, 88 | "prompt": prompt, 89 | } 90 | response = requests.post(url=script_args.url + ":" + str(port) + "/generate", json=json) 91 | response_json = response.json() 92 | return [response_json["text"][i][len(prompt) :] for i in range(len(response_json["text"]))] 93 | 94 | 95 | default_args = { 96 | "use_beam_search": script_args.use_beam_search, 97 | "n": script_args.K, 98 | "temperature": script_args.temperature, 99 | "max_tokens": script_args.max_new_tokens, 100 | "seed": script_args.seed, 101 | "top_p": 1.0, 102 | "top_k": -1, 103 | "stop_token_ids": [tokenizer.eos_token_id] + script_args.eos_ids, 104 | } 105 | 106 | print(default_args) 107 | 108 | ds = load_dataset(ds_dir, split="train") 109 | # load_dataset("json", data_files=ds_dir, split="train", field="instances") 110 | print(ds) 111 | 112 | # use tokenizer.apply_template to apply the template to the prompt 113 | ds = ds.map( 114 | lambda x: { 115 | "prompt": tokenizer.apply_chat_template(x[script_args.dataset_key], tokenize=False, add_generation_prompt=True) 116 | } 117 | ) 118 | 119 | 120 | with ThreadPoolExecutor(max_workers=script_args.max_workers) as executor: 121 | result = [ 122 | executor.submit(query_model, ds[i]["prompt"], default_args, ports[i % len(ports)]) for i in range(len(ds)) 123 | ] 124 | # use tqdm to show progress 125 | for _ in tqdm(as_completed(result), total=len(result)): 126 | pass 127 | 128 | responses = [r.result() for r in result] 129 | 130 | 131 | gathered_data = [] 132 | for i in range(len(ds)): 133 | tmp_data = {"prompt": ds[i][script_args.dataset_key], "responses": responses[i]} 134 | gathered_data.append(tmp_data) 135 | 136 | print("I collect ", len(gathered_data), "samples") 137 | 138 | 139 | with open(output_dir, 'w', encoding='utf8') as f: 140 | for i in range(len(gathered_data)): 141 | json.dump(gathered_data[i], f, ensure_ascii=False) 142 | f.write('\n') 143 | -------------------------------------------------------------------------------- /annotate_data/get_rewards.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import Optional 5 | import numpy as np 6 | import torch 7 | from datasets import load_dataset 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, HfArgumentParser, pipeline 10 | from accelerate import Accelerator 11 | 12 | tqdm.pandas() 13 | 14 | ##### 15 | # This script takes a dataset as the input, where each sample is {"prompt": "the pormpt", "responses": ["response1", "response2", "response3", ...]} 16 | # The script will compute the reward for each input-output pair, and eventually output a new dataset, where each sample contains {"prompt": "the pormpt", "responses": ["response1", "response2", "response3", ...], "rewards": [reward1, reward2, ...]} 17 | ##### 18 | 19 | 20 | @dataclass 21 | class ScriptArguments: 22 | """ 23 | The arguments for the DPO training script. 
24 | """ 25 | 26 | dataset_name_or_path: Optional[str] = field( 27 | default="uf_split0_responses_K8.jsonl", 28 | metadata={"help": "the location of the dataset name or path"}, 29 | ) 30 | output_dir: Optional[str] = field( 31 | default="uf_split0_responses_K8_reward.json", 32 | metadata={"help": "the location of the output file"}, 33 | ) 34 | record_dir: Optional[str] = field( 35 | default=None, 36 | metadata={"help": "the location of the recording file"}, 37 | ) 38 | reward_name_or_path: Optional[str] = field( 39 | default="sfairXC/FsfairX-LLaMA3-RM-v0.1", 40 | metadata={"help": "the name of the reward model"}, 41 | ) 42 | input_output_delimiter: Optional[str] = field( 43 | default="", 44 | metadata={"help": "the delimiter between input and output"}, 45 | ) 46 | K: Optional[int] = field( 47 | default=8, 48 | metadata={"help": "the number of responses per prompt"}, 49 | ) 50 | 51 | 52 | accelerator = Accelerator() 53 | 54 | parser = HfArgumentParser(ScriptArguments) 55 | script_args = parser.parse_args_into_dataclasses()[0] 56 | 57 | device = accelerator.device 58 | pipe_kwargs = { 59 | "return_all_scores": True, 60 | "function_to_apply": "none", 61 | "batch_size": 1, 62 | } 63 | reward_model = script_args.reward_name_or_path 64 | rm_tokenizer = AutoTokenizer.from_pretrained(reward_model) 65 | rm_pipe = pipeline( 66 | "sentiment-analysis", 67 | model=reward_model, 68 | device=device, 69 | tokenizer=rm_tokenizer, 70 | model_kwargs={"torch_dtype": torch.bfloat16}, 71 | truncation=True, 72 | ) 73 | 74 | 75 | ds_dir = script_args.dataset_name_or_path 76 | world_size = int(os.getenv("WORLD_SIZE", "1")) 77 | ds = load_dataset("json", data_files=ds_dir, split="train") 78 | 79 | local_rank = Accelerator().local_process_index 80 | 81 | data_size = len(ds["prompt"]) 82 | 83 | share = int(data_size / world_size) + 1 84 | ds = ds.select(np.arange(local_rank * share, min((local_rank + 1) * share, len(ds)))) 85 | 86 | """ 87 | We process the data format here and query the reward model to get the rewards. 
88 | """ 89 | 90 | 91 | def get_reward(test_texts): 92 | pipe_outputs = rm_pipe(test_texts, **pipe_kwargs) 93 | rewards = [output[0]["score"] for output in pipe_outputs] 94 | return rewards 95 | 96 | 97 | def change_of_format(prom, resp): 98 | # To be modified according to the reward model and the LLM you use 99 | # Be careful about multi-turn conversions 100 | """ 101 | prom = prom.replace("GPT4 Correct User: ", "").replace("<|end_of_turn|>GPT4 Correct Assistant:", "") 102 | 103 | final_resp = resp.split("GPT4 Correct User")[0] 104 | """ 105 | message = prom + [{"role": "assistant", "content": resp}] 106 | return rm_tokenizer.apply_chat_template(message, tokenize=False).replace(rm_tokenizer.bos_token, "") 107 | 108 | 109 | data = [] 110 | 111 | # tqdm is used to show the progress bar 112 | with torch.no_grad(): 113 | for sample in tqdm(ds): 114 | # The VLLM may not generate responses for some prompts because it is too long, we skip them 115 | if len(sample["responses"]) < script_args.K: 116 | continue 117 | test_texts = [change_of_format(sample['prompt'], tmp_output) for tmp_output in sample['responses']] 118 | 119 | rewards = get_reward(test_texts) 120 | data.append({"prompt": sample["prompt"], "responses": sample["responses"], "rewards": rewards}) 121 | 122 | 123 | # Send the data to other GPUs 124 | world_size = int(os.getenv("WORLD_SIZE", "1")) 125 | all_process_list = [{}] * world_size 126 | 127 | data_to_send = { 128 | "data": [[data[i]] for i in range(len(data))], 129 | } 130 | 131 | import torch.distributed as dist 132 | 133 | dist.all_gather_object(all_process_list, data_to_send) 134 | gathered_data = [] 135 | 136 | 137 | for i in range(world_size): 138 | tmp_data = [tmp[0] for tmp in all_process_list[i]["data"]] 139 | gathered_data.extend(tmp_data) 140 | 141 | all_rewards = [sample["rewards"] for sample in gathered_data] 142 | top1_scores = np.mean(np.max(all_rewards, axis=1)) 143 | mean_scores = np.mean(all_rewards) 144 | 145 | 146 | if local_rank == 0: 147 | print( 148 | "Collect {} data from {} inputs. mean score {} top1 score: {}".format( 149 | len(gathered_data), data_size, mean_scores, top1_scores 150 | ) 151 | ) 152 | if len(gathered_data) < data_size: 153 | print( 154 | "Some of the prompts are with responses < {}. This can happen because the prompt is too long and is ignored by VLLM".format( 155 | script_args.K 156 | ) 157 | ) 158 | 159 | with open(script_args.output_dir, "w", encoding="utf8") as f: 160 | for i in range(len(gathered_data)): 161 | json.dump(gathered_data[i], f, ensure_ascii=False) 162 | f.write('\n') 163 | 164 | if script_args.record_dir is not None: 165 | with open(script_args.record_dir, "a") as f: 166 | f.write(str(mean_scores) + "\t" + str(top1_scores) + "\n") 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RAFT 2 | This is an official implementation of the [Reward rAnked Fine-Tuning Algorithm (RAFT)](https://arxiv.org/pdf/2304.06767), also known as iterative best-of-n fine-tuning or rejection sampling fine-tuning. 3 | 4 | 5 | ## 1 Structure 6 | 7 | The initial release of tis project focus on the Bradley-Terry reward modeling and pairwise preference model. Since then, we have included more advanced techniques to construct preference model. 
The structure of this project is as follows:
8 | 
9 | - [`Data generation`](./generation/) to generate n responses per prompt;
10 | - [`Reward Ranking`](./annotate_data/) to compute the rewards of the responses and select the response with the highest reward;
11 | - [`Finetuning`](./sft/) to finetune the model on the selected responses.
12 | 
13 | 
14 | We also provide a small demo of RAFT with a diffusion model in [diffusion-example](./diffusion-example).
15 | 
16 | You may also refer to our [colab](https://colab.research.google.com/drive/1bQmlSiKnqFjrkijFUJ5ylbYW-zUwObqL) for more information.
17 | 
18 | ## 2 Installation instructions
19 | 
20 | It is recommended to use two separate environments, one for inference and one for training.
21 | 
22 | **Inference Environment**
23 | 
24 | ```sh
25 | conda create -n vllm python=3.10.9
26 | conda activate vllm
27 | pip install datasets
28 | 
29 | # The following setup is tested with CUDA 12.0-12.2 and CUDA 12.6
30 | # To work with llama-3, mistral, gemma-1, 1.1, 2, and deepseek, you can consider the following vllm version
31 | pip install vllm==0.5.4
32 | 
33 | pip install accelerate==0.33.0
34 | pip install deepspeed==0.14.5
35 | pip install transformers==4.43.4
36 | pip install numpy==1.26.4 # Note that the numpy version should be `numpy<2.0`. `Numpy 2.0` will encounter unexpected issues!!!
37 | ```
38 | 
39 | **Training Environment**
40 | 
41 | 
42 | ```shell
43 | conda create -n sft_train python=3.10.9
44 | conda activate sft_train
45 | 
46 | ## Get axolotl for general model
47 | git clone https://github.com/OpenAccess-AI-Collective/axolotl
48 | cd axolotl
49 | git checkout 55cc214c767741e83ee7b346e5e13e6c03b7b9fa
50 | pip install -e .
51 | 
52 | # The tested cuda versions are 12.1 and 12.2. You may need to update the torch version based on your cuda version...
53 | # you may encounter an undefined symbol error related to cuda and flash-attn, and torch 2.1.2 can solve it ...
54 | pip3 install torch==2.1.2 torchvision torchaudio
55 | pip install flash-attn==2.6.3
56 | 
57 | 
58 | ## Get FastChat
59 | git clone https://github.com/lm-sys/FastChat.git
60 | cd FastChat
61 | pip install -e .
62 | 
63 | git clone https://github.com/WeiXiongUST/RLHF-Reward-Modeling.git
64 | pip install deepspeed
65 | ```
66 | 
67 | ## 3 Running the code
68 | 
69 | ### 3.1 Data generation
70 | 
71 | We have prepared some prompt sets on Hugging Face:
72 | - UltraFeedback: RLHFlow/ultrafeedback_iter1, RLHFlow/ultrafeedback_iter2, RLHFlow/ultrafeedback_iter3
73 | - RLHFlow/iterative-prompt-v1-iter1-20K, RLHFlow/iterative-prompt-v1-iter2-20K, RLHFlow/iterative-prompt-v1-iter3-20K ...
74 | 
75 | To accelerate data generation, we use vLLM. We provide two ways of using vLLM for inference to make the implementation more robust; you can try them out and choose the one that best fits your environment. We use LLaMA3-8B as an example.
76 | 
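Both generation scripts write the same record format: each line of the output file is a JSON object holding the original chat messages under `prompt` (by default the `context_messages` field of the prompt dataset) and the sampled completions under `responses`. The snippet below is purely illustrative (the prompt and responses are made up), but it reflects the fields written by `gen_hf.py` and `gen_hf2.py`:

```python
# Illustrative record written by the generation scripts (one JSON object per line; values are made up).
example_record = {
    # the original chat messages taken from the prompt dataset
    "prompt": [{"role": "user", "content": "How can I improve my time management skills?"}],
    # the K completions sampled by vLLM for this prompt
    "responses": [
        "You can start by keeping a simple daily schedule ...",
        "A common approach is the Pomodoro technique ...",
    ],
}
```
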
77 | You may create a `test_gen.sh` file, copy the following contents into it, and run `bash test_gen.sh`.
78 | 
79 | ```sh
80 | # First approach: initialize 8 vLLM processes and split the prompt set across the 8 agents
81 | # The generated samples will be stored at output_dir + local_index + ".json"
82 | 
83 | my_world_size=8 # how many gpus you use
84 | infer_model=RLHFlow/LLaMA3-SFT
85 | prompt_dir=RLHFlow/test_generation_2k
86 | mkdir data
87 | output_dir=./data/gen_data
88 | 
89 | conda activate vllm
90 | CUDA_VISIBLE_DEVICES=0 python ./generation/gen_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 0 --my_world_size ${my_world_size} &
91 | CUDA_VISIBLE_DEVICES=1 python ./generation/gen_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 1 --my_world_size ${my_world_size} &
92 | CUDA_VISIBLE_DEVICES=2 python ./generation/gen_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 2 --my_world_size ${my_world_size} &
93 | CUDA_VISIBLE_DEVICES=3 python ./generation/gen_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 3 --my_world_size ${my_world_size} &
94 | CUDA_VISIBLE_DEVICES=4 python ./generation/gen_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 4 --my_world_size ${my_world_size} &
95 | CUDA_VISIBLE_DEVICES=5 python ./generation/gen_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 5 --my_world_size ${my_world_size} &
96 | CUDA_VISIBLE_DEVICES=6 python ./generation/gen_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 6 --my_world_size ${my_world_size} &
97 | CUDA_VISIBLE_DEVICES=7 python ./generation/gen_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 7 --my_world_size ${my_world_size} &
98 | 
99 | # then, we merge the 8 datasets into one dataset.
100 | wait
101 | python ./generation/merge_data.py --base_path ${output_dir} --output_dir ./data/gen_data.json --num_datasets ${my_world_size}
102 | ```
103 | 
104 | We can also use the API server to generate new responses.
105 | 
106 | ```sh
107 | mkdir data
108 | conda activate vllm
109 | 
110 | # register the api server
111 | bash ./generation/register_server.sh RLHFlow/LLaMA3-SFT
112 | 
113 | # start to generate
114 | python ./generation/gen_hf.py --ports 8000 8001 8002 8003 8004 8005 8006 8007 --tokenizer RLHFlow/LLaMA3-SFT --dataset_name_or_path RLHFlow/test_generation_2k --output_dir ./data/gen_data.jsonl --K 4 --temperature 1.0
115 | ```
116 | 
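Before moving on to annotation, you may want to sanity-check the generated file. Below is a minimal sketch (assuming the merged file `./data/gen_data.json` from the first approach; use `./data/gen_data.jsonl` if you generated with the API server):

```python
from datasets import load_dataset

# Load the generated file (one JSON object per line).
ds = load_dataset("json", data_files="./data/gen_data.json", split="train")

print("number of prompts:", len(ds))
sample = ds[0]
print("prompt messages:", sample["prompt"])
print("number of responses:", len(sample["responses"]))  # should match the K used above
```
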
117 | ### 3.2 Data Annotation
118 | Then, we call the reward/preference model to rank the generated responses.
119 | 
120 | ```sh
121 | accelerate launch ./annotate_data/get_rewards.py --dataset_name_or_path ./data/gen_data.jsonl --output_dir ./data/data_with_rewards.jsonl --K 4
122 | 
123 | python ./annotate_data/get_bon_data.py --dataset_name_or_path ./data/data_with_rewards.jsonl --output_dir your_huggingface_dataset_dir
124 | ```
125 | 
126 | If you encounter the error `TypeError: Got unsupported ScalarType BFloat16`, consider adjusting your transformers version.
127 | 
128 | ### 3.3 Training
129 | 
130 | ```sh
131 | conda activate sft_train
132 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" torchrun --nproc_per_node 8 --master_port 20001 -m axolotl.cli.train gemma-2b-it.yaml --deepspeed ./configs/deepspeed_stage2.json
133 | ```
134 | If you encounter an out-of-memory issue, run the code with Gemma-2b-it using DeepSpeed stage 3 and gradient checkpointing (set in the config).
135 | 
136 | ## Citation
137 | 
138 | If you find the content of this repo useful in your work, please consider citing:
139 | 
140 | ```bibtex
141 | @article{dong2023raft,
142 |   title={{RAFT}: Reward rAnked FineTuning for Generative Foundation Model Alignment},
143 |   author={Hanze Dong and Wei Xiong and Deepanshu Goyal and Yihan Zhang and Winnie Chow and Rui Pan and Shizhe Diao and Jipeng Zhang and KaShun SHUM and Tong Zhang},
144 |   journal={Transactions on Machine Learning Research},
145 |   issn={2835-8856},
146 |   year={2023},
147 |   url={https://openreview.net/forum?id=m7p5O7zblY},
148 | }
149 | 
150 | ```
151 | 
--------------------------------------------------------------------------------
/diffusion-example/SD256-RAFT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {
6 |     "id": "YoURrh11fbIc"
7 |    },
8 |    "source": [
9 |     "# RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment\n",
10 |     "\n",
11 |     "This notebook showcases how RAFT can be leveraged to fine-tune a text-to-image diffusion model.\n",
12 |     "\n",
13 |     "\n",
14 |     "\n",
15 |     "\n",
16 |     "\n",
17 |     "Curious how this works? Read our [paper](https://arxiv.org/abs/2304.06767) to explore the intricacies of our innovative approach."
18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "BzmCovNKkwbi" 24 | }, 25 | "source": [ 26 | "## Initial Setup" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "id": "n7TI5hirlzn8" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "#@title Install the required libs\n", 38 | "%pip install -q accelerate diffusers transformers ftfy bitsandbytes gradio natsort safetensors xformers datasets\n", 39 | "%pip install -qq \"ipywidgets>=7,<8\"\n", 40 | "!wget -q https://raw.githubusercontent.com/OptimalScale/LMFlow/main/experimental/RAFT-diffusion/train_text_to_image_lora.py" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "cellView": "form", 48 | "id": "fvCBZCnrqcX1" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "#@title Install CLIP\n", 53 | "\n", 54 | "!pip install git+https://github.com/deepgoyal19/CLIP.git" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": { 61 | "id": "guDgmswnmW-4" 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "#@title Import required libraries\n", 66 | "import argparse\n", 67 | "import itertools\n", 68 | "import math\n", 69 | "import os\n", 70 | "import shutil\n", 71 | "from os.path import expanduser # pylint: disable=import-outside-toplevel\n", 72 | "from urllib.request import urlretrieve # pylint: disable=import-outside-toplevel\n", 73 | "from contextlib import nullcontext\n", 74 | "import random\n", 75 | "import pandas as pd\n", 76 | "import numpy as np\n", 77 | "import torch\n", 78 | "import torch.nn.functional as F\n", 79 | "import torch.utils.checkpoint\n", 80 | "from torch.utils.data import Dataset\n", 81 | "import concurrent\n", 82 | "import PIL\n", 83 | "from accelerate import Accelerator\n", 84 | "from accelerate.logging import get_logger\n", 85 | "from accelerate.utils import set_seed\n", 86 | "from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler\n", 87 | "from diffusers.optimization import get_scheduler\n", 88 | "from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\n", 89 | "from PIL import Image\n", 90 | "from torchvision import transforms\n", 91 | "from tqdm.auto import tqdm\n", 92 | "from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n", 93 | "import clip\n", 94 | "import bitsandbytes as bnb\n", 95 | "from torch.utils.data import DataLoader\n", 96 | "def image_grid(imgs, rows, cols):\n", 97 | " assert len(imgs) == rows*cols\n", 98 | "\n", 99 | " w, h = imgs[0].size\n", 100 | " grid = Image.new('RGB', size=(cols*w, rows*h))\n", 101 | " grid_w, grid_h = grid.size\n", 102 | " \n", 103 | " for i, img in enumerate(imgs):\n", 104 | " grid.paste(img, box=(i%cols*w, i//cols*h))\n", 105 | " return grid" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": { 111 | "id": "f4D64FI9pI38" 112 | }, 113 | "source": [ 114 | "## Loading Dataset" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": { 121 | "id": "7IryKE4wq0SZ", 122 | "cellView": "form" 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "#@title Creating Dataloader\n", 127 | "\n", 128 | "prompts=['airplane','automobile','bird','deer','dog','cat','frog','horse','ship','truck'] # CIFAR labels\n", 129 | "prompts = pd.DataFrame({'prompts': prompts}) #converting prompts list into a pandas dataframe\n", 130 | "\n", 131 | 
"class CIFAR10Dataset():\n", 132 | " def __init__(self):\n", 133 | " global prompts\n", 134 | " self.prompts=prompts.iloc[:,0]\n", 135 | " \n", 136 | " def __len__(self):\n", 137 | " return len(self.prompts)\n", 138 | " \n", 139 | " def __getitem__(self,index):\n", 140 | " return self.prompts.iloc[index]\n", 141 | "\n", 142 | "#@markdown Please mention the batch size.\n", 143 | "batch_size =5 #@param {type:\"integer\"}\n", 144 | "\n", 145 | "\n", 146 | "dataset = CIFAR10Dataset()\n", 147 | "finetune_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": { 153 | "id": "BWH9vc1kvhvC" 154 | }, 155 | "source": [ 156 | "## Loading CLIP" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "id": "lJAguhs1d89L" 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "def get_aesthetic_model(clip_model=\"vit_l_14\"):\n", 168 | " \"\"\"load the aethetic model\"\"\"\n", 169 | " home = expanduser(\"~\")\n", 170 | " cache_folder = home + \"/.cache/emb_reader\"\n", 171 | " path_to_model = cache_folder + \"/sa_0_4_\"+clip_model+\"_linear.pth\"\n", 172 | " if not os.path.exists(path_to_model):\n", 173 | " os.makedirs(cache_folder, exist_ok=True)\n", 174 | " url_model = (\n", 175 | " \"https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_\"+clip_model+\"_linear.pth?raw=true\"\n", 176 | " )\n", 177 | " urlretrieve(url_model, path_to_model)\n", 178 | " if clip_model == \"vit_l_14\":\n", 179 | " m = torch.nn.Linear(768, 1)\n", 180 | " elif clip_model == \"vit_b_32\":\n", 181 | " m = torch.nn.Linear(512, 1)\n", 182 | " else:\n", 183 | " raise ValueError()\n", 184 | " s = torch.load(path_to_model)\n", 185 | " m.load_state_dict(s)\n", 186 | " m.eval()\n", 187 | " return m\n", 188 | "\n", 189 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 190 | "amodel= get_aesthetic_model(clip_model=\"vit_l_14\").to(device)\n", 191 | "amodel.eval()\n", 192 | "\n", 193 | "model, preprocess = clip.load('ViT-L/14', device=device)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": { 199 | "id": "0RPeQGHUzUZp" 200 | }, 201 | "source": [ 202 | "## Evaluating Aesthetic Score" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": { 209 | "id": "s61Ljr9Sd89M" 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "def get_image_score(image): #Evaluating Scores if images\n", 214 | " images = preprocess(image).unsqueeze(0).to(device)\n", 215 | " with torch.no_grad():\n", 216 | " image_features= model.encode_image(images).to(device)\n", 217 | " image_features /= image_features.norm(dim=-1, keepdim=True)\n", 218 | " image_features=image_features.to(torch.float32)\n", 219 | " prediction = amodel(image_features)\n", 220 | " return(float(prediction))\n", 221 | " \n", 222 | "def get_max_score(image_list,index,epoch=0): #The get_max_score function will return prompt's image with the highest aesthetic score will be chosen for additional fine-tuning.\n", 223 | " score_list=[]\n", 224 | " for image in image_list:\n", 225 | " score_list.append(get_image_score(image))\n", 226 | " torch.cuda.empty_cache()\n", 227 | "\n", 228 | " prompts.loc[index,f'Epoch{epoch} Scores']=max(score_list)\n", 229 | " return [max(score_list),score_list.index(max(score_list))]\n" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "id": "Ak1jArUL0eCi" 236 | }, 237 | "source": [ 238 | "##Parameters" 
239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": { 245 | "id": "jv6WYJos0iT5", 246 | "cellView": "form" 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "#@title Settings for the model\n", 251 | "\n", 252 | "#@markdown All settings have been configured to achieve optimal output. Changing them is not advisable.\n", 253 | "\n", 254 | "#@markdown Enter value for `resolution`.\n", 255 | "resolution=256 #@param {type:\"integer\"}\n", 256 | "\n", 257 | "#@markdown Enter value for `num_images_per_prompt`.\n", 258 | "num_images_per_prompt=10 #@param {type:\"integer\"} \n", 259 | "\n", 260 | "#@markdown Enter value for `epochs`. \n", 261 | "epochs=10 #@param {type:\"integer\"} |" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "metadata": { 268 | "id": "7gFbnMaLd89N", 269 | "cellView": "form" 270 | }, 271 | "outputs": [], 272 | "source": [ 273 | "# @title Setting Stable Diffusion pipeline\n", 274 | "model_id = \"runwayml/stable-diffusion-v1-5\"\n", 275 | "pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)\n", 276 | "pipe.enable_xformers_memory_efficient_attention()\n", 277 | "torch.cuda.empty_cache()\n", 278 | "\n", 279 | "#@markdown Check the `set_progress_bar_config` option if you would like to hide the progress bar for image generation\n", 280 | "set_progress_bar_config= False #@param {type:\"boolean\"}\n", 281 | "pipe.set_progress_bar_config(disable=set_progress_bar_config) \n", 282 | "\n", 283 | "\n", 284 | "scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n", 285 | "pipe.scheduler = scheduler\n", 286 | "\n", 287 | "torch.cuda.empty_cache()\n" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": { 293 | "id": "9U2P_PUN-5xX" 294 | }, 295 | "source": [ 296 | "##Finetuning" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": { 303 | "id": "F-m6S9Sg-yS_", 304 | "cellView": "form" 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "#@title Generating images on the pretrained model\n", 309 | "\n", 310 | "#@markdown Check the box to generate images using the pretrained model.\n", 311 | "generate_pretrained_model_images= True #@param {type:\"boolean\"}\n", 312 | "\n", 313 | "if generate_pretrained_model_images:\n", 314 | " image_list=[]\n", 315 | " for step, prompt_list in enumerate(finetune_dataloader):\n", 316 | " image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images \n", 317 | " image_list+=image\n", 318 | " torch.cuda.empty_cache()\n", 319 | "\n", 320 | " grid = image_grid(image_list, len(prompts),num_images_per_prompt)\n", 321 | " grid.save(\"pretrained.png\") \n", 322 | " grid\n", 323 | "\n" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": { 330 | "id": "kPfHR4HQd89N" 331 | }, 332 | "outputs": [], 333 | "source": [ 334 | "#@title Run training\n", 335 | "\n", 336 | "os.environ['MODEL_NAME'] = model_id\n", 337 | "os.environ['OUTPUT_DIR'] = f\"./CustomModel/\"\n", 338 | "topk=8\n", 339 | "training_steps_per_epoch=topk*10\n", 340 | "os.environ['CHECKPOINTING_STEPS']=str(training_steps_per_epoch)\n", 341 | "os.environ['RESOLUTION']=str(resolution)\n", 342 | "os.environ['LEARNING_RATE']=str(9e-6)\n", 343 | "\n", 344 | "# remove old account directory\n", 345 | "try: \n", 346 | " shutil.rmtree('./CustomModel')\n", 347 | "except:\n", 348 | " pass\n", 349 | "try: 
\n", 350 | " shutil.rmtree('./trainingdataset/imagefolder/')\n", 351 | "except:\n", 352 | " pass\n", 353 | "\n", 354 | "model_id = \"runwayml/stable-diffusion-v1-5\"\n", 355 | "\n", 356 | "\n", 357 | "for epoch in range(epochs+1):\n", 358 | " print(\"Epoch: \",epoch)\n", 359 | " epoch=epoch\n", 360 | " training_steps=str(training_steps_per_epoch*(epoch+1))\n", 361 | " os.environ['TRAINING_STEPS']=training_steps\n", 362 | " os.environ['TRAINING_DIR'] = f'./trainingdataset/imagefolder/{epoch}'\n", 363 | "\n", 364 | " training_prompts=[]\n", 365 | " prompts[f'Epoch{epoch} Scores']=np.nan\n", 366 | "\n", 367 | " for step, prompt_list in enumerate(finetune_dataloader):\n", 368 | " image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images\n", 369 | " image_list=[]\n", 370 | "\n", 371 | " for i in range(int(len(image)/num_images_per_prompt)):\n", 372 | " image_list.append(image[i*num_images_per_prompt:(i+1)*num_images_per_prompt])\n", 373 | " torch.cuda.empty_cache()\n", 374 | " \n", 375 | " with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:\n", 376 | " step_list=[i for i in range(step*batch_size,(step+1)*batch_size)]\n", 377 | " score_index=executor.map(get_max_score,image_list,step_list,[epoch for i in range(len(step_list))])\n", 378 | "\n", 379 | " iterator=0\n", 380 | " for max_scores in score_index:\n", 381 | " training_prompts.append([max_scores[0],image_list[iterator][max_scores[1]],prompt_list[iterator]])\n", 382 | " iterator+=1\n", 383 | "\n", 384 | " training_prompts=[row[1:3] for row in sorted(training_prompts,key=lambda x: (x[0]),reverse=True)[:topk]]\n", 385 | " training_prompts=pd.DataFrame(training_prompts)\n", 386 | "\n", 387 | " if not os.path.exists(f\"./trainingdataset/imagefolder/{epoch}/train/\"):\n", 388 | " os.makedirs(f\"./trainingdataset/imagefolder/{epoch}/train/\")\n", 389 | " if not os.path.exists(f\"./CustomModel/\"):\n", 390 | " os.makedirs(f\"./CustomModel/\")\n", 391 | " for i in range(len(training_prompts)):\n", 392 | " training_prompts.iloc[i,0].save(f'./trainingdataset/imagefolder/{epoch}/train/{i}.png')\n", 393 | "\n", 394 | " training_prompts['file_name']=[f\"{i}.png\" for i in range(len(training_prompts))]\n", 395 | " training_prompts.columns = ['0','text','file_name']\n", 396 | " training_prompts.drop('0',axis=1,inplace=True)\n", 397 | " training_prompts.to_csv(f'./trainingdataset/imagefolder/{epoch}/train/metadata.csv',index=False)\n", 398 | " torch.cuda.empty_cache()\n", 399 | "\n", 400 | " if epoch=" 291 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 292 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 293 | ), 294 | ) 295 | parser.add_argument( 296 | "--report_to", 297 | type=str, 298 | default="tensorboard", 299 | help=( 300 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 301 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 302 | ), 303 | ) 304 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 305 | parser.add_argument( 306 | "--checkpointing_steps", 307 | type=int, 308 | default=500, 309 | help=( 310 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 311 | " training using `--resume_from_checkpoint`." 
312 | ), 313 | ) 314 | parser.add_argument( 315 | "--checkpoints_total_limit", 316 | type=int, 317 | default=None, 318 | help=( 319 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 320 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 321 | " for more docs" 322 | ), 323 | ) 324 | parser.add_argument( 325 | "--resume_from_checkpoint", 326 | type=str, 327 | default=None, 328 | help=( 329 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 330 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 331 | ), 332 | ) 333 | parser.add_argument( 334 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 335 | ) 336 | 337 | args = parser.parse_args() 338 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 339 | if env_local_rank != -1 and env_local_rank != args.local_rank: 340 | args.local_rank = env_local_rank 341 | 342 | # Sanity checks 343 | if args.dataset_name is None and args.train_data_dir is None: 344 | raise ValueError("Need either a dataset name or a training folder.") 345 | 346 | return args 347 | 348 | 349 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 350 | if token is None: 351 | token = HfFolder.get_token() 352 | if organization is None: 353 | username = whoami(token)["name"] 354 | return f"{username}/{model_id}" 355 | else: 356 | return f"{organization}/{model_id}" 357 | 358 | 359 | DATASET_NAME_MAPPING = { 360 | "lambdalabs/pokemon-blip-captions": ("image", "text"), 361 | } 362 | 363 | 364 | def main(): 365 | args = parse_args() 366 | project_dir = os.path.join(args.output_dir, args.project_dir) 367 | 368 | accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) 369 | 370 | accelerator = Accelerator( 371 | gradient_accumulation_steps=args.gradient_accumulation_steps, 372 | mixed_precision=args.mixed_precision, 373 | log_with=args.report_to, 374 | project_dir=project_dir, 375 | project_config=accelerator_project_config, 376 | ) 377 | if args.report_to == "wandb": 378 | if not is_wandb_available(): 379 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 380 | import wandb 381 | 382 | # Make one log on every process with the configuration for debugging. 383 | logging.basicConfig( 384 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 385 | datefmt="%m/%d/%Y %H:%M:%S", 386 | level=logging.INFO, 387 | ) 388 | logger.info(accelerator.state, main_process_only=False) 389 | if accelerator.is_local_main_process: 390 | datasets.utils.logging.set_verbosity_warning() 391 | transformers.utils.logging.set_verbosity_warning() 392 | diffusers.utils.logging.set_verbosity_info() 393 | else: 394 | datasets.utils.logging.set_verbosity_error() 395 | transformers.utils.logging.set_verbosity_error() 396 | diffusers.utils.logging.set_verbosity_error() 397 | 398 | # If passed along, set the training seed now. 
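For orientation, the flags declared by `parse_args()` are normally supplied on the command line through `accelerate launch`. The sketch below shows how a driver such as the notebook's training cell could assemble that call. It reuses the environment variables the notebook exports (MODEL_NAME, TRAINING_DIR, OUTPUT_DIR, RESOLUTION, TRAINING_STEPS, CHECKPOINTING_STEPS, LEARNING_RATE); the flag spellings follow the attributes referenced in main(), and the fallback values are illustrative rather than the notebook's exact invocation:

import os
import subprocess

cmd = [
    "accelerate", "launch", "train_text_to_image_lora.py",
    "--pretrained_model_name_or_path", os.environ.get("MODEL_NAME", "runwayml/stable-diffusion-v1-5"),
    "--train_data_dir", os.environ.get("TRAINING_DIR", "./trainingdataset/imagefolder/0"),
    "--output_dir", os.environ.get("OUTPUT_DIR", "./CustomModel/"),
    "--resolution", os.environ.get("RESOLUTION", "256"),
    "--train_batch_size", "1",  # illustrative
    "--max_train_steps", os.environ.get("TRAINING_STEPS", "80"),
    "--checkpointing_steps", os.environ.get("CHECKPOINTING_STEPS", "80"),
    "--learning_rate", os.environ.get("LEARNING_RATE", "9e-06"),
    "--resume_from_checkpoint", "latest",  # optional: continue from the last saved checkpoint
    "--report_to", "tensorboard",
]
subprocess.run(cmd, check=True)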
399 | if args.seed is not None: 400 | set_seed(args.seed) 401 | 402 | # Handle the repository creation 403 | if accelerator.is_main_process: 404 | if args.output_dir is not None: 405 | os.makedirs(args.output_dir, exist_ok=True) 406 | 407 | if args.push_to_hub: 408 | repo_id = create_repo( 409 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 410 | ).repo_id 411 | 412 | # Load scheduler, tokenizer and models. 413 | noise_scheduler = DDPMScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler") 414 | tokenizer = CLIPTokenizer.from_pretrained( 415 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision 416 | ) 417 | text_encoder = CLIPTextModel.from_pretrained( 418 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 419 | ) 420 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 421 | unet = UNet2DConditionModel.from_pretrained( 422 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 423 | ) 424 | # freeze parameters of models to save more memory 425 | unet.requires_grad_(False) 426 | vae.requires_grad_(False) 427 | 428 | text_encoder.requires_grad_(False) 429 | 430 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 431 | # as these models are only used for inference, keeping weights in full precision is not required. 432 | weight_dtype = torch.float32 433 | if accelerator.mixed_precision == "fp16": 434 | weight_dtype = torch.float16 435 | elif accelerator.mixed_precision == "bf16": 436 | weight_dtype = torch.bfloat16 437 | 438 | # Move unet, vae and text_encoder to device and cast to weight_dtype 439 | unet.to(accelerator.device, dtype=weight_dtype) 440 | vae.to(accelerator.device, dtype=weight_dtype) 441 | text_encoder.to(accelerator.device, dtype=weight_dtype) 442 | 443 | # now we will add new LoRA weights to the attention layers 444 | # It's important to realize here how many attention weights will be added and of which sizes 445 | # The sizes of the attention layers consist only of two different variables: 446 | # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. 447 | # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. 448 | 449 | # Let's first see how many attention processors we will have to set. 
450 | # For Stable Diffusion, it should be equal to: 451 | # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 452 | # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 453 | # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 454 | # => 32 layers 455 | 456 | # Set correct lora layers 457 | lora_attn_procs = {} 458 | for name in unet.attn_processors.keys(): 459 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 460 | if name.startswith("mid_block"): 461 | hidden_size = unet.config.block_out_channels[-1] 462 | elif name.startswith("up_blocks"): 463 | block_id = int(name[len("up_blocks.")]) 464 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 465 | elif name.startswith("down_blocks"): 466 | block_id = int(name[len("down_blocks.")]) 467 | hidden_size = unet.config.block_out_channels[block_id] 468 | 469 | lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) 470 | 471 | unet.set_attn_processor(lora_attn_procs) 472 | 473 | if args.enable_xformers_memory_efficient_attention: 474 | if is_xformers_available(): 475 | import xformers 476 | 477 | xformers_version = version.parse(xformers.__version__) 478 | if xformers_version == version.parse("0.0.16"): 479 | logger.warn( 480 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 481 | ) 482 | unet.enable_xformers_memory_efficient_attention() 483 | else: 484 | raise ValueError("xformers is not available. Make sure it is installed correctly") 485 | 486 | lora_layers = AttnProcsLayers(unet.attn_processors) 487 | 488 | # Enable TF32 for faster training on Ampere GPUs, 489 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 490 | if args.allow_tf32: 491 | torch.backends.cuda.matmul.allow_tf32 = True 492 | 493 | if args.scale_lr: 494 | args.learning_rate = ( 495 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 496 | ) 497 | 498 | # Initialize the optimizer 499 | if args.use_8bit_adam: 500 | try: 501 | import bitsandbytes as bnb 502 | except ImportError: 503 | raise ImportError( 504 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 505 | ) 506 | 507 | optimizer_cls = bnb.optim.AdamW8bit 508 | else: 509 | optimizer_cls = torch.optim.AdamW 510 | 511 | optimizer = optimizer_cls( 512 | lora_layers.parameters(), 513 | lr=args.learning_rate, 514 | betas=(args.adam_beta1, args.adam_beta2), 515 | weight_decay=args.adam_weight_decay, 516 | eps=args.adam_epsilon, 517 | ) 518 | 519 | # Get the datasets: you can either provide your own training and evaluation files (see below) 520 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 521 | 522 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 523 | # download the dataset. 524 | if args.dataset_name is not None: 525 | # Downloading and loading a dataset from the hub. 
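The attention-layer counting in the comment at the top of this block can be reproduced with plain arithmetic. A small sketch, assuming the stock Stable Diffusion v1.5 UNet values `block_out_channels = (320, 640, 1280, 1280)` and `cross_attention_dim = 768`, and mirroring the name-parsing logic used to build `lora_attn_procs` above:

block_out_channels = (320, 640, 1280, 1280)  # SD v1.5 UNet (assumed)
cross_attention_dim = 768                    # SD v1.5 text-encoder width (assumed)

def lora_shape(name):
    # Same resolution rule as the loop above: attn1 is self-attention
    # (no cross-attention dim), and the hidden size follows the block index.
    cross = None if name.endswith("attn1.processor") else cross_attention_dim
    if name.startswith("mid_block"):
        hidden = block_out_channels[-1]
    elif name.startswith("up_blocks"):
        hidden = list(reversed(block_out_channels))[int(name[len("up_blocks.")])]
    elif name.startswith("down_blocks"):
        hidden = block_out_channels[int(name[len("down_blocks.")])]
    return hidden, cross

print(lora_shape("down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor"))  # (320, 768)
print(lora_shape("mid_block.attentions.0.transformer_blocks.0.attn1.processor"))      # (1280, None)
print(lora_shape("up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor"))    # (640, 768)

Enumerating every such processor name gives the 12 + 2 + 18 = 32 layers noted in the comment.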
526 | dataset = load_dataset( 527 | args.dataset_name, 528 | args.dataset_config_name, 529 | cache_dir=args.cache_dir, 530 | ) 531 | else: 532 | data_files = {} 533 | if args.train_data_dir is not None: 534 | data_files["train"] = os.path.join(args.train_data_dir, "**") 535 | dataset = load_dataset( 536 | "imagefolder", 537 | data_files=data_files, 538 | cache_dir=args.cache_dir, 539 | ) 540 | # See more about loading custom images at 541 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 542 | 543 | # Preprocessing the datasets. 544 | # We need to tokenize inputs and targets. 545 | column_names = dataset["train"].column_names 546 | 547 | # 6. Get the column names for input/target. 548 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) 549 | if args.image_column is None: 550 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 551 | else: 552 | image_column = args.image_column 553 | if image_column not in column_names: 554 | raise ValueError( 555 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 556 | ) 557 | if args.caption_column is None: 558 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 559 | else: 560 | caption_column = args.caption_column 561 | if caption_column not in column_names: 562 | raise ValueError( 563 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 564 | ) 565 | 566 | # Preprocessing the datasets. 567 | # We need to tokenize input captions and transform the images. 568 | def tokenize_captions(examples, is_train=True): 569 | captions = [] 570 | for caption in examples[caption_column]: 571 | if isinstance(caption, str): 572 | captions.append(caption) 573 | elif isinstance(caption, (list, np.ndarray)): 574 | # take a random caption if there are multiple 575 | captions.append(random.choice(caption) if is_train else caption[0]) 576 | else: 577 | raise ValueError( 578 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 579 | ) 580 | inputs = tokenizer( 581 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 582 | ) 583 | return inputs.input_ids 584 | 585 | # Preprocessing the datasets. 
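The `--train_data_dir` branch above is the one the RAFT notebook exercises: each epoch it writes the top-k selected images plus a `metadata.csv` with `file_name` and `text` columns into `./trainingdataset/imagefolder/<epoch>/train/`, and the `imagefolder` builder turns that folder into an image dataset with a caption column. A self-contained sketch of that layout (images and captions are dummy placeholders):

import csv
import os
from PIL import Image
from datasets import load_dataset

data_dir = "./trainingdataset/imagefolder/0"  # epoch-0 folder, as written by the notebook
train_dir = os.path.join(data_dir, "train")
os.makedirs(train_dir, exist_ok=True)

# Two dummy entries standing in for the reward-selected images.
rows = [("0.png", "a photo of a cat"), ("1.png", "a photo of a dog")]
for file_name, _ in rows:
    Image.new("RGB", (256, 256)).save(os.path.join(train_dir, file_name))
with open(os.path.join(train_dir, "metadata.csv"), "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["file_name", "text"])
    writer.writerows(rows)

# Same call shape as the --train_data_dir branch above.
ds = load_dataset("imagefolder", data_files={"train": os.path.join(data_dir, "**")})
print(ds["train"].column_names)  # expected: ['image', 'text']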
586 | train_transforms = transforms.Compose( 587 | [ 588 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 589 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 590 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 591 | transforms.ToTensor(), 592 | transforms.Normalize([0.5], [0.5]), 593 | ] 594 | ) 595 | 596 | def preprocess_train(examples): 597 | images = [image.convert("RGB") for image in examples[image_column]] 598 | examples["pixel_values"] = [train_transforms(image) for image in images] 599 | examples["input_ids"] = tokenize_captions(examples) 600 | return examples 601 | 602 | with accelerator.main_process_first(): 603 | if args.max_train_samples is not None: 604 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 605 | # Set the training transforms 606 | train_dataset = dataset["train"].with_transform(preprocess_train) 607 | 608 | def collate_fn(examples): 609 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 610 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 611 | input_ids = torch.stack([example["input_ids"] for example in examples]) 612 | return {"pixel_values": pixel_values, "input_ids": input_ids} 613 | 614 | # DataLoaders creation: 615 | train_dataloader = torch.utils.data.DataLoader( 616 | train_dataset, 617 | shuffle=False, 618 | collate_fn=collate_fn, 619 | batch_size=args.train_batch_size, 620 | num_workers=args.dataloader_num_workers, 621 | ) 622 | 623 | # Scheduler and math around the number of training steps. 624 | overrode_max_train_steps = False 625 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 626 | if args.max_train_steps is None: 627 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 628 | overrode_max_train_steps = True 629 | 630 | lr_scheduler = get_scheduler( 631 | args.lr_scheduler, 632 | optimizer=optimizer, 633 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 634 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 635 | ) 636 | 637 | # Prepare everything with our `accelerator`. 638 | lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 639 | lora_layers, optimizer, train_dataloader, lr_scheduler 640 | ) 641 | 642 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 643 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 644 | if overrode_max_train_steps: 645 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 646 | # Afterwards we recalculate our number of training epochs 647 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 648 | 649 | # We need to initialize the trackers we use, and also store our configuration. 650 | # The trackers initializes automatically on the main process. 651 | if accelerator.is_main_process: 652 | accelerator.init_trackers("text2image-fine-tune", config=vars(args)) 653 | 654 | # Train! 
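The step bookkeeping above (and the `total_batch_size` computed just below) is easy to check with concrete numbers. A worked example under assumed settings close to one RAFT epoch from the notebook: 8 selected images, batch size 1, no gradient accumulation, a single process, and `max_train_steps = 80`:

import math

num_examples = 8
train_batch_size = 1
gradient_accumulation_steps = 1
num_processes = 1
max_train_steps = 80

dataloader_len = math.ceil(num_examples / train_batch_size)                            # 8 batches
num_update_steps_per_epoch = math.ceil(dataloader_len / gradient_accumulation_steps)   # 8
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)             # 10
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps      # 1
print(num_update_steps_per_epoch, num_train_epochs, total_batch_size)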
655 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 656 | 657 | logger.info("***** Running training *****") 658 | logger.info(f" Num examples = {len(train_dataset)}") 659 | logger.info(f" Num Epochs = {args.num_train_epochs}") 660 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 661 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 662 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 663 | logger.info(f" Total optimization steps = {args.max_train_steps}") 664 | global_step = 0 665 | first_epoch = 0 666 | 667 | # Potentially load in the weights and states from a previous save 668 | if args.resume_from_checkpoint: 669 | if args.resume_from_checkpoint != "latest": 670 | path = os.path.basename(args.resume_from_checkpoint) 671 | else: 672 | # Get the most recent checkpoint 673 | dirs = os.listdir(args.output_dir) 674 | dirs = [d for d in dirs if d.startswith("checkpoint")] 675 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 676 | path = dirs[-1] if len(dirs) > 0 else None 677 | 678 | if path is None: 679 | accelerator.print( 680 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 681 | ) 682 | args.resume_from_checkpoint = None 683 | else: 684 | accelerator.print(f"Resuming from checkpoint {path}") 685 | accelerator.load_state(os.path.join(args.output_dir, path)) 686 | global_step = int(path.split("-")[1]) 687 | 688 | resume_global_step = global_step * args.gradient_accumulation_steps 689 | first_epoch = global_step // num_update_steps_per_epoch 690 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 691 | 692 | # Only show the progress bar once on each machine. 
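The resume bookkeeping above maps a checkpoint directory name back onto an epoch and an in-epoch position. A worked example with assumed numbers (`checkpoint-80`, 8 update steps per epoch, no gradient accumulation):

path = "checkpoint-80"
gradient_accumulation_steps = 1
num_update_steps_per_epoch = 8

global_step = int(path.split("-")[1])                           # 80
resume_global_step = global_step * gradient_accumulation_steps  # 80
first_epoch = global_step // num_update_steps_per_epoch         # resume at epoch 10
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)  # 0 batches to skip
print(global_step, first_epoch, resume_step)

For a checkpoint taken mid-epoch, say `checkpoint-83`, the same arithmetic gives `first_epoch = 10` and `resume_step = 3`, and the loop below skips those three batches before optimizing again.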
693 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 694 | progress_bar.set_description("Steps") 695 | 696 | for epoch in range(first_epoch, args.num_train_epochs): 697 | unet.train() 698 | train_loss = 0.0 699 | for step, batch in enumerate(train_dataloader): 700 | # Skip steps until we reach the resumed step 701 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 702 | if step % args.gradient_accumulation_steps == 0: 703 | progress_bar.update(1) 704 | continue 705 | with accelerator.accumulate(unet): 706 | # Convert images to latent space 707 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 708 | latents = latents * vae.config.scaling_factor 709 | # Sample noise that we'll add to the latents 710 | noise = torch.randn_like(latents) 711 | bsz = latents.shape[0] 712 | # Sample a random timestep for each image 713 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 714 | timesteps = timesteps.long() 715 | 716 | # Add noise to the latents according to the noise magnitude at each timestep 717 | # (this is the forward diffusion process) 718 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 719 | 720 | # Get the text embedding for conditioning 721 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 722 | 723 | # Get the target for loss depending on the prediction type 724 | if noise_scheduler.config.prediction_type == "epsilon": 725 | target = noise 726 | elif noise_scheduler.config.prediction_type == "v_prediction": 727 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 728 | else: 729 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 730 | 731 | # Predict the noise residual and compute loss 732 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 733 | 734 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 735 | 736 | # Gather the losses across all processes for logging (if we use distributed training). 737 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 738 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 739 | # logger.info(loss) 740 | # Backpropagate 741 | accelerator.backward(loss) 742 | if accelerator.sync_gradients: 743 | params_to_clip = lora_layers.parameters() 744 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 745 | optimizer.step() 746 | lr_scheduler.step() 747 | optimizer.zero_grad() 748 | 749 | # Checks if the accelerator has performed an optimization step behind the scenes 750 | if accelerator.sync_gradients: 751 | progress_bar.update(1) 752 | global_step += 1 753 | accelerator.log({"train_loss": train_loss}, step=global_step) 754 | train_loss = 0.0 755 | 756 | if global_step % args.checkpointing_steps == 0: 757 | if accelerator.is_main_process: 758 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 759 | accelerator.save_state(save_path) 760 | logger.info(f"Saved state to {save_path}") 761 | 762 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 763 | progress_bar.set_postfix(**logs) 764 | 765 | if global_step >= args.max_train_steps: 766 | break 767 | 768 | if accelerator.is_main_process: 769 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0: 770 | logger.info( 771 | f"Running validation... 
\n Generating {args.num_validation_images} images with prompt:" 772 | f" {args.validation_prompt}." 773 | ) 774 | # create pipeline 775 | pipeline = DiffusionPipeline.from_pretrained( 776 | args.pretrained_model_name_or_path, 777 | unet=accelerator.unwrap_model(unet), 778 | revision=args.revision, 779 | torch_dtype=weight_dtype, 780 | ) 781 | pipeline = pipeline.to(accelerator.device) 782 | pipeline.set_progress_bar_config(disable=True) 783 | 784 | # run inference 785 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 786 | images = [] 787 | for _ in range(args.num_validation_images): 788 | images.append( 789 | pipeline(args.validation_prompt, num_inference_steps=50, generator=generator).images[0] 790 | ) 791 | 792 | for tracker in accelerator.trackers: 793 | if tracker.name == "tensorboard": 794 | np_images = np.stack([np.asarray(img) for img in images]) 795 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 796 | if tracker.name == "wandb": 797 | tracker.log( 798 | { 799 | "validation": [ 800 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") 801 | for i, image in enumerate(images) 802 | ] 803 | } 804 | ) 805 | 806 | del pipeline 807 | torch.cuda.empty_cache() 808 | 809 | # Save the lora layers 810 | accelerator.wait_for_everyone() 811 | if accelerator.is_main_process: 812 | unet = unet.to(torch.float32) 813 | unet.save_attn_procs(args.output_dir) 814 | 815 | if args.push_to_hub: 816 | save_model_card( 817 | repo_id, 818 | base_model=args.pretrained_model_name_or_path, 819 | dataset_name=args.dataset_name, 820 | repo_folder=args.output_dir, 821 | ) 822 | upload_folder( 823 | repo_id=repo_id, 824 | folder_path=args.output_dir, 825 | commit_message="End of training", 826 | ignore_patterns=["step_*", "epoch_*"], 827 | ) 828 | 829 | # # Final inference 830 | # Load previous pipeline 831 | # pipeline = DiffusionPipeline.from_pretrained( 832 | # args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype 833 | # ) 834 | # pipeline = pipeline.to(accelerator.device) 835 | 836 | # # load attention processors 837 | # pipeline.unet.load_attn_procs(args.output_dir) 838 | 839 | # run inference 840 | # generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 841 | # images = [] 842 | # for _ in range(args.num_validation_images): 843 | # images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) 844 | 845 | 846 | # for _ in range(1): 847 | # images.append(pipeline('a photo of cat').images[0]) 848 | # images[0].save('/root/autodl-tmp/deepanshu/output.png') 849 | 850 | # if accelerator.is_main_process: 851 | # for tracker in accelerator.trackers: 852 | # if tracker.name == "tensorboard": 853 | # np_images = np.stack([np.asarray(img) for img in images]) 854 | # tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") 855 | # if tracker.name == "wandb": 856 | # tracker.log( 857 | # { 858 | # "test": [ 859 | # wandb.Image(image, caption=f"{i}: {args.validation_prompt}") 860 | # for i, image in enumerate(images) 861 | # ] 862 | # } 863 | # ) 864 | 865 | accelerator.end_training() 866 | 867 | 868 | if __name__ == "__main__": 869 | main() 870 | --------------------------------------------------------------------------------
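End-of-training usage: the script stores only the LoRA attention processors via `unet.save_attn_procs(args.output_dir)`, so inference needs the base model plus those weights, as the commented-out block above suggests. A minimal sketch (model id and output directory follow the notebook's choices; prompt and generation settings are illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Attach the LoRA attention processors saved by train_text_to_image_lora.py.
pipe.unet.load_attn_procs("./CustomModel/")

image = pipe("a photo of cat", num_inference_steps=30, width=256, height=256).images[0]
image.save("raft_lora_sample.png")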