├── assets
│   ├── trajectory.jpeg
│   ├── ar_example_demo.gif
│   ├── decoding_comparison.gif
│   ├── jacobi_forcing_logo.jpeg
│   ├── jacobi_forcing_example_demo.gif
│   ├── multiblock_rejection_recycling.gif
│   ├── noisy_context_attention_mask.jpeg
│   ├── noise_schedule_and_sequence_packing.gif
│   ├── baseline_comparison.sh
│   └── baseline_comparison.py
├── .gitignore
├── JacobiForcing
│   ├── scripts
│   │   ├── train
│   │   │   ├── ds_config.json
│   │   │   ├── ds_config_cpu_offloading.json
│   │   │   ├── baseline_sft_train.sh
│   │   │   ├── train_clean_context_conditioned_cllm_openthinker2_n64.sh
│   │   │   ├── train_jacobi_forcing_coder_n32.sh
│   │   │   └── train_jacobi_forcing_coder_n64.sh
│   │   ├── inference
│   │   │   └── scanning_hyperparameter_jacobi_decoding_mr.sh
│   │   └── tool
│   │       ├── plot_inference_configuration_search.py
│   │       ├── extract_inference_profiling_datapoints_from_log.py
│   │       └── 3d_plot_inference_configuration_search_with_quadratic_fit.py
│   ├── train
│   │   ├── baseline_sft_train.py
│   │   ├── deprecated
│   │   │   ├── flexattn_cllm_trainer.py
│   │   │   ├── soft_cllm_loss_trainer.py
│   │   │   ├── soft_flexattn_cllm_trainer_legacy.py
│   │   │   └── vanilla_efficient_cllm_trainer.py
│   │   └── cllm_trainer.py
│   ├── ar_inference_baseline.py
│   └── jacobi_forcing_inference_MATH500.py
├── generate_trajectory
│   ├── data
│   │   ├── 3_str_seq_length_filtering_dataset.py
│   │   ├── tool_merge_standalone_jsonl_data.py
│   │   ├── tool_debug_complete_training_seq_data.py
│   │   ├── -1_opencodeinstruct_data_filtering.py
│   │   ├── tool_profile_trajectory_dataset.py
│   │   ├── 2_prepare_baseline_training_data_sft_reverse_engineering.py
│   │   ├── 3_downsample_dataset.py
│   │   ├── 3_postprocessing_data_length_filtering.py
│   │   ├── tool_merge_ds_ckpts.py
│   │   ├── 1_masking_based_prepare_trajectory.py
│   │   ├── 0_bucketing_opencodeinstruct.py
│   │   ├── 0_bucketing_openthought2.py
│   │   ├── 2_prepare_efficient_cllm_training_data_new.py
│   │   ├── 1_progressive_masking_based_prepare_trajectory.py
│   │   ├── 2_prepare_baseline_training_data_sft.py
│   │   ├── 2_prepare_efficient_cllm_training_data.py
│   │   └── 2_prepare_efficient_cllm_training_data_new_progressive_noise.py
│   └── generation
│       ├── generate_trajectory_opencodeinstruct_nongreedy.sh
│       └── generate_trajectory_opencodeinstruct_greedy.sh
├── requirements.txt
└── applications
    └── jacobi_streaming_driver.py
/assets/trajectory.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/JacobiForcing/HEAD/assets/trajectory.jpeg
--------------------------------------------------------------------------------
/assets/ar_example_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/JacobiForcing/HEAD/assets/ar_example_demo.gif
--------------------------------------------------------------------------------
/assets/decoding_comparison.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/JacobiForcing/HEAD/assets/decoding_comparison.gif
--------------------------------------------------------------------------------
/assets/jacobi_forcing_logo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/JacobiForcing/HEAD/assets/jacobi_forcing_logo.jpeg
--------------------------------------------------------------------------------
/assets/jacobi_forcing_example_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/JacobiForcing/HEAD/assets/jacobi_forcing_example_demo.gif
-------------------------------------------------------------------------------- /assets/multiblock_rejection_recycling.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hao-ai-lab/JacobiForcing/HEAD/assets/multiblock_rejection_recycling.gif -------------------------------------------------------------------------------- /assets/noisy_context_attention_mask.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hao-ai-lab/JacobiForcing/HEAD/assets/noisy_context_attention_mask.jpeg -------------------------------------------------------------------------------- /assets/noise_schedule_and_sequence_packing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hao-ai-lab/JacobiForcing/HEAD/assets/noise_schedule_and_sequence_packing.gif -------------------------------------------------------------------------------- /assets/baseline_comparison.sh: -------------------------------------------------------------------------------- 1 | python3 assets/baseline_comparison.py \ 2 | --csv assets/baselines.csv \ 3 | --out assets/baselines_comparison.png 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/collected_diffusion_trajectory/* 2 | */__pycache__ 3 | cllm2 4 | cllm2_venv 5 | wandb 6 | ckpts 7 | *.png 8 | *.csv 9 | *.txt 10 | *.jsonl 11 | profiling_results/* 12 | logs/* 13 | 14 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/train/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { "enabled": true }, 3 | "zero_optimization": { 4 | "stage": 3, 5 | "allgather_partitions": true, 6 | "overlap_comm": false, 7 | "reduce_scatter": true, 8 | "contiguous_gradients": true, 9 | "stage3_param_persistence_threshold": 0 10 | }, 11 | "gradient_accumulation_steps": 1, 12 | "train_micro_batch_size_per_gpu": 1, 13 | "wall_clock_breakdown": false 14 | } 15 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/train/ds_config_cpu_offloading.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { "enabled": true }, 3 | "zero_optimization": { 4 | "stage": 3, 5 | "offload_param": { "device": "cpu", "pin_memory": true }, 6 | "offload_optimizer": { "device": "cpu", "pin_memory": true }, 7 | 8 | "allgather_partitions": true, 9 | "overlap_comm": false, 10 | "reduce_scatter": true, 11 | "contiguous_gradients": true, 12 | "stage3_param_persistence_threshold": 0 13 | }, 14 | "activation_checkpointing": { 15 | "partition_activations": true, 16 | "cpu_checkpointing": true, 17 | "contiguous_memory_optimization": true 18 | }, 19 | "gradient_accumulation_steps": 1, 20 | "train_micro_batch_size_per_gpu": 1, 21 | "zero_allow_untested_optimizer": true, 22 | "zero_force_ds_cpu_optimizer": true 23 | } 24 | -------------------------------------------------------------------------------- /generate_trajectory/data/3_str_seq_length_filtering_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | def main(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--input_path", required=True) 7 | 
parser.add_argument("--output_path", required=True, help="Output for kept entries.") 8 | parser.add_argument("--max-seq-len", type=int, required=True) 9 | args = parser.parse_args() 10 | 11 | with open(args.input_path, "r", encoding="utf-8") as fin, \ 12 | open(args.output_path, "w", encoding="utf-8") as fout_keep: 13 | for line in fin: 14 | item = json.loads(line) 15 | seq = item.get("complete_training_sequence_ids", []) 16 | if len(seq) <= args.max_seq_len: 17 | print(f"Keeping item {item['data_id']} with sequence length {len(seq)}") 18 | fout_keep.write(json.dumps(item) + "\n") 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/train/baseline_sft_train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 2 | export WANDB_PROJECT=cllm2_training 3 | export WANDB_RUN_NAME="sft_baseline_data_v1" 4 | 5 | export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:256" 6 | 7 | model_path="/checkpoint/lhu/models/OpenThinker2-7B" 8 | trajectory_file="/checkpoint/lhu/data/postprocessed_trajectory_collection_merged/openhought2_sft_bs_k8s_postprocessed_merged_all.jsonl" 9 | output_path="/checkpoint/lhu/train_ckpts/cllm/sft_baseline_data_v1" 10 | qlora=False 11 | 12 | torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=101 \ 13 | --rdzv_endpoint='localhost:5666' \ 14 | --master_port 10000 \ 15 | train/baseline_sft_train.py \ 16 | --target_model_path ${model_path} \ 17 | --data_path ${trajectory_file} \ 18 | --output_dir ${output_path} \ 19 | --bf16 True \ 20 | --report_to wandb \ 21 | --do_train \ 22 | --num_train_epochs 1 \ 23 | --per_device_train_batch_size 1 \ 24 | --gradient_accumulation_steps 1 \ 25 | --gradient_checkpointing True \ 26 | --save_strategy "steps" \ 27 | --save_steps 500 \ 28 | --save_total_limit 2 \ 29 | --learning_rate 2e-5 \ 30 | --weight_decay 0. 
\ 31 | --warmup_ratio 0.03 \ 32 | --lr_scheduler_type "cosine" \ 33 | --logging_steps 10 \ 34 | --model_max_length 16384 \ 35 | --qlora ${qlora} 36 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/train/train_clean_context_conditioned_cllm_openthinker2_n64.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 2 | export WANDB_PROJECT=cllm2_training 3 | 4 | export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:256" 5 | 6 | model_path="/checkpoint/lhu/models/OpenThinker2-7B" 7 | trajectory_file="/checkpoint/lhu/data/postprocessed_trajectory_collection_merged/14k_filtered_output_postprocessed_merged_all.jsonl" 8 | output_path="/checkpoint/lhu/train_ckpts/cllm/one-pass-efficient-train-cllm-openthinker2-7B-ntok64_cllm_soft_loss_length_capped_16k_flexattn_siqi_data_1_ar_10" 9 | n_token_seq_size=64 10 | qlora=False 11 | 12 | torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=101 \ 13 | --rdzv_endpoint='localhost:5666' \ 14 | --master_port 10000 \ 15 | train/soft_flexattn_train_cllm.py \ 16 | --target_model_path ${model_path} \ 17 | --data_path ${trajectory_file} \ 18 | --output_dir ${output_path} \ 19 | --max_new_tokens ${n_token_seq_size} \ 20 | --bf16 True \ 21 | --report_to wandb \ 22 | --do_train \ 23 | --num_train_epochs 1 \ 24 | --per_device_train_batch_size 1 \ 25 | --gradient_accumulation_steps 1 \ 26 | --gradient_checkpointing True \ 27 | --save_strategy "steps" \ 28 | --save_steps 500 \ 29 | --save_total_limit 2 \ 30 | --learning_rate 5e-6 \ 31 | --weight_decay 0. \ 32 | --warmup_ratio 0.03 \ 33 | --lr_scheduler_type "cosine" \ 34 | --logging_steps 10 \ 35 | --model_max_length 16384 \ 36 | --lazy_preprocess True \ 37 | --qlora ${qlora} 38 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/train/train_jacobi_forcing_coder_n32.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | export WANDB_PROJECT=cllm2_training 3 | 4 | export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:256" 5 | 6 | model_path="/data/numa0/train-tests/models/progressive_noise_cllm2_mask_1m_steps" 7 | #trajectory_file="/checkpoint/lhu/data/CLLM2_openthought/merged/merged_data_v2_8_27_opencodeinstruct.jsonl" 8 | trajectory_file="/data/numa0/train-tests/data/cllm2_packed_data/merged-traj-data-oct-16-n32w16.jsonl" 9 | output_path="/data/numa0/train-tests/ckpts/v2/shiftedattn-11-19-cllm-qwen2p5-coder-7B-n16w16-n32w16_distill_data_v2_ar_1_only_cyclic_progressive_noise_all_lr1e-5" 10 | n_token_seq_size=32 11 | qlora=False 12 | 13 | torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 \ 14 | --rdzv_endpoint='localhost:5666' \ 15 | --master_port 10000 \ 16 | train/soft_flexattn_train_cllm_multiblock.py \ 17 | --target_model_path ${model_path} \ 18 | --data_path ${trajectory_file} \ 19 | --output_dir ${output_path} \ 20 | --max_new_tokens ${n_token_seq_size} \ 21 | --bf16 True \ 22 | --report_to wandb \ 23 | --do_train \ 24 | --num_train_epochs 1 \ 25 | --per_device_train_batch_size 1 \ 26 | --gradient_accumulation_steps 1 \ 27 | --gradient_checkpointing True \ 28 | --save_strategy "steps" \ 29 | --save_steps 5000 \ 30 | --save_total_limit 8 \ 31 | --learning_rate 1e-5 \ 32 | --weight_decay 0. 
\ 33 | --warmup_ratio 0.03 \ 34 | --lr_scheduler_type "cosine" \ 35 | --logging_steps 10 \ 36 | --model_max_length 16384 \ 37 | --qlora ${qlora} 38 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/train/train_jacobi_forcing_coder_n64.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | export WANDB_PROJECT=cllm2_training 3 | 4 | export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:256" 5 | 6 | model_path="/data/numa0/train-tests/models/progressive_noise_cllm2_mask_1m_steps" 7 | #trajectory_file="/checkpoint/lhu/data/CLLM2_openthought/merged/merged_data_v2_8_27_opencodeinstruct.jsonl" 8 | trajectory_file="/data/numa0/train-tests/data/merged_oct_22_opencoderinstruct_qwen2.5coder_7b_n64/merged_packed_sequence_10_22_distill_n64w32_data.jsonl" 9 | output_path="/data/numa0/train-tests/ckpts/v2/shiftedattn-11-19-cllm-qwen2p5-coder-7B-n16w16-n64w32_distill_data_n64v2_ar_1_only_cyclic_progressive_noise_all_lr5e-7" 10 | n_token_seq_size=64 11 | qlora=False 12 | 13 | torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 \ 14 | --rdzv_endpoint='localhost:5666' \ 15 | --master_port 10000 \ 16 | train/soft_flexattn_train_cllm_multiblock.py \ 17 | --target_model_path ${model_path} \ 18 | --data_path ${trajectory_file} \ 19 | --output_dir ${output_path} \ 20 | --max_new_tokens ${n_token_seq_size} \ 21 | --bf16 True \ 22 | --report_to wandb \ 23 | --do_train \ 24 | --num_train_epochs 1 \ 25 | --per_device_train_batch_size 1 \ 26 | --gradient_accumulation_steps 1 \ 27 | --gradient_checkpointing True \ 28 | --save_strategy "steps" \ 29 | --save_steps 5000 \ 30 | --save_total_limit 8 \ 31 | --learning_rate 1e-6 \ 32 | --weight_decay 0. 
\ 33 | --warmup_ratio 0.03 \ 34 | --lr_scheduler_type "cosine" \ 35 | --logging_steps 10 \ 36 | --model_max_length 16384 \ 37 | --qlora ${qlora} 38 | -------------------------------------------------------------------------------- /generate_trajectory/generation/generate_trajectory_opencodeinstruct_nongreedy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ===== Config ===== 4 | json_files=( 5 | "/data/nfs01/lanxiang/data/OpenCodeInstruct_bucketed/bucket_0016_avg328_min325_max330.json" 6 | "/data/nfs01/lanxiang/data/OpenCodeInstruct_bucketed/bucket_0017_avg332_min330_max334.json" 7 | "/data/nfs01/lanxiang/data/OpenCodeInstruct_bucketed/bucket_0018_avg336_min334_max338.json" 8 | "/data/nfs01/lanxiang/data/OpenCodeInstruct_bucketed/bucket_0019_avg341_min338_max343.json" 9 | ) 10 | 11 | save_path="/data/nfs01/lanxiang/data/CLLM2_data_prep/trajectory_bs_k8s_opencoderinstruct" 12 | log_file="/data/nfs01/lanxiang/data/cllm_logs/generate_trajectory_bs_k8s_batch_1_data16_to_data19.log" 13 | # ===== Config ===== 14 | 15 | model_path="/home/ubuntu/Qwen2.5-Coder-7B-Instruct" 16 | n_token_seq_len=32 17 | max_new_seq_len=2048 18 | data_start_id=0 19 | data_eos_id=25000 20 | batch_size=128 21 | 22 | # ===== Launch jobs ===== 23 | for i in "${!json_files[@]}"; do 24 | cuda_device=$i 25 | echo "Device CUDA: ${i}" 26 | filename="${json_files[$i]}" 27 | 28 | # Each GPU gets one file 29 | echo "Launching process on CUDA:${cuda_device} for file ${filename}" 30 | 31 | CUDA_VISIBLE_DEVICES=${cuda_device} python3 generate_trajectory/v2/generate_trajectory_opencodeinstruct_nongreedy.py \ 32 | --filename "${filename}" \ 33 | --model "${model_path}" \ 34 | --n_token_seq_len "${n_token_seq_len}" \ 35 | --max_new_seq_len "${max_new_seq_len}" \ 36 | --data_bos_id ${data_start_id} \ 37 | --data_eos_id ${data_eos_id} \ 38 | --batch_size "${batch_size}" \ 39 | --save_path "${save_path}" \ 40 | >"${log_file}" 2>&1 & 41 | done 42 | 43 | wait 44 | echo "All trajectory generation processes completed." 
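# NOTE (editorial sketch, not part of the original script): every background job
# above redirects to the same ${log_file}, so the four per-GPU processes truncate
# and interleave one shared log. If per-job logs are preferred, a minimal variant
# (illustrative path derivation only) inside the launch loop would be:
#
#   job_log="$(dirname "${log_file}")/$(basename "${filename%.json}").log"
#   CUDA_VISIBLE_DEVICES=${cuda_device} python3 ... >"${job_log}" 2>&1 &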
45 | -------------------------------------------------------------------------------- /generate_trajectory/generation/generate_trajectory_opencodeinstruct_greedy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ===== Config ===== 4 | json_files=( 5 | "/data/numa0/train-tests/data/OpenCodeInstruct_bucketed/bucket_0030_avg390_min387_max392.json" 6 | "/data/numa0/train-tests/data/OpenCodeInstruct_bucketed/bucket_0031_avg394_min392_max397.json" 7 | "/data/numa0/train-tests/data/OpenCodeInstruct_bucketed/bucket_0032_avg399_min397_max402.json" 8 | "/data/numa0/train-tests/data/OpenCodeInstruct_bucketed/bucket_0033_avg404_min402_max407.json" 9 | ) 10 | 11 | save_path="/data/numa0/train-tests/data/opencodeinstruct_generated_trajectory_blk32" 12 | log_file="/data/numa0/train-tests/data/cllm_logs/generate_trajectory_greedy_k8s_batch_0_data30_to_data33.log" 13 | # ===== Config ===== 14 | 15 | model_path="/data/numa0/train-tests/models/progressive_noise_cllm2_mask_1m_steps" 16 | n_token_seq_len=32 17 | max_new_seq_len=1024 18 | data_start_id=0 19 | data_eos_id=5000 20 | batch_size=1 21 | 22 | # ===== Launch jobs ===== 23 | for i in "${!json_files[@]}"; do 24 | cuda_device=$i 25 | echo "Device CUDA: ${i}" 26 | filename="${json_files[$i]}" 27 | 28 | # Each GPU gets one file 29 | echo "Launching process on CUDA:${cuda_device} for file ${filename}" 30 | 31 | CUDA_VISIBLE_DEVICES=${cuda_device} python3 generate_trajectory/v2/generate_trajectory_opencodeinstruct_greedy.py \ 32 | --filename "${filename}" \ 33 | --model "${model_path}" \ 34 | --n_token_seq_len "${n_token_seq_len}" \ 35 | --max_new_seq_len "${max_new_seq_len}" \ 36 | --data_bos_id ${data_start_id} \ 37 | --data_eos_id ${data_eos_id} \ 38 | --batch_size "${batch_size}" \ 39 | --save_path "${save_path}" \ 40 | >"${log_file}" 2>&1 & 41 | done 42 | 43 | wait 44 | echo "All trajectory generation processes completed." 45 | -------------------------------------------------------------------------------- /generate_trajectory/data/tool_merge_standalone_jsonl_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import os 4 | 5 | 6 | def merge_jsonl_files(input_files, output_file): 7 | """ 8 | Merge multiple JSONL files into a single JSONL file with a global progress bar. 9 | 10 | Args: 11 | input_files (list of str): List of input file paths. 12 | output_file (str): Path to the merged output file. 
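    Example (illustrative paths):
        merge_jsonl_files(["part_a.jsonl", "part_b.jsonl"], "merged.jsonl")

    Lines that fail to parse as JSON are skipped with a warning instead of
    aborting the merge.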
13 | """ 14 | # Count total lines across all files 15 | total_lines = 0 16 | for file in input_files: 17 | with open(file, 'r', encoding='utf-8') as f: 18 | total_lines += sum(1 for _ in f) 19 | 20 | with open(output_file, 'w', encoding='utf-8') as outfile, tqdm(total=total_lines, desc="Merging JSONL files") as pbar: 21 | for file in input_files: 22 | with open(file, 'r', encoding='utf-8') as infile: 23 | for line in infile: 24 | line = line.strip() 25 | if line: # skip empty lines 26 | try: 27 | json_obj = json.loads(line) 28 | outfile.write(json.dumps(json_obj, ensure_ascii=False) + "\n") 29 | except json.JSONDecodeError as e: 30 | print(f"Skipping invalid JSON in {file}: {e}") 31 | pbar.update(1) 32 | 33 | if __name__ == "__main__": 34 | # Example usage 35 | input_files = [ 36 | "/checkpoint/lhu/data/CLLM2_data_prep/trajectory_bs_k8s_08_27_merged/merged_all_08_27_lanxiang.jsonl", 37 | "/checkpoint/lhu/data/TRACE-08-27/merged_all_08_27_yichao.jsonl", 38 | "/checkpoint/lhu/data/opencoderinstruct_trajectory/merged/merged_all_8_27_siqi.jsonl" 39 | ] 40 | output_file = "/checkpoint/lhu/data/CLLM2_openthought/merged/merged_data_v2_8_28_raw.jsonl" 41 | merge_jsonl_files(input_files, output_file) 42 | print(f"Merged {len(input_files)} files into {output_file}") 43 | -------------------------------------------------------------------------------- /generate_trajectory/data/tool_debug_complete_training_seq_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from transformers import AutoTokenizer 3 | 4 | # Set your .jsonl file path and tokenizer name/path 5 | #json_path = "/checkpoint/lhu/data/CLLM2_openthought/merged/4k_samples_sft_length_20k_filtered_output_merged_data_v1_8_18.jsonl" 6 | jsonl_path = "/checkpoint/lhu/data/CLLM2_openthought/merged/40k_samples_merged_data_v2_8_27_opencodeinstruct_progressive_noise_cyclic_cap_idx_0p5.jsonl" 7 | 8 | tokenizer_name = "/checkpoint/lhu/models/Qwen2.5-Coder-7B-Instruct" 9 | 10 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 11 | 12 | with open(jsonl_path, "r", encoding="utf-8") as f: 13 | for i, line in enumerate(f): 14 | if i > 3: 15 | break 16 | 17 | data = json.loads(line) 18 | print(f"\ndata id: {data['data_id']}") 19 | 20 | if "complete_training_sequence_ids" in data: 21 | ids = data["complete_training_sequence_ids"] 22 | 23 | print(f"\ncomplete training sequence length: {len(ids)}") 24 | 25 | #decoded = tokenizer.decode(ids) 26 | #print(f"\ndecoded length: {len(decoded)}") 27 | 28 | print(f"\n[Line {i}] clean tokens:\n\n{ids[-16:]}") 29 | print(f"\n[Line {i}] clean decoded:\n\n{tokenizer.decode(ids[-16:])}") 30 | 31 | print(f"\n[Line {i}] noisy tokens:\n\n{ids[-32:-16]}") 32 | print(f"\n[Line {i}] noisy decoded:\n\n{tokenizer.decode(ids[-32:-16])}") 33 | 34 | elif "labels_ids" in data: 35 | 36 | ids = data["labels_ids"] 37 | print(f"\nlabels length: {len(ids)}") 38 | decoded = tokenizer.decode(ids) 39 | print(f"\n[Line {i}] Decoded:\n\n{decoded}") 40 | 41 | #print(f"\ncomplete training sequence length: {len(data['complete_training_sequence_ids'])}") 42 | #print(f"\nlabel length: {len(data['labels_ids'])}") 43 | #print(f"prompt id: {data['prompt_ids']}") 44 | #print(f"\nprompt id length: {data['prompt_ids_len'][0]}") 45 | -------------------------------------------------------------------------------- /generate_trajectory/data/-1_opencodeinstruct_data_filtering.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 
import argparse 4 | import pyarrow.parquet as pq 5 | 6 | def filter_and_rank(input_dir: str, output_path: str): 7 | """ 8 | Filters OpenCodeInstruct data entries with average_test_score == 1.0, 9 | ranks them by llm_judgement['score'], and writes sorted entries to a JSONL file. 10 | """ 11 | records = [] 12 | 13 | parquet_files = sorted( 14 | os.path.join(input_dir, fname) 15 | for fname in os.listdir(input_dir) 16 | if fname.endswith(".parquet") 17 | ) 18 | 19 | for pfile in parquet_files: 20 | print(f"Reading {pfile}...") 21 | table = pq.read_table(pfile) 22 | df = table.to_pandas() 23 | 24 | perfect = df[df['average_test_score'].astype(float) == 1.0] 25 | print(f" --> Found {len(perfect)} perfect-score records in this file.") 26 | 27 | for _, row in perfect.iterrows(): 28 | lj = json.loads(row['llm_judgement']) 29 | score = float(lj.get("score", 0)) 30 | 31 | rec = row.to_dict() 32 | rec['llm_score'] = score 33 | records.append(rec) 34 | 35 | print(f"Sorting {len(records)} records by LLM judgement score (descending)...") 36 | records.sort(key=lambda x: x['llm_score'], reverse=True) 37 | 38 | print(f"Writing sorted records to {output_path}...") 39 | with open(output_path, 'w', encoding='utf-8') as f_out: 40 | for rec in records: 41 | rec.pop('llm_score', None) 42 | f_out.write(json.dumps(rec) + "\n") 43 | print("Done.") 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser( 47 | description="Filter & rank OpenCodeInstruct entries by llm_judgement score" 48 | ) 49 | parser.add_argument( 50 | "--input_dir", 51 | required=True, 52 | help="Directory containing .parquet input files" 53 | ) 54 | parser.add_argument( 55 | "--output_path", 56 | required=True, 57 | help="Path to output .jsonl file" 58 | ) 59 | args = parser.parse_args() 60 | 61 | filter_and_rank(args.input_dir, args.output_path) 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.8.1 2 | aiohappyeyeballs==2.6.1 3 | aiohttp==3.12.13 4 | aiosignal==1.3.2 5 | altair==6.0.0 6 | annotated-types==0.7.0 7 | attrs==25.3.0 8 | blessed==1.21.0 9 | blinker==1.9.0 10 | cachetools==6.2.2 11 | certifi==2025.6.15 12 | charset-normalizer==3.4.2 13 | click==8.2.1 14 | contourpy==1.3.3 15 | cycler==0.12.1 16 | datasets==3.6.0 17 | deepspeed==0.17.1 18 | dill==0.3.8 19 | einops==0.8.1 20 | filelock==3.18.0 21 | flash_attn==2.8.3 22 | fonttools==4.60.1 23 | frozenlist==1.7.0 24 | fsspec==2025.3.0 25 | gitdb==4.0.12 26 | GitPython==3.1.44 27 | gpustat==1.1.1 28 | hf-xet==1.1.5 29 | hjson==3.1.0 30 | huggingface-hub==0.33.2 31 | idna==3.10 32 | Jinja2==3.1.6 33 | jsonschema==4.25.1 34 | jsonschema-specifications==2025.9.1 35 | kiwisolver==1.4.9 36 | MarkupSafe==3.0.2 37 | matplotlib==3.10.7 38 | mpmath==1.3.0 39 | msgpack==1.1.1 40 | multidict==6.6.2 41 | multiprocess==0.70.16 42 | narwhals==2.13.0 43 | networkx==3.5 44 | ninja==1.11.1.4 45 | numpy==2.3.3 46 | nvidia-cublas-cu12==12.8.3.14 47 | nvidia-cuda-cupti-cu12==12.8.57 48 | nvidia-cuda-nvrtc-cu12==12.8.61 49 | nvidia-cuda-runtime-cu12==12.8.57 50 | nvidia-cudnn-cu12==9.7.1.26 51 | nvidia-cufft-cu12==11.3.3.41 52 | nvidia-cufile-cu12==1.13.0.11 53 | nvidia-curand-cu12==10.3.9.55 54 | nvidia-cusolver-cu12==11.7.2.55 55 | nvidia-cusparse-cu12==12.5.7.53 56 | nvidia-cusparselt-cu12==0.6.3 57 | nvidia-ml-py==12.575.51 58 | nvidia-nccl-cu12==2.26.2 59 | nvidia-nvjitlink-cu12==12.8.61 60 | nvidia-nvtx-cu12==12.8.55 61 | 
orjson==3.11.3 62 | packaging==25.0 63 | pandas==2.3.0 64 | peft==0.15.2 65 | pillow==11.3.0 66 | platformdirs==4.3.8 67 | propcache==0.3.2 68 | protobuf==6.31.1 69 | psutil==7.0.0 70 | py-cpuinfo==9.0.0 71 | pyarrow==20.0.0 72 | pydantic==2.11.7 73 | pydantic_core==2.33.2 74 | pydeck==0.9.1 75 | pyparsing==3.2.5 76 | python-dateutil==2.9.0.post0 77 | pytz==2025.2 78 | PyYAML==6.0.2 79 | referencing==0.37.0 80 | regex==2024.11.6 81 | requests==2.32.4 82 | rpds-py==0.30.0 83 | safetensors==0.5.3 84 | scipy==1.16.3 85 | sentry-sdk==2.32.0 86 | setuptools==80.9.0 87 | six==1.17.0 88 | smmap==5.0.2 89 | streamlit==1.52.1 90 | sympy==1.14.0 91 | tenacity==9.1.2 92 | tokenizers==0.21.2 93 | toml==0.10.2 94 | torch==2.7.1+cu128 95 | tornado==6.5.2 96 | tqdm==4.67.1 97 | transformers==4.53.0 98 | triton==3.3.1 99 | typing-inspection==0.4.1 100 | typing_extensions==4.14.0 101 | tzdata==2025.2 102 | urllib3==2.5.0 103 | wandb==0.21.0 104 | watchdog==6.0.0 105 | wcwidth==0.2.13 106 | xxhash==3.5.0 107 | yarl==1.20.1 108 | -------------------------------------------------------------------------------- /generate_trajectory/data/tool_profile_trajectory_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import statistics 4 | 5 | def smart_open(path): 6 | with open(path, "r", encoding="utf-8") as f: 7 | first = f.read(1) 8 | f.seek(0) 9 | if first == "[": 10 | return json.load(f) 11 | else: 12 | return [json.loads(line) for line in f] 13 | 14 | def describe_list(lst): 15 | if not lst: 16 | return "[]" 17 | if all(isinstance(x, int) for x in lst): 18 | return f"[int] len={len(lst)}" 19 | if all(isinstance(x, list) and all(isinstance(xx, int) for xx in x) for x in lst): 20 | lens = [len(x) for x in lst] 21 | return f"[[int,…],…] {len(lst)} chunks (chunk_lens={lens})" 22 | return f"{type(lst).__name__} of {len(lst)}" 23 | 24 | def profile(file, max_rows=100): 25 | data = smart_open(file) 26 | n = min(max_rows, len(data)) 27 | print(f"Profiling {n} / {len(data)} rows from: {file}\n") 28 | 29 | # Gather per-field values 30 | fields = [ 31 | "data_id", 32 | "diffusion_itr_id", 33 | "prompt_ids_len", 34 | "prompt_ids", 35 | "answer_trajectory_ids", 36 | "teacher_output_ids", 37 | "labels_ids" 38 | ] 39 | stats = {k: [] for k in fields} 40 | for row in data[:n]: 41 | for k in fields: 42 | stats[k].append(row.get(k)) 43 | 44 | # Print the schema-aligned summary 45 | def ex(x): # get an example value 46 | for v in x: 47 | if v is not None: 48 | return v 49 | 50 | for k in fields: 51 | v = stats[k] 52 | sample = ex(v) 53 | # Types 54 | if k in ("data_id", "diffusion_itr_id"): 55 | print(f'{k:<22} str\t\te.g. "{sample}"') 56 | elif k == "prompt_ids_len": 57 | print(f'{k:<22} [int]\t\te.g. 
{sample}') 58 | elif k == "prompt_ids": 59 | lens = [len(x) for x in v if isinstance(x, list)] 60 | print(f'{k:<22} [int,…]\tlen={statistics.mean(lens):.1f} (min={min(lens)}, max={max(lens)})') 61 | elif k == "answer_trajectory_ids": 62 | if isinstance(sample, list) and sample and isinstance(sample[0], list): 63 | num_chunks = [len(x) for x in v if isinstance(x, list)] 64 | flat_lens = [[len(xx) for xx in x] for x in v if isinstance(x, list)] 65 | print(f'{k:<22} [[int,…],…]\tnum_chunks={statistics.mean(num_chunks):.1f} per row; chunk_lens (first row)={flat_lens[0] if flat_lens else "[]"}') 66 | else: 67 | print(f'{k:<22} [[int,…],…]\tempty or missing') 68 | elif k in ("teacher_output_ids", "labels_ids"): 69 | lens = [len(x) for x in v if isinstance(x, list)] 70 | print(f'{k:<22} [int,…]\tlen={statistics.mean(lens):.1f} (min={min(lens)}, max={max(lens)})') 71 | print("\nDone.") 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument("--file") 76 | parser.add_argument("--max_rows", type=int, default=100) 77 | args = parser.parse_args() 78 | profile(args.file, args.max_rows) 79 | -------------------------------------------------------------------------------- /generate_trajectory/data/2_prepare_baseline_training_data_sft_reverse_engineering.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse, json, sys 3 | from typing import Any, Dict, List 4 | from tqdm import tqdm 5 | 6 | def parse_args(): 7 | p = argparse.ArgumentParser( 8 | description="Reconstruct labels_ids = prompt_ids + concat_of_all(last_j) from preprocessed JSONL." 9 | ) 10 | p.add_argument("--input_path", required=True, help="Input JSONL from the previous stage.") 11 | p.add_argument("--output_path", required=True, help="Output JSONL with added labels_ids.") 12 | p.add_argument("--n-token-seq-length", type=int, required=True, 13 | help="Same n used to build the pairs: each pair = sampled(n) + last(n).") 14 | p.add_argument("--best-effort", action="store_true", 15 | help="If set, tolerate length mismatches by slicing whatever is available.") 16 | return p.parse_args() 17 | 18 | def reconstruct_labels(entry: Dict[str, Any], n: int, best_effort: bool) -> List[int]: 19 | """ 20 | Given one line's dict with: 21 | - prompt_ids: List[int] 22 | - complete_training_sequence_ids: List[int] = prompt_ids + concat over pairs of [sampled_kj | last_j] 23 | - traj_position_indices: List[int] (length == number of pairs) 24 | - prompt_ids_len: int (optional; will verify if present) 25 | return labels_ids = prompt_ids + concat_of_all(last_j). 
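    Worked example (hypothetical token ids, n=2, two pairs):
        prompt_ids                      = [p0, p1]
        complete_training_sequence_ids  = [p0, p1, s0, s1, l0, l1, s2, s3, l2, l3]
        labels_ids                      = [p0, p1, l0, l1, l2, l3]
    i.e. each 2n-token pair contributes only its last n tokens to labels_ids.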
26 | """ 27 | prompt: List[int] = entry["prompt_ids"] 28 | prompt_len_reported = entry.get("prompt_ids_len", len(prompt)) 29 | 30 | full_seq: List[int] = entry["complete_training_sequence_ids"] 31 | tail_after_prompt = full_seq[len(prompt):] 32 | 33 | num_pairs = len(entry.get("traj_position_indices", [])) 34 | expected_tail_len = 2 * n * num_pairs 35 | 36 | labels_tail: List[int] = [] 37 | pos = 0 38 | for i in range(num_pairs): 39 | block = tail_after_prompt[pos:pos + 2 * n] if pos < len(tail_after_prompt) else [] 40 | last_j = block[-n:] if len(block) >= n else block 41 | labels_tail.extend(last_j) 42 | pos += 2 * n 43 | 44 | labels_ids = list(prompt) + labels_tail 45 | return labels_ids 46 | 47 | def main(): 48 | args = parse_args() 49 | n = args.n_token_seq_length 50 | 51 | count_in, count_out = 0, 0 52 | with open(args.input_path, "r", encoding="utf-8") as fin, \ 53 | open(args.output_path, "w", encoding="utf-8") as fout: 54 | for line in tqdm(fin, desc="Reconstructing labels_ids"): 55 | line = line.strip() 56 | if not line: 57 | continue 58 | count_in += 1 59 | obj = json.loads(line) 60 | 61 | labels_ids = reconstruct_labels(obj, n=n, best_effort=args.best_effort) 62 | obj.pop("complete_training_sequence_ids") 63 | obj.pop("traj_position_indices") 64 | obj["labels_ids"] = labels_ids 65 | 66 | # obj["labels_ids_len"] = len(labels_ids) 67 | 68 | fout.write(json.dumps(obj, ensure_ascii=False)) 69 | fout.write("\n") 70 | count_out += 1 71 | 72 | print(f"Read {count_in} lines, wrote {count_out} with labels_ids --> {args.output_path}") 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/inference/scanning_hyperparameter_jacobi_decoding_mr.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 | 4 | PROFILE_PY="jacobi_forcing_inference_MR_humaneval_config_grid_search.py" 5 | 6 | DF_FILE="/home/lah003/data/openai_humaneval/openai_humaneval/test-00000-of-00001.parquet" 7 | MODEL_NAME="/raid/lah003/shiftedattn-10-16-7b-qwen2p5-coder-n32w16-n16distill-data-v2-ar-1-cyclic-noise-all-1e-6/ckpt-344092" 8 | TOKENIZER_NAME="/home/lah003/models/Qwen2.5-Coder-7B-Instruct" 9 | 10 | CSV_DIR="profiling_results" 11 | EVAL_DIR="/home/lah003/data/CLLM2_eval_generations/multiblock_testing_prompt/scanning_generation" 12 | 13 | MAX_CALLS=64 14 | MAX_NEW_TOKENS=1024 15 | 16 | LOG_DIR="logs/hparam_sweep_$(date +%Y%m%d_%H%M%S)" 17 | mkdir -p "$LOG_DIR" "$CSV_DIR" 18 | 19 | block_sizes=(256 128 64 32 16 8) 20 | Ks=(1 2 3 4) 21 | pool_sizes=(1 2 4 8 12) 22 | 23 | # r sweep: from 0.50 to 0.95, in increments of 0.5 24 | r_values=() 25 | while read -r val; do r_values+=("$val"); done < <(seq 0.50 0.05 0.95) 26 | 27 | # --------------------------- 28 | # GPU semaphore 29 | # --------------------------- 30 | GPUS=(0 1 2 3 4 5) 31 | NUM_GPUS=${#GPUS[@]} 32 | 33 | FIFO="$(mktemp -u)" 34 | mkfifo "$FIFO" 35 | exec 9<>"$FIFO" 36 | rm "$FIFO" 37 | 38 | # seed tokens 39 | for g in "${GPUS[@]}"; do 40 | echo "$g" >&9 41 | done 42 | 43 | pids=() 44 | fails=0 45 | 46 | launch_one() { 47 | local n="$1" K="$2" r_fmt="$3" ng="$4" 48 | 49 | local run_id="ntok${n}_K${K}_r${r_fmt}_ng${ng}" 50 | local csv_path="${CSV_DIR}/${run_id}_diffusion_profile_humaneval.csv" 51 | local log_path="${LOG_DIR}/${run_id}.log" 52 | 53 | # acquire a GPU token 54 | local gpu 55 | read -r gpu <&9 56 | 57 | echo "========= LAUNCH $run_id on GPU $gpu =========" 58 | 59 | ( 60 
| set +e 61 | CUDA_VISIBLE_DEVICES="$gpu" python3 "$PROFILE_PY" \ 62 | "$DF_FILE" \ 63 | "$MODEL_NAME" \ 64 | "$TOKENIZER_NAME" \ 65 | "$csv_path" \ 66 | "$MAX_CALLS" \ 67 | "$MAX_NEW_TOKENS" \ 68 | --n_token_seq_len "$n" \ 69 | --K "$K" \ 70 | --r "$r_fmt" \ 71 | --n_gram_pool_size "$ng" \ 72 | --eval_dir "$EVAL_DIR" \ 73 | --out_prefix "$run_id" \ 74 | > "$log_path" 2>&1 75 | rc=$? 76 | 77 | if [[ $rc -ne 0 ]]; then 78 | echo "FAILED $run_id on GPU $gpu (see $log_path)" >&2 79 | fi 80 | 81 | # release GPU token 82 | echo "$gpu" >&9 83 | exit $rc 84 | ) & 85 | 86 | pids+=("$!") 87 | } 88 | 89 | # --------------------------- 90 | # Sweep 91 | # --------------------------- 92 | for n in "${block_sizes[@]}"; do 93 | for K in "${Ks[@]}"; do 94 | for r in "${r_values[@]}"; do 95 | r_fmt=$(printf "%.2f" "$r") 96 | for ng in "${pool_sizes[@]}"; do 97 | launch_one "$n" "$K" "$r_fmt" "$ng" 98 | done 99 | done 100 | done 101 | done 102 | 103 | # wait all 104 | for pid in "${pids[@]}"; do 105 | if wait "$pid"; then 106 | : 107 | else 108 | fails=$((fails + 1)) 109 | fi 110 | done 111 | 112 | exec 9>&- 113 | 114 | echo "Sweep complete. Logs in $LOG_DIR, CSVs in $CSV_DIR" 115 | if ((fails > 0)); then 116 | echo "WARNING: $fails runs failed. Check logs." 117 | exit 1 118 | fi 119 | echo "All runs succeeded." 120 | -------------------------------------------------------------------------------- /generate_trajectory/data/3_downsample_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse, json, os, random, sys 3 | 4 | def detect_input_kind(path): 5 | # Return "jsonl" if not starting with '[', otherwise "json" 6 | with open(path, "r", encoding="utf-8") as f: 7 | while True: 8 | ch = f.read(1) 9 | if not ch: 10 | return "jsonl" # empty -> treat as jsonl 11 | if ch.isspace(): 12 | continue 13 | return "json" if ch == "[" else "jsonl" 14 | 15 | def load_json_array(path): 16 | with open(path, "r", encoding="utf-8") as f: 17 | return json.load(f) 18 | 19 | def read_jsonl(path): 20 | with open(path, "r", encoding="utf-8") as f: 21 | for lineno, line in enumerate(f, start=1): 22 | line = line.strip() 23 | if not line: 24 | continue 25 | try: 26 | yield json.loads(line) 27 | except json.JSONDecodeError as e: 28 | sys.stderr.write(f"[warn] Skipping invalid JSON on line {lineno}: {e}\n") 29 | 30 | print(f"total read size: {lineno}") 31 | 32 | def write_jsonl(path, items): 33 | with open(path, "w", encoding="utf-8") as f: 34 | for obj in items: 35 | f.write(json.dumps(obj, ensure_ascii=False)) 36 | f.write("\n") 37 | 38 | def write_json_array(path, items): 39 | with open(path, "w", encoding="utf-8") as f: 40 | json.dump(items, f, ensure_ascii=False, indent=2) 41 | 42 | def reservoir_sample(iterable, k, rng): 43 | """Return up to k items sampled uniformly without knowing the total size.""" 44 | reservoir = [] 45 | for i, item in enumerate(iterable, start=1): 46 | if i <= k: 47 | reservoir.append(item) 48 | else: 49 | j = rng.randrange(i) 50 | if j < k: 51 | reservoir[j] = item 52 | return reservoir 53 | 54 | def first_k(iterable, k): 55 | out = [] 56 | for item in iterable: 57 | if len(out) >= k: 58 | break 59 | out.append(item) 60 | return out 61 | 62 | def main(): 63 | p = argparse.ArgumentParser(description="Downsample a JSONL or JSON array file to a specified size.") 64 | p.add_argument("--input", help="Input file path (JSONL or a JSON array).") 65 | p.add_argument("--output", help="Output file path.") 66 | p.add_argument("-n", 
"--size", type=int, required=True, help="Target number of records.") 67 | p.add_argument("--method", choices=["random", "first"], default="random", help="Sampling method.") 68 | p.add_argument("--seed", type=int, default=None, help="Random seed (for --method random).") 69 | p.add_argument("--output-format", 70 | choices=["jsonl", "json"], 71 | default="jsonl", 72 | help="Output format (default: jsonl).") 73 | args = p.parse_args() 74 | 75 | if args.size <= 0: 76 | sys.stderr.write("[error] --size must be a positive integer.\n") 77 | sys.exit(2) 78 | 79 | rng = random.Random(args.seed) 80 | 81 | kind = detect_input_kind(args.input) 82 | 83 | # Read input as an iterator of dicts 84 | if kind == "json": 85 | data = load_json_array(args.input) 86 | if not isinstance(data, list): 87 | sys.stderr.write("[error] JSON input is not an array.\n") 88 | sys.exit(2) 89 | iterable = data 90 | else: 91 | iterable = read_jsonl(args.input) 92 | 93 | # Sample 94 | if args.method == "first": 95 | sampled = first_k(iterable, args.size) 96 | else: 97 | sampled = reservoir_sample(iterable, args.size, rng) 98 | 99 | # Write output 100 | if args.output_format == "json": 101 | write_json_array(args.output, sampled) 102 | else: 103 | write_jsonl(args.output, sampled) 104 | 105 | # Small summary to stderr 106 | sys.stderr.write(f"[ok] Wrote {len(sampled)} records to {args.output} " 107 | f"({args.output_format}), method={args.method}, seed={args.seed}\n") 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /generate_trajectory/data/3_postprocessing_data_length_filtering.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse, json, os, sys, gzip 3 | from typing import Iterable, Tuple, Union 4 | from transformers import AutoTokenizer 5 | 6 | Field = "complete_training_sequence_ids" 7 | 8 | def open_maybe_gz(path: str, mode: str): 9 | if path.endswith(".gz"): 10 | return gzip.open(path, mode + "t", encoding="utf-8") 11 | return open(path, mode, encoding="utf-8") 12 | 13 | def is_jsonl(path: str) -> bool: 14 | return path.endswith(".jsonl") or path.endswith(".jsonl.gz") 15 | 16 | def load_json(path: str) -> list: 17 | with open_maybe_gz(path, "r") as f: 18 | data = json.load(f) 19 | if not isinstance(data, list): 20 | raise ValueError(f"{path} must contain a JSON array of records.") 21 | return data 22 | 23 | def iter_jsonl(path: str) -> Iterable[str]: 24 | with open_maybe_gz(path, "r") as f: 25 | for line in f: 26 | line = line.strip() 27 | if line: 28 | yield line 29 | 30 | def count_tokens(tokenizer, value: Union[list, str]) -> int: 31 | if isinstance(value, list): 32 | return len(value) 33 | if isinstance(value, str): 34 | # count tokens without adding special tokens 35 | return len(tokenizer(value, add_special_tokens=False).input_ids) 36 | # Unknown type → treat as zero to keep the record 37 | return 0 38 | 39 | def filter_stream_jsonl(in_path: str, out_path: str, tokenizer, field: str, threshold: int) -> Tuple[int, int]: 40 | kept = dropped = 0 41 | with open_maybe_gz(out_path, "w") as out_f: 42 | for line in iter_jsonl(in_path): 43 | obj = json.loads(line) 44 | tok_val = obj.get(field, None) 45 | n = count_tokens(tokenizer, tok_val) if tok_val is not None else 0 46 | if n > threshold: 47 | dropped += 1 48 | else: 49 | out_f.write(json.dumps(obj, ensure_ascii=False) + "\n") 50 | kept += 1 51 | return kept, dropped 52 | 53 | def 
filter_json_array(in_path: str, out_path: str, tokenizer, field: str, threshold: int) -> Tuple[int, int]: 54 | data = load_json(in_path) 55 | kept_data = [] 56 | kept = dropped = 0 57 | for obj in data: 58 | tok_val = obj.get(field, None) 59 | n = count_tokens(tokenizer, tok_val) if tok_val is not None else 0 60 | if n > threshold: 61 | dropped += 1 62 | else: 63 | kept_data.append(obj) 64 | kept += 1 65 | with open_maybe_gz(out_path, "w") as f: 66 | json.dump(kept_data, f, ensure_ascii=False) 67 | return kept, dropped 68 | 69 | def main(): 70 | ap = argparse.ArgumentParser(description="Filter records by token length.") 71 | ap.add_argument("--input", help="Path to input .json / .jsonl (optionally .gz)") 72 | ap.add_argument("-o", "--output", help="Path to output (defaults to *_filtered.json/.jsonl)") 73 | ap.add_argument("--model", default="/checkpoint/lhu/models/OpenThinker2-7B", help="Tokenizer path") 74 | ap.add_argument("--field", default=Field, help="Field containing IDs or text") 75 | ap.add_argument("--threshold", type=int, default=20480, help="Max allowed tokens (strictly > is dropped)") 76 | args = ap.parse_args() 77 | 78 | in_path = args.input 79 | if not os.path.exists(in_path): 80 | print(f"Input not found: {in_path}", file=sys.stderr) 81 | sys.exit(1) 82 | if args.output: 83 | out_path = args.output 84 | else: 85 | base, ext = os.path.splitext(in_path) 86 | if ext == ".gz": 87 | base2, ext2 = os.path.splitext(base) 88 | ext = ext2 + ext 89 | base = base2 90 | out_path = f"{base}_filtered{ext or '.json'}" 91 | 92 | print(f"Loading tokenizer from {args.model} ...", file=sys.stderr) 93 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True) 94 | 95 | print(f"Filtering '{in_path}' --> '{out_path}' using field '{args.field}' with threshold {args.threshold} tokens", file=sys.stderr) 96 | 97 | if is_jsonl(in_path): 98 | kept, dropped = filter_stream_jsonl(in_path, out_path, tokenizer, args.field, args.threshold) 99 | else: 100 | kept, dropped = filter_json_array(in_path, out_path, tokenizer, args.field, args.threshold) 101 | 102 | print(f"Done. kept={kept}, dropped={dropped}, total={kept + dropped}", file=sys.stderr) 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /generate_trajectory/data/tool_merge_ds_ckpts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import AutoConfig, Qwen2ForCausalLM 4 | from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint 5 | 6 | def normalize_ckpt_keys(state): 7 | new_state = {} 8 | changed = 0 9 | for k, v in state.items(): 10 | new_k = k 11 | # strip repeatedly in case of 'module.module.' etc. 
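        # ("module." is the usual DDP/DeepSpeed engine wrapper prefix and
        # "_fsdp_wrapped_module." the FSDP one; stripping both recovers plain
        # HF parameter names such as "lm_head.weight".)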
12 | while True: 13 | if new_k.startswith("module."): 14 | new_k = new_k[len("module."):] 15 | changed += 1 16 | continue 17 | if new_k.startswith("_fsdp_wrapped_module."): 18 | new_k = new_k[len("_fsdp_wrapped_module."):] 19 | changed += 1 20 | continue 21 | break 22 | new_state[new_k] = v 23 | print(f"Stripped prefixes on {changed} keys.") 24 | # sanity check 25 | if "module.lm_head.weight" in state and "lm_head.weight" not in new_state: 26 | raise RuntimeError("Expected lm_head.weight after prefix strip, but didn't find it.") 27 | return new_state 28 | 29 | ckpt_parent = "/data/numa0/train-tests/ckpts/v2/shiftedattn-10-16-cllm-qwen2p5-coder-7B-ntok16-distill32_distill_data_v2_ar_1_only_cyclic_progressive_noise_all_lr1e-6/checkpoint-212000" 30 | ckpt_dir = os.path.join(ckpt_parent, "") 31 | tag = "global_step212000" 32 | output_dtype = torch.bfloat16 33 | output_dir = os.path.join(ckpt_dir, "hf_merged_step_212000") 34 | os.makedirs(output_dir, exist_ok=True) 35 | 36 | # 1) Reconstruct FP32 state_dict 37 | fp32_state = get_fp32_state_dict_from_zero_checkpoint(ckpt_parent, tag=tag, lazy_mode=False) 38 | print(f"Successfully loaded {len(fp32_state)} tensors from DeepSpeed ZeRO checkpoint.") 39 | 40 | # 2) Convert to bfloat16 41 | bf16_state = {k: v.to(dtype=output_dtype) for k, v in fp32_state.items()} 42 | bf16_state = normalize_ckpt_keys(bf16_state) 43 | 44 | # 3) init HF model 45 | config = AutoConfig.from_pretrained(ckpt_parent) 46 | model = Qwen2ForCausalLM(config).to(dtype=output_dtype) 47 | 48 | model_sd = model.state_dict() 49 | model_keys = set(model_sd.keys()) 50 | ckpt_keys = set(bf16_state.keys()) 51 | 52 | common = sorted(model_keys & ckpt_keys) 53 | missing_before = sorted(model_keys - ckpt_keys) 54 | unexpected_before = sorted(ckpt_keys - model_keys) 55 | 56 | shape_mismatches = [] 57 | for k in common: 58 | if model_sd[k].shape != bf16_state[k].shape: 59 | shape_mismatches.append((k, tuple(model_sd[k].shape), tuple(bf16_state[k].shape))) 60 | 61 | def head(items, n=20): 62 | return items[:n] if len(items) > n else items 63 | 64 | print(f"\n=== Key comparison (pre-load) ===") 65 | print(f"Model keys: {len(model_keys)}") 66 | print(f"Checkpoint keys: {len(ckpt_keys)}") 67 | print(f"Common keys: {len(common)}") 68 | print(f"Missing keys: {len(missing_before)}") 69 | print(f"Unexpected keys: {len(unexpected_before)}") 70 | print(f"Shape mismatches:{len(shape_mismatches)}\n") 71 | 72 | if missing_before: 73 | print("• Missing (first few):") 74 | for k in head(missing_before): print(" -", k) 75 | if unexpected_before: 76 | print("\n• Unexpected (first few):") 77 | for k in head(unexpected_before): print(" -", k) 78 | if shape_mismatches: 79 | print("\n• Shape mismatches (first few):") 80 | for k, mshape, cshape in head(shape_mismatches): 81 | print(f" - {k}: model {mshape} vs ckpt {cshape}") 82 | 83 | # Save reports 84 | with open(os.path.join(output_dir, "missing_keys.txt"), "w") as f: 85 | f.write("\n".join(missing_before)) 86 | with open(os.path.join(output_dir, "unexpected_keys.txt"), "w") as f: 87 | f.write("\n".join(unexpected_before)) 88 | with open(os.path.join(output_dir, "shape_mismatches.tsv"), "w") as f: 89 | for k, mshape, cshape in shape_mismatches: 90 | f.write(f"{k}\t{mshape}\t{cshape}\n") 91 | 92 | # 4) Load with strict=False so we can proceed while inspecting diffs 93 | missing_after, unexpected_after = model.load_state_dict(bf16_state, strict=False) 94 | 95 | print(f"\n=== load_state_dict results (PyTorch report) ===") 96 | print(f"Missing: 
{len(missing_after)}") 97 | print(f"Unexpected:{len(unexpected_after)}") 98 | if missing_after: 99 | print("• Missing after load (first few):") 100 | for k in head(sorted(missing_after)): print(" -", k) 101 | if unexpected_after: 102 | print("\n• Unexpected after load (first few):") 103 | for k in head(sorted(unexpected_after)): print(" -", k) 104 | 105 | # 5) Save merged model 106 | model.save_pretrained(output_dir, safe_serialization=True) 107 | print(f"\n Merged model saved to: {output_dir}") 108 | print(f" Load with: Qwen2ForCausalLM.from_pretrained('{output_dir}')") 109 | print(f" Full reports: missing_keys.txt, unexpected_keys.txt, shape_mismatches.tsv in {output_dir}") 110 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/tool/plot_inference_configuration_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import math 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import pandas as pd 8 | 9 | 10 | def load_data(jsonl_path: Path) -> pd.DataFrame: 11 | df = pd.read_json(jsonl_path, lines=True) 12 | 13 | # basic sanity checks 14 | required_cols = ["K", "r", "block_size", "ngram_size", "avg_toks_per_sec"] 15 | missing = [c for c in required_cols if c not in df.columns] 16 | if missing: 17 | raise ValueError(f"Missing required columns in JSONL: {missing}") 18 | 19 | return df 20 | 21 | 22 | def filter_for_hparams(df: pd.DataFrame, K: int, r: float, r_tol: float = 1e-3): 23 | """Filter dataframe for given K and r (with tolerance on r).""" 24 | mask = (df["K"] == K) & (df["r"].sub(r).abs() <= r_tol) 25 | filtered = df[mask].copy() 26 | if filtered.empty: 27 | raise ValueError(f"No rows found with K={K} and r≈{r}") 28 | return filtered 29 | 30 | 31 | def aggregate_tokens_per_sec(df: pd.DataFrame) -> pd.DataFrame: 32 | """Average tokens/s in case of duplicates for same config.""" 33 | return ( 34 | df.groupby(["block_size", "ngram_size"], as_index=False)["avg_toks_per_sec"] 35 | .mean() 36 | .rename(columns={"avg_toks_per_sec": "toks_per_sec"}) 37 | ) 38 | 39 | 40 | def plot_tokens_vs_block_size(df_agg: pd.DataFrame, out_path: Path): 41 | """ 42 | First set: 43 | x-axis: block_size 44 | y-axis: tokens/s 45 | one curve per ngram_size 46 | 47 | Color scheme: darker for smaller ngram_size, lighter for larger ngram_size. 
48 | """ 49 | plt.figure() 50 | 51 | ngram_values = sorted(df_agg["ngram_size"].unique()) 52 | cmap = plt.get_cmap("Blues_r") # reversed: small -> dark, large -> light 53 | n = len(ngram_values) 54 | denom = max(1, n - 1) 55 | 56 | for i, ngram_size in enumerate(ngram_values): 57 | sub = df_agg[df_agg["ngram_size"] == ngram_size] 58 | sub_sorted = sub.sort_values("block_size") 59 | 60 | # normalized position in [0,1] 61 | t = i / denom 62 | color = cmap(t) 63 | 64 | plt.plot( 65 | sub_sorted["block_size"], 66 | sub_sorted["toks_per_sec"], 67 | marker="o", 68 | label=f"ngram={ngram_size}", 69 | color=color, 70 | ) 71 | 72 | plt.xlabel("Block size (n_token_seq_len)") 73 | plt.ylabel("Tokens / second") 74 | plt.title("Tokens/s vs Block Size (fixed K=2, r=0.85)") 75 | plt.grid(True, alpha=0.3) 76 | plt.legend(title="N-gram size") 77 | plt.tight_layout() 78 | plt.savefig(out_path) 79 | print(f"[PLOT] Saved tokens_vs_block_size to {out_path}") 80 | 81 | 82 | def plot_tokens_vs_ngram_size(df_agg: pd.DataFrame, out_path: Path): 83 | """ 84 | Second set: 85 | x-axis: ngram_size 86 | y-axis: tokens/s 87 | one curve per block_size 88 | 89 | Color scheme: darker for smaller block_size, lighter for larger block_size. 90 | """ 91 | plt.figure() 92 | 93 | block_values = sorted(df_agg["block_size"].unique()) 94 | cmap = plt.get_cmap("Greens_r") # reversed: small -> dark, large -> light 95 | n = len(block_values) 96 | denom = max(1, n - 1) 97 | 98 | for i, block_size in enumerate(block_values): 99 | sub = df_agg[df_agg["block_size"] == block_size] 100 | sub_sorted = sub.sort_values("ngram_size") 101 | 102 | # normalized position in [0,1] 103 | t = i / denom 104 | color = cmap(t) 105 | 106 | plt.plot( 107 | sub_sorted["ngram_size"], 108 | sub_sorted["toks_per_sec"], 109 | marker="o", 110 | label=f"block={block_size}", 111 | color=color, 112 | ) 113 | 114 | plt.xlabel("N-gram size") 115 | plt.ylabel("Tokens / second") 116 | plt.title("Tokens/s vs N-gram Size (fixed K=2, r=0.85)") 117 | plt.grid(True, alpha=0.3) 118 | plt.legend(title="Block size") 119 | plt.tight_layout() 120 | plt.savefig(out_path) 121 | print(f"[PLOT] Saved tokens_vs_ngram_size to {out_path}") 122 | 123 | 124 | def main(): 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument( 127 | "--input_jsonl", 128 | default="/home/lah003/workspace/CLLM2/profiling_results/summary/shiftedattn-10-16-7b-qwen2p5-coder-n32w16-n16distill-data-v2-ar-1-cyclic-noise-all-1e-6-summary.jsonl", 129 | help="Path to summary JSONL generated from logs", 130 | ) 131 | parser.add_argument( 132 | "--out_dir", 133 | type=str, 134 | default=None, 135 | help="Directory to save plots (default: same dir as JSONL)", 136 | ) 137 | parser.add_argument("--K", type=int, default=2, help="Fixed K value") 138 | parser.add_argument("--r", type=float, default=0.85, help="Fixed r value") 139 | parser.add_argument( 140 | "--r_tol", 141 | type=float, 142 | default=1e-3, 143 | help="Tolerance when matching r (default: 1e-3)", 144 | ) 145 | args = parser.parse_args() 146 | 147 | jsonl_path = Path(args.input_jsonl) 148 | if not jsonl_path.is_file(): 149 | raise SystemExit(f"JSONL file not found: {jsonl_path}") 150 | 151 | out_dir = Path(args.out_dir) if args.out_dir else jsonl_path.parent 152 | out_dir.mkdir(parents=True, exist_ok=True) 153 | 154 | df = load_data(jsonl_path) 155 | df_filtered = filter_for_hparams(df, K=args.K, r=args.r, r_tol=args.r_tol) 156 | 157 | df_agg = aggregate_tokens_per_sec(df_filtered) 158 | 159 | # First plot: tokens/s vs block size (curves by 
ngram_size) 160 | plot_tokens_vs_block_size( 161 | df_agg, out_dir / "tokens_vs_block_size_K2_r0.85.png" 162 | ) 163 | 164 | # Second plot: tokens/s vs ngram size (curves by block_size) 165 | plot_tokens_vs_ngram_size( 166 | df_agg, out_dir / "tokens_vs_ngram_size_K2_r0.85.png" 167 | ) 168 | 169 | print("[DONE]") 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/tool/extract_inference_profiling_datapoints_from_log.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | import re 6 | from pathlib import Path 7 | 8 | 9 | RUN_ID_RE = re.compile( 10 | r"ntok(?P\d+)_K(?P\d+)_r(?P[0-9.]+)_ng(?P\d+)" 11 | ) 12 | 13 | TOKS_PER_SEC_RE = re.compile(r"Avg\s+toks/sec:\s*([0-9.eE+-]+)") 14 | ITERS_PER_TOKEN_RE = re.compile(r"Avg\s+iterations\s*/\s*token:\s*([0-9.eE+-]+)") 15 | 16 | 17 | def parse_run_id(text: str, fallback_name: str = ""): 18 | """ 19 | Extract K, r, block_size, ngram_size from either the [RUN] line or the filename. 20 | """ 21 | m_run = re.search(r"\[RUN\]\s+([^\s]+)", text) 22 | run_id_str = None 23 | run_id_str = m_run.group(1).strip() 24 | 25 | assert run_id_str is not None, f"file name {text} is not in valid format." 26 | 27 | 28 | m = RUN_ID_RE.search(run_id_str) 29 | if not m: 30 | return None 31 | 32 | return { 33 | "block_size": int(m.group("block_size")), 34 | "K": int(m.group("K")), 35 | "r": float(m.group("r")), 36 | "ngram_size": int(m.group("ngram_size")), 37 | } 38 | 39 | 40 | def parse_metrics(text: str): 41 | """ 42 | Extract Avg toks/sec and Avg iterations / token from the EOS-only summary. 43 | """ 44 | m_tps = TOKS_PER_SEC_RE.search(text) 45 | m_ipt = ITERS_PER_TOKEN_RE.search(text) 46 | 47 | if not (m_tps and m_ipt): 48 | return None 49 | 50 | avg_toks_per_sec = float(m_tps.group(1)) 51 | avg_iters_per_token = float(m_ipt.group(1)) 52 | 53 | # Compute Avg tokens / iteration; guard against zero 54 | if avg_iters_per_token > 0: 55 | avg_tokens_per_iter = 1.0 / avg_iters_per_token 56 | else: 57 | avg_tokens_per_iter = math.nan 58 | 59 | return { 60 | "avg_toks_per_sec": avg_toks_per_sec, 61 | "avg_iters_per_token": avg_iters_per_token, 62 | "avg_tokens_per_iter": avg_tokens_per_iter, 63 | } 64 | 65 | 66 | def collect_from_logs(log_dir: Path): 67 | entries = [] 68 | 69 | for log_path in sorted(log_dir.rglob("*.log")): 70 | text = log_path.read_text(encoding="utf-8", errors="ignore") 71 | 72 | params = parse_run_id(text, fallback_name=log_path.stem) 73 | if params is None: 74 | print(f"[WARNING!!!] Could not parse run id from {log_path}") 75 | continue 76 | 77 | metrics = parse_metrics(text) 78 | if metrics is None: 79 | print(f"[WARNING!!!] Could not parse metrics from {log_path}") 80 | continue 81 | 82 | entry = { 83 | "log_path": str(log_path), 84 | } 85 | entry.update(params) 86 | entry.update(metrics) 87 | entries.append(entry) 88 | 89 | return entries 90 | 91 | 92 | def sanitize_for_json_value(v): 93 | """ 94 | Replace NaN/inf with None so json is valid. 
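    e.g. sanitize_for_json_value(float("nan")) -> None,
         sanitize_for_json_value(1.5)          -> 1.5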
95 | """ 96 | if isinstance(v, float) and (math.isnan(v) or math.isinf(v)): 97 | return None 98 | return v 99 | 100 | 101 | def main(): 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument( 104 | "--log_dir", 105 | default="/home/lah003/workspace/CLLM2/logs/hparam_sweep_20251125_171121", 106 | help="Directory containing *.log files", 107 | ) 108 | parser.add_argument( 109 | "--out_jsonl", 110 | type=str, 111 | default=None, 112 | help="Output JSONL path (default: /summary.jsonl)", 113 | ) 114 | args = parser.parse_args() 115 | 116 | log_dir = Path(args.log_dir) 117 | if not log_dir.is_dir(): 118 | raise SystemExit(f"Log directory does not exist or is not a directory: {log_dir}") 119 | 120 | out_jsonl = ( 121 | Path(args.out_jsonl) 122 | if args.out_jsonl is not None 123 | else log_dir / "summary.jsonl" 124 | ) 125 | 126 | entries = collect_from_logs(log_dir) 127 | if not entries: 128 | print(f"[INFO] No valid entries found under {log_dir}") 129 | return 130 | 131 | # ---- Find best runs ---- 132 | def valid_float(e, key): 133 | v = e.get(key) 134 | return isinstance(v, (int, float)) and not math.isnan(v) and not math.isinf(v) 135 | 136 | # best avg_toks_per_sec 137 | valid_tps_entries = [e for e in entries if valid_float(e, "avg_toks_per_sec")] 138 | best_tps = max(valid_tps_entries, key=lambda e: e["avg_toks_per_sec"]) if valid_tps_entries else None 139 | 140 | # best avg_tokens_per_iter 141 | valid_tpi_entries = [e for e in entries if valid_float(e, "avg_tokens_per_iter")] 142 | best_tpi = max(valid_tpi_entries, key=lambda e: e["avg_tokens_per_iter"]) if valid_tpi_entries else None 143 | 144 | if best_tps: 145 | print("\n=== Best Avg toks/sec ===") 146 | print(f"value : {best_tps['avg_toks_per_sec']:.4f}") 147 | print(f"log_path : {best_tps['log_path']}") 148 | print(f"K : {best_tps['K']}") 149 | print(f"r : {best_tps['r']}") 150 | print(f"block_size : {best_tps['block_size']}") 151 | print(f"ngram_size : {best_tps['ngram_size']}") 152 | 153 | if best_tpi: 154 | print("\n=== Best Avg tokens / iteration ===") 155 | print(f"value : {best_tpi['avg_tokens_per_iter']:.6f}") 156 | print(f"log_path : {best_tpi['log_path']}") 157 | print(f"K : {best_tpi['K']}") 158 | print(f"r : {best_tpi['r']}") 159 | print(f"block_size : {best_tpi['block_size']}") 160 | print(f"ngram_size : {best_tpi['ngram_size']}") 161 | 162 | # ---- Write JSONL ---- 163 | os.makedirs(out_jsonl.parent, exist_ok=True) 164 | with out_jsonl.open("w", encoding="utf-8") as f: 165 | for item in entries: 166 | sanitized = {k: sanitize_for_json_value(v) for k, v in item.items()} 167 | f.write(json.dumps(sanitized) + "\n") 168 | 169 | print(f"\n[DONE] Wrote {len(entries)} entries to {out_jsonl}") 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /applications/jacobi_streaming_driver.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Callable, List, Dict, Optional 3 | 4 | import torch 5 | 6 | 7 | @torch.inference_mode() 8 | def jacobi_stream_chat( 9 | model, 10 | tokenizer, 11 | messages: List[Dict[str, str]], 12 | n_token_seq_len: int = 64, 13 | max_new_tokens: int = 512, 14 | K: int = 2, 15 | r: float = 0.8, 16 | n_gram_pool_size: int = 4, 17 | on_text: Optional[Callable[[str], None]] = None, 18 | stream_per_token: bool = True, # NEW 19 | ): 20 | eos_id = tokenizer.eos_token_id 21 | pad_id = tokenizer.pad_token_id 22 | 23 | prompt = 
tokenizer.apply_chat_template( 24 | messages, 25 | tokenize=False, 26 | add_generation_prompt=True, 27 | ) 28 | 29 | model_inputs = tokenizer([prompt], return_tensors="pt") 30 | model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()} 31 | input_ids = model_inputs["input_ids"] 32 | prompt_len = input_ids.shape[1] 33 | 34 | attention_mask = torch.ones_like(input_ids, device=input_ids.device) 35 | 36 | past_key_values = None 37 | prefill_phase = True 38 | generated_ids = input_ids.clone() 39 | 40 | total_new_tokens_est = 0 41 | calls = 0 42 | generated_text = "" 43 | 44 | jacobi_time = 0.0 45 | MAX_CALLS = 128 46 | 47 | first_correct_token = None 48 | prefill_drafted_n_gram = None 49 | 50 | while True: 51 | generated_part = generated_ids[:, prompt_len:] 52 | if eos_id is not None and generated_part.numel() > 0: 53 | if (generated_part == eos_id).any().item(): 54 | break 55 | 56 | if total_new_tokens_est >= max_new_tokens: 57 | break 58 | if calls >= MAX_CALLS: 59 | break 60 | 61 | if prefill_phase: 62 | seq_len = generated_ids.shape[1] 63 | idxs = torch.randint( 64 | low=0, 65 | high=seq_len, 66 | size=(n_token_seq_len,), 67 | device=generated_ids.device, 68 | ) 69 | prefill_draft_token_ids = generated_ids[0, idxs].unsqueeze(0) 70 | prefill_input_ids = torch.cat((input_ids, prefill_draft_token_ids), dim=-1) 71 | 72 | past_key_values, first_correct_token, prefill_drafted_n_gram, _ = ( 73 | model.jacobi_forward_greedy_multiblock( 74 | input_ids=prefill_input_ids, 75 | attention_mask=attention_mask, 76 | past_key_values=None, 77 | use_cache=True, 78 | prefill_phase=True, 79 | n_token_seq_len=n_token_seq_len, 80 | K=K, 81 | r=r, 82 | n_gram_pool_size=n_gram_pool_size, 83 | tokenizer=tokenizer, 84 | eos_token_id=eos_id, 85 | pad_token_id=pad_id, 86 | ) 87 | ) 88 | 89 | prefill_phase = False 90 | generated_ids = input_ids 91 | calls += 1 92 | continue 93 | 94 | if calls == 1: 95 | draft_input_ids = prefill_drafted_n_gram 96 | else: 97 | seq_len = generated_ids.shape[1] 98 | tail_len = max(n_token_seq_len - 1, 1) 99 | idxs = torch.randint( 100 | low=0, 101 | high=seq_len, 102 | size=(tail_len,), 103 | device=generated_ids.device, 104 | ) 105 | tail = generated_ids[0, idxs].unsqueeze(0) 106 | draft_input_ids = torch.cat( 107 | (first_correct_token.view(1, -1), tail), dim=-1 108 | ) 109 | 110 | t0 = time.perf_counter() 111 | past_key_values, first_correct_token, accepted_n_gram, _ = ( 112 | model.jacobi_forward_greedy_multiblock( 113 | input_ids=draft_input_ids, 114 | attention_mask=None, 115 | past_key_values=past_key_values, 116 | use_cache=True, 117 | prefill_phase=False, 118 | n_token_seq_len=n_token_seq_len, 119 | K=K, 120 | r=r, 121 | n_gram_pool_size=n_gram_pool_size, 122 | tokenizer=tokenizer, 123 | eos_token_id=eos_id, 124 | pad_token_id=pad_id, 125 | ) 126 | ) 127 | jacobi_time += (time.perf_counter() - t0) 128 | calls += 1 129 | 130 | if accepted_n_gram is None or accepted_n_gram.numel() == 0: 131 | continue 132 | 133 | generated_ids = torch.cat((generated_ids, accepted_n_gram), dim=-1) 134 | 135 | token_ids = accepted_n_gram[0].tolist() 136 | 137 | if pad_id is not None: 138 | token_ids = [t for t in token_ids if t != pad_id] 139 | 140 | eos_hit = False 141 | usable_ids = [] 142 | for t in token_ids: 143 | if eos_id is not None and t == eos_id: 144 | eos_hit = True 145 | break 146 | usable_ids.append(t) 147 | 148 | if usable_ids: 149 | total_new_tokens_est += len(usable_ids) 150 | 151 | if stream_per_token: 152 | # STREAM EVERY TOKEN (UI update per token) 153 | for t in 
usable_ids: 154 | delta = tokenizer.decode( 155 | [t], 156 | skip_special_tokens=True, 157 | clean_up_tokenization_spaces=False, 158 | ) 159 | if not delta: 160 | continue 161 | generated_text += delta 162 | if on_text is not None: 163 | on_text(generated_text) 164 | else: 165 | # STREAM PER CHUNK (your old behavior) 166 | chunk_text = tokenizer.decode( 167 | usable_ids, 168 | skip_special_tokens=True, 169 | clean_up_tokenization_spaces=False, 170 | ) 171 | if chunk_text: 172 | generated_text += chunk_text 173 | if on_text is not None: 174 | on_text(generated_text) 175 | 176 | if eos_hit: 177 | break 178 | 179 | gen_time = jacobi_time 180 | assistant_text = generated_text.strip() 181 | 182 | final_token_ids = torch.empty((1, 0), dtype=torch.long) 183 | if assistant_text: 184 | final_token_ids = tokenizer( 185 | assistant_text, 186 | add_special_tokens=False, 187 | return_tensors="pt", 188 | ).input_ids 189 | new_tokens = int(final_token_ids.shape[1]) - 1 # keep your "-1" behavior 190 | else: 191 | new_tokens = 0 192 | 193 | return assistant_text, final_token_ids, total_new_tokens_est, gen_time 194 | -------------------------------------------------------------------------------- /generate_trajectory/data/1_masking_based_prepare_trajectory.py: -------------------------------------------------------------------------------- 1 | """ 2 | –––––––––––––––––––––––––––––––––– 3 | final JSON format: 4 | 5 | { 6 | "data_id": str, # "data_" 7 | "diffusion_itr_id": str, # always "itr_0" here 8 | "prompt_ids_len": [int], # list(len(prompt_ids)) 9 | "prompt_ids": [int, …], # full encoded prompt 10 | "answer_trajectory_ids":[[int,…], …], # 1 element per chunk 11 | "teacher_output_ids": [int, …], # = labels_ids 12 | "labels_ids": [int, …] # same as above (optional) 13 | } 14 | """ 15 | from datasets import load_dataset 16 | from transformers import AutoTokenizer 17 | from tqdm import tqdm 18 | import random, copy, json, math, re, os 19 | from multiprocessing import Pool 20 | from functools import partial 21 | 22 | THOUGHT_RE = re.compile( 23 | r"<\|begin_of_thought\|>\n\n(.*?)\n\n<\|end_of_thought\|>", re.DOTALL 24 | ) 25 | SOLUTION_RE = re.compile( 26 | r"<\|begin_of_solution\|>\n\n(.*?)\n\n<\|end_of_solution\|>", re.DOTALL 27 | ) 28 | 29 | def process_response(resp: str) -> str: 30 | tm, sm = THOUGHT_RE.search(resp), SOLUTION_RE.search(resp) 31 | if not (tm and sm): 32 | return resp.strip() 33 | return f"\n{tm.group(1).strip()}\n\n\n{sm.group(1).strip()}" 34 | 35 | def build_messages(sample, use_think_format=False, use_system_prompt=False): 36 | 37 | if use_system_prompt: 38 | system_msg = sample["system"] 39 | msgs = [{"role": "system", "content": system_msg}] 40 | else: 41 | msgs = [] 42 | 43 | for turn in sample["conversations"]: 44 | role = "user" if turn["from"] == "user" else "assistant" 45 | if use_think_format: 46 | content = turn["value"] if role == "user" else process_response(turn["value"]) 47 | else: 48 | content = turn["value"] 49 | msgs.append({"role": role, "content": content}) 50 | return msgs 51 | 52 | def build_user_prompt(sample, use_system_prompt=False): 53 | if use_system_prompt: 54 | system_msg = sample["system"] 55 | msgs = [{"role": "system", "content": system_msg}] 56 | else: 57 | msgs = [] 58 | for turn in sample["conversations"]: 59 | if turn["from"] == "user": 60 | role = "user" 61 | content = turn["value"] 62 | else: 63 | continue 64 | msgs.append({"role": role, "content": content}) 65 | return msgs 66 | 67 | def convert_sample(sample, row_id: int, tokenizer, chunk_size=32, 
use_think_format=False, use_system_prompt=False): 68 | """Return a list[dict] ready for the final JSON dump.""" 69 | prompt_msgs = build_user_prompt(sample, use_system_prompt=use_system_prompt) 70 | msgs = build_messages(sample, use_think_format=use_think_format, use_system_prompt=use_system_prompt) 71 | 72 | # prompt (generation-prompt=True ➜ no assistant answer) 73 | prompt_ids = tokenizer.apply_chat_template( 74 | prompt_msgs, 75 | tokenize=True, 76 | add_generation_prompt=True, 77 | return_tensors="pt" 78 | ).squeeze(0).tolist() 79 | 80 | # target (= teacher output) 81 | full_ids = tokenizer.apply_chat_template( 82 | msgs, 83 | tokenize=True, 84 | add_generation_prompt=False, 85 | return_tensors="pt" 86 | ).squeeze(0).tolist() 87 | 88 | print("prompt_ids:", len(prompt_ids), "tokens") 89 | print("full_ids:", len(full_ids), "tokens") 90 | 91 | if len(full_ids) > 16_384: # length guard 92 | return [] 93 | 94 | # build *answer_trajectory_ids* 95 | # keep only the ground-truth chunk for each 32-token block 96 | pad_id = tokenizer.eos_token_id 97 | 98 | # Number of tokens after the prompt 99 | response_length = len(full_ids) - len(prompt_ids) 100 | 101 | # If not divisible by chunk_size (32), pad with pad_id 102 | if response_length % chunk_size != 0: 103 | pad_amt = chunk_size - (response_length % chunk_size) 104 | full_ids = full_ids + [pad_id] * pad_amt 105 | 106 | answer_trajectory = [] 107 | for i in range(len(prompt_ids), len(full_ids), chunk_size): 108 | chunk = full_ids[i : i + chunk_size] 109 | answer_trajectory.append(chunk) 110 | 111 | record = dict( 112 | data_id = f"data_{row_id}", 113 | diffusion_itr_id = "itr_0", 114 | prompt_ids_len = [len(prompt_ids)], 115 | prompt_ids = prompt_ids, 116 | answer_trajectory_ids = answer_trajectory, 117 | teacher_output_ids = full_ids, 118 | labels_ids = full_ids, # set if --use_labels 119 | ) 120 | assert ("answer_trajectory_ids" in record) and (record["answer_trajectory_ids"] is not None), f"Missing key in row {row_id}" 121 | 122 | return [record] 123 | 124 | def preprocess_parallel(data, tokenizer, chunk_size=32, n_workers=16, start_idx=0, use_think_format=False, use_system_prompt=False): 125 | func = partial(convert_sample, tokenizer=tokenizer, chunk_size=chunk_size, use_think_format=use_think_format, use_system_prompt=use_system_prompt) 126 | 127 | # build (sample, row_id) pairs 128 | jobs = [(s, start_idx + i) for i, s in enumerate(data)] 129 | with Pool(n_workers) as pool: 130 | out = [] 131 | for recs in tqdm(pool.starmap(func, jobs), total=len(jobs)): 132 | out.extend(recs) 133 | 134 | return out 135 | 136 | if __name__ == "__main__": 137 | random.seed(42) 138 | 139 | tokenizer = AutoTokenizer.from_pretrained( 140 | "/checkpoint/lhu/models/OpenThinker2-7B", trust_remote_code=True 141 | ) 142 | ds = load_dataset( 143 | "parquet", 144 | data_files="/checkpoint/lhu/data/OpenThoughts-114k/data/train-*.parquet", 145 | split="train" 146 | ) 147 | 148 | print("Loaded", len(ds), "rows") 149 | 150 | # 2. 
Randomly select a split of the data 151 | SPLIT_RATIO = 1 152 | subset_size = len(ds) // SPLIT_RATIO 153 | indices = list(range(len(ds))) 154 | random.shuffle(indices) 155 | selected_indices = indices[:subset_size] 156 | ds_subset = ds.select(selected_indices) 157 | 158 | print("Subset size:", len(ds_subset)) 159 | 160 | CHUNK = 512 # rows per processing batch 161 | outfile = f"/checkpoint/lhu/data/CLLM2_openthought/train_openthoughts__split_ratio_{SPLIT_RATIO}_size_{len(ds_subset)}_ntok64_formatted_with_eos_tokens_with_think_format_without_sysmsg.json" 162 | os.makedirs(os.path.dirname(outfile), exist_ok=True) 163 | 164 | all_records = [] 165 | for shard in range(math.ceil(len(ds_subset) / CHUNK)): 166 | a, b = shard * CHUNK, min((shard + 1) * CHUNK, len(ds_subset)) 167 | print(f"Processing rows {a}…{b-1}") 168 | sub = ds_subset.select(range(a, b)) 169 | all_records.extend( 170 | preprocess_parallel(sub, tokenizer, chunk_size=64, start_idx=a, use_think_format=True, use_system_prompt=False) 171 | ) 172 | 173 | with open(outfile, "w", encoding="utf-8") as f: 174 | json.dump(all_records, f) 175 | 176 | print(f"Wrote {len(all_records):,} aligned records ➜ {outfile}") 177 | -------------------------------------------------------------------------------- /generate_trajectory/data/0_bucketing_opencodeinstruct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Bucket **OpenCodeInstruct**-style data into fixed-count shards (default: 5k examples), 5 | using a chat template to tokenize the full (user,input ↔ assistant,output) pair and 6 | sorting by total token count. 7 | 8 | Strict assumptions (per dataset card): 9 | - Each JSONL row has fields: `input` (question/instruction) and `output` (LLM response). 10 | - We emit ONLY `user` and `assistant` roles; no system role. 11 | - Input format: ONLY `*.jsonl` under --input_path (one JSON object per line). 12 | - Output: JSON files named `bucket_XXXX_avgA_minB_maxC.json`, each a JSON array of the 13 | FIRST user prompts (i.e., the `input` text) for the 5k examples in that bucket.
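For illustration only, an output shard might be named bucket_0003_avg812_min640_max1103.json; the avg/min/max numbers here are hypothetical and are filled in from each bucket's token-count stats at run time.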
14 | 15 | Usage: 16 | python bucket_opencodeinstruct.py \ 17 | --input_path /path/to/jsonl_dir \ 18 | --output_path /path/to/out \ 19 | --tokenizer_path Qwen/Qwen2.5-7B-Instruct \ 20 | --bucket_size 5000 \ 21 | --n_workers 8 22 | """ 23 | 24 | import os, glob, json, argparse, multiprocessing as mp 25 | from functools import partial 26 | from typing import List, Dict, Any, Optional 27 | from tqdm import tqdm 28 | from transformers import AutoTokenizer 29 | 30 | TOKENIZER_PATH: Optional[str] = "/checkpoint/lhu/models/Qwen2.5-Coder-7B-Instruct" # set from CLI in main() 31 | TOKENIZER = None # global per worker 32 | 33 | 34 | def init_worker(): 35 | """Initialise the global tokenizer once per worker.""" 36 | global TOKENIZER 37 | if TOKENIZER is None: 38 | if not TOKENIZER_PATH: 39 | raise RuntimeError("TOKENIZER_PATH not set in worker.") 40 | TOKENIZER = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True) 41 | 42 | 43 | def tokenize_pair(user_text: str, assistant_text: str) -> Optional[int]: 44 | """Return total token count for a (user, assistant) pair via apply_chat_template.""" 45 | global TOKENIZER 46 | msgs = [ 47 | {"role": "user", "content": user_text}, 48 | {"role": "assistant", "content": assistant_text}, 49 | ] 50 | try: 51 | ids = ( 52 | TOKENIZER.apply_chat_template( 53 | msgs, 54 | tokenize=True, 55 | add_generation_prompt=False, 56 | return_tensors="pt", 57 | ) 58 | .squeeze(0) 59 | .tolist() 60 | ) 61 | return len(ids) 62 | except Exception as e: 63 | print("⚠️ Tokenization error:", e) 64 | return None 65 | 66 | 67 | def process_sample(sample: Dict[str, Any]) -> Optional[Dict[str, Any]]: 68 | """Build messages from `input`/`output`, tokenise, and return stats for bucketing.""" 69 | inp = sample.get("input") 70 | out = sample.get("output") 71 | if not isinstance(inp, str) or not isinstance(out, str): 72 | return None 73 | n_tok = tokenize_pair(inp, out) 74 | 75 | # tidy whitespace of prompt for output file 76 | prompt_clean = " ".join(inp.split()) 77 | return {"prompt": prompt_clean, "n_tokens": n_tok} 78 | 79 | 80 | def load_all_records(input_path: str) -> List[Dict[str, Any]]: 81 | """Load only *.jsonl files under a directory into a list of dicts.""" 82 | paths = sorted(glob.glob(os.path.join(input_path, "*.jsonl"))) 83 | if not paths: 84 | raise FileNotFoundError(f"No *.jsonl found under {input_path}") 85 | print(f"STEP 0: Found {len(paths)} jsonl file(s). 
Reading...") 86 | 87 | rows: List[Dict[str, Any]] = [] 88 | for p in paths: 89 | with open(p, "r", encoding="utf-8") as fin: 90 | for i, line in enumerate(fin, start=1): 91 | line = line.strip() 92 | if not line: 93 | continue 94 | try: 95 | obj = json.loads(line) 96 | except Exception as e: 97 | print(f"⚠️ Skipping bad JSON line {i} in {os.path.basename(p)}: {e}") 98 | continue 99 | rows.append(obj) 100 | print(f"STEP 1: Loaded {len(rows):,} json-rows from {len(paths)} file(s)") 101 | return rows 102 | 103 | 104 | def main(input_path: str, 105 | output_path: str, 106 | *, 107 | tokenizer_path: str, 108 | bucket_size: int = 5_000, 109 | n_workers: int = 8): 110 | 111 | global TOKENIZER_PATH 112 | TOKENIZER_PATH = tokenizer_path 113 | 114 | os.makedirs(output_path, exist_ok=True) 115 | 116 | # 1) Load ------------------------------------------------------------------------------------ 117 | samples = load_all_records(input_path) 118 | 119 | # 2) Token-count each sample in parallel ----------------------------------------------------- 120 | with mp.Pool(n_workers, initializer=init_worker) as pool: 121 | processed = list( 122 | tqdm( 123 | pool.imap(partial(process_sample), samples), 124 | total=len(samples), 125 | desc="Tokenising", 126 | ) 127 | ) 128 | 129 | processed = [p for p in processed if p is not None] 130 | print(f"STEP 2: Tokenised {len(processed):,} samples") 131 | 132 | # 3) Sort by token length (ascending) 133 | processed.sort(key=lambda x: x["n_tokens"]) 134 | 135 | # 4) Slice into buckets of N prompts 136 | print(f"STEP 3: Bucketing {len(processed):,} prompts --> {bucket_size} prompts per file") 137 | for i in range(0, len(processed), bucket_size): 138 | bucket_idx = i // bucket_size 139 | bucket = processed[i : i + bucket_size] 140 | if not bucket: 141 | continue 142 | 143 | tok_counts = [b["n_tokens"] for b in bucket] 144 | min_tok = min(tok_counts) 145 | max_tok = max(tok_counts) 146 | avg_tok = int(round(sum(tok_counts) / len(tok_counts))) 147 | 148 | out_fname = ( 149 | f"bucket_{bucket_idx:04d}" 150 | f"_avg{avg_tok}_min{min_tok}_max{max_tok}.json" 151 | ) 152 | out_path = os.path.join(output_path, out_fname) 153 | 154 | prompts = [item["prompt"] for item in bucket] 155 | with open(out_path, "w", encoding="utf-8") as fout: 156 | json.dump(prompts, fout, ensure_ascii=False, indent=2) 157 | 158 | print(f"-- {out_fname} ({len(prompts)} prompts, {sum(tok_counts):,} tokens)") 159 | 160 | 161 | if __name__ == "__main__": 162 | parser = argparse.ArgumentParser( 163 | description=( 164 | "Bucket OpenCodeInstruct prompts (input→user) by total token length. " 165 | "Each output file contains a fixed NUMBER of prompts." 
166 | ) 167 | ) 168 | parser.add_argument("--input_path", required=True, 169 | help="Directory containing *.jsonl files") 170 | parser.add_argument("--output_path", required=True, 171 | help="Directory for bucketed prompt files") 172 | parser.add_argument("--tokenizer_path", required=True, 173 | help="HF repo or local path for tokenizer with a chat template") 174 | parser.add_argument("--bucket_size", type=int, default=25_000, 175 | help="Number of prompts per output file (default: 25 000)") 176 | parser.add_argument("--n_workers", type=int, default=8, 177 | help="Tokenisation workers (default: 8)") 178 | 179 | args = parser.parse_args() 180 | mp.set_start_method("spawn", force=True) 181 | 182 | main( 183 | args.input_path, 184 | args.output_path, 185 | tokenizer_path=args.tokenizer_path, 186 | bucket_size=args.bucket_size, 187 | n_workers=args.n_workers, 188 | ) 189 | -------------------------------------------------------------------------------- /JacobiForcing/train/baseline_sft_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from dataclasses import dataclass, field 5 | import math 6 | import pathlib 7 | from typing import Dict, Optional, List, Any 8 | 9 | import os 10 | import sys 11 | import torch 12 | from torch.utils.data import Dataset, DataLoader 13 | from torch.utils.data import default_collate 14 | import transformers 15 | from transformers.trainer_pt_utils import LabelSmoother 16 | import logging 17 | from datasets import load_dataset 18 | 19 | from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training 20 | 21 | logger = logging.getLogger(__name__) 22 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index # -100 23 | 24 | @dataclass 25 | class ModelArguments: 26 | target_model_path: Optional[str] = field( 27 | default="models/vicuna-7b-v1.5", 28 | metadata={"help": "Path or HF id for the base model"}, 29 | ) 30 | qlora: bool = field(default=False, metadata={"help": "Enable QLoRA"}) 31 | 32 | @dataclass 33 | class DataArguments: 34 | data_path: str = field(default=None, metadata={"help": "Path to the training data JSONL"}) 35 | 36 | @dataclass 37 | class TrainingArguments(transformers.TrainingArguments): 38 | cache_dir: Optional[str] = field(default=None) 39 | optim: str = field(default="adamw_torch") 40 | model_max_length: int = field( 41 | default=1024, metadata={"help": "Max seq length for tokenizer/model"} 42 | ) 43 | report_to: str = field(default="wandb") 44 | remove_unused_columns: bool = field(default=False) 45 | gradient_checkpointing: bool = field( 46 | default=True, 47 | metadata={"help": "Enable gradient checkpointing to reduce activation memory"}, 48 | ) 49 | bf16: bool = field( 50 | default=True, metadata={"help": "Train in bfloat16 for efficiency."} 51 | ) 52 | fp16: bool = field(default=False) 53 | 54 | class LabelsDataset(Dataset): 55 | def __init__(self, hf_dataset): 56 | super().__init__() 57 | self.ds = hf_dataset 58 | 59 | def __len__(self): 60 | return len(self.ds) 61 | 62 | def __getitem__(self, idx) -> Dict[str, List[int]]: 63 | ex = self.ds[idx] 64 | if "labels_ids" not in ex: 65 | raise KeyError("Each JSONL row must contain 'labels_ids'.") 66 | seq = ex["labels_ids"] 67 | if not isinstance(seq, list) or len(seq) == 0: 68 | raise ValueError("'labels_ids' must be a non-empty list of ints.") 69 | return {"labels_ids": seq} 70 | 71 | def make_collate_fn(tokenizer: transformers.PreTrainedTokenizer, max_len: int): 72 | if 
tokenizer.pad_token_id is None: 73 | if tokenizer.eos_token_id is not None: 74 | tokenizer.pad_token = tokenizer.eos_token 75 | else: 76 | tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) 77 | pad_id = tokenizer.pad_token_id 78 | eos_id = tokenizer.eos_token_id 79 | 80 | def collate(batch: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: 81 | # Trim each sequence at the FIRST eos 82 | trimmed = [] 83 | for ex in batch: 84 | seq = ex["labels_ids"] 85 | if eos_id is not None and eos_id in seq: 86 | first_eos = seq.index(eos_id) 87 | seq = seq[: first_eos + 1] # keep eos, drop after 88 | trimmed.append(seq) 89 | 90 | # target length 91 | seq_lens = [len(s) for s in trimmed] 92 | target_len = min(max(seq_lens), max_len) 93 | 94 | input_ids_list, labels_list, attn_list = [], [], [] 95 | for seq in trimmed: 96 | seq = seq[:target_len] 97 | pad_needed = target_len - len(seq) 98 | 99 | # after-eos positions are removed; pad to target_len. 100 | ids = seq + [pad_id] * pad_needed 101 | lbl = seq + [IGNORE_TOKEN_ID] * pad_needed 102 | attn = [1] * len(seq) + [0] * pad_needed 103 | 104 | input_ids_list.append(torch.tensor(ids, dtype=torch.long)) 105 | labels_list.append(torch.tensor(lbl, dtype=torch.long)) 106 | attn_list.append(torch.tensor(attn, dtype=torch.long)) 107 | 108 | input_ids = torch.stack(input_ids_list, dim=0) 109 | labels = torch.stack(labels_list, dim=0) 110 | attention_mask = torch.stack(attn_list, dim=0) 111 | return { 112 | "input_ids": input_ids, 113 | "labels": labels, 114 | "attention_mask": attention_mask 115 | } 116 | 117 | return collate 118 | 119 | def safe_save_model(model, tokenizer, output_dir): 120 | os.makedirs(output_dir, exist_ok=True) 121 | to_save = model 122 | if hasattr(to_save, "module"): 123 | to_save = to_save.module 124 | to_save.save_pretrained(output_dir, safe_serialization=False) 125 | if tokenizer is not None: 126 | tokenizer.save_pretrained(output_dir) 127 | 128 | def train(): 129 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 130 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 131 | 132 | logging.basicConfig( 133 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 134 | datefmt="%m/%d/%Y %H:%M:%S", 135 | handlers=[logging.StreamHandler(sys.stdout)], 136 | level=logging.INFO, 137 | ) 138 | transformers.utils.logging.set_verbosity_info() 139 | 140 | # Config 141 | config = transformers.AutoConfig.from_pretrained( 142 | model_args.target_model_path, cache_dir=training_args.cache_dir 143 | ) 144 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 145 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len: 146 | raise ValueError( 147 | f"model_max_length ({training_args.model_max_length}) exceeds model context ({orig_ctx_len})." 
148 | ) 149 | config.use_cache = False 150 | 151 | # Model + Tokenizer 152 | model = transformers.AutoModelForCausalLM.from_pretrained( 153 | model_args.target_model_path, 154 | config=config, 155 | cache_dir=training_args.cache_dir, 156 | torch_dtype=torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None), 157 | low_cpu_mem_usage=True, 158 | ) 159 | tokenizer = transformers.AutoTokenizer.from_pretrained( 160 | model_args.target_model_path, padding_side="right", use_fast=False 161 | ) 162 | # If we added a new pad token above, resize embeddings 163 | if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: 164 | tokenizer.pad_token = tokenizer.eos_token 165 | if len(tokenizer) > model.get_input_embeddings().weight.size(0): 166 | model.resize_token_embeddings(len(tokenizer)) 167 | 168 | if getattr(training_args, "gradient_checkpointing", False): 169 | model.gradient_checkpointing_enable() 170 | 171 | if model_args.qlora: 172 | model = prepare_model_for_kbit_training(model) 173 | lora_cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, r=32, lora_alpha=16, lora_dropout=0.05) 174 | model = get_peft_model(model, lora_cfg) 175 | 176 | hf_ds = load_dataset("json", data_files={"train": data_args.data_path}, 177 | split="train", cache_dir=training_args.cache_dir) 178 | train_dataset = LabelsDataset(hf_ds) 179 | 180 | data_collator = make_collate_fn(tokenizer, training_args.model_max_length) 181 | 182 | # --- Regular HF Trainer --- 183 | trainer = transformers.Trainer( 184 | model=model, 185 | tokenizer=tokenizer, 186 | args=training_args, 187 | train_dataset=train_dataset, 188 | data_collator=data_collator, 189 | ) 190 | 191 | # Train 192 | trainer.train(resume_from_checkpoint=True if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")) else None) 193 | 194 | # Save 195 | safe_save_model(model, tokenizer, training_args.output_dir) 196 | 197 | if __name__ == "__main__": 198 | train() 199 | -------------------------------------------------------------------------------- /generate_trajectory/data/0_bucketing_openthought2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os, glob, json, random, re, argparse, multiprocessing as mp 5 | from functools import partial 6 | from tqdm import tqdm 7 | import pandas as pd 8 | import numpy as np 9 | from transformers import AutoTokenizer 10 | 11 | # ---------- Prompt-template helpers ---------- 12 | THOUGHT_RE = re.compile(r"<\|begin_of_thought\|>\n\n(.*?)\n\n<\|end_of_thought\|>", re.DOTALL) 13 | SOLUTION_RE = re.compile(r"<\|begin_of_solution\|>\n\n(.*?)\n\n<\|end_of_solution\|>", re.DOTALL) 14 | 15 | TOKENIZER_PATH = "/checkpoint/lhu/models/OpenThinker2-7B" 16 | tokenizer = None 17 | 18 | def init_worker(): 19 | """Initialise the global tokenizer once per worker.""" 20 | global tokenizer 21 | tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True) 22 | 23 | def process_response(resp: str) -> str: 24 | """Replace / blocks with … for assistant turns (if wanted).""" 25 | tm, sm = THOUGHT_RE.search(resp), SOLUTION_RE.search(resp) 26 | if not (tm and sm): 27 | return resp.strip() 28 | return f"\n{tm.group(1).strip()}\n\n\n{sm.group(1).strip()}" 29 | 30 | # ---------- JSON-safe utility ---------- 31 | def to_json_safe(obj): 32 | """Recursively convert NumPy arrays to lists so json.dumps doesn’t choke (not used now, but handy).""" 33 | if isinstance(obj, dict): 34 | return {k: 
to_json_safe(v) for k, v in obj.items()} 35 | if isinstance(obj, list): 36 | return [to_json_safe(v) for v in obj] 37 | if isinstance(obj, np.ndarray): 38 | return obj.tolist() 39 | return obj 40 | 41 | # ---------- Build chat messages ---------- 42 | def build_messages(sample, *, use_think_format=False, use_system_prompt=False): 43 | msgs = [] 44 | if use_system_prompt and "system" in sample: 45 | msgs.append({"role": "system", "content": sample["system"]}) 46 | for turn in sample["conversations"]: 47 | role = "user" if turn["from"] == "user" else "assistant" 48 | if role == "assistant" and use_think_format: 49 | content = process_response(turn["value"]) 50 | else: 51 | content = turn["value"] 52 | msgs.append({"role": role, "content": content}) 53 | return msgs 54 | 55 | def tokenize_full(sample, *, use_think_format=False, use_system_prompt=False): 56 | """Return list of token IDs for the *whole* sample conversation.""" 57 | global tokenizer 58 | msgs = build_messages(sample, 59 | use_think_format=use_think_format, 60 | use_system_prompt=use_system_prompt) 61 | try: 62 | ids = tokenizer.apply_chat_template( 63 | msgs, 64 | tokenize=True, 65 | add_generation_prompt=False, 66 | return_tensors="pt" 67 | ).squeeze(0).tolist() 68 | return ids 69 | except Exception as e: 70 | print("⚠️ Tokenization error:", e) 71 | return None 72 | 73 | # ---------- Worker wrapper ---------- 74 | def process_sample(sample, 75 | *, use_think_format=False, use_system_prompt=False): 76 | ids = tokenize_full(sample, 77 | use_think_format=use_think_format, 78 | use_system_prompt=use_system_prompt) 79 | if ids is None: 80 | return None 81 | # Extract *first* user prompt string 82 | prompt_text = next( 83 | (turn["value"] for turn in sample["conversations"] 84 | if turn.get("from") == "user"), None) 85 | if prompt_text is None: 86 | return None 87 | return dict(prompt=prompt_text, n_tokens=len(ids)) 88 | 89 | # ---------- Main pipeline ---------- 90 | def main(input_path, output_path, 91 | *, bucket_size=50_000, 92 | use_think_format=True, 93 | use_system_prompt=False, 94 | n_workers=8): 95 | 96 | os.makedirs(output_path, exist_ok=True) 97 | 98 | # 1. Load data -------------------------------------------------------------------------------- 99 | parquet_files = sorted(glob.glob(os.path.join(input_path, "*.parquet"))) 100 | if not parquet_files: 101 | raise FileNotFoundError(f"No *.parquet found under {input_path}") 102 | print(f"🗄️ Found {len(parquet_files)} parquet file(s). Reading...") 103 | 104 | dfs = [pd.read_parquet(p) for p in parquet_files] 105 | df = pd.concat(dfs, ignore_index=True) 106 | samples = df.to_dict("records") 107 | print(f"✅ Loaded {len(samples):,} json-rows from parquets") 108 | 109 | # 2. Token-count each sample in parallel ------------------------------------------------------ 110 | with mp.Pool(n_workers, initializer=init_worker) as pool: 111 | processed = list( 112 | tqdm(pool.imap( 113 | partial(process_sample, 114 | use_think_format=use_think_format, 115 | use_system_prompt=use_system_prompt), 116 | samples), 117 | total=len(samples), 118 | desc="Tokenising")) 119 | # Remove failures 120 | processed = [p for p in processed if p is not None] 121 | print(f"✅ Tokenised {len(processed):,} samples") 122 | 123 | # 3. Sort by token length --------------------------------------------------------------------- 124 | processed.sort(key=lambda x: x["n_tokens"]) 125 | 126 | # 4. 
Slice into buckets of N prompts ---------------------------------------------------------- 127 | print(f"📦 Bucketing {len(processed):,} prompts --> {bucket_size} prompts per file") 128 | for i in range(0, len(processed), bucket_size): 129 | bucket_idx = i // bucket_size 130 | bucket = processed[i : i + bucket_size] 131 | 132 | # ----- compute stats ----- 133 | tok_counts = [b["n_tokens"] for b in bucket] 134 | min_tok = min(tok_counts) 135 | max_tok = max(tok_counts) 136 | avg_tok = int(round(sum(tok_counts) / len(tok_counts))) 137 | 138 | # ----- build filename ----- 139 | out_fname = ( 140 | f"bucket_{bucket_idx:04d}" 141 | f"_avg{avg_tok}_min{min_tok}_max{max_tok}.json" 142 | ) 143 | out_path = os.path.join(output_path, out_fname) 144 | 145 | # ----- write JSON array of prompt strings ----- 146 | prompts = [" ".join(item["prompt"].split()) for item in bucket] # tidy whitespace 147 | with open(out_path, "w", encoding="utf-8") as fout: 148 | json.dump(prompts, fout, ensure_ascii=False, indent=2) 149 | 150 | print(f"- {out_fname} ({len(prompts)} prompts, {sum(tok_counts):,} tokens)") 151 | 152 | 153 | # ---------- Entry-point ---------- 154 | if __name__ == "__main__": 155 | parser = argparse.ArgumentParser( 156 | description="Bucket user prompts by token length. " 157 | "Each output file contains a fixed NUMBER of prompts.") 158 | parser.add_argument("--input_path", required=True, 159 | help="Directory containing *.parquet files") 160 | parser.add_argument("--output_path", required=True, 161 | help="Directory for bucketed prompt files") 162 | parser.add_argument("--bucket_size", type=int, default=5_000, 163 | help="Number of prompts per output file (default: 5 000)") 164 | parser.add_argument("--n_workers", type=int, default=8, 165 | help="Tokenisation workers (default: 8)") 166 | parser.add_argument("--think_format", action="store_true", 167 | help="Apply replacement to assistant messages") 168 | parser.add_argument("--system_prompt", action="store_true", 169 | help="Include system field as first chat message") 170 | 171 | args = parser.parse_args() 172 | mp.set_start_method("spawn", force=True) 173 | 174 | main(args.input_path, 175 | args.output_path, 176 | bucket_size=args.bucket_size, 177 | use_think_format=args.think_format, 178 | use_system_prompt=args.system_prompt, 179 | n_workers=args.n_workers) 180 | -------------------------------------------------------------------------------- /JacobiForcing/train/deprecated/flexattn_cllm_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | from torch.cuda.amp import autocast 4 | from transformers import Trainer 5 | from transformers.trainer_pt_utils import LabelSmoother 6 | 7 | import torch.nn.functional as F 8 | 9 | from torch.nn.attention.flex_attention import create_block_mask 10 | 11 | from functools import lru_cache 12 | 13 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 14 | 15 | class CllmTrainer(Trainer): 16 | def __init__(self, *args, accelerator=None, optimizer=None, lr_scheduler=None, train_dataloader=None, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | args = kwargs["args"] 19 | self.accelerator = accelerator 20 | self.optimizer = optimizer 21 | self.lr_scheduler = lr_scheduler 22 | self.train_dataloader = train_dataloader 23 | 24 | self.train_step_cnt = 0 25 | self.max_new_tokens = args.max_new_tokens 26 | self.use_gt_labels = args.use_gt_labels 27 | 28 | # ---------------- Utilities ---------------- # 29 | @staticmethod 30 | def _to_int(x): 31 
| return x.item() if isinstance(x, torch.Tensor) else int(x) 32 | 33 | def _unpack_sample(self, inputs): 34 | """ 35 | Extract a single sample. (Assumes per_device_train_batch_size == 1.) 36 | Required keys: 37 | - input_ids: [1, L] 38 | - prompt_ids_len: scalar or [1] 39 | - T: length of traj_position_indices (last uncorrupted token positions) in [1, T] 40 | """ 41 | input_ids = inputs["input_ids"][0] 42 | 43 | prompt_len = inputs["prompt_ids_len"] 44 | if isinstance(prompt_len, torch.Tensor): 45 | if prompt_len.dim() > 0: 46 | prompt_len = prompt_len[0] 47 | prompt_len = self._to_int(prompt_len) 48 | 49 | traj_position_indices = inputs["traj_position_indices"][0][0] 50 | traj_position_indices = [int(u) for u in traj_position_indices] 51 | T = len(traj_position_indices) 52 | 53 | return ( 54 | input_ids.to(self.args.device), 55 | prompt_len, 56 | T, 57 | ) 58 | 59 | @staticmethod 60 | def _index_layout(prompt_len: int, T: int, N): 61 | """Return lists of start indices for all k_j and last_j blocks in flattened sequence.""" 62 | k_starts = [prompt_len + 2 * j * N for j in range(T)] 63 | l_starts = [prompt_len + (2 * j + 1) * N for j in range(T)] 64 | return k_starts, l_starts 65 | 66 | # FlexAttention BlockMask 67 | def _build_block_mask(self, L: int, prompt_len: int, T: int, heads: int): 68 | N = self.max_new_tokens 69 | k_starts, l_starts = self._index_layout(prompt_len, T, N) 70 | ks = torch.tensor(k_starts, device=self.args.device) # [T] 71 | ls = torch.tensor(l_starts, device=self.args.device) # [T] 72 | num_traj = T 73 | 74 | def mask_mod(b, h, q, k): 75 | # q, k: any shape, torch tensors (can be batched) 76 | rel_q = q - prompt_len 77 | rel_k = k - prompt_len 78 | block_idx_q = torch.div(rel_q, N, rounding_mode="floor") 79 | block_idx_k = torch.div(rel_k, N, rounding_mode="floor") 80 | 81 | is_prompt_q = q < prompt_len 82 | is_prompt_k = k < prompt_len 83 | is_kj_q = (q >= prompt_len) & (block_idx_q % 2 == 0) 84 | is_lastj_q = (q >= prompt_len) & (block_idx_q % 2 == 1) 85 | is_kj_k = (k >= prompt_len) & (block_idx_k % 2 == 0) 86 | is_lastj_k = (k >= prompt_len) & (block_idx_k % 2 == 1) 87 | j_q = torch.clamp(block_idx_q // 2, min=0) 88 | j_k = torch.clamp(block_idx_k // 2, min=0) 89 | 90 | # k_in_prev_last: is_lastj_k & (block_idx_k < 2 * j_q) 91 | k_in_prev_last = is_lastj_k & (block_idx_k < 2 * j_q) 92 | 93 | same_kj_block = is_kj_q & is_kj_k & (block_idx_q == block_idx_k) 94 | same_lastj_block = is_lastj_q & is_lastj_k & (block_idx_q == block_idx_k) 95 | 96 | ks_per_q = ks[torch.clamp(j_q, max=len(ks) - 1)] 97 | ls_per_q = ls[torch.clamp(j_q, max=len(ls) - 1)] 98 | 99 | mask_prompt = is_prompt_q & (k <= q) 100 | 101 | mask_kj = is_kj_q & ( 102 | is_prompt_k | 103 | k_in_prev_last | 104 | (same_kj_block & (k >= ks_per_q) & (k <= q)) 105 | ) 106 | 107 | mask_lastj = is_lastj_q & ( 108 | is_prompt_k | 109 | k_in_prev_last | 110 | (same_lastj_block & (k >= ls_per_q) & (k <= q)) 111 | ) 112 | 113 | mask = mask_prompt | mask_kj | mask_lastj 114 | return mask 115 | 116 | block_mask = create_block_mask( 117 | mask_mod, B=1, H=heads, Q_LEN=L, KV_LEN=L, device=self.args.device 118 | ) 119 | 120 | return block_mask 121 | 122 | # ---------------- Core Training Step ---------------- 123 | def training_step(self, model, inputs, num_items_in_batch=None): 124 | self.train_step_cnt += 1 125 | return self._one_pass_losses_step(model, inputs) 126 | 127 | def _one_pass_losses_step(self, model, inputs): 128 | """ 129 | Single forward pass to compute: 130 | - AR loss: first u_j tokens of each last_j 
(shifted LM) 131 | - Consistency loss: corrupted tail of each k_j vs teacher last_j at same offsets 132 | """ 133 | input_ids, prompt_len, T = self._unpack_sample(inputs) 134 | 135 | # Basic layout 136 | L = input_ids.size(0) 137 | expected_len = prompt_len + 2 * T * self.max_new_tokens 138 | if L != expected_len: 139 | raise ValueError( 140 | f"Length mismatch: L={L}, expected {expected_len} (prompt_len={prompt_len}, T={T}, n_token_sequence_size={self.max_new_tokens})" 141 | ) 142 | 143 | # ---- FlexAttention: build BlockMask & forward once ---- 144 | num_heads = getattr(getattr(model, "config", None), "num_attention_heads", 1) 145 | blk_mask = self._build_block_mask(L, prompt_len, T, num_heads) 146 | 147 | outputs = model( 148 | input_ids=input_ids.unsqueeze(0), 149 | block_mask=blk_mask, 150 | attn_implementation="flex_attention", 151 | ) 152 | # [1, L, V] 153 | logits = outputs.logits 154 | 155 | # ========== AR loss ========== 156 | ar_labels = torch.full((L,), IGNORE_TOKEN_ID, device=self.args.device) 157 | k_starts, l_starts = self._index_layout(prompt_len, T, self.max_new_tokens) 158 | for j in range(T): 159 | ls = l_starts[j] 160 | ar_labels[ls : ls + self.max_new_tokens] = input_ids[ls : ls + self.max_new_tokens] 161 | 162 | label_smoother = LabelSmoother(epsilon=0.1, ignore_index=IGNORE_TOKEN_ID) 163 | loss_ar = label_smoother( 164 | outputs, ar_labels.unsqueeze(0), shift_labels=True 165 | ) 166 | loss_ar = loss_ar * 10.0 # scale (as in your previous code) 167 | 168 | # ========== Consistency loss (hard) ========== 169 | student_positions, teacher_positions = [], [] 170 | for j in range(T): 171 | ks, ls = k_starts[j], l_starts[j] 172 | offs = range(self.max_new_tokens) 173 | student_positions.extend(ks + off for off in offs) 174 | teacher_positions.extend(ls + off for off in offs) 175 | 176 | if len(student_positions) == 0: 177 | loss_consistency = torch.zeros((), device=self.args.device) 178 | else: 179 | student_logits_sel = logits[0, student_positions, :] 180 | 181 | # Hard targets: token ids from the aligned last_j positions 182 | target_ids = input_ids[teacher_positions].to(device=self.args.device, dtype=torch.long) 183 | loss_consistency = F.cross_entropy(student_logits_sel, target_ids, reduction="mean") 184 | 185 | total_loss = loss_ar + loss_consistency 186 | 187 | if self.args.qlora: 188 | total_loss.requires_grad = True 189 | 190 | if self.args.local_rank == 0: 191 | wandb.log( 192 | { 193 | "ar loss": float(loss_ar.detach().cpu()), 194 | "consistency loss": float(loss_consistency.detach().cpu()), 195 | } 196 | ) 197 | 198 | with self.accelerator.accumulate(model): 199 | self.accelerator.backward(total_loss) 200 | 201 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 202 | torch.distributed.barrier() 203 | return total_loss.detach() 204 | -------------------------------------------------------------------------------- /JacobiForcing/train/cllm_trainer.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | from transformers import Trainer 4 | from transformers.trainer_pt_utils import LabelSmoother 5 | import wandb 6 | import random 7 | from torch.utils.data import DataLoader 8 | from torch.cuda.amp import autocast 9 | 10 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 11 | 12 | class CllmTrainer(Trainer): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | args = kwargs["args"] 16 | self.train_step_cnt = 0 17 | self.max_new_tokens = 
args.max_new_tokens 18 | self.use_gt_labels = args.use_gt_labels 19 | 20 | def training_step(self, model, inputs, num_items_in_batch): 21 | self.train_step_cnt += 1 22 | return self.consistency_training_step(model, inputs) 23 | 24 | def consistency_training_step(self, model, inputs): 25 | 26 | max_new_tokens = self.max_new_tokens 27 | 28 | jacobian_trajectory = inputs["jacobian_trajectory"] 29 | input_masks = inputs["attention_mask"] 30 | bsz = jacobian_trajectory[0].shape[0] 31 | eos_reached = torch.tensor([False] * bsz).to(model.device) 32 | 33 | ### tokens generated after are set to 34 | for i in range(len(jacobian_trajectory)): 35 | for j in range(bsz): 36 | trajectory_len = torch.sum(input_masks, dim=-1) 37 | # find the first accurate 38 | eos_positions = torch.where(jacobian_trajectory[i][j, :(trajectory_len[j]-max_new_tokens)]==self.processing_class.eos_token_id)[0] 39 | if len(eos_positions)==0: 40 | continue 41 | # otherwise, set tokens coming after the accurate as pad 42 | eos_reached[j] = True 43 | trajectory_copy = jacobian_trajectory[i].clone().detach() 44 | eos_pos = eos_positions[0] 45 | trajectory_copy[j, int(eos_pos)+1:] = self.processing_class.pad_token_id 46 | jacobian_trajectory[i] = trajectory_copy 47 | 48 | ### compute AutoRegression loss ### 49 | # use labels to avoid pattern collapse 50 | if self.use_gt_labels: 51 | labels = inputs['labels_ids'] 52 | else: 53 | labels = inputs['teacher_output_ids'] 54 | # TODO: check if it's right when batch size > 1 55 | labels = torch.tensor(labels).to(model.device) 56 | attention_mask = torch.full_like(labels, 1).to(model.device) 57 | # FlashAttention only support fp16 and bf16 data type 58 | with autocast(dtype=torch.bfloat16): 59 | label_student_model_output = model(labels, attention_mask) 60 | 61 | label_smoother = LabelSmoother(epsilon=0.1, ignore_index= -100) 62 | loss_ar = label_smoother(label_student_model_output, labels, shift_labels=True) 63 | loss_ar*=10 64 | if self.args.qlora: 65 | loss_ar.requires_grad = True 66 | print(f'loss ar: {loss_ar} computed! 
performing backward pass...') 67 | with self.accelerator.accumulate(model): 68 | self.accelerator.backward(loss_ar) 69 | 70 | # ### compute soft Consistency loss (global) ### 71 | # # Get teacher logits from converged point 72 | # attention_mask = torch.full_like(jacobian_trajectory[0], 1).to(model.device) 73 | # attention_mask = jacobian_trajectory[-1] != self.processing_class.pad_token_id 74 | # logits_last = self.get_logits(model, jacobian_trajectory[-1].clone().detach(), attention_mask) 75 | # # random select one point from trajectory 76 | # i = random.choice(range(len(jacobian_trajectory))[:-1]) 77 | 78 | # attention_mask = torch.full_like(jacobian_trajectory[0], 1).to(jacobian_trajectory[0].device) 79 | # attention_mask = jacobian_trajectory[i] != self.processing_class.pad_token_id 80 | # logits_i = self.get_logits(model, jacobian_trajectory[i].clone().detach(), attention_mask) 81 | 82 | # output_mask = jacobian_trajectory[i][..., 1:] == self.processing_class.pad_token_id 83 | # # We do not calculate the cross entrophy of same logits to alleviate misleading gradients 84 | # for j in range(bsz): 85 | # end_of_mask_position = torch.where(jacobian_trajectory[i][j, 1:] != jacobian_trajectory[-1][j, 1:])[0] 86 | # if len(end_of_mask_position)==0: 87 | # output_mask[j, :] = True 88 | # else: 89 | # output_mask[j, :end_of_mask_position[0]] = True 90 | 91 | # loss_global = self.soft_cross_entropy( 92 | # logits_i[..., :-1, :].float(), # logits generated by the last token is dropped 93 | # logits_last[..., :-1, :].to(logits_i.device).clone().detach().float(), 94 | # output_mask.to(logits_i.device) 95 | # ) 96 | # if self.args.qlora: 97 | # loss_global.requires_grad = True 98 | # print(f'loss global {loss_global} computed! performing backward pass...') 99 | # with self.accelerator.accumulate(model): 100 | # self.accelerator.backward(loss_global) 101 | 102 | ### compute hard Consistency loss (global) ### 103 | # random select one point from trajectory 104 | i = random.choice(range(len(jacobian_trajectory))[:-1]) 105 | 106 | attention_mask = torch.full_like(jacobian_trajectory[0], 1).to(jacobian_trajectory[0].device) 107 | attention_mask = jacobian_trajectory[i] != self.processing_class.pad_token_id 108 | with autocast(dtype=torch.bfloat16): 109 | trajectiory_i_student_model_output = model(jacobian_trajectory[i].clone().detach(), attention_mask) 110 | 111 | output_mask = jacobian_trajectory[i][..., 1:] == self.processing_class.pad_token_id 112 | # We do not calculate the cross entrophy of same logits to alleviate misleading gradients 113 | for j in range(bsz): 114 | end_of_mask_position = torch.where(jacobian_trajectory[i][j, 1:] != jacobian_trajectory[-1][j, 1:])[0] 115 | if len(end_of_mask_position)==0: 116 | jacobian_trajectory[-1][j, 1:] = IGNORE_TOKEN_ID 117 | else: 118 | jacobian_trajectory[-1][j, :end_of_mask_position[0]] = IGNORE_TOKEN_ID 119 | loss_global = label_smoother(trajectiory_i_student_model_output, jacobian_trajectory[-1].clone().detach(), shift_labels=True) 120 | if self.args.qlora: 121 | loss_global.requires_grad = True 122 | print(f'loss global {loss_global} computed! 
performing backward pass...') 123 | 124 | with self.accelerator.accumulate(model): 125 | self.accelerator.backward(loss_global) 126 | 127 | ### Logging results on wandb and syncing processes ### 128 | if self.args.local_rank == 0: 129 | wandb.log({"ar loss": loss_ar}) 130 | wandb.log({"consistency loss": loss_global}) 131 | 132 | # sync processes 133 | torch.distributed.barrier() 134 | loss = loss_ar.detach() + loss_global.detach() 135 | 136 | return loss 137 | 138 | def log(self, logs, *args, **kwargs): 139 | 140 | if 'loss' in logs and logs['loss'] == -1: 141 | del logs['loss'] 142 | 143 | super().log(logs) 144 | 145 | 146 | def get_train_dataloader(self): 147 | # Create custom DataLoader with shuffle set to False 148 | shuffle = True 149 | dataloader_params = { 150 | "batch_size": self.args.per_device_train_batch_size, 151 | "shuffle": shuffle, 152 | "num_workers": self.args.dataloader_num_workers, 153 | "pin_memory": self.args.dataloader_pin_memory, 154 | } 155 | 156 | return self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) 157 | 158 | ###################### Helper Functions ############################# 159 | def soft_cross_entropy(self, predicts, targets, padding_mask): 160 | # TODO: support batch_size >1 here. 161 | if (~padding_mask).sum() == 0: 162 | return 0*predicts[0][0][0] 163 | predict_log_prob = torch.nn.functional.log_softmax(predicts, dim=-1) 164 | targets_prob = torch.nn.functional.softmax(targets, dim=-1) 165 | entropy = -targets_prob * predict_log_prob 166 | expand_mask = padding_mask.unsqueeze(-1).expand_as(entropy) 167 | entropy.masked_fill_(expand_mask, 0) 168 | mean_entropy = entropy.sum() / (~padding_mask).sum() 169 | return mean_entropy 170 | 171 | def get_logits(self, model, input_ids, attention_mask): 172 | 173 | # FlashAttention only support fp16 and bf16 data type 174 | with autocast(dtype=torch.bfloat16): 175 | logits = model( 176 | input_ids=input_ids, 177 | attention_mask=attention_mask, 178 | ).logits 179 | 180 | return logits 181 | 182 | -------------------------------------------------------------------------------- /assets/baseline_comparison.py: -------------------------------------------------------------------------------- 1 | # baseline_comparison.py 2 | # Usage: 3 | # python3 baseline_comparison.py \ 4 | # --csv data.csv \ 5 | # --out chart.png \ 6 | # --baseline-throughput 40.0 \ 7 | # --baseline-pass1 90.0 8 | # 9 | # CSV columns (two supported formats): 10 | # A) Absolute values: 11 | # technique,throughput_tps,pass1,train_tokens_B 12 | # (If --baseline-* not provided, baseline = first row.) 13 | # B) Precomputed deltas (skip --baseline-*): 14 | # technique,delta_throughput_tps,delta_pass1,train_tokens_B 15 | # 16 | # Notes: 17 | # - Bubble size is proportional to sqrt(training tokens, in billions). 18 | # - Labels have a small offset; tune --label-offset-x/--label-offset-y. 
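#
# Illustrative CSV in format A (absolute values); every number below is a placeholder
# except 41.3 t/s and 87.8%, which are the baseline figures hard-coded in the chart
# annotation further down:
#   technique,throughput_tps,pass1,train_tokens_B
#   qwen-2.5-coder-7b-instruct,41.3,87.8,0
#   Jacobi Forcing,<throughput_tps>,<pass1>,<train_tokens_B>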
19 | 20 | import argparse 21 | import math 22 | import os 23 | import re 24 | import numpy as np 25 | import pandas as pd 26 | import matplotlib.pyplot as plt 27 | 28 | PURPLE = "#7e57c2" # single color for all data points 29 | 30 | def load_data(path: str) -> pd.DataFrame: 31 | return pd.read_csv(path) 32 | 33 | def _parse_tokens_to_float(x) -> float: 34 | """Accept numbers or strings like '~1', '322+', '50B' and return a float.""" 35 | if pd.isna(x): 36 | return 0.0 37 | if isinstance(x, (int, float)): 38 | return float(x) 39 | m = re.search(r'(\d+(\.\d+)?)', str(x)) 40 | return float(m.group(1)) if m else 0.0 41 | 42 | def compute_deltas(df: pd.DataFrame, 43 | baseline_tps: float | None, 44 | baseline_pass1: float | None) -> pd.DataFrame: 45 | df = df.copy() 46 | 47 | has_deltas = {"delta_throughput_tps", "delta_pass1"}.issubset(df.columns) 48 | if has_deltas: 49 | # Normalize delta column names 50 | df.rename(columns={ 51 | "delta_throughput_tps": "delta_throughput", 52 | "delta_pass1": "delta_pass1" 53 | }, inplace=True) 54 | # Recover absolute metrics if baseline provided or can be inferred 55 | if baseline_tps is not None and baseline_pass1 is not None: 56 | df["abs_throughput"] = df["delta_throughput"] + float(baseline_tps) 57 | df["abs_pass1"] = df["delta_pass1"] + float(baseline_pass1) 58 | else: 59 | # Cannot compute absolute values without a baseline -> mark NaN 60 | df["abs_throughput"] = np.nan 61 | df["abs_pass1"] = np.nan 62 | return df 63 | 64 | # Absolute metrics path -> need a baseline (explicit or first row) 65 | needed = {"throughput_tps", "pass1"} 66 | if not needed.issubset(df.columns): 67 | raise ValueError("CSV must include either {delta_throughput_tps, delta_pass1} " 68 | "or {throughput_tps, pass1}.") 69 | if baseline_tps is None or baseline_pass1 is None: 70 | baseline_tps = float(df.iloc[0]["throughput_tps"]) 71 | baseline_pass1 = float(df.iloc[0]["pass1"]) 72 | 73 | df["delta_throughput"] = df["throughput_tps"] - baseline_tps 74 | df["delta_pass1"] = df["pass1"] - baseline_pass1 75 | df["abs_throughput"] = df["throughput_tps"] 76 | df["abs_pass1"] = df["pass1"] 77 | return df 78 | 79 | def bubble_sizes(tokens_B_series: pd.Series, scale: float = 80.0): 80 | # Matplotlib scatter 's=' is in points^2; area grows with sqrt(tokens). 
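    # With the default --size-scale of 80, 1B tokens maps to s=80 and 100B tokens to s=800;
    # because 's' is an area, the marker radius grows roughly with the fourth root of the count.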
81 | vals = tokens_B_series.apply(_parse_tokens_to_float).clip(lower=0.0) 82 | return [scale * math.sqrt(v) for v in vals] 83 | 84 | def main(): 85 | p = argparse.ArgumentParser() 86 | p.add_argument("--csv", required=True) 87 | p.add_argument("--out", default="paper/baseline_comparisons.png") 88 | # AR-adapted Parallel Decoders: Speed-Quality Trade-offs (with Costs) 89 | p.add_argument("--title", default="Speed-Quality Trade-offs (on A100 GPU)") 90 | p.add_argument("--baseline-throughput", type=float, default=None) 91 | p.add_argument("--baseline-pass1", type=float, default=None) 92 | p.add_argument("--size-scale", type=float, default=80.0) 93 | p.add_argument("--label-offset-x", type=float, default=+8.0) 94 | p.add_argument("--label-offset-y", type=float, default=0.0) 95 | args = p.parse_args() 96 | 97 | df = load_data(args.csv) 98 | df_delta = compute_deltas(df, args.baseline_throughput, args.baseline_pass1) 99 | 100 | if "train_tokens_B" not in df.columns: 101 | raise ValueError("CSV must include 'train_tokens_B' (billions of tokens) for bubble sizes.") 102 | 103 | # Bubble sizes (area) 104 | sizes = bubble_sizes(df["train_tokens_B"], scale=args.size_scale) 105 | 106 | fig, ax = plt.subplots(figsize=(8, 6)) 107 | 108 | sc = ax.scatter( 109 | df_delta["delta_throughput"], 110 | df_delta["delta_pass1"], 111 | s=sizes, 112 | color=PURPLE, 113 | ) 114 | 115 | # Labels near each point 116 | color="black" 117 | for i, r in df_delta.iterrows(): 118 | if "Jacobi Forcing" in df_delta.loc[i, "technique"]: 119 | label_offset_x = -38.0 120 | label_offset_y = 1.0 121 | color = "red" 122 | elif "qwen-2.5-coder-7b-instruct" in df_delta.loc[i, "technique"]: 123 | label_offset_x = -8.0 124 | label_offset_y = 1.0 125 | else: 126 | label_offset_x = args.label_offset_x 127 | label_offset_y = args.label_offset_y 128 | ax.annotate( 129 | df_delta.loc[i, "technique"], 130 | xy=(r["delta_throughput"], r["delta_pass1"]), 131 | xytext=( 132 | r["delta_throughput"] + label_offset_x, 133 | r["delta_pass1"] + label_offset_y 134 | ), 135 | fontsize=16, 136 | color=color 137 | ) 138 | 139 | 140 | # Axes lines (dashed) through zero 141 | ax.axhline(0, linewidth=1, linestyle="--") 142 | ax.axvline(0, linewidth=1, linestyle="--") 143 | 144 | # Mark the baseline (origin) and annotate absolute baseline values 145 | ax.scatter(0, 0, 146 | s=260, 147 | marker="X", 148 | color=PURPLE, 149 | zorder=5, 150 | alpha=0.5 151 | ) 152 | 153 | ax.annotate( 154 | "Baseline (41.3 t/s, 87.8%)", 155 | xy=(0, 0), 156 | xytext=(-40, 4), # adjust offsets if needed 157 | fontsize=16, 158 | bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="gray", alpha=0.8), 159 | ) 160 | 161 | # mark the data point with highest (accuracy × throughput) --- 162 | # Use absolute metrics when available (df_delta['abs_*'] may be NaN if deltas w/o baseline). 
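    # The top-scoring run (largest abs_throughput * abs_pass1) only receives the text callout
    # with its absolute numbers; the starred marker for it stays commented out below.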
163 | if "abs_throughput" in df_delta.columns and "abs_pass1" in df_delta.columns: 164 | score = df_delta["abs_throughput"] * df_delta["abs_pass1"] 165 | if score.notna().any(): 166 | idx = score.idxmax() 167 | x_star = df_delta.loc[idx, "delta_throughput"] 168 | y_star = df_delta.loc[idx, "delta_pass1"] 169 | abs_tps = df_delta.loc[idx, "abs_throughput"] 170 | abs_acc = df_delta.loc[idx, "abs_pass1"] 171 | 172 | #ax.scatter(x_star, y_star, 173 | # s=320, 174 | # marker="*", 175 | # color="red", 176 | # edgecolor="white", 177 | # linewidth=0.8, 178 | # zorder=6, 179 | # alpha=0.5, 180 | #) 181 | 182 | ax.annotate( 183 | f"{abs_tps:.1f} t/s, {abs_acc:.1f}%", 184 | xy=(x_star, y_star), 185 | xytext=(x_star - 26.75, y_star - 8.0), 186 | fontsize=16, 187 | bbox=dict(boxstyle="round,pad=0.25", fc="white", ec="gray", alpha=0.8), 188 | arrowprops=dict(arrowstyle="->", lw=1) 189 | ) 190 | 191 | ax.set_xlabel("Δ Throughput (tokens/sec) vs. Baseline", fontsize=20) 192 | ax.set_ylabel("Δ pass@1 on HumanEval vs. Baseline", fontsize=20) 193 | ax.set_title(args.title, fontsize=20) 194 | 195 | # --- Size legend: fixed 0B, 1B, 10B, 100B --- 196 | legend_levels = [0.001, 1.0, 10.0, 100.0] 197 | legend_sizes = [args.size_scale * math.sqrt(lv) for lv in legend_levels] 198 | handles = [ax.scatter([], [], s=s, color=PURPLE, alpha=1.0) for s in legend_sizes] 199 | labels = [f"{int(lv)}B" for lv in legend_levels] 200 | size_legend = ax.legend(handles, labels, title="adaptation cost (train tokens)", loc="lower right", 201 | frameon=True, labelspacing=1.0, fontsize=12, title_fontsize=12) 202 | ax.add_artist(size_legend) 203 | 204 | # Make axis numbers larger (optional) 205 | ax.tick_params(axis="both", which="major", labelsize=16) 206 | 207 | # Ensure output directory exists 208 | out_dir = os.path.dirname(args.out) 209 | if out_dir: 210 | os.makedirs(out_dir, exist_ok=True) 211 | 212 | ax.set_ylim(top=7) 213 | 214 | fig.tight_layout() 215 | fig.savefig(args.out, dpi=200) 216 | print(f"Saved {args.out}") 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /JacobiForcing/train/deprecated/soft_cllm_loss_trainer.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | from transformers import Trainer 4 | from transformers.trainer_pt_utils import LabelSmoother 5 | import wandb 6 | import random 7 | from torch.utils.data import DataLoader 8 | from torch.cuda.amp import autocast 9 | 10 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 11 | 12 | class CllmTrainer(Trainer): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | args = kwargs["args"] 16 | self.train_step_cnt = 0 17 | self.max_new_tokens = args.max_new_tokens 18 | self.use_gt_labels = args.use_gt_labels 19 | 20 | def training_step(self, model, inputs, num_items_in_batch): 21 | self.train_step_cnt += 1 22 | return self.consistency_training_step(model, inputs) 23 | 24 | def consistency_training_step(self, model, inputs): 25 | 26 | max_new_tokens = self.max_new_tokens 27 | 28 | jacobian_trajectory = inputs["jacobian_trajectory"] 29 | input_masks = inputs["attention_mask"] 30 | bsz = jacobian_trajectory[0].shape[0] 31 | eos_reached = torch.tensor([False] * bsz).to(model.device) 32 | 33 | ### tokens generated after are set to 34 | for i in range(len(jacobian_trajectory)): 35 | for j in range(bsz): 36 | trajectory_len = torch.sum(input_masks, dim=-1) 37 | # find the first 
accurate 38 | eos_positions = torch.where(jacobian_trajectory[i][j, :(trajectory_len[j]-max_new_tokens)]==self.processing_class.eos_token_id)[0] 39 | if len(eos_positions)==0: 40 | continue 41 | # otherwise, set tokens coming after the accurate as pad 42 | eos_reached[j] = True 43 | trajectory_copy = jacobian_trajectory[i].clone().detach() 44 | eos_pos = eos_positions[0] 45 | trajectory_copy[j, int(eos_pos)+1:] = self.processing_class.pad_token_id 46 | jacobian_trajectory[i] = trajectory_copy 47 | 48 | ### compute AutoRegression loss ### 49 | # use labels to avoid pattern collapse 50 | if self.use_gt_labels: 51 | labels = inputs['labels_ids'] 52 | else: 53 | labels = inputs['teacher_output_ids'] 54 | # TODO: check if it's right when batch size > 1 55 | labels = torch.tensor(labels).to(model.device) 56 | attention_mask = torch.full_like(labels, 1).to(model.device) 57 | # FlashAttention only support fp16 and bf16 data type 58 | with autocast(dtype=torch.bfloat16): 59 | label_student_model_output = model(labels, attention_mask) 60 | 61 | label_smoother = LabelSmoother(epsilon=0.1, ignore_index= -100) 62 | loss_ar = label_smoother(label_student_model_output, labels, shift_labels=True) 63 | loss_ar *= 10 64 | if self.args.qlora: 65 | loss_ar.requires_grad = True 66 | print(f'loss ar: {loss_ar} computed! performing backward pass...') 67 | with self.accelerator.accumulate(model): 68 | self.accelerator.backward(loss_ar) 69 | 70 | # ========================================== # 71 | ### compute soft Consistency loss (global) ### 72 | # Get teacher logits from converged point 73 | attention_mask_last = torch.full_like(jacobian_trajectory[0], 1).to(model.device) 74 | attention_mask_last = jacobian_trajectory[-1] != self.processing_class.pad_token_id 75 | logits_last = self.get_logits(model, jacobian_trajectory[-1].clone().detach(), attention_mask_last) 76 | 77 | # random select one point from trajectory 78 | i = random.choice(range(len(jacobian_trajectory))[:-1]) 79 | attention_mask = torch.full_like(jacobian_trajectory[0], 1).to(jacobian_trajectory[0].device) 80 | attention_mask = jacobian_trajectory[i] != self.processing_class.pad_token_id 81 | logits_i = self.get_logits(model, jacobian_trajectory[i].clone().detach(), attention_mask) 82 | 83 | output_mask = jacobian_trajectory[i][..., 1:] == self.processing_class.pad_token_id 84 | # We do not calculate the cross entrophy of same logits to minimize misleading gradients 85 | for j in range(bsz): 86 | end_of_mask_position = torch.where(jacobian_trajectory[i][j, 1:] != jacobian_trajectory[-1][j, 1:])[0] 87 | if len(end_of_mask_position)==0: 88 | output_mask[j, :] = True 89 | else: 90 | output_mask[j, :end_of_mask_position[0]] = True 91 | 92 | loss_global = self.soft_cross_entropy( 93 | logits_i[..., :-1, :].float(), # logits generated by the last token is dropped 94 | logits_last[..., :-1, :].to(logits_i.device).clone().detach().float(), 95 | output_mask.to(logits_i.device) 96 | ) 97 | if self.args.qlora: 98 | loss_global.requires_grad = True 99 | print(f'loss global {loss_global} computed! 
performing backward pass...') 100 | with self.accelerator.accumulate(model): 101 | self.accelerator.backward(loss_global) 102 | 103 | # ========================================== # 104 | ### compute hard Consistency loss (global) ### 105 | # random select one point from trajectory 106 | #i = random.choice(range(len(jacobian_trajectory))[:-1]) 107 | 108 | #attention_mask = torch.full_like(jacobian_trajectory[0], 1).to(jacobian_trajectory[0].device) 109 | #attention_mask = jacobian_trajectory[i] != self.processing_class.pad_token_id 110 | #with autocast(dtype=torch.bfloat16): 111 | # trajectiory_i_student_model_output = model(jacobian_trajectory[i].clone().detach(), attention_mask) 112 | 113 | #output_mask = jacobian_trajectory[i][..., 1:] == self.processing_class.pad_token_id 114 | # We do not calculate the cross entrophy of same logits to alleviate misleading gradients 115 | #for j in range(bsz): 116 | # end_of_mask_position = torch.where(jacobian_trajectory[i][j, 1:] != jacobian_trajectory[-1][j, 1:])[0] 117 | # if len(end_of_mask_position)==0: 118 | # jacobian_trajectory[-1][j, 1:] = IGNORE_TOKEN_ID 119 | # else: 120 | # jacobian_trajectory[-1][j, :end_of_mask_position[0]] = IGNORE_TOKEN_ID 121 | #loss_global = label_smoother(trajectiory_i_student_model_output, jacobian_trajectory[-1].clone().detach(), shift_labels=True) 122 | #if self.args.qlora: 123 | # loss_global.requires_grad = True 124 | #print(f'loss global {loss_global} computed! performing backward pass...') 125 | 126 | #with self.accelerator.accumulate(model): 127 | # self.accelerator.backward(loss_global) 128 | 129 | ### Logging results on wandb and syncing processes ### 130 | if self.args.local_rank == 0: 131 | wandb.log({"ar loss": loss_ar}) 132 | wandb.log({"consistency loss": loss_global}) 133 | 134 | # sync processes 135 | torch.distributed.barrier() 136 | loss = loss_ar.detach() + loss_global.detach() 137 | 138 | return loss 139 | 140 | def log(self, logs, *args, **kwargs): 141 | 142 | if 'loss' in logs and logs['loss'] == -1: 143 | del logs['loss'] 144 | 145 | super().log(logs) 146 | 147 | 148 | def get_train_dataloader(self): 149 | # Create custom DataLoader with shuffle set to False 150 | shuffle = True 151 | dataloader_params = { 152 | "batch_size": self.args.per_device_train_batch_size, 153 | "shuffle": shuffle, 154 | "num_workers": self.args.dataloader_num_workers, 155 | "pin_memory": self.args.dataloader_pin_memory, 156 | } 157 | 158 | return self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) 159 | 160 | ###################### Helper Functions ############################# 161 | def soft_cross_entropy(self, predicts, targets, padding_mask): 162 | # TODO: support batch_size >1 here. 
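# Computes H(p_teacher, p_student) = -sum_v softmax(targets)[v] * log_softmax(predicts)[v] per position,
# zeroes out positions where padding_mask is True, and averages over the remaining (unmasked) positions.
# If every position is masked, it returns 0 * predicts[0][0][0] so the result stays attached to the autograd graph.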
163 | if (~padding_mask).sum() == 0: 164 | return 0*predicts[0][0][0] 165 | predict_log_prob = torch.nn.functional.log_softmax(predicts, dim=-1) 166 | targets_prob = torch.nn.functional.softmax(targets, dim=-1) 167 | entropy = -targets_prob * predict_log_prob 168 | expand_mask = padding_mask.unsqueeze(-1).expand_as(entropy) 169 | entropy.masked_fill_(expand_mask, 0) 170 | mean_entropy = entropy.sum() / (~padding_mask).sum() 171 | return mean_entropy 172 | 173 | def get_logits(self, model, input_ids, attention_mask): 174 | 175 | # FlashAttention only support fp16 and bf16 data type 176 | with autocast(dtype=torch.bfloat16): 177 | logits = model( 178 | input_ids=input_ids, 179 | attention_mask=attention_mask, 180 | ).logits 181 | 182 | return logits 183 | 184 | -------------------------------------------------------------------------------- /JacobiForcing/train/deprecated/soft_flexattn_cllm_trainer_legacy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | from torch.cuda.amp import autocast 4 | from transformers import Trainer 5 | from transformers.trainer_pt_utils import LabelSmoother 6 | 7 | import torch.nn.functional as F 8 | 9 | from torch.nn.attention.flex_attention import create_block_mask 10 | 11 | from functools import lru_cache 12 | 13 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 14 | 15 | class CllmTrainer(Trainer): 16 | def __init__(self, *args, accelerator=None, optimizer=None, lr_scheduler=None, train_dataloader=None, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | args = kwargs["args"] 19 | self.accelerator = accelerator 20 | self.optimizer = optimizer 21 | self.lr_scheduler = lr_scheduler 22 | self.train_dataloader = train_dataloader 23 | 24 | self.train_step_cnt = 0 25 | 26 | self.max_new_tokens = args.max_new_tokens 27 | self.use_gt_labels = args.use_gt_labels 28 | # cache BlockMasks keyed by (L, prompt_len, heads) — no u_list dependence 29 | self._blockmask_cache = {} 30 | 31 | # ---------------- Utilities ---------------- # 32 | 33 | @staticmethod 34 | def _to_int(x): 35 | return x.item() if isinstance(x, torch.Tensor) else int(x) 36 | 37 | def _unpack_sample(self, inputs): 38 | """ 39 | Extract a single sample. (Assumes per_device_train_batch_size == 1.) 
40 | Required keys: 41 | - input_ids: [1, L] 42 | - prompt_ids_len: scalar or [1] 43 | - T: length of traj_position_indices (last uncorrupted token positions) in [1, T] 44 | """ 45 | input_ids = inputs["input_ids"][0] 46 | prompt_len = inputs["prompt_ids_len"] 47 | if isinstance(prompt_len, torch.Tensor): 48 | if prompt_len.dim() > 0: 49 | prompt_len = prompt_len[0] 50 | prompt_len = self._to_int(prompt_len) 51 | 52 | traj_position_indices = inputs["traj_position_indices"][0][0] 53 | traj_position_indices = [int(u) for u in traj_position_indices] 54 | T = len(traj_position_indices) 55 | 56 | return ( 57 | input_ids.to(self.args.device), 58 | prompt_len, 59 | T, 60 | ) 61 | 62 | @staticmethod 63 | def _index_layout(prompt_len: int, T: int, N): 64 | """Return lists of start indices for all k_j and last_j blocks in flattened sequence.""" 65 | k_starts = [prompt_len + 2 * j * N for j in range(T)] 66 | l_starts = [prompt_len + (2 * j + 1) * N for j in range(T)] 67 | return k_starts, l_starts 68 | 69 | # FlexAttention BlockMask 70 | def _build_block_mask(self, L: int, prompt_len: int, T: int, heads: int): 71 | N = self.max_new_tokens 72 | k_starts, l_starts = self._index_layout(prompt_len, T, N) 73 | ks = torch.tensor(k_starts, device=self.args.device) # [T] 74 | ls = torch.tensor(l_starts, device=self.args.device) # [T] 75 | num_traj = T 76 | 77 | def mask_mod(b, h, q, k): 78 | # q, k: any shape, torch tensors (can be batched) 79 | rel_q = q - prompt_len 80 | rel_k = k - prompt_len 81 | block_idx_q = torch.div(rel_q, N, rounding_mode="floor") 82 | block_idx_k = torch.div(rel_k, N, rounding_mode="floor") 83 | 84 | is_prompt_q = q < prompt_len 85 | is_prompt_k = k < prompt_len 86 | is_kj_q = (q >= prompt_len) & (block_idx_q % 2 == 0) 87 | is_lastj_q = (q >= prompt_len) & (block_idx_q % 2 == 1) 88 | is_kj_k = (k >= prompt_len) & (block_idx_k % 2 == 0) 89 | is_lastj_k = (k >= prompt_len) & (block_idx_k % 2 == 1) 90 | j_q = torch.clamp(block_idx_q // 2, min=0) 91 | j_k = torch.clamp(block_idx_k // 2, min=0) 92 | 93 | # k_in_prev_last: is_lastj_k & (block_idx_k < 2 * j_q) 94 | k_in_prev_last = is_lastj_k & (block_idx_k < 2 * j_q) 95 | 96 | same_kj_block = is_kj_q & is_kj_k & (block_idx_q == block_idx_k) 97 | same_lastj_block = is_lastj_q & is_lastj_k & (block_idx_q == block_idx_k) 98 | 99 | ks_per_q = ks[torch.clamp(j_q, max=len(ks) - 1)] 100 | ls_per_q = ls[torch.clamp(j_q, max=len(ls) - 1)] 101 | 102 | mask_prompt = is_prompt_q & (k <= q) 103 | 104 | mask_kj = is_kj_q & ( 105 | is_prompt_k | 106 | k_in_prev_last | 107 | (same_kj_block & (k >= ks_per_q) & (k <= q)) 108 | ) 109 | 110 | mask_lastj = is_lastj_q & ( 111 | is_prompt_k | 112 | k_in_prev_last | 113 | (same_lastj_block & (k >= ls_per_q) & (k <= q)) 114 | ) 115 | 116 | mask = mask_prompt | mask_kj | mask_lastj 117 | return mask 118 | 119 | block_mask = create_block_mask( 120 | mask_mod, B=1, H=heads, Q_LEN=L, KV_LEN=L, device=self.args.device 121 | ) 122 | 123 | return block_mask 124 | 125 | 126 | # ---------------- Core Training Step ---------------- 127 | def training_step(self, model, inputs, num_items_in_batch=None): 128 | self.train_step_cnt += 1 129 | return self._one_pass_losses_step(model, inputs) 130 | 131 | def _one_pass_losses_step(self, model, inputs): 132 | """ 133 | Single forward pass to compute: 134 | - AR loss: first u_j tokens of each last_j (shifted LM) 135 | - Consistency loss: corrupted tail of each k_j vs teacher last_j at same offsets 136 | """ 137 | input_ids, prompt_len, T = self._unpack_sample(inputs) 138 | 139 | L 
= input_ids.size(0) 140 | expected_len = prompt_len + 2 * T * self.max_new_tokens 141 | if L != expected_len: 142 | raise ValueError( 143 | f"Length mismatch: L={L}, expected {expected_len} (prompt_len={prompt_len}, T={T}, n_token_sequence_size={self.max_new_tokens})" 144 | ) 145 | 146 | num_heads = getattr(getattr(model, "config", None), "num_attention_heads", 1) 147 | blk_mask = self._build_block_mask(L, prompt_len, T, num_heads) 148 | 149 | outputs = model( 150 | input_ids=input_ids.unsqueeze(0), 151 | block_mask=blk_mask, 152 | attn_implementation="flex_attention", 153 | ) 154 | # [1, L, V] 155 | logits = outputs.logits 156 | del blk_mask 157 | 158 | # ========== AR loss ========== 159 | ar_labels = torch.full((L,), IGNORE_TOKEN_ID, device=self.args.device) 160 | k_starts, l_starts = self._index_layout(prompt_len, T, self.max_new_tokens) 161 | for j in range(T): 162 | ls = l_starts[j] 163 | ar_labels[ls : ls + self.max_new_tokens] = input_ids[ls : ls + self.max_new_tokens] 164 | 165 | label_smoother = LabelSmoother(epsilon=0.1, ignore_index=IGNORE_TOKEN_ID) 166 | loss_ar = label_smoother( 167 | outputs, ar_labels.unsqueeze(0), shift_labels=True 168 | ) * 10.0 169 | 170 | del ar_labels, label_smoother 171 | torch.cuda.empty_cache() 172 | 173 | # ========== Consistency loss (soft) ========== 174 | T_soft = getattr(self.args, "distill_temperature", 1.0) 175 | student_positions, teacher_positions = [], [] 176 | for j in range(T): 177 | ks, ls = k_starts[j], l_starts[j] 178 | offs = range(self.max_new_tokens) 179 | student_positions.extend(ks + off for off in offs) 180 | teacher_positions.extend(ls + off for off in offs) 181 | 182 | if len(student_positions) == 0: 183 | loss_consistency = torch.zeros((), device=self.args.device) 184 | else: 185 | # [M, V] 186 | student_logits_sel = logits[0, student_positions, :] 187 | teacher_logits_sel = logits[0, teacher_positions, :].detach() 188 | 189 | log_ps = F.log_softmax(student_logits_sel / T_soft, dim=-1) 190 | p_t = F.softmax(teacher_logits_sel / T_soft, dim=-1) 191 | loss_consistency = (-(p_t * log_ps).sum(dim=-1)).mean() * (T_soft * T_soft) 192 | 193 | del student_logits_sel, teacher_logits_sel, log_ps, p_t 194 | torch.cuda.empty_cache() 195 | 196 | del logits, k_starts, l_starts 197 | torch.cuda.empty_cache() 198 | 199 | total_loss = loss_ar + loss_consistency 200 | 201 | if self.args.qlora: 202 | total_loss.requires_grad = True 203 | 204 | if self.args.local_rank == 0: 205 | wandb.log( 206 | { 207 | "ar loss": float(loss_ar.detach().cpu()), 208 | "consistency loss": float(loss_consistency.detach().cpu()), 209 | } 210 | ) 211 | 212 | del loss_ar, loss_consistency, outputs 213 | torch.cuda.empty_cache() 214 | 215 | with self.accelerator.accumulate(model): 216 | self.accelerator.backward(total_loss) 217 | 218 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 219 | torch.distributed.barrier() 220 | return total_loss.detach() 221 | -------------------------------------------------------------------------------- /JacobiForcing/ar_inference_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import json 5 | import math 6 | import random 7 | import argparse 8 | from pathlib import Path 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import pandas as pd 13 | from tqdm import tqdm 14 | from transformers import AutoModelForCausalLM, AutoTokenizer 15 | 16 | 17 | def parse_args(): 18 | p = argparse.ArgumentParser() 19 | 
p.add_argument("--dataset_parquet", type=str, 20 | default="/home/lah003/data/openai_humaneval/openai_humaneval/test-00000-of-00001.parquet") 21 | p.add_argument("--model_name", type=str, 22 | default="/home/lah003/models/Qwen2.5-Coder-7B-Instruct") 23 | p.add_argument("--tokenizer_name", type=str, 24 | default="/home/lah003/models/Qwen2.5-Coder-7B-Instruct") 25 | p.add_argument("--eval_dir", type=str, 26 | default="/home/lah003/data/CLLM2_eval_generations/baselines") 27 | p.add_argument("--original_jsonl", type=str, 28 | default="/home/lah003/data/CLLM2_eval_generations/humaneval_python_example.jsonl") 29 | 30 | # Gen settings 31 | p.add_argument("--max_new_tokens", type=int, default=1024) 32 | p.add_argument("--temperature", type=float, default=0.0) 33 | p.add_argument("--top_p", type=float, default=1.0) 34 | p.add_argument("--do_sample", action="store_true", default=False) 35 | p.add_argument("--seed", type=int, default=1234) 36 | p.add_argument("--limit", type=int, default=0, help="Limit number of samples (0 = all)") 37 | 38 | # Misc 39 | p.add_argument("--attention_impl", type=str, default="flash_attention_2") 40 | p.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16", "float32"]) 41 | p.add_argument("--device_map", type=str, default="cuda") 42 | return p.parse_args() 43 | 44 | 45 | def set_seed(seed: int): 46 | random.seed(seed) 47 | torch.manual_seed(seed) 48 | torch.cuda.manual_seed_all(seed) 49 | 50 | 51 | def load_jsonl(file_path): 52 | with open(file_path, 'r') as f: 53 | return [json.loads(line.strip()) for line in f] 54 | 55 | 56 | def save_jsonl(data, save_path): 57 | with open(save_path, 'w') as f: 58 | for item in data: 59 | f.write(json.dumps(item) + '\n') 60 | 61 | 62 | def extract_python_code(text: str) -> str: 63 | """Extract the first ```python ... ``` code block; fallback to raw text if none.""" 64 | match = re.search(r'```python([\s\S]*?)```', text) 65 | if match: 66 | return match.group(1).strip() 67 | return text 68 | 69 | 70 | def main(): 71 | args = parse_args() 72 | set_seed(args.seed) 73 | 74 | # Dtype mapping 75 | dtype_map = { 76 | "float16": torch.float16, 77 | "bfloat16": torch.bfloat16, 78 | "float32": torch.float32, 79 | } 80 | torch_dtype = dtype_map[args.dtype] 81 | 82 | df = pd.read_parquet(args.dataset_parquet) 83 | if args.limit and args.limit > 0: 84 | df = df.head(args.limit).copy() 85 | records = df.to_dict(orient="records") 86 | print(f"Loaded HumanEval dataset with {len(records)} samples") 87 | 88 | 89 | print("Loading model/tokenizer...") 90 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) 91 | model = AutoModelForCausalLM.from_pretrained( 92 | args.model_name, 93 | device_map=args.device_map, 94 | torch_dtype=torch_dtype, 95 | attn_implementation=args.attention_impl, 96 | ) 97 | model.eval() 98 | 99 | # Handle pad token if missing 100 | if tokenizer.pad_token_id is None: 101 | tokenizer.pad_token = tokenizer.eos_token 102 | 103 | eos_id = tokenizer.eos_token_id 104 | alt_eos_id = 151645 105 | 106 | os.makedirs(args.eval_dir, exist_ok=True) 107 | all_rows = [] 108 | all_generations = [] 109 | 110 | overall_gen_time = 0.0 111 | overall_total_tokens = 0 112 | 113 | t0_overall = time.perf_counter() 114 | 115 | with torch.inference_mode(): 116 | for idx, row in tqdm(list(enumerate(records)), total=len(records)): 117 | task_id = row.get("task_id", f"idx_{idx}") 118 | prompt = ( 119 | "Please continue to complete the function. 
You are not allowed to modify the given code and do the completion only. " 120 | "Please return all completed function in a codeblock. Here is the given code to do completion:\n" 121 | "```python\n" 122 | f"{row['prompt'].strip()}\n" 123 | "```" 124 | ) 125 | 126 | messages = [{"role": "user", "content": prompt}] 127 | text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 128 | model_inputs = tokenizer([text], return_tensors="pt").to(model.device) 129 | 130 | input_ids = model_inputs["input_ids"] 131 | prompt_len = input_ids.shape[1] 132 | 133 | # ============================== 134 | # === Generation-only timing === 135 | t0 = time.perf_counter() 136 | output_ids = model.generate( 137 | **model_inputs, 138 | max_new_tokens=args.max_new_tokens, 139 | do_sample=args.do_sample, 140 | temperature=args.temperature, 141 | top_p=args.top_p, 142 | eos_token_id=[eos_id, alt_eos_id] if alt_eos_id is not None else eos_id, 143 | pad_token_id=tokenizer.pad_token_id, 144 | use_cache=True, 145 | ) 146 | gen_time = time.perf_counter() - t0 147 | # ============================== 148 | 149 | # Stats 150 | new_tokens = int(output_ids.shape[1] - prompt_len) 151 | total_tokens = new_tokens 152 | toks_per_sec = (total_tokens / gen_time) 153 | 154 | # Determine stop reason 155 | generated_part = output_ids[0, prompt_len:] 156 | hit_eos = False 157 | if eos_id is not None: 158 | hit_eos = (generated_part == eos_id).any().item() 159 | if not hit_eos and alt_eos_id is not None: 160 | hit_eos = (generated_part == alt_eos_id).any().item() 161 | stop_reason = "eos" if hit_eos else ("max_new_tokens" if new_tokens >= args.max_new_tokens else "unknown") 162 | 163 | # Decode only the newly generated portion 164 | gen_str = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=False) 165 | print(f"Generated answer:\n{gen_str}\n") 166 | all_generations.append(gen_str) 167 | 168 | all_rows.append({ 169 | "index": idx, 170 | "task_id": task_id, 171 | "prompt_tokens": int(prompt_len), 172 | "new_tokens": int(new_tokens), 173 | "total_tokens": int(total_tokens), 174 | "gen_time_sec": float(gen_time), 175 | "toks_per_sec": float(toks_per_sec), 176 | "stop_reason": stop_reason, 177 | }) 178 | 179 | overall_gen_time += gen_time 180 | overall_total_tokens += total_tokens 181 | 182 | if (idx + 1) % 5 == 0 or (idx + 1) == len(records): 183 | print(f"====[{idx+1}/{len(records)}] task_id={task_id} " 184 | f"new_toks={new_tokens} gen_time={gen_time:.2f}s toks/sec={toks_per_sec:.2f} " 185 | f"reason={stop_reason}====") 186 | 187 | break 188 | 189 | t_overall = time.perf_counter() - t0_overall 190 | 191 | # --------------------------- 192 | # Save generations as JSONL 193 | # --------------------------- 194 | original_generations = load_jsonl(args.original_jsonl) 195 | if len(original_generations) != len(all_generations): 196 | print(f"[WARN] original_jsonl has {len(original_generations)} entries, but we produced {len(all_generations)}.") 197 | 198 | for i, original in enumerate(original_generations[:len(all_generations)]): 199 | original['output'] = all_generations[i] 200 | code_only = extract_python_code(all_generations[i]) 201 | print(f"Task id: {i}, Extracted answer:\n{code_only}\n") 202 | original['generation'] = code_only 203 | 204 | ar_save_path = os.path.join( 205 | args.eval_dir, 206 | f"ar_code_only_prompt_humaneval_generation_{Path(args.model_name).name}.jsonl" 207 | ) 208 | save_jsonl(original_generations[:len(all_generations)], ar_save_path) 209 | print(f"\n=== All AR 
generations done (HumanEval). Results are saved to {ar_save_path} ===") 210 | 211 | df_profile = pd.DataFrame(all_rows) 212 | 213 | def _safe_mean(series): 214 | s = pd.to_numeric(series, errors="coerce") 215 | return float(s.mean()) if s.size and not pd.isna(s).all() else float("nan") 216 | 217 | df_eos = df_profile[df_profile["stop_reason"] == "eos"].copy() 218 | n_eos = len(df_eos) 219 | n_total = len(df_profile) 220 | 221 | print("\n=== AR Generation Profiling (HumanEval) ===") 222 | print(f"Examples (eos): {n_eos} / {n_total} Total gen time: {overall_gen_time:.2f}s (overall wall: {t_overall:.2f}s)") 223 | print(f"Avg new tokens / prompt: {_safe_mean(df_eos['new_tokens']):.2f}") 224 | print(f"Avg toks/sec: {_safe_mean(df_eos['toks_per_sec']):.2f}") 225 | 226 | if __name__ == "__main__": 227 | main() 228 | -------------------------------------------------------------------------------- /generate_trajectory/data/2_prepare_efficient_cllm_training_data_new.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse, json, hashlib, random, os, sqlite3, tempfile, atexit, itertools, re 3 | from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED 4 | from typing import Optional, Dict, Any, Tuple, List 5 | from tqdm import tqdm 6 | from datasets import load_dataset 7 | 8 | # ----------------------- 9 | # Args 10 | # ----------------------- 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description="Low-RAM parallel preprocessing for efficient CLLM training JSONL." 14 | ) 15 | parser.add_argument("--input_path", required=True) 16 | parser.add_argument("--output_path", required=True) 17 | parser.add_argument("--n-token-seq-length", type=int, default=64) 18 | parser.add_argument("--cache-dir", default=None) 19 | parser.add_argument("--seed", type=int, default=42) 20 | parser.add_argument("--num-workers", type=int, default=max(1, os.cpu_count() or 8)) 21 | parser.add_argument("--max-in-flight", type=int, default=None) 22 | parser.add_argument("--db-path", default=None) 23 | parser.add_argument("--no-progress", action="store_true") 24 | parser.add_argument("--verbose", action="store_true") 25 | parser.add_argument("--single-process", action="store_true") 26 | return parser.parse_args() 27 | 28 | # ----------------------- 29 | # Deterministic RNG 30 | # ----------------------- 31 | def stable_seed(*parts: Any, base_seed: int = 0) -> int: 32 | h = hashlib.sha256() 33 | for p in parts: 34 | h.update(str(p).encode("utf-8")); h.update(b"|") 35 | h.update(base_seed.to_bytes(8, "little", signed=False)) 36 | return int.from_bytes(h.digest()[:8], "big", signed=False) 37 | 38 | # ----------------------- 39 | # ID parsing helpers 40 | # ----------------------- 41 | def parse_data_id_int(data_id: str) -> int: 42 | """Extract the integer from 'data_{id}' robustly.""" 43 | if data_id.startswith("data_"): 44 | return int(data_id[5:]) 45 | m = re.search(r"(\d+)", data_id) 46 | return int(m.group(1)) if m else 0 47 | 48 | def parse_itr_int(itr_id: str) -> int: 49 | """Extract the integer from 'itr_{iteration_id}' robustly.""" 50 | if itr_id.startswith("itr_"): 51 | return int(itr_id[4:]) 52 | m = re.search(r"(\d+)", itr_id) 53 | return int(m.group(1)) if m else 0 54 | 55 | # ----------------------- 56 | # Per-sample transform 57 | # ----------------------- 58 | def process_one_sample( 59 | sample: Dict[str, Any], 60 | n_token_seq_length: int, 61 | base_seed: int, 62 | ) -> Optional[Tuple[str, 
Dict[str, Any]]]: 63 | """ 64 | Return (data_id, entry) or None. 65 | Entry contains one (k_j,last_j) pair; final stitching happens later. 66 | """ 67 | data_id = sample["data_id"] # e.g. "data_123" 68 | diffusion_itr_id = sample["diffusion_itr_id"] # e.g. "itr_7" 69 | data_id_int = parse_data_id_int(data_id) 70 | diffusion_itr = parse_itr_int(diffusion_itr_id) 71 | 72 | prompt_ids = sample["prompt_ids"] 73 | answer_traj = sample["answer_trajectory_ids"] 74 | 75 | if len(answer_traj) < 2: 76 | return None # skip 77 | 78 | rng = random.Random(stable_seed(data_id, diffusion_itr_id, base_seed=base_seed)) 79 | k_j = rng.randint(0, len(answer_traj) - 2) 80 | 81 | sampled_seq = answer_traj[k_j][-n_token_seq_length:] 82 | fixed_seq = answer_traj[-1][-n_token_seq_length:] 83 | 84 | pair_seq = list(sampled_seq) + list(fixed_seq) 85 | 86 | entry = dict( 87 | data_id=data_id, 88 | data_id_int=int(data_id_int), 89 | prompt_ids=list(prompt_ids), 90 | pairs=[dict( 91 | diffusion_itr=int(diffusion_itr), # for sorting 92 | traj_position_index=int(k_j), 93 | seq=pair_seq 94 | )], 95 | ) 96 | return data_id, entry 97 | 98 | # ----------------------- 99 | # In-memory merge helpers 100 | # ----------------------- 101 | def merge_entry(existing: Dict[str, Any], new_entry: Dict[str, Any], verbose: bool = False): 102 | existing["pairs"].extend(new_entry["pairs"]) 103 | if verbose: 104 | print(f"Merged duplicate data_id {existing['data_id']} " 105 | f"(total pairs: {len(existing['pairs'])})") 106 | 107 | # ----------------------- 108 | # SQLite helpers 109 | # ----------------------- 110 | def open_db(path: str) -> sqlite3.Connection: 111 | conn = sqlite3.connect(path, timeout=60) 112 | conn.execute("PRAGMA journal_mode=WAL;") 113 | conn.execute(""" 114 | CREATE TABLE IF NOT EXISTS entries ( 115 | data_id TEXT PRIMARY KEY, 116 | data_id_int INTEGER NOT NULL, 117 | value TEXT NOT NULL 118 | ) 119 | """) 120 | conn.execute("CREATE INDEX IF NOT EXISTS idx_entries_data_id_int ON entries(data_id_int)") 121 | return conn 122 | 123 | def db_get(conn: sqlite3.Connection, data_id: str) -> Optional[Dict[str, Any]]: 124 | cur = conn.execute("SELECT value FROM entries WHERE data_id=?", (data_id,)) 125 | row = cur.fetchone() 126 | return None if row is None else json.loads(row[0]) 127 | 128 | def db_put(conn: sqlite3.Connection, data_id: str, entry: Dict[str, Any]): 129 | conn.execute( 130 | "INSERT INTO entries (data_id, data_id_int, value) VALUES(?,?,?) " 131 | "ON CONFLICT(data_id) DO UPDATE SET value=excluded.value", 132 | (data_id, int(entry["data_id_int"]), json.dumps(entry, ensure_ascii=False)) 133 | ) 134 | 135 | def merge_into_db(conn, data_id, entry, verbose): 136 | existing = db_get(conn, data_id) 137 | if existing is None: 138 | db_put(conn, data_id, entry) 139 | else: 140 | merge_entry(existing, entry, verbose) 141 | db_put(conn, data_id, existing) 142 | 143 | # ----------------------- 144 | # Main 145 | # ----------------------- 146 | def main(): 147 | args = parse_args() 148 | max_in_flight = args.max_in_flight or max(4, args.num_workers * 4) 149 | 150 | # Streaming dataset 151 | data = load_dataset( 152 | "json", 153 | data_files={"train": args.input_path}, 154 | split="train", 155 | streaming=True, 156 | cache_dir=args.cache_dir 157 | ) 158 | 159 | # SQLite store 160 | db_path = args.db_path or tempfile.mkstemp( 161 | prefix="merge_", suffix=".sqlite", 162 | dir=os.path.dirname(os.path.abspath(args.output_path)) or "." 
163 | )[1] 164 | conn = open_db(db_path) 165 | atexit.register(lambda: conn.close()) 166 | 167 | pbar = tqdm(desc="Processing", disable=args.no_progress) 168 | 169 | def handle(res): 170 | if res is None: 171 | pbar.update(1); return 172 | data_id, entry = res 173 | merge_into_db(conn, data_id, entry, args.verbose) 174 | if (pbar.n % 5000) == 0: 175 | conn.commit() 176 | pbar.update(1) 177 | 178 | try: 179 | if args.single_process or args.num_workers <= 1: 180 | for sample in data: 181 | handle(process_one_sample(sample, args.n_token_seq_length, args.seed)) 182 | conn.commit() 183 | else: 184 | in_flight = set() 185 | with ProcessPoolExecutor(max_workers=args.num_workers) as ex: 186 | for i, sample in enumerate(data): 187 | in_flight.add(ex.submit(process_one_sample, 188 | sample, args.n_token_seq_length, args.seed)) 189 | if len(in_flight) >= max_in_flight: 190 | done, in_flight = wait(in_flight, return_when=FIRST_COMPLETED) 191 | for fut in done: handle(fut.result()) 192 | for fut in as_completed(in_flight): handle(fut.result()) 193 | conn.commit() 194 | finally: 195 | pbar.close() 196 | 197 | # -------- Final write-out with sorting -------- 198 | with open(args.output_path, "w", encoding="utf-8") as fout: 199 | cur = cur = conn.execute("SELECT value FROM entries ORDER BY data_id_int") 200 | count = 0 201 | 202 | # across all data_id 203 | for (value_str,) in cur: 204 | entry = json.loads(value_str) 205 | 206 | # Sort the (k_j,last_j) pairs by diffusion_itr (int) 207 | pairs_sorted = sorted(entry["pairs"], key=lambda p: p["diffusion_itr"]) 208 | 209 | # Flatten sequences 210 | concatenated_pairs: List[int] = list( 211 | itertools.chain.from_iterable(p["seq"] for p in pairs_sorted) 212 | ) 213 | 214 | traj_position_indices: List[int] = list( 215 | p["traj_position_index"] for p in pairs_sorted 216 | ) 217 | 218 | output_entry = dict( 219 | data_id = entry["data_id"], 220 | prompt_ids = entry["prompt_ids"][0], 221 | complete_training_sequence_ids = entry["prompt_ids"][0] + concatenated_pairs, 222 | prompt_ids_len = len(entry["prompt_ids"][0]), 223 | traj_position_indices = traj_position_indices, 224 | ) 225 | fout.write(json.dumps(output_entry, ensure_ascii=False)) 226 | fout.write("\n") 227 | count += 1 228 | 229 | # Remove temp DB if we created it 230 | if not args.db_path: 231 | try: os.remove(db_path) 232 | except Exception: pass 233 | 234 | print(f"Processed {count} unique data_id samples --> {args.output_path}") 235 | 236 | if __name__ == "__main__": 237 | main() -------------------------------------------------------------------------------- /generate_trajectory/data/1_progressive_masking_based_prepare_trajectory.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import AutoTokenizer 3 | from tqdm import tqdm 4 | import random, copy, json, math, re, os 5 | 6 | import multiprocessing as mp 7 | 8 | import random 9 | 10 | from functools import partial 11 | 12 | import os 13 | 14 | mp.set_start_method("spawn", force=True) 15 | 16 | THOUGHT_RE = re.compile( 17 | r"<\|begin_of_thought\|>\n\n(.*?)\n\n<\|end_of_thought\|>", re.DOTALL 18 | ) 19 | SOLUTION_RE = re.compile( 20 | r"<\|begin_of_solution\|>\n\n(.*?)\n\n<\|end_of_solution\|>", re.DOTALL 21 | ) 22 | 23 | tokenizer = None # global (per process) 24 | TOKENIZER_PATH = "/checkpoint/lhu/models/OpenThinker2-7B" 25 | 26 | def init_worker(): 27 | global tokenizer 28 | from transformers import AutoTokenizer 29 | tokenizer = 
AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True) 30 | 31 | 32 | def process_response(resp: str) -> str: 33 | tm, sm = THOUGHT_RE.search(resp), SOLUTION_RE.search(resp) 34 | if not (tm and sm): 35 | return resp.strip() 36 | return f"\n{tm.group(1).strip()}\n\n\n{sm.group(1).strip()}" 37 | 38 | def build_messages(sample, use_think_format=False, use_system_prompt=False): 39 | if use_system_prompt: 40 | system_msg = sample["system"] 41 | msgs = [{"role": "system", "content": system_msg}] 42 | else: 43 | msgs = [] 44 | for turn in sample["conversations"]: 45 | role = "user" if turn["from"] == "user" else "assistant" 46 | if use_think_format: 47 | content = turn["value"] if role == "user" else process_response(turn["value"]) 48 | else: 49 | content = turn["value"] 50 | msgs.append({"role": role, "content": content}) 51 | return msgs 52 | 53 | def build_user_prompt(sample, use_system_prompt=False): 54 | if use_system_prompt: 55 | system_msg = sample["system"] 56 | msgs = [{"role": "system", "content": system_msg}] 57 | else: 58 | msgs = [] 59 | for turn in sample["conversations"]: 60 | if turn["from"] == "user": 61 | msgs.append({"role": "user", "content": turn["value"]}) 62 | return msgs 63 | 64 | def corrupt_chunk(chunk, i, full_ids, prompt_ids_len, lookup_context_len, pad_id): 65 | """ 66 | Return list of progressively corrupted versions of the chunk, 67 | where each is a full-length accepted tokens up to end of this chunk, 68 | with rightmost n tokens of this chunk replaced by random context tokens. 69 | """ 70 | corrupted_sequence = [] 71 | start_idx = prompt_ids_len + i # offset into full_ids (start of current chunk) 72 | prefix = full_ids[:start_idx] 73 | chunk_len = len(chunk) 74 | for corrupt_right in reversed(range(chunk_len + 1)): 75 | keep = chunk[:chunk_len - corrupt_right] if corrupt_right > 0 else chunk[:] 76 | # corrupt these tokens with random tokens from context 77 | corrupt = [] 78 | if corrupt_right > 0: 79 | sampling_start = max(0, start_idx - lookup_context_len) 80 | sampling_pool = full_ids[sampling_start:start_idx] 81 | if not sampling_pool: 82 | sampling_pool = [pad_id] 83 | corrupt = [random.choice(sampling_pool) for _ in range(corrupt_right)] 84 | # prefix + (corrupted chunk) 85 | seq = prefix + keep + corrupt 86 | corrupted_sequence.append(seq) 87 | return corrupted_sequence 88 | 89 | def convert_sample( 90 | sample, 91 | row_id: int, 92 | chunk_size=32, 93 | use_think_format=False, 94 | use_system_prompt=False, 95 | lookup_context_len=128, 96 | sequence_sampling_ratio=1.0, 97 | ): 98 | global tokenizer 99 | 100 | try: 101 | prompt_msgs = build_user_prompt(sample, use_system_prompt=use_system_prompt) 102 | msgs = build_messages(sample, use_think_format=use_think_format, use_system_prompt=use_system_prompt) 103 | prompt_ids = tokenizer.apply_chat_template( 104 | prompt_msgs, 105 | tokenize=True, 106 | add_generation_prompt=True, 107 | return_tensors="pt" 108 | ).squeeze(0).tolist() 109 | full_ids = tokenizer.apply_chat_template( 110 | msgs, 111 | tokenize=True, 112 | add_generation_prompt=False, 113 | return_tensors="pt" 114 | ).squeeze(0).tolist() 115 | pad_id = tokenizer.eos_token_id 116 | 117 | if len(full_ids) > 16_384: 118 | return [] 119 | 120 | response_length = len(full_ids) - len(prompt_ids) 121 | if response_length % chunk_size != 0: 122 | pad_amt = chunk_size - (response_length % chunk_size) 123 | full_ids = full_ids + [pad_id] * pad_amt 124 | 125 | records = [] 126 | num_chunks = (len(full_ids) - len(prompt_ids)) // chunk_size 127 | 128 
| # === Sampling indices to include according to the ratio === 129 | chunk_indices = list(range(num_chunks)) 130 | num_to_sample = max(1, int(num_chunks * sequence_sampling_ratio)) # always keep at least 1 131 | sampled_indices = set(random.sample(chunk_indices, num_to_sample)) 132 | # ========================================================= 133 | 134 | for chunk_idx in range(num_chunks): 135 | if chunk_idx not in sampled_indices: 136 | continue 137 | 138 | i = chunk_idx * chunk_size 139 | chunk = full_ids[len(prompt_ids) + i : len(prompt_ids) + i + chunk_size] 140 | answer_trajectory = corrupt_chunk( 141 | chunk, i, full_ids, len(prompt_ids), lookup_context_len, pad_id 142 | ) 143 | record = dict( 144 | data_id = f"data_{row_id}", 145 | diffusion_itr_id = f"itr_{chunk_idx}", 146 | prompt_ids_len = [len(prompt_ids)], 147 | prompt_ids = prompt_ids, 148 | answer_trajectory_ids = answer_trajectory, 149 | teacher_output_ids = full_ids, 150 | labels_ids = full_ids, 151 | ) 152 | records.append(record) 153 | 154 | return records 155 | except Exception as e: 156 | import traceback 157 | print(f"❌ Worker crashed on row {row_id}: {e}") 158 | traceback.print_exc() 159 | return [] 160 | 161 | def preprocess_parallel( 162 | data, 163 | chunk_size=32, 164 | n_workers=4, 165 | start_idx=0, 166 | use_think_format=False, 167 | use_system_prompt=False, 168 | lookup_context_len=128, 169 | sequence_sampling_ratio=1.0, 170 | ): 171 | func = partial( 172 | convert_sample, 173 | chunk_size=chunk_size, 174 | use_think_format=use_think_format, 175 | use_system_prompt=use_system_prompt, 176 | lookup_context_len=lookup_context_len, 177 | sequence_sampling_ratio=sequence_sampling_ratio, 178 | ) 179 | jobs = [(s, start_idx + i) for i, s in enumerate(data)] 180 | with mp.Pool( 181 | n_workers, 182 | initializer=init_worker, 183 | maxtasksperchild=200 184 | ) as pool: 185 | out = [] 186 | for recs in tqdm(pool.starmap(func, jobs), total=len(jobs)): 187 | out.extend(recs) 188 | return out 189 | 190 | if __name__ == "__main__": 191 | random.seed(42) 192 | 193 | ds = load_dataset( 194 | "parquet", 195 | data_files="/home/yak/data/OpenThoughts-114k/data/train-*.parquet", 196 | split="train" 197 | ) 198 | 199 | print("Loaded", len(ds), "rows") 200 | 201 | SPLIT_RATIO = 1 202 | subset_size = len(ds) // SPLIT_RATIO 203 | indices = list(range(len(ds))) 204 | random.shuffle(indices) 205 | selected_indices = indices[:subset_size] 206 | ds_subset = ds.select(selected_indices) 207 | 208 | print("Subset size:", len(ds_subset)) 209 | 210 | CHUNK = 36 211 | N_TOKEN_SEQ_LENGTH = 64 212 | LOOKUP_CONTEXT_LENGTH = N_TOKEN_SEQ_LENGTH * 10 213 | SEQUENCE_SAMPLING_RATIO = 1 214 | 215 | OUTFILE = f"/checkpoint/lhu/data/CLLM2_openthought/train_openthoughts_split_ratio_{SPLIT_RATIO}_size_{len(ds_subset)}_ntok_size_{N_TOKEN_SEQ_LENGTH}_lookup_size_{LOOKUP_CONTEXT_LENGTH}_sampling_ratio_{SEQUENCE_SAMPLING_RATIO}_eos_tokens_termination_with_think_format_without_sysmsg.json" 216 | os.makedirs(os.path.dirname(OUTFILE), exist_ok=True) 217 | 218 | N_WORKERS = min(12, os.cpu_count()) 219 | 220 | all_records = [] 221 | with open(OUTFILE, "w", encoding="utf-8") as f: 222 | for shard in range(math.ceil(len(ds_subset) / CHUNK)): 223 | a, b = shard * CHUNK, min((shard + 1) * CHUNK, len(ds_subset)) 224 | print(f"Processing rows {a}…{b-1}") 225 | sub = ds_subset.select(range(a, b)) 226 | # Convert to list of dicts for safe serialization 227 | sub_dicts = [dict(x) for x in sub] 228 | records = preprocess_parallel( 229 | sub_dicts, 230 | 
chunk_size=N_TOKEN_SEQ_LENGTH, 231 | n_workers=N_WORKERS, 232 | start_idx=a, 233 | use_think_format=True, 234 | use_system_prompt=False, 235 | lookup_context_len=LOOKUP_CONTEXT_LENGTH, 236 | sequence_sampling_ratio=SEQUENCE_SAMPLING_RATIO, 237 | ) 238 | for rec in records: 239 | f.write(json.dumps(rec) + "\n") 240 | 241 | print(f"Wrote {len(records):,} aligned records --> {OUTFILE}") 242 | 243 | del records, sub_dicts 244 | -------------------------------------------------------------------------------- /generate_trajectory/data/2_prepare_baseline_training_data_sft.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse, json, hashlib, random, os, sqlite3, tempfile, atexit, itertools, re 3 | from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED 4 | from typing import Optional, Dict, Any, Tuple, List 5 | from tqdm import tqdm 6 | from datasets import load_dataset 7 | 8 | # ----------------------- 9 | # Args 10 | # ----------------------- 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description="Low-RAM parallel preprocessing for efficient CLLM training JSONL." 14 | ) 15 | parser.add_argument("--input_path", required=True) 16 | parser.add_argument("--output_path", required=True) 17 | parser.add_argument("--n-token-seq-length", type=int, default=64) 18 | parser.add_argument("--cache-dir", default=None) 19 | parser.add_argument("--seed", type=int, default=42) 20 | parser.add_argument("--num-workers", type=int, default=max(1, os.cpu_count() or 8)) 21 | parser.add_argument("--max-in-flight", type=int, default=None) 22 | parser.add_argument("--db-path", default=None) 23 | parser.add_argument("--no-progress", action="store_true") 24 | parser.add_argument("--verbose", action="store_true") 25 | parser.add_argument("--single-process", action="store_true") 26 | return parser.parse_args() 27 | 28 | # ----------------------- 29 | # Deterministic RNG 30 | # ----------------------- 31 | def stable_seed(*parts: Any, base_seed: int = 0) -> int: 32 | h = hashlib.sha256() 33 | for p in parts: 34 | h.update(str(p).encode("utf-8")); h.update(b"|") 35 | h.update(base_seed.to_bytes(8, "little", signed=False)) 36 | return int.from_bytes(h.digest()[:8], "big", signed=False) 37 | 38 | # ----------------------- 39 | # ID parsing helpers 40 | # ----------------------- 41 | def parse_data_id_int(data_id: str) -> int: 42 | """Extract the integer from 'data_{id}' robustly.""" 43 | if data_id.startswith("data_"): 44 | return int(data_id[5:]) 45 | m = re.search(r"(\d+)", data_id) 46 | return int(m.group(1)) if m else 0 47 | 48 | def parse_itr_int(itr_id: str) -> int: 49 | """Extract the integer from 'itr_{iteration_id}' robustly.""" 50 | if itr_id.startswith("itr_"): 51 | return int(itr_id[4:]) 52 | m = re.search(r"(\d+)", itr_id) 53 | return int(m.group(1)) if m else 0 54 | 55 | # ----------------------- 56 | # Per-sample transform 57 | # ----------------------- 58 | def process_one_sample( 59 | sample: Dict[str, Any], 60 | n_token_seq_length: int, 61 | base_seed: int, 62 | ) -> Optional[Tuple[str, Dict[str, Any]]]: 63 | """ 64 | Return (data_id, entry) or None. 65 | Entry contains one (k_j,last_j) pair; final stitching happens later. 
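    Illustrative shape of the returned value (field values below are hypothetical):
        ("data_123", {"data_id": "data_123", "data_id_int": 123, "prompt_ids": [...],
                      "pairs": [{"diffusion_itr": 7,
                                 "last_seq": <last n_token_seq_length ids of the converged trajectory>}]})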
66 | """ 67 | data_id = sample["data_id"] # e.g., "data_123" 68 | diffusion_itr_id = sample["diffusion_itr_id"] # e.g., "itr_7" 69 | data_id_int = parse_data_id_int(data_id) 70 | diffusion_itr = parse_itr_int(diffusion_itr_id) 71 | 72 | prompt_ids = sample["prompt_ids"] 73 | answer_traj = sample["answer_trajectory_ids"] 74 | 75 | if len(answer_traj) < 2: 76 | return None # skip 77 | 78 | rng = random.Random(stable_seed(data_id, diffusion_itr_id, base_seed=base_seed)) 79 | k_j = rng.randint(0, len(answer_traj) - 2) 80 | 81 | sampled_seq = answer_traj[k_j][-n_token_seq_length:] 82 | fixed_seq = answer_traj[-1][-n_token_seq_length:] # this is last_j 83 | 84 | pair_seq = list(sampled_seq) + list(fixed_seq) 85 | 86 | entry = dict( 87 | data_id=data_id, 88 | data_id_int=int(data_id_int), 89 | prompt_ids=list(prompt_ids), 90 | pairs=[dict( 91 | diffusion_itr=int(diffusion_itr), # for sorting 92 | # store the truncated last_j explicitly so we can build labels_ids later 93 | last_seq=list(fixed_seq), 94 | )], 95 | ) 96 | return data_id, entry 97 | 98 | # ----------------------- 99 | # In-memory merge helpers 100 | # ----------------------- 101 | def merge_entry(existing: Dict[str, Any], new_entry: Dict[str, Any], verbose: bool = False): 102 | existing["pairs"].extend(new_entry["pairs"]) 103 | if verbose: 104 | print(f"Merged duplicate data_id {existing['data_id']} " 105 | f"(total pairs: {len(existing['pairs'])})") 106 | 107 | # ----------------------- 108 | # SQLites 109 | # ----------------------- 110 | def open_db(path: str) -> sqlite3.Connection: 111 | conn = sqlite3.connect(path, timeout=60) 112 | conn.execute("PRAGMA journal_mode=WAL;") 113 | conn.execute(""" 114 | CREATE TABLE IF NOT EXISTS entries ( 115 | data_id TEXT PRIMARY KEY, 116 | data_id_int INTEGER NOT NULL, 117 | value TEXT NOT NULL 118 | ) 119 | """) 120 | conn.execute("CREATE INDEX IF NOT EXISTS idx_entries_data_id_int ON entries(data_id_int)") 121 | return conn 122 | 123 | def db_get(conn: sqlite3.Connection, data_id: str) -> Optional[Dict[str, Any]]: 124 | cur = conn.execute("SELECT value FROM entries WHERE data_id=?", (data_id,)) 125 | row = cur.fetchone() 126 | return None if row is None else json.loads(row[0]) 127 | 128 | def db_put(conn: sqlite3.Connection, data_id: str, entry: Dict[str, Any]): 129 | conn.execute( 130 | "INSERT INTO entries (data_id, data_id_int, value) VALUES(?,?,?) " 131 | "ON CONFLICT(data_id) DO UPDATE SET value=excluded.value", 132 | (data_id, int(entry["data_id_int"]), json.dumps(entry, ensure_ascii=False)) 133 | ) 134 | 135 | def merge_into_db(conn, data_id, entry, verbose): 136 | existing = db_get(conn, data_id) 137 | if existing is None: 138 | db_put(conn, data_id, entry) 139 | else: 140 | merge_entry(existing, entry, verbose) 141 | db_put(conn, data_id, existing) 142 | 143 | # ----------------------- 144 | # Main 145 | # ----------------------- 146 | def main(): 147 | args = parse_args() 148 | max_in_flight = args.max_in_flight or max(4, args.num_workers * 4) 149 | 150 | # Streaming dataset 151 | data = load_dataset( 152 | "json", 153 | data_files={"train": args.input_path}, 154 | split="train", 155 | streaming=True, 156 | cache_dir=args.cache_dir 157 | ) 158 | 159 | # SQLite store 160 | db_path = args.db_path or tempfile.mkstemp( 161 | prefix="merge_", suffix=".sqlite", 162 | dir=os.path.dirname(os.path.abspath(args.output_path)) or "." 
163 | )[1] 164 | conn = open_db(db_path) 165 | atexit.register(lambda: conn.close()) 166 | 167 | pbar = tqdm(desc="Processing", disable=args.no_progress) 168 | 169 | def handle(res): 170 | if res is None: 171 | pbar.update(1); return 172 | data_id, entry = res 173 | merge_into_db(conn, data_id, entry, args.verbose) 174 | if (pbar.n % 5000) == 0: 175 | conn.commit() 176 | pbar.update(1) 177 | 178 | try: 179 | if args.single_process or args.num_workers <= 1: 180 | for sample in data: 181 | handle(process_one_sample(sample, args.n_token_seq_length, args.seed)) 182 | conn.commit() 183 | else: 184 | in_flight = set() 185 | with ProcessPoolExecutor(max_workers=args.num_workers) as ex: 186 | for i, sample in enumerate(data): 187 | in_flight.add(ex.submit(process_one_sample, 188 | sample, args.n_token_seq_length, args.seed)) 189 | if len(in_flight) >= max_in_flight: 190 | done, in_flight = wait(in_flight, return_when=FIRST_COMPLETED) 191 | for fut in done: handle(fut.result()) 192 | for fut in as_completed(in_flight): handle(fut.result()) 193 | conn.commit() 194 | finally: 195 | pbar.close() 196 | 197 | # Final write-out with sorting 198 | with open(args.output_path, "w", encoding="utf-8") as fout: 199 | cur = conn.execute("SELECT value FROM entries ORDER BY data_id_int") 200 | count = 0 201 | 202 | # across all data_id 203 | for (value_str,) in cur: 204 | entry = json.loads(value_str) 205 | 206 | # Sort the (k_j,last_j) pairs by diffusion_itr (int) 207 | pairs_sorted = sorted(entry["pairs"], key=lambda p: p["diffusion_itr"]) 208 | 209 | # Collect only the last_j sequences (as stored) for labels_ids 210 | # If 'last_seq' is missing for any reason, fall back to the tail of seq. 211 | labels_tail: List[int] = list( 212 | itertools.chain.from_iterable( 213 | p["last_seq"] for p in pairs_sorted 214 | ) 215 | ) 216 | 217 | prompt = entry["prompt_ids"][0] 218 | 219 | output_entry = dict( 220 | data_id = entry["data_id"], 221 | prompt_ids = prompt, 222 | # prompt + concat of all last_j (truncated) sequences 223 | labels_ids = prompt + labels_tail, 224 | prompt_ids_len = len(prompt), 225 | ) 226 | fout.write(json.dumps(output_entry, ensure_ascii=False)) 227 | fout.write("\n") 228 | count += 1 229 | 230 | # Remove temp DB if we created it 231 | if not args.db_path: 232 | try: os.remove(db_path) 233 | except Exception: pass 234 | 235 | print(f"Processed {count} unique data_id samples --> {args.output_path}") 236 | 237 | if __name__ == "__main__": 238 | main() 239 | -------------------------------------------------------------------------------- /generate_trajectory/data/2_prepare_efficient_cllm_training_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse, json, hashlib, random, os, sqlite3, tempfile, atexit, itertools, re 3 | from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED 4 | from typing import Optional, Dict, Any, Tuple, List 5 | from tqdm import tqdm 6 | from datasets import load_dataset 7 | 8 | # ----------------------- 9 | # Args 10 | # ----------------------- 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description="Low-RAM parallel preprocessing for efficient CLLM training JSONL." 
14 | ) 15 | parser.add_argument("--input_path", required=True) 16 | parser.add_argument("--output_path", required=True) 17 | parser.add_argument("--n-token-seq-length", type=int, default=64) 18 | parser.add_argument("--cache-dir", default=None) 19 | parser.add_argument("--seed", type=int, default=42) 20 | parser.add_argument("--num-workers", type=int, default=max(1, os.cpu_count() or 8)) 21 | parser.add_argument("--max-in-flight", type=int, default=None) 22 | parser.add_argument("--db-path", default=None) 23 | parser.add_argument("--no-progress", action="store_true") 24 | parser.add_argument("--verbose", action="store_true") 25 | parser.add_argument("--single-process", action="store_true") 26 | return parser.parse_args() 27 | 28 | # ----------------------- 29 | # Deterministic RNG 30 | # ----------------------- 31 | def stable_seed(*parts: Any, base_seed: int = 0) -> int: 32 | h = hashlib.sha256() 33 | for p in parts: 34 | h.update(str(p).encode("utf-8")); h.update(b"|") 35 | h.update(base_seed.to_bytes(8, "little", signed=False)) 36 | return int.from_bytes(h.digest()[:8], "big", signed=False) 37 | 38 | # ----------------------- 39 | # ID parsing helpers 40 | # ----------------------- 41 | def parse_data_id_int(data_id: str) -> int: 42 | """Extract the integer from 'data_{id}' robustly.""" 43 | if data_id.startswith("data_"): 44 | return int(data_id[5:]) 45 | m = re.search(r"(\d+)", data_id) 46 | return int(m.group(1)) if m else 0 47 | 48 | def parse_itr_int(itr_id: str) -> int: 49 | """Extract the integer from 'itr_{iteration_id}' robustly.""" 50 | if itr_id.startswith("itr_"): 51 | return int(itr_id[4:]) 52 | m = re.search(r"(\d+)", itr_id) 53 | return int(m.group(1)) if m else 0 54 | 55 | 56 | # ----------------------- 57 | # Per-sample transform 58 | # ----------------------- 59 | def process_one_sample( 60 | sample: Dict[str, Any], 61 | n_token_seq_length: int, 62 | base_seed: int 63 | ) -> Optional[Tuple[str, Dict[str, Any]]]: 64 | """ 65 | Return (data_id, entry) or None. 66 | Entry contains *one* (k_j,last_j) pair; final stitching happens later. 
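    Illustrative shape of the returned value (field values below are hypothetical):
        ("data_123", {"data_id": "data_123", "data_id_int": 123, "prompt_ids": [...],
                      "pairs": [{"diffusion_itr": 7, "traj_position_index": k_j,
                                 "seq": <sampled k_j tail + converged last_j tail, n_token_seq_length ids each>}]})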
67 | """ 68 | data_id = sample["data_id"] # e.g., "data_123" 69 | diffusion_itr_id = sample["diffusion_itr_id"] # e.g., "itr_7" 70 | data_id_int = parse_data_id_int(data_id) 71 | diffusion_itr = parse_itr_int(diffusion_itr_id) 72 | 73 | prompt_ids = sample["prompt_ids"] 74 | #prompt_ids_len = sample["prompt_ids_len"] 75 | #labels_ids = sample["labels_ids"] 76 | answer_traj = sample["answer_trajectory_ids"] 77 | 78 | if len(answer_traj) < 2: 79 | return None # skip 80 | 81 | rng = random.Random(stable_seed(data_id, diffusion_itr_id, base_seed=base_seed)) 82 | k_j = rng.randint(0, len(answer_traj) - 2) 83 | 84 | sampled_seq = answer_traj[k_j][-n_token_seq_length:] 85 | fixed_seq = answer_traj[-1][-n_token_seq_length:] 86 | 87 | pair_seq = list(sampled_seq) + list(fixed_seq) 88 | 89 | entry = dict( 90 | data_id=data_id, 91 | data_id_int=int(data_id_int), 92 | #prompt_ids_len=prompt_ids_len, 93 | prompt_ids=list(prompt_ids), 94 | #labels_ids=list(labels_ids), 95 | pairs=[dict( 96 | diffusion_itr=int(diffusion_itr), # for sorting 97 | traj_position_index=int(k_j), 98 | seq=pair_seq 99 | )], 100 | ) 101 | return data_id, entry 102 | 103 | # ----------------------- 104 | # In-memory merge helpers 105 | # ----------------------- 106 | def merge_entry(existing: Dict[str, Any], new_entry: Dict[str, Any], verbose: bool = False): 107 | existing["pairs"].extend(new_entry["pairs"]) 108 | if verbose: 109 | print(f"Merged duplicate data_id {existing['data_id']} " 110 | f"(total pairs: {len(existing['pairs'])})") 111 | 112 | # ----------------------- 113 | # SQLite helpers 114 | # ----------------------- 115 | def open_db(path: str) -> sqlite3.Connection: 116 | conn = sqlite3.connect(path, timeout=60) 117 | conn.execute("PRAGMA journal_mode=WAL;") 118 | conn.execute(""" 119 | CREATE TABLE IF NOT EXISTS entries ( 120 | data_id TEXT PRIMARY KEY, 121 | data_id_int INTEGER NOT NULL, 122 | value TEXT NOT NULL 123 | ) 124 | """) 125 | conn.execute("CREATE INDEX IF NOT EXISTS idx_entries_data_id_int ON entries(data_id_int)") 126 | return conn 127 | 128 | def db_get(conn: sqlite3.Connection, data_id: str) -> Optional[Dict[str, Any]]: 129 | cur = conn.execute("SELECT value FROM entries WHERE data_id=?", (data_id,)) 130 | row = cur.fetchone() 131 | return None if row is None else json.loads(row[0]) 132 | 133 | def db_put(conn: sqlite3.Connection, data_id: str, entry: Dict[str, Any]): 134 | conn.execute( 135 | "INSERT INTO entries (data_id, data_id_int, value) VALUES(?,?,?) " 136 | "ON CONFLICT(data_id) DO UPDATE SET value=excluded.value", 137 | (data_id, int(entry["data_id_int"]), json.dumps(entry, ensure_ascii=False)) 138 | ) 139 | 140 | def merge_into_db(conn, data_id, entry, verbose): 141 | existing = db_get(conn, data_id) 142 | if existing is None: 143 | db_put(conn, data_id, entry) 144 | else: 145 | merge_entry(existing, entry, verbose) 146 | db_put(conn, data_id, existing) 147 | 148 | # ----------------------- 149 | # Main 150 | # ----------------------- 151 | def main(): 152 | args = parse_args() 153 | max_in_flight = args.max_in_flight or max(4, args.num_workers * 4) 154 | 155 | # Streaming dataset 156 | data = load_dataset( 157 | "json", 158 | data_files={"train": args.input_path}, 159 | split="train", 160 | streaming=True, 161 | cache_dir=args.cache_dir 162 | ) 163 | 164 | # SQLite store 165 | db_path = args.db_path or tempfile.mkstemp( 166 | prefix="merge_", suffix=".sqlite", 167 | dir=os.path.dirname(os.path.abspath(args.output_path)) or "." 
168 | )[1] 169 | conn = open_db(db_path) 170 | atexit.register(lambda: conn.close()) 171 | 172 | pbar = tqdm(desc="Processing", disable=args.no_progress) 173 | 174 | def handle(res): 175 | if res is None: 176 | pbar.update(1); return 177 | data_id, entry = res 178 | merge_into_db(conn, data_id, entry, args.verbose) 179 | if (pbar.n % 5000) == 0: 180 | conn.commit() 181 | pbar.update(1) 182 | 183 | try: 184 | if args.single_process or args.num_workers <= 1: 185 | for sample in data: 186 | handle(process_one_sample(sample, args.n_token_seq_length, args.seed)) 187 | conn.commit() 188 | else: 189 | in_flight = set() 190 | with ProcessPoolExecutor(max_workers=args.num_workers) as ex: 191 | for i, sample in enumerate(data): 192 | in_flight.add(ex.submit(process_one_sample, 193 | sample, args.n_token_seq_length, args.seed)) 194 | if len(in_flight) >= max_in_flight: 195 | done, in_flight = wait(in_flight, return_when=FIRST_COMPLETED) 196 | for fut in done: handle(fut.result()) 197 | for fut in as_completed(in_flight): handle(fut.result()) 198 | conn.commit() 199 | finally: 200 | pbar.close() 201 | 202 | # -------- Final write-out with sorting -------- 203 | with open(args.output_path, "w", encoding="utf-8") as fout: 204 | cur = cur = conn.execute("SELECT value FROM entries ORDER BY data_id_int") 205 | count = 0 206 | 207 | # across all data_id 208 | for (value_str,) in cur: 209 | entry = json.loads(value_str) 210 | 211 | # Sort the (k_j,last_j) pairs by diffusion_itr (int) 212 | pairs_sorted = sorted(entry["pairs"], key=lambda p: p["diffusion_itr"]) 213 | 214 | # Flatten sequences 215 | concatenated_pairs: List[int] = list( 216 | itertools.chain.from_iterable(p["seq"] for p in pairs_sorted) 217 | ) 218 | traj_position_indices: List[int] = list( 219 | p["traj_position_index"] for p in pairs_sorted 220 | ) 221 | 222 | output_entry = dict( 223 | data_id = entry["data_id"], 224 | #prompt_ids_len = entry["prompt_ids_len"], 225 | prompt_ids = entry["prompt_ids"], 226 | #labels_ids = entry["labels_ids"], 227 | complete_training_sequence_ids = entry["prompt_ids"] + concatenated_pairs, 228 | traj_position_indices = traj_position_indices, 229 | ) 230 | fout.write(json.dumps(output_entry, ensure_ascii=False)) 231 | fout.write("\n") 232 | count += 1 233 | 234 | # Remove temp DB if we created it 235 | if not args.db_path: 236 | try: os.remove(db_path) 237 | except Exception: pass 238 | 239 | print(f"Processed {count} unique data_id samples --> {args.output_path}") 240 | 241 | if __name__ == "__main__": 242 | main() 243 | -------------------------------------------------------------------------------- /JacobiForcing/train/deprecated/vanilla_efficient_cllm_trainer.py: -------------------------------------------------------------------------------- 1 | # one_pass_trainer.py 2 | # 3 | # CllmTrainer that computes BOTH AR loss and Consistency loss 4 | # in a SINGLE forward pass by constructing a 2D attention mask. 5 | # 6 | # Assumptions: 7 | # - Inputs per sample (batch size == 1 expected here): 8 | # inputs["input_ids"] : Tensor [1, L] (flattened: prompt + k_0 + last_0 + k_1 + last_1 + ... 
) 9 | # inputs["labels_ids"] : Tensor [1, L] (teacher/fixed sequence; same shape as input_ids) 10 | # inputs["prompt_ids_len"] : Tensor([P]) or int (length of prompt prefix) 11 | # inputs["traj_position_indices"]: Tensor [1, T] or list[int]; u_j for each k_j (uncorrupted prefix length in k_j) 12 | # 13 | # - Block layout (N_BLOCK = 64): 14 | # input_ids = 15 | # [ prompt(P), 16 | # k_0(64), last_0(64), 17 | # k_1(64), last_1(64), 18 | # ... 19 | # k_{T-1}(64), last_{T-1}(64) ] 20 | # 21 | # - AR loss: 22 | # supervise ONLY the first u_j tokens of each last_j block (shifted LM loss), 23 | # where u_j = traj_position_indices[j]. 24 | # Visibility for a last_j token at offset t (< u_j): 25 | # prompt + all previous last_m (m 1 can be added by looping over batch dimension and stacking masks/logits. 37 | 38 | import torch 39 | import wandb 40 | from torch.cuda.amp import autocast 41 | from transformers import Trainer 42 | from transformers.trainer_pt_utils import LabelSmoother 43 | 44 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 45 | N_BLOCK = 64 # tokens per (k_j / last_j) block 46 | 47 | 48 | class CllmTrainer(Trainer): 49 | def __init__(self, *args, **kwargs): 50 | super().__init__(*args, **kwargs) 51 | args = kwargs["args"] 52 | self.train_step_cnt = 0 53 | self.max_new_tokens = args.max_new_tokens 54 | self.use_gt_labels = args.use_gt_labels 55 | 56 | # ---------------- Utilities ---------------- # 57 | 58 | @staticmethod 59 | def _to_int(x): 60 | return x.item() if isinstance(x, torch.Tensor) else int(x) 61 | 62 | def _unpack_sample(self, inputs): 63 | """ 64 | Extract a single sample. (Assumes per_device_train_batch_size == 1.) 65 | Required keys: 66 | - input_ids, labels_ids: [1, L] 67 | - prompt_ids_len: scalar or [1] 68 | - traj_position_indices: [1, T] or list[int] 69 | """ 70 | input_ids = inputs["input_ids"][0] 71 | 72 | label_ids = inputs["labels_ids"][0] 73 | prompt_len = inputs["prompt_ids_len"] 74 | if isinstance(prompt_len, torch.Tensor): 75 | if prompt_len.dim() > 0: 76 | prompt_len = prompt_len[0] 77 | prompt_len = self._to_int(prompt_len) 78 | 79 | traj_position_indices = inputs["traj_position_indices"][0] 80 | if isinstance(traj_position_indices, torch.Tensor): 81 | traj_position_indices = traj_position_indices.tolist() 82 | traj_position_indices = [int(u) for u in traj_position_indices] 83 | 84 | return ( 85 | input_ids.to(self.args.device), 86 | label_ids.to(self.args.device), 87 | prompt_len, 88 | traj_position_indices, 89 | ) 90 | 91 | 92 | @staticmethod 93 | def _index_layout(prompt_len: int, T: int, N: int = N_BLOCK): 94 | """Return lists of start indices for all k_j and last_j blocks in flattened sequence.""" 95 | k_starts = [prompt_len + 2 * j * N for j in range(T)] 96 | l_starts = [prompt_len + (2 * j + 1) * N for j in range(T)] 97 | return k_starts, l_starts 98 | 99 | 100 | def _build_full_attention_mask(self, L: int, prompt_len: int, u_list, N: int = N_BLOCK): 101 | """ 102 | Build a single boolean attention mask of shape [1, L, L] that encodes: 103 | - prompt causal 104 | - k_j rows attend to: prompt + all previous last blocks + previous tokens in k_j 105 | - last_j rows (for t=u_j keep at least self-attend to avoid NaNs; they don't contribute to loss 107 | """ 108 | device = self.args.device 109 | T = len(u_list) 110 | M = torch.zeros((L, L), dtype=torch.bool, device=device) 111 | 112 | # Prompt causal 113 | for i in range(prompt_len): 114 | M[i, : i + 1] = True 115 | 116 | k_starts, l_starts = self._index_layout(prompt_len, T, N) 117 | 118 | def 
prev_last_sources(j_idx): 119 | allowed = list(range(prompt_len)) 120 | for m in range(j_idx): 121 | allowed.extend(range(l_starts[m], l_starts[m] + N)) 122 | return allowed 123 | 124 | for j, u_j in enumerate(u_list): 125 | u = max(0, min(N, int(u_j))) 126 | ks, ls = k_starts[j], l_starts[j] 127 | allowed_k_prefix = prev_last_sources(j) 128 | allowed_l_prefix = prev_last_sources(j) 129 | 130 | # k_j rows: prompt + prev last_m + own causal k_j 131 | for t in range(N): 132 | r = ks + t 133 | if allowed_k_prefix: 134 | M[r, allowed_k_prefix] = True 135 | M[r, ks : r + 1] = True 136 | 137 | # last_j rows: allow only first u tokens; others (>=u) keep self-attend 138 | for t in range(N): 139 | r = ls + t 140 | if t < u: 141 | if allowed_l_prefix: 142 | M[r, allowed_l_prefix] = True 143 | M[r, ls : r + 1] = True 144 | else: 145 | M[r, r] = True # numerical safety 146 | 147 | return M.unsqueeze(0) # [1, L, L] 148 | 149 | # ---------------- Core Training Step ---------------- # 150 | def training_step(self, model, inputs, num_items_in_batch=None): 151 | self.train_step_cnt += 1 152 | return self._one_pass_losses_step(model, inputs) 153 | 154 | def _one_pass_losses_step(self, model, inputs): 155 | """ 156 | Single forward pass to compute: 157 | - AR loss: first u_j tokens of each last_j (shifted LM) 158 | - Consistency loss: corrupted tail of each k_j vs teacher last_j at same offsets 159 | """ 160 | input_ids, label_ids, prompt_len, u_list = self._unpack_sample(inputs) 161 | 162 | # Basic layout 163 | L = input_ids.size(0) 164 | T = len(u_list) 165 | expected_len = prompt_len + 2 * T * N_BLOCK 166 | if L != expected_len: 167 | raise ValueError( 168 | f"Length mismatch: L={L}, expected {expected_len} (prompt_len={prompt_len}, T={T}, N_BLOCK={N_BLOCK})" 169 | ) 170 | 171 | # Build attention mask & forward once 172 | full_mask = self._build_full_attention_mask(L, prompt_len, u_list, N_BLOCK) # [1, L, L] 173 | with autocast(dtype=torch.bfloat16): 174 | outputs = model(input_ids=input_ids.unsqueeze(0), attention_mask=full_mask) 175 | logits = outputs.logits # [1, L, V] 176 | vocab_size = logits.size(-1) 177 | 178 | # AR loss (shifted LM on first u_j tokens of each last_j) 179 | ar_labels = torch.full((L,), IGNORE_TOKEN_ID, device=self.args.device) 180 | k_starts, l_starts = self._index_layout(prompt_len, T, N_BLOCK) 181 | for j, u_j in enumerate(u_list): 182 | u = max(0, min(N_BLOCK, int(u_j))) 183 | if u > 0: 184 | ls = l_starts[j] 185 | ar_labels[ls : ls + u] = label_ids[ls : ls + u] 186 | 187 | label_smoother = LabelSmoother(epsilon=0.1, ignore_index=IGNORE_TOKEN_ID) 188 | loss_ar = label_smoother( 189 | type("obj", (), {"logits": logits}), ar_labels.unsqueeze(0), shift_labels=True 190 | ) 191 | loss_ar = loss_ar * 10.0 # scale (as in your previous code) 192 | 193 | # Consistency loss (unshifted CE on corrupted tail of k_j) 194 | student_positions = [] 195 | target_token_ids = [] 196 | for j, u_j in enumerate(u_list): 197 | u = max(0, min(N_BLOCK, int(u_j))) 198 | if u < N_BLOCK: 199 | ks, ls = k_starts[j], l_starts[j] 200 | # collect k_j offsets [u..N_BLOCK-1] 201 | offs = range(u, N_BLOCK) 202 | student_positions.extend(ks + off for off in offs) 203 | target_token_ids.extend(label_ids[ls + off] for off in offs) 204 | 205 | if student_positions: 206 | student_logits_sel = logits[0, student_positions, :] # [M, V] 207 | targets = torch.stack(target_token_ids) # [M] 208 | loss_consistency = torch.nn.functional.cross_entropy( 209 | student_logits_sel, targets, reduction="mean" 210 | ) 211 | else: 212 | 
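            # No corrupted tail positions in this sample: fall back to a zero that is
            # still a tensor on the logits graph, keeping dtype/device and the
            # logging path below uniform.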
loss_consistency = logits.sum() * 0.0 # zero 213 | 214 | total_loss = loss_ar + loss_consistency 215 | 216 | if self.args.qlora: 217 | total_loss.requires_grad = True 218 | 219 | # Logging 220 | if self.args.local_rank == 0: 221 | wandb.log( 222 | { 223 | "ar loss": float(loss_ar.detach().cpu()), 224 | "consistency loss": float(loss_consistency.detach().cpu()), 225 | } 226 | ) 227 | 228 | # Backprop 229 | with self.accelerator.accumulate(model): 230 | self.accelerator.backward(total_loss) 231 | 232 | # Sync & return 233 | torch.distributed.barrier() 234 | return total_loss.detach() 235 | -------------------------------------------------------------------------------- /JacobiForcing/scripts/tool/3d_plot_inference_configuration_search_with_quadratic_fit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.mplot3d import Axes3D # noqa: F401 9 | 10 | 11 | def load_data(jsonl_path: Path) -> pd.DataFrame: 12 | df = pd.read_json(jsonl_path, lines=True) 13 | 14 | required_cols = ["K", "r", "block_size", "ngram_size", "avg_toks_per_sec"] 15 | missing = [c for c in required_cols if c not in df.columns] 16 | if missing: 17 | raise ValueError(f"Missing required columns in JSONL: {missing}") 18 | 19 | return df 20 | 21 | 22 | def filter_for_r(df: pd.DataFrame, r: float, r_tol: float = 1e-3) -> pd.DataFrame: 23 | """Filter dataframe for given r (with tolerance).""" 24 | mask = df["r"].sub(r).abs() <= r_tol 25 | filtered = df[mask].copy() 26 | if filtered.empty: 27 | raise ValueError(f"No rows found with r≈{r}") 28 | return filtered 29 | 30 | 31 | def aggregate_tokens_per_sec(df: pd.DataFrame) -> pd.DataFrame: 32 | """ 33 | Average tokens/s in case of duplicates for same (block_size, ngram_size, K). 34 | """ 35 | return ( 36 | df.groupby(["block_size", "ngram_size", "K"], as_index=False)["avg_toks_per_sec"] 37 | .mean() 38 | .rename(columns={"avg_toks_per_sec": "toks_per_sec"}) 39 | ) 40 | 41 | 42 | def fit_quadratic_surface(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> np.ndarray: 43 | """ 44 | Fit z = a0 + a1*x + a2*y + a3*x*y + a4*x^2 + a5*y^2 using least squares. 45 | Returns coeffs [a0..a5]. 
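    A well-determined fit needs at least 6 distinct (x, y) points, one per
    coefficient; the caller skips K slices with fewer points.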
46 | """ 47 | A = np.column_stack( 48 | [ 49 | np.ones_like(x), 50 | x, 51 | y, 52 | x * y, 53 | x ** 2, 54 | y ** 2, 55 | ] 56 | ) 57 | coeffs, *_ = np.linalg.lstsq(A, z, rcond=None) 58 | return coeffs 59 | 60 | 61 | def eval_quadratic_surface(coeffs: np.ndarray, X: np.ndarray, Y: np.ndarray) -> np.ndarray: 62 | a0, a1, a2, a3, a4, a5 = coeffs 63 | return a0 + a1 * X + a2 * Y + a3 * X * Y + a4 * X ** 2 + a5 * Y ** 2 64 | 65 | 66 | def plot_best_fit_for_all_K(df_agg: pd.DataFrame, out_dir: Path, r: float): 67 | """ 68 | For each K: 69 | - fit quadratic surface 70 | - save 3D surface plot: best_fit_surface_K{K}_r{r}.png 71 | - save 2D contour plot: best_fit_contour_K{K}_r{r}.png 72 | 73 | Also: 74 | - overlay 3D surface for all K -> best_fit_surface_overlay_r{r}.png 75 | - overlay 2D contour for all K -> best_fit_contour_overlay_r{r}.png 76 | """ 77 | Ks = sorted(df_agg["K"].unique()) 78 | if not Ks: 79 | print("[INFO] No K values found after aggregation.") 80 | return 81 | 82 | # Global grid so all K plots are comparable in x/y range 83 | x_min, x_max = df_agg["block_size"].min(), df_agg["block_size"].max() 84 | y_min, y_max = df_agg["ngram_size"].min(), df_agg["ngram_size"].max() 85 | 86 | x_grid = np.linspace(x_min, x_max, 60) 87 | y_grid = np.linspace(y_min, y_max, 60) 88 | Xg, Yg = np.meshgrid(x_grid, y_grid) 89 | 90 | # Overlay figures (one 3D, one 2D) with multi-K 91 | fig_overlay_3d = plt.figure() 92 | ax_overlay_3d = fig_overlay_3d.add_subplot(111, projection="3d") 93 | 94 | fig_overlay_2d = plt.figure() 95 | ax_overlay_2d = fig_overlay_2d.add_subplot(111) 96 | 97 | cmap = plt.get_cmap("Blues_r") # smaller K index -> darker, larger -> lighter 98 | from matplotlib.lines import Line2D 99 | legend_handles = [] 100 | 101 | nK = len(Ks) 102 | denom = max(1, nK - 1) 103 | 104 | any_overlay_plotted = False 105 | 106 | for i, K in enumerate(Ks): 107 | df_k = df_agg[df_agg["K"] == K] 108 | if len(df_k) < 6: 109 | print(f"[WARN] Not enough points to fit a reliable surface for K={K} (have {len(df_k)}), skipping.") 110 | continue 111 | 112 | x = df_k["block_size"].to_numpy(dtype=float) 113 | y = df_k["ngram_size"].to_numpy(dtype=float) 114 | z = df_k["toks_per_sec"].to_numpy(dtype=float) 115 | 116 | coeffs = fit_quadratic_surface(x, y, z) 117 | Zg = eval_quadratic_surface(coeffs, Xg, Yg) 118 | 119 | # ---------- individual 3D surface ---------- 120 | fig = plt.figure() 121 | ax = fig.add_subplot(111, projection="3d") 122 | 123 | surf = ax.plot_surface(Xg, Yg, Zg, alpha=0.8, cmap="viridis") 124 | ax.scatter(x, y, z, s=35, color="k") 125 | 126 | ax.set_xlabel("Block size (n_token_seq_len)") 127 | ax.set_ylabel("N-gram size") 128 | ax.set_zlabel("Tokens / second") 129 | ax.set_title(f"Best-fit surface (K={K}, r={r})") 130 | 131 | cbar = fig.colorbar(surf, ax=ax, shrink=0.7) 132 | cbar.set_label("Tokens / second") 133 | 134 | ax.view_init(elev=25, azim=-135) 135 | plt.tight_layout() 136 | 137 | out_surface = out_dir / f"best_fit_surface_K{K}_r{r}.png" 138 | plt.savefig(out_surface) 139 | plt.close(fig) 140 | print(f"[PLOT] Saved 3D best-fit surface to {out_surface}") 141 | 142 | # ---------- individual 2D contour ---------- 143 | fig2 = plt.figure() 144 | cs = plt.contourf(Xg, Yg, Zg, levels=20) 145 | plt.scatter(x, y, c="k", s=20) 146 | 147 | plt.xlabel("Block size (n_token_seq_len)") 148 | plt.ylabel("N-gram size") 149 | plt.title(f"Best-fit surface (contour, K={K}, r={r})") 150 | 151 | cbar2 = plt.colorbar(cs) 152 | cbar2.set_label("Tokens / second") 153 | 154 | plt.tight_layout() 155 | 
out_contour = out_dir / f"best_fit_contour_K{K}_r{r}.png" 156 | plt.savefig(out_contour) 157 | plt.close(fig2) 158 | print(f"[PLOT] Saved 2D best-fit contour to {out_contour}") 159 | 160 | # ---------- overlay contributions ---------- 161 | # Color for this K (smaller K -> darker) 162 | t = i / denom 163 | color = cmap(t) 164 | 165 | # 3D overlay: uniform-colored surface + scatter 166 | ax_overlay_3d.plot_surface( 167 | Xg, 168 | Yg, 169 | Zg, 170 | alpha=0.4, 171 | color=color, 172 | linewidth=0, 173 | antialiased=True, 174 | ) 175 | ax_overlay_3d.scatter(x, y, z, color=color, s=20) 176 | 177 | # 2D overlay: contour lines + scatter 178 | zmin, zmax = Zg.min(), Zg.max() 179 | if zmin != zmax: 180 | levels = np.linspace(zmin, zmax, 7) 181 | ax_overlay_2d.contour( 182 | Xg, 183 | Yg, 184 | Zg, 185 | levels=levels, 186 | colors=[color], 187 | alpha=0.7, 188 | ) 189 | ax_overlay_2d.scatter(x, y, color=color, s=20, edgecolor="none") 190 | 191 | legend_handles.append( 192 | Line2D([0], [0], color=color, lw=2, label=f"K={K}") 193 | ) 194 | 195 | any_overlay_plotted = True 196 | print(f"[INFO] Finished K={K} with {len(df_k)} points.") 197 | 198 | # Finalize & save overlay figures 199 | if any_overlay_plotted: 200 | # 3D overlay 201 | ax_overlay_3d.set_xlabel("Block size (n_token_seq_len)") 202 | ax_overlay_3d.set_ylabel("N-gram size") 203 | ax_overlay_3d.set_zlabel("Tokens / second") 204 | ax_overlay_3d.set_title(f"Overlay best-fit surfaces across K (r={r})") 205 | ax_overlay_3d.view_init(elev=25, azim=-135) 206 | plt.tight_layout() 207 | out_overlay_surface = out_dir / f"best_fit_surface_overlay_r{r}.png" 208 | fig_overlay_3d.savefig(out_overlay_surface) 209 | plt.close(fig_overlay_3d) 210 | print(f"[PLOT] Saved overlay 3D best-fit surfaces to {out_overlay_surface}") 211 | 212 | # 2D overlay 213 | ax_overlay_2d.set_xlabel("Block size (n_token_seq_len)") 214 | ax_overlay_2d.set_ylabel("N-gram size") 215 | ax_overlay_2d.set_title(f"Overlay best-fit contours across K (r={r})") 216 | if legend_handles: 217 | ax_overlay_2d.legend(handles=legend_handles, title="K") 218 | ax_overlay_2d.grid(True, alpha=0.2) 219 | plt.tight_layout() 220 | out_overlay_contour = out_dir / f"best_fit_contour_overlay_r{r}.png" 221 | fig_overlay_2d.savefig(out_overlay_contour) 222 | plt.close(fig_overlay_2d) 223 | print(f"[PLOT] Saved overlay 2D best-fit contours to {out_overlay_contour}") 224 | else: 225 | plt.close(fig_overlay_3d) 226 | plt.close(fig_overlay_2d) 227 | print("[WARN] No valid K slices to plot overlays.") 228 | 229 | 230 | def main(): 231 | parser = argparse.ArgumentParser() 232 | parser.add_argument( 233 | "--input_jsonl", 234 | default="/home/lah003/workspace/CLLM2/profiling_results/summary/shiftedattn-10-16-7b-qwen2p5-coder-n32w16-n16distill-data-v2-ar-1-cyclic-noise-all-1e-6-summary.jsonl", 235 | help="Path to summary JSONL generated from logs", 236 | ) 237 | parser.add_argument( 238 | "--out_dir", 239 | type=str, 240 | default=None, 241 | help="Directory to save plots (default: same dir as JSONL)", 242 | ) 243 | parser.add_argument( 244 | "--r", 245 | type=float, 246 | default=0.85, 247 | help="Fixed r value to filter on (default: 0.85)", 248 | ) 249 | parser.add_argument( 250 | "--r_tol", 251 | type=float, 252 | default=1e-3, 253 | help="Tolerance when matching r (default: 1e-3)", 254 | ) 255 | args = parser.parse_args() 256 | 257 | jsonl_path = Path(args.input_jsonl) 258 | if not jsonl_path.is_file(): 259 | raise SystemExit(f"JSONL file not found: {jsonl_path}") 260 | 261 | out_dir = Path(args.out_dir) 
if args.out_dir else jsonl_path.parent 262 | out_dir.mkdir(parents=True, exist_ok=True) 263 | 264 | df = load_data(jsonl_path) 265 | df_r = filter_for_r(df, r=args.r, r_tol=args.r_tol) 266 | df_agg = aggregate_tokens_per_sec(df_r) 267 | 268 | plot_best_fit_for_all_K(df_agg, out_dir=out_dir, r=args.r) 269 | 270 | print("[DONE]") 271 | 272 | 273 | if __name__ == "__main__": 274 | main() 275 | -------------------------------------------------------------------------------- /JacobiForcing/jacobi_forcing_inference_MATH500.py: -------------------------------------------------------------------------------- 1 | from transformers import Qwen2ForCausalLM, Qwen3ForCausalLM, AutoTokenizer 2 | from datasets import load_dataset 3 | from einops import rearrange 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import torch 7 | import random 8 | import math 9 | import json 10 | from tqdm import tqdm 11 | import time 12 | 13 | import os 14 | 15 | import pandas as pd 16 | 17 | from pathlib import Path 18 | import sys 19 | path_root = Path(__file__).parents[1] 20 | sys.path.append(str(path_root)) 21 | 22 | from modeling.cllm2_qwen2_modeling_kv_terminate_on_eos_improved import jacobi_forward_greedy 23 | #Qwen2ForCausalLM.jacobi_forward_greedy = jacobi_forward_greedy 24 | Qwen3ForCausalLM.jacobi_forward_greedy = jacobi_forward_greedy 25 | 26 | # --------------------------- 27 | # Load dataset (first 100) 28 | # --------------------------- 29 | import pandas as pd 30 | 31 | df = pd.read_json("/home/lah003/data/MATH-500/test.jsonl", lines=True) 32 | df_size = len(df) 33 | print(f"Loaded MATH500 dataset with {df_size} samples") 34 | records = df.to_dict(orient="records") 35 | 36 | # --------------------------- 37 | # Load model/tokenizer once 38 | # --------------------------- 39 | model_name = "/home/lah003/models/1022-lx-math-4b-math-n16w16" 40 | model = Qwen3ForCausalLM.from_pretrained( 41 | model_name, 42 | device_map="cuda", 43 | torch_dtype=torch.bfloat16, 44 | attn_implementation="flash_attention_2" 45 | ) 46 | 47 | tokenizer = AutoTokenizer.from_pretrained("/home/lah003/models/Qwen3-4B-Instruct-2507") 48 | model.eval() 49 | 50 | 51 | eos_id = tokenizer.eos_token_id 52 | alt_eos_id = 151645 # keep your special EOS as a fallback 53 | 54 | # --------------------------- 55 | # Generation/profiling config 56 | # --------------------------- 57 | n_token_seq_len = 128 58 | 59 | # Safety caps so a sample can't run forever. 60 | max_new_tokens = 512 # hard cap on total new tokens per prompt 61 | max_calls = 1024 # hard cap on number of diffusion_decoding calls per prompt 62 | 63 | # --------------------------- 64 | # Iterate the dataset 65 | # --------------------------- 66 | all_rows = [] 67 | t0_overall = time.perf_counter() 68 | all_generations = [] 69 | 70 | total_gen_only_time = 0 71 | 72 | for idx, row in tqdm(enumerate(records[:10])): 73 | task_id = row.get("task_id", f"idx_{idx}") 74 | # prompt = """Problem: {}\nMark your solution with \\boxed\nAnswer:""".strip().format( 75 | # row["problem"].strip() 76 | # ) 77 | prompt = row["problem"] 78 | # messages = [{"role": "user", "content": prompt}] 79 | messages = [ 80 | {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant."}, 81 | {"role": "user", "content": prompt} 82 | ] 83 | text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 84 | model_inputs = tokenizer([text], return_tensors="pt").to(model.device) 85 | 86 | input_ids = model_inputs["input_ids"] 87 | attention_mask = torch.full_like(input_ids, 1, device=model.device) 88 | 89 | # per-example stats 90 | iters = [] 91 | total_new_tokens = 0 92 | calls = 0 93 | prev_len = input_ids.shape[1] 94 | prompt_len = prev_len 95 | stop_reason = None 96 | prefill_phase = True 97 | generated_ids = input_ids 98 | 99 | prefill_drafted_n_gram = None 100 | 101 | gen_only_time = 0 102 | 103 | t_start = time.time() 104 | # run until EOS or caps 105 | while True: 106 | # Check EOS 107 | generated_part = generated_ids[0, prompt_len:] 108 | hit_eos = False 109 | if eos_id is not None: 110 | hit_eos = (generated_part == eos_id).any().item() 111 | if not hit_eos: 112 | # allow alternate special EOS id 113 | hit_eos = (generated_part == alt_eos_id).any().item() 114 | 115 | if hit_eos: 116 | stop_reason = "eos" 117 | break 118 | if total_new_tokens >= max_new_tokens: 119 | stop_reason = "max_new_tokens" 120 | break 121 | if calls >= max_calls: 122 | stop_reason = "max_calls" 123 | break 124 | 125 | #print(f"\nInit new subsequence {calls}...\n") 126 | 127 | ### One diffusion decoding call 128 | if prefill_phase: 129 | # pass in random-init draft 130 | q_sampled = [] 131 | for _ in range(n_token_seq_len): 132 | q_sample = torch.tensor([random.choice(generated_ids[0].tolist())], dtype=torch.long, device=model.device).unsqueeze(0) 133 | q_sampled.append(q_sample) 134 | prefill_draft_token_ids = torch.cat(q_sampled, dim=1) # shape [1, n_token_seq_len] 135 | 136 | prefill_input_ids = torch.cat((input_ids, prefill_draft_token_ids),dim=-1) 137 | 138 | # `jacobi_forward_greedy` will return iteration result from first iteration 139 | past_key_values, first_correct_token, prefill_drafted_n_gram, iter_count = model.jacobi_forward_greedy( 140 | input_ids=prefill_input_ids, 141 | attention_mask=attention_mask, 142 | past_key_values=None, 143 | use_cache=True, 144 | prefill_phase=prefill_phase, 145 | n_token_seq_len=n_token_seq_len, 146 | tokenizer=tokenizer, 147 | eos_token_id=eos_id, 148 | ) 149 | prefill_phase = False 150 | generated_ids = input_ids 151 | itr_count = 0 152 | else: 153 | # generation phase 154 | # ---- Initialize a draft tail (any tokens work; we'll fix on the first pass). 155 | # We keep your "random from prompt" init to avoid extra forward passes. 
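            # calls == 1 is the first decoding call after prefill: reuse the n-gram the
            # prefill pass already drafted. Later calls seed the block with the verified
            # first_correct_token followed by (n_token_seq_len - 1) random draft tokens.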
156 | if calls == 1: 157 | # First non-prefill call: reuse draft_tokens produced by prefill 158 | input_ids = prefill_drafted_n_gram 159 | else: 160 | q_sampled = [] 161 | for _ in range(n_token_seq_len-1): 162 | q_sample = torch.tensor([random.choice(generated_ids[0].tolist())], dtype=torch.long, device=model.device).unsqueeze(0) 163 | q_sampled.append(q_sample) 164 | q_sampled = torch.cat(q_sampled, dim=1) # shape [1, n_token_seq_len-1] 165 | input_ids = torch.cat((first_correct_token.view(1,-1), q_sampled),dim=-1) 166 | 167 | t_gen_start = time.perf_counter() 168 | past_key_values, first_correct_token, accepted_n_gram, itr_count = model.jacobi_forward_greedy( 169 | input_ids=input_ids, 170 | attention_mask=None, 171 | past_key_values=past_key_values, 172 | use_cache=True, 173 | prefill_phase=prefill_phase, 174 | n_token_seq_len=n_token_seq_len, 175 | tokenizer=tokenizer, 176 | eos_token_id=eos_id, 177 | ) 178 | t_gen_time = time.perf_counter() - t_gen_start 179 | gen_only_time += t_gen_time 180 | 181 | generated_ids = torch.cat((generated_ids, accepted_n_gram), dim=-1) 182 | 183 | calls += 1 184 | iters.append(itr_count) 185 | 186 | added = generated_ids.shape[1] - prev_len 187 | if added > 0: 188 | total_new_tokens += added 189 | prev_len = generated_ids.shape[1] 190 | 191 | # subtract prefill 192 | total_new_tokens -= 1 193 | # per-example finalize 194 | dt = time.time() - t_start 195 | total_iterations = sum(iters) 196 | avg_iter_per_call = (total_iterations / calls) 197 | avg_iter_per_token = (total_iterations / total_new_tokens) 198 | 199 | toks_per_sec = (total_new_tokens / gen_only_time) 200 | 201 | total_gen_only_time += gen_only_time 202 | 203 | prompt_len = model_inputs["input_ids"].shape[1] 204 | generated_str = ''.join(tokenizer.decode(generated_ids[0, prompt_len:], skip_special_tokens=False)) 205 | print(f'Generated answers: {generated_str}') 206 | all_generations.append(generated_str) 207 | 208 | all_rows.append( 209 | { 210 | "index": idx, 211 | "task_id": task_id, 212 | "prompt_tokens": prompt_len, 213 | "new_tokens": total_new_tokens, 214 | "calls": calls, 215 | "total_iterations": total_iterations, 216 | "avg_iter_per_call": avg_iter_per_call, 217 | "avg_iter_per_token": avg_iter_per_token, 218 | "time_sec": dt, 219 | "toks_per_sec": toks_per_sec, 220 | "stop_reason": stop_reason, 221 | } 222 | ) 223 | 224 | # light progress 225 | if (idx + 1) % 5 == 0 or (idx + 1) == len(records): 226 | print(f"====[{idx+1}/{len(records)}] task_id={task_id} new_toks={total_new_tokens} " 227 | f"calls={calls} avg_iter/call={avg_iter_per_call:.2f} reason={stop_reason}====") 228 | 229 | #### ADDED Lines #### 230 | # --------------------------- 231 | # Aggregate + save 232 | # --------------------------- 233 | t_overall = time.perf_counter() - t0_overall 234 | df_profile = pd.DataFrame(all_rows) 235 | csv_path = "diffusion_profile_math500.csv" 236 | df_profile.to_csv(csv_path, index=False) 237 | 238 | # Print quick summary (EOS-only) 239 | def _safe_mean(series): 240 | s = pd.to_numeric(series, errors="coerce") 241 | return float(s.mean()) if s.size and not pd.isna(s).all() else float("nan") 242 | 243 | df_eos = df_profile[df_profile["stop_reason"] == "eos"].copy() 244 | n_eos = len(df_eos) 245 | n_total = len(df_profile) 246 | 247 | print("\n=== Diffusion Decoding Profiling — EOS-only ===") 248 | print(f"Examples (eos): {n_eos} / {n_total} Total wall time: {t_overall:.4f}s") 249 | print(f"Avg new tokens / prompt: {_safe_mean(df_eos['new_tokens']):.4f}") 250 | print(f"Avg calls / prompt: 
{_safe_mean(df_eos['calls']):.4f}") 251 | print(f"Avg iterations / call: {_safe_mean(df_eos['avg_iter_per_call']):.4f}") 252 | print(f"Avg iterations / token: {_safe_mean(df_eos['avg_iter_per_token']):.4f}") 253 | print(f"Avg toks/sec: {_safe_mean(df_eos['toks_per_sec']):.4f}") 254 | 255 | # Optional: also show overall stop-reason distribution for context 256 | print("\nStop reasons (all examples):") 257 | print(df_profile['stop_reason'].value_counts()) 258 | 259 | # Optional: save EOS-only rows too 260 | df_eos.to_csv("diffusion_profile_greedy_math500_eos.csv", index=False) 261 | -------------------------------------------------------------------------------- /generate_trajectory/data/2_prepare_efficient_cllm_training_data_new_progressive_noise.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse, json, hashlib, random, os, sqlite3, tempfile, atexit, itertools, re 3 | from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED 4 | from typing import Optional, Dict, Any, Tuple, List 5 | from tqdm import tqdm 6 | from datasets import load_dataset 7 | 8 | # ----------------------- 9 | # Args 10 | # ----------------------- 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description="Low-RAM parallel preprocessing for efficient CLLM training JSONL." 14 | ) 15 | parser.add_argument("--input_path", required=True) 16 | parser.add_argument("--output_path", required=True) 17 | parser.add_argument("--half-cap-idx", 18 | type=int, 19 | required=True, 20 | help="Cap for the negative support index: we use -(min(2 + diffusion_itr, half_cap_idx)). " 21 | "Must be >= 2. Example: 8") 22 | parser.add_argument("--n-token-seq-length", type=int, default=64) 23 | parser.add_argument("--cache-dir", default=None) 24 | parser.add_argument("--seed", type=int, default=42) 25 | parser.add_argument("--num-workers", type=int, default=max(os.cpu_count(), 1)) 26 | parser.add_argument("--max-in-flight", type=int, default=None) 27 | parser.add_argument("--db-path", default=None) 28 | parser.add_argument("--no-progress", action="store_true") 29 | parser.add_argument("--verbose", action="store_true") 30 | parser.add_argument("--single-process", action="store_true") 31 | # in parse_args() 32 | 33 | return parser.parse_args() 34 | 35 | # ----------------------- 36 | # Deterministic RNG 37 | # ----------------------- 38 | def stable_seed(*parts: Any, base_seed: int = 0) -> int: 39 | h = hashlib.sha256() 40 | for p in parts: 41 | h.update(str(p).encode("utf-8")); h.update(b"|") 42 | h.update(base_seed.to_bytes(8, "little", signed=False)) 43 | return int.from_bytes(h.digest()[:8], "big", signed=False) 44 | 45 | # ----------------------- 46 | # ID parsing helpers 47 | # ----------------------- 48 | def parse_data_id_int(data_id: str) -> int: 49 | """Extract the integer from 'data_{id}' robustly.""" 50 | if data_id.startswith("data_"): 51 | return int(data_id[5:]) 52 | m = re.search(r"(\d+)", data_id) 53 | return int(m.group(1)) if m else 0 54 | 55 | def parse_itr_int(itr_id: str) -> int: 56 | """Extract the integer from 'itr_{iteration_id}' robustly.""" 57 | if itr_id.startswith("itr_"): 58 | return int(itr_id[4:]) 59 | m = re.search(r"(\d+)", itr_id) 60 | return int(m.group(1)) if m else 0 61 | 62 | # ----------------------- 63 | # Per-sample transform 64 | # ----------------------- 65 | def parse_itr_int(itr_id: str) -> int: 66 | """ 67 | Extracts the integer that follows 'itr_' using regex. 
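    Note: this regex-based definition shadows the simpler parse_itr_int defined
    earlier in this file; the regex version is the one actually used at runtime.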
68 | Examples: 69 | 'itr_0' -> 0 70 | 'itr_17a' -> 17 71 | Falls back to 0 if no match. 72 | """ 73 | m = re.search(r"itr_(\d+)", itr_id) 74 | return int(m.group(1)) if m else 0 75 | 76 | 77 | def process_one_sample( 78 | sample: Dict[str, Any], 79 | n_token_seq_length: int, 80 | base_seed: int, # kept for signature compatibility; unused 81 | half_cap_idx: int, # positional cap 82 | ) -> Optional[Tuple[str, Dict[str, Any]]]: 83 | """ 84 | For diffusion_itr = i, choose support index = -(min(2 + i, half_cap_idx)). 85 | i.e., itr_0 -> -2, itr_1 -> -3, ..., capped at -half_cap_idx. 86 | """ 87 | data_id = sample["data_id"] 88 | diffusion_itr_id = sample["diffusion_itr_id"] 89 | data_id_int = parse_data_id_int(data_id) 90 | diffusion_itr = parse_itr_int(diffusion_itr_id) # <-- regex-derived integer 91 | 92 | prompt_ids = sample["prompt_ids"] 93 | answer_traj = sample["answer_trajectory_ids"] 94 | 95 | if len(answer_traj) < 2: 96 | return None # need at least one support and the final (low-noise) step 97 | 98 | # sanitize the cap (ensure we never choose -1) 99 | half_cap_idx = max(2, int(half_cap_idx)) 100 | 101 | # compute the negative offset: -(min(2 + i, half_cap_idx)), but don’t exceed traj length 102 | # convert to forward index k_j in [0, len(answer_traj)-2] 103 | neg_offset = min(2 + diffusion_itr, half_cap_idx, len(answer_traj)) 104 | k_j = len(answer_traj) - neg_offset 105 | 106 | sampled_seq = answer_traj[k_j][-n_token_seq_length:] 107 | fixed_seq = answer_traj[-1][-n_token_seq_length:] 108 | pair_seq = list(sampled_seq) + list(fixed_seq) 109 | 110 | entry = dict( 111 | data_id=data_id, 112 | data_id_int=int(data_id_int), 113 | prompt_ids=list(prompt_ids), 114 | pairs=[dict( 115 | diffusion_itr=int(diffusion_itr), 116 | traj_position_index=int(k_j), 117 | seq=pair_seq 118 | )], 119 | ) 120 | return data_id, entry 121 | 122 | # ----------------------- 123 | # In-memory merge helpers 124 | # ----------------------- 125 | def merge_entry(existing: Dict[str, Any], new_entry: Dict[str, Any], verbose: bool = False): 126 | existing["pairs"].extend(new_entry["pairs"]) 127 | if verbose: 128 | print(f"Merged duplicate data_id {existing['data_id']} " 129 | f"(total pairs: {len(existing['pairs'])})") 130 | 131 | # ----------------------- 132 | # SQLite helpers 133 | # ----------------------- 134 | def open_db(path: str) -> sqlite3.Connection: 135 | conn = sqlite3.connect(path, timeout=60) 136 | conn.execute("PRAGMA journal_mode=WAL;") 137 | conn.execute(""" 138 | CREATE TABLE IF NOT EXISTS entries ( 139 | data_id TEXT PRIMARY KEY, 140 | data_id_int INTEGER NOT NULL, 141 | value TEXT NOT NULL 142 | ) 143 | """) 144 | conn.execute("CREATE INDEX IF NOT EXISTS idx_entries_data_id_int ON entries(data_id_int)") 145 | return conn 146 | 147 | def db_get(conn: sqlite3.Connection, data_id: str) -> Optional[Dict[str, Any]]: 148 | cur = conn.execute("SELECT value FROM entries WHERE data_id=?", (data_id,)) 149 | row = cur.fetchone() 150 | return None if row is None else json.loads(row[0]) 151 | 152 | def db_put(conn: sqlite3.Connection, data_id: str, entry: Dict[str, Any]): 153 | conn.execute( 154 | "INSERT INTO entries (data_id, data_id_int, value) VALUES(?,?,?) 
" 155 | "ON CONFLICT(data_id) DO UPDATE SET value=excluded.value", 156 | (data_id, int(entry["data_id_int"]), json.dumps(entry, ensure_ascii=False)) 157 | ) 158 | 159 | def merge_into_db(conn, data_id, entry, verbose): 160 | existing = db_get(conn, data_id) 161 | if existing is None: 162 | db_put(conn, data_id, entry) 163 | else: 164 | merge_entry(existing, entry, verbose) 165 | db_put(conn, data_id, existing) 166 | 167 | # ----------------------- 168 | # Main 169 | # ----------------------- 170 | def main(): 171 | args = parse_args() 172 | max_in_flight = args.max_in_flight or max(4, args.num_workers * 4) 173 | 174 | # Streaming dataset 175 | data = load_dataset( 176 | "json", 177 | data_files={"train": args.input_path}, 178 | split="train", 179 | streaming=True, 180 | cache_dir=args.cache_dir 181 | ) 182 | 183 | # SQLite store 184 | db_path = args.db_path or tempfile.mkstemp( 185 | prefix="merge_", suffix=".sqlite", 186 | dir=os.path.dirname(os.path.abspath(args.output_path)) or "." 187 | )[1] 188 | conn = open_db(db_path) 189 | atexit.register(lambda: conn.close()) 190 | 191 | pbar = tqdm(desc="Processing", disable=args.no_progress) 192 | 193 | def handle(res): 194 | if res is None: 195 | pbar.update(1); return 196 | data_id, entry = res 197 | merge_into_db(conn, data_id, entry, args.verbose) 198 | if (pbar.n % 5000) == 0: 199 | conn.commit() 200 | pbar.update(1) 201 | 202 | try: 203 | if args.single_process or args.num_workers <= 1: 204 | for sample in data: 205 | handle(process_one_sample(sample, args.n_token_seq_length, args.seed, args.half_cap_idx)) 206 | conn.commit() 207 | else: 208 | in_flight = set() 209 | with ProcessPoolExecutor(max_workers=args.num_workers) as ex: 210 | for i, sample in enumerate(data): 211 | in_flight.add(ex.submit( 212 | process_one_sample, sample, args.n_token_seq_length, args.seed, args.half_cap_idx 213 | )) 214 | if len(in_flight) >= max_in_flight: 215 | done, in_flight = wait(in_flight, return_when=FIRST_COMPLETED) 216 | for fut in done: handle(fut.result()) 217 | for fut in as_completed(in_flight): handle(fut.result()) 218 | conn.commit() 219 | finally: 220 | pbar.close() 221 | 222 | # -------- Final write-out with sorting -------- 223 | with open(args.output_path, "w", encoding="utf-8") as fout: 224 | cur = cur = conn.execute("SELECT value FROM entries ORDER BY data_id_int") 225 | count = 0 226 | 227 | # across all data_id 228 | for (value_str,) in cur: 229 | entry = json.loads(value_str) 230 | 231 | # Sort the (k_j,last_j) pairs by diffusion_itr (int) 232 | pairs_sorted = sorted(entry["pairs"], key=lambda p: p["diffusion_itr"]) 233 | 234 | # Flatten sequences 235 | concatenated_pairs: List[int] = list( 236 | itertools.chain.from_iterable(p["seq"] for p in pairs_sorted) 237 | ) 238 | 239 | traj_position_indices: List[int] = list( 240 | p["traj_position_index"] for p in pairs_sorted 241 | ) 242 | 243 | output_entry = dict( 244 | data_id = entry["data_id"], 245 | prompt_ids = entry["prompt_ids"][0], 246 | complete_training_sequence_ids = entry["prompt_ids"][0] + concatenated_pairs, 247 | prompt_ids_len = len(entry["prompt_ids"][0]), 248 | traj_position_indices = traj_position_indices, 249 | ) 250 | fout.write(json.dumps(output_entry, ensure_ascii=False)) 251 | fout.write("\n") 252 | count += 1 253 | 254 | # Remove temp DB if we created it 255 | if not args.db_path: 256 | try: os.remove(db_path) 257 | except Exception: pass 258 | 259 | print(f"Processed {count} unique data_id samples --> {args.output_path}") 260 | 261 | if __name__ == "__main__": 262 | 
main() --------------------------------------------------------------------------------