├── resource
│   └── logo.png
├── bash
│   ├── eval
│   │   └── run_eval.sh
│   ├── train
│   │   └── sft_train.sh
│   └── ablation
│       └── ablation_study.sh
├── script
│   ├── analysis
│   │   ├── length_ratio
│   │   │   ├── length_ratio_analysis_with_improvement.png
│   │   │   ├── extract_metrics_enhanced.py
│   │   │   └── compute_response_lengths.py
│   │   ├── venn
│   │   │   ├── venn_config.json
│   │   │   └── plot_venn_from_json.py
│   │   ├── proportion
│   │   │   ├── proportion.py
│   │   │   └── auto_kv_cache_evaluation.py
│   │   ├── gate_weight
│   │   │   ├── collect_projector_weights.py
│   │   │   └── analyze_projector_weights.ipynb
│   │   └── scaling
│   │       ├── scaling_curve.py
│   │       ├── batch_evaluate_T2T.py
│   │       └── batch_evaluate_checkpoints.py
│   ├── dataset
│   │   ├── launch_server.sh
│   │   ├── run_generation.sh
│   │   ├── save_dataset.py
│   │   └── qwen3_nonthinking.jinja
│   ├── evaluation
│   │   └── run_tests.sh
│   ├── train
│   │   └── generate_configs.py
│   ├── playground
│   │   ├── live_chat_example.py
│   │   ├── sample_response.py
│   │   └── inference_example.py
│   └── examples
│       ├── two_stage_example.py
│       └── two_stage_rosetta_example.py
├── .gitignore
├── rosetta
│   ├── train
│   │   ├── __init__.py
│   │   └── model_utils.py
│   ├── utils
│   │   ├── core.py
│   │   └── registry.py
│   └── model
│       ├── sampling.py
│       └── ablation_projector.py
├── recipe
│   ├── eval_recipe
│   │   ├── ablation_base.yaml
│   │   └── unified_eval.yaml
│   └── train_recipe
│       ├── baseline_partial_config.json
│       ├── baseline_config.json
│       ├── baseline_lora_config.json
│       ├── oracle.json
│       ├── C2C_0.6+0.5.json
│       ├── C2C_ablation.json
│       └── all_in_one.json
├── pyproject.toml
├── environment.yml
├── LICENSE
└── README.md

/resource/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-nics/C2C/HEAD/resource/logo.png
--------------------------------------------------------------------------------
/bash/eval/run_eval.sh:
--------------------------------------------------------------------------------
1 | python script/evaluation/unified_evaluator.py --config recipe/eval_recipe/unified_eval.yaml
--------------------------------------------------------------------------------
/bash/train/sft_train.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | torchrun --nproc_per_node=1 --master_port=29501 script/train/SFT_train.py \
3 |     --config recipe/train_recipe/C2C_0.6+0.5.json \
--------------------------------------------------------------------------------
/script/analysis/length_ratio/length_ratio_analysis_with_improvement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-nics/C2C/HEAD/script/analysis/length_ratio/length_ratio_analysis_with_improvement.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # File types
2 | *.env.sh
3 | *__pycache__*
4 | *.egg-info
5 | *.png
6 | 
7 | # Folders
8 | .vscode/
9 | .gradio/
10 | local/
11 | wandb/
12 | logs/
13 | eval_results/
14 | outputs/
15 | results/
16 | checkpoints/
17 | 
18 | recipe/experiments/
19 | script/evaluation/deprecated/
20 | script/analysis/oracle/
21 | 
22 | # Special files
23 | run.sh
24 | test.py
25 | script/train/SFT_train_debug.py
--------------------------------------------------------------------------------
/rosetta/train/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Training utilities for RosettaModel
3 | """
4 | 
5 | from .dataset_adapters import (
6 |     ChatDataset,
RosettaDataCollator, 8 | ) 9 | from .model_utils import setup_models 10 | 11 | __all__ = [ 12 | "RosettaTrainer", 13 | "ProjectorSaveCallback", 14 | "freeze_model_components", 15 | "InstructCoderChatDataset", 16 | "ChatDataset", 17 | "RosettaDataCollator", 18 | "create_instructcoder_dataset", 19 | "setup_models" 20 | ] -------------------------------------------------------------------------------- /script/dataset/launch_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Simple script to launch SGLang server for OpenHermes dataset generation 4 | 5 | echo "🚀 Launching SGLang server..." 6 | 7 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 8 | 9 | python3 -m sglang.launch_server \ 10 | --model-path Qwen/Qwen3-4B \ 11 | --host 0.0.0.0 \ 12 | --port 30000 \ 13 | --tp-size 1 \ 14 | --dp-size 8 \ 15 | --mem-fraction-static 0.9 \ 16 | --dtype bfloat16 \ 17 | --log-level warning \ 18 | --chat-template script/dataset/create/qwen3_nonthinking.jinja 19 | 20 | echo "Server stopped." 21 | -------------------------------------------------------------------------------- /recipe/eval_recipe/ablation_base.yaml: -------------------------------------------------------------------------------- 1 | eval: 2 | answer_method: generate 3 | dataset: mmlu-redux 4 | gpu_ids: 5 | - 0 6 | - 1 7 | - 2 8 | - 3 9 | - 4 10 | - 5 11 | - 6 12 | - 7 13 | sample_interval: 1 14 | use_cot: false 15 | use_template: true 16 | model: 17 | generation_config: 18 | do_sample: false 19 | max_new_tokens: 8 20 | model_name: Rosetta 21 | rosetta_config: 22 | alignment_strategy: longest 23 | base_model: Qwen/Qwen3-0.6B 24 | checkpoints_dir: local/checkpoints/ablation_study_general/level_0/final 25 | is_do_alignment: false 26 | teacher_model: Qwen/Qwen2.5-0.5B-Instruct 27 | output: 28 | output_dir: local/ablation_results/level_0_mmlu-redux 29 | -------------------------------------------------------------------------------- /script/dataset/run_generation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Simple script to run OpenHermes dataset generation 4 | # Make sure the server is running first! 5 | 6 | echo "🚀 Starting dataset generation..." 7 | 8 | python script/dataset/create/create_mmlu.py \ 9 | --model_path "Qwen/Qwen3-4B" \ 10 | --api_url "http://localhost:30000/v1" \ 11 | --dataset_path "cais/mmlu" \ 12 | --output_dir "local/teacher_datasets/mmlu_4b_output_150_words" \ 13 | --max_concurrent_requests 256 \ 14 | --max_new_tokens 512 \ 15 | --split auxiliary_train \ 16 | --temperature 0.7 \ 17 | --top_p 0.8 \ 18 | --top_k 20 \ 19 | --min_p 0.0 \ 20 | --request_timeout 6000 \ 21 | --save_every 100 \ 22 | --sample_every_n 6 \ 23 | 24 | echo "✅ Generation completed!" 
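# Illustrative note (not part of the original script, added for clarity): this
# generation script assumes the SGLang server started by launch_server.sh is
# already reachable at the host/port configured there. A quick sanity check,
# assuming the default port 30000, is:
#   curl -s http://localhost:30000/v1/models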
25 | 
--------------------------------------------------------------------------------
/script/analysis/venn/venn_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "model": {
3 |         "model_name": "Rosetta",
4 |         "rosetta_config":{
5 |             "base_model": "Qwen/Qwen3-0.6B",
6 |             "teacher_model": "Qwen/Qwen2.5-0.5B-Instruct"
7 |         }
8 |     },
9 |     "output": {
10 |         "output_dir": "./script/analysis/venn/venn_results",
11 |         "result_dir": "./script/analysis/venn/venn_results"
12 |     },
13 |     "eval":{
14 |         "use_cot": false,
15 |         "gpu_ids": [0, 1, 2, 3, 4, 5, 6, 7],
16 |         "answer_method": "logits",
17 |         "checkpoints_dir": "local/checkpoints/0.6+0.5B_C2C_mmlu/final"
18 |     },
19 |     "models":{
20 |         "slm": "Qwen/Qwen3-0.6B",
21 |         "llm": "Qwen/Qwen2.5-0.5B-Instruct",
22 |         "rosetta": "0.6+0.5B_C2C_mmlu"
23 |     }
24 | }
--------------------------------------------------------------------------------
/bash/ablation/ablation_study.sh:
--------------------------------------------------------------------------------
1 | # Defaults (override by exporting env vars before calling this script)
2 | BASE_CONFIG=${BASE_CONFIG:-recipe/train_recipe/ablation_base.json}
3 | BASE_EVAL_CONFIG=${BASE_EVAL_CONFIG:-recipe/eval_recipe/ablation_base.yaml}
4 | OUTPUT_DIR=${OUTPUT_DIR:-local/checkpoints/ablation_study_general}
5 | EVAL_OUTPUT_DIR=${EVAL_OUTPUT_DIR:-local/ablation_results_general}
6 | GPU_IDS=${GPU_IDS:-0,1,2,3,4,5,6,7}
7 | ABLATION_LEVELS=${ABLATION_LEVELS:-0,1,2,3,4}
8 | MASTER_PORT=${MASTER_PORT:-29504}
9 | 
10 | # GPU visibility
11 | export CUDA_VISIBLE_DEVICES="$GPU_IDS"
12 | 
13 | # Forward any additional flags directly to the Python script
14 | python script/ablation/ablation_study.py \
15 |     --base_config "$BASE_CONFIG" \
16 |     --base_eval_config "$BASE_EVAL_CONFIG" \
17 |     --output_dir "$OUTPUT_DIR" \
18 |     --eval_output_dir "$EVAL_OUTPUT_DIR" \
19 |     --gpu_ids "$GPU_IDS" \
20 |     --ablation_levels "$ABLATION_LEVELS" \
21 |     --master_port "$MASTER_PORT" \
22 |     "$@"
23 | 
--------------------------------------------------------------------------------
/script/analysis/proportion/proportion.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 | 
5 | # Data
6 | data = {
7 |     "Percentage": ["0%", "25%", "50%", "75%", "100%"],
8 |     "Former": [35.53, 21.32, 46.64, 56.41, 61.86],
9 |     "Latter": [35.53, 30.58, 33.65, 45.88, 61.86]
10 | }
11 | df = pd.DataFrame(data)
12 | 
13 | # Use a publication-style theme
14 | sns.set_theme(style="whitegrid", font="serif", font_scale=1.2)
15 | 
16 | # Plot the line chart
17 | plt.figure(figsize=(6,4))
18 | sns.lineplot(data=df, x="Percentage", y="Former", marker="o", label="Former", linewidth=2)
19 | line = sns.lineplot(data=df, x="Percentage", y="Latter", marker="s", label="Latter", linewidth=2)
20 | # Add a horizontal dashed baseline, using the same color as the line plot
21 | plt.axhline(y=71.38, color=line.get_lines()[0].get_color(), linestyle='--', linewidth=2, label='Sharer Model')
22 | 
23 | # Polish the figure
24 | plt.xlabel("Proportion of C2C Fused KV-Cache", fontsize=13)
25 | plt.ylabel("Accuracy", fontsize=13)
26 | plt.legend(title="", fontsize=11, frameon=False)
27 | plt.tight_layout()
28 | 
29 | # Save a high-resolution figure
30 | plt.savefig("./proportion.pdf", dpi=300)
31 | plt.show()
--------------------------------------------------------------------------------
/recipe/train_recipe/baseline_partial_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "model": {
3 |         "baseline_model": 
"Qwen/Qwen3-0.6B" 4 | }, 5 | "training": { 6 | "learning_rate": 3e-4, 7 | "weight_decay": 0.01, 8 | "num_epochs": 3, 9 | "max_length": 2048, 10 | "device": "cuda", 11 | "scheduler_type": "linear", 12 | "warmup_ratio": 0.1, 13 | "max_grad_norm": 1.0, 14 | "per_device_train_batch_size": 4, 15 | "gradient_accumulation_steps": 4, 16 | "seed": 42, 17 | "partial_training": { 18 | "method": "layer_wise", 19 | "ratio": 0.6 20 | } 21 | }, 22 | "data": { 23 | "type": "MMLUChatDataset", 24 | "kwargs": { 25 | "split": "test", 26 | "num_samples": 1000 27 | }, 28 | "train_ratio": 0.8 29 | }, 30 | "output": { 31 | "output_dir": "outputs/baseline_partial", 32 | "eval_steps": 100, 33 | "save_steps": 500, 34 | "wandb_config": { 35 | "project": "baseline_partial_training", 36 | "run_name": "baseline_partial_run", 37 | "mode": "offline" 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /recipe/train_recipe/baseline_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "baseline_model": "Qwen/Qwen3-0.6B" 4 | }, 5 | "training": { 6 | "learning_rate": 1e-4, 7 | "weight_decay": 0.01, 8 | "num_epochs": 1, 9 | "max_length": 2048, 10 | "device": "cuda", 11 | "scheduler_type": "linear", 12 | "warmup_ratio": 0.1, 13 | "max_grad_norm": 1.0, 14 | "per_device_train_batch_size": 4, 15 | "num_processes": 8, 16 | "freeze": [], 17 | "gradient_accumulation_steps": 8, 18 | "seed": 42 19 | }, 20 | "output": { 21 | "output_dir": "local/checkpoints/0.6B_general_finetune", 22 | "save_steps": 500, 23 | "eval_steps": 100, 24 | "wandb_config": { 25 | "project": "Rosetta", 26 | "mode": "offline", 27 | "run_name": "baseline_general_finetune_OpenHermes_500k", 28 | "entity": "nics-efc" 29 | } 30 | }, 31 | "data": { 32 | "type": "OpenHermesChatDataset", 33 | "kwargs": { 34 | "split": "train", 35 | "max_word_count": 2048, 36 | "num_samples": 500000 37 | }, 38 | "train_ratio": 0.99 39 | } 40 | } -------------------------------------------------------------------------------- /recipe/train_recipe/baseline_lora_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "baseline_model": "Qwen/Qwen3-0.6B" 4 | }, 5 | "training": { 6 | "learning_rate": 1e-4, 7 | "weight_decay": 0.01, 8 | "num_epochs": 3, 9 | "max_length": 2048, 10 | "device": "cuda", 11 | "scheduler_type": "linear", 12 | "warmup_ratio": 0.1, 13 | "max_grad_norm": 1.0, 14 | "per_device_train_batch_size": 4, 15 | "gradient_accumulation_steps": 4, 16 | "seed": 42, 17 | "lora": { 18 | "r": 16, 19 | "lora_alpha": 32, 20 | "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 21 | "lora_dropout": 0.1, 22 | "bias": "none", 23 | "task_type": "CAUSAL_LM" 24 | } 25 | }, 26 | "data": { 27 | "type": "MMLUChatDataset", 28 | "kwargs": { 29 | "split": "test", 30 | "num_samples": 1000 31 | }, 32 | "train_ratio": 0.8 33 | }, 34 | "output": { 35 | "output_dir": "outputs/baseline_lora", 36 | "eval_steps": 100, 37 | "save_steps": 500, 38 | "wandb_config": { 39 | "project": "baseline_lora_training", 40 | "run_name": "baseline_lora_run", 41 | "mode": "offline", 42 | "entity": "nics-efc" 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /recipe/train_recipe/oracle.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "base_model": 
"/mnt/public/public_models/Qwen3-0.6B", 4 | "teacher_model": "/mnt/public/public_models/Qwen3-4B", 5 | "projector": { 6 | "type": "ReplaceProjector", 7 | "params": { 8 | "hidden_dim": 1024, 9 | "num_layers": 3, 10 | "dropout": 0.1, 11 | "activation": "gelu", 12 | "use_layer_norm": true, 13 | "init_weight": 0.0, 14 | "anneal_steps": 1160 15 | } 16 | }, 17 | "mapping": "last_aligned" 18 | }, 19 | "training": { 20 | "learning_rate": 3e-4, 21 | "weight_decay": 0.01, 22 | "num_epochs": 1, 23 | "max_length": 32768, 24 | "device": "cuda", 25 | "scheduler_type": "linear", 26 | "warmup_ratio": 0.1, 27 | "max_grad_norm": 1.0, 28 | "per_device_train_batch_size": 4, 29 | "num_processes": 8, 30 | "freeze": ["teacher","base"], 31 | "seed": 42 32 | }, 33 | "output": { 34 | "output_dir": "local/checkpoints", 35 | "save_steps": 10000, 36 | "eval_steps": 10000, 37 | "wandb_config": { 38 | "project": "Rosetta", 39 | "mode": "offline", 40 | "run_name": "std_mc" 41 | } 42 | }, 43 | "data": { 44 | "type": "MMLUChatDataset", 45 | "kwargs": { 46 | "split": "test", 47 | "num_samples": null 48 | }, 49 | "train_ratio": 0.99 50 | } 51 | } -------------------------------------------------------------------------------- /recipe/train_recipe/C2C_0.6+0.5.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "base_model": "Qwen/Qwen3-0.6B", 4 | "teacher_model": "Qwen/Qwen2.5-0.5B-Instruct", 5 | "is_do_alignment": false, 6 | "alignment_strategy": "first", 7 | "projector": { 8 | "type": "C2CProjector", 9 | "params": { 10 | "hidden_dim": 1024, 11 | "intermediate_dim": 1024, 12 | "num_layers": 3, 13 | "dropout": 0.1, 14 | "initial_temperature": 1.0, 15 | "final_temperature": 0.001, 16 | "anneal_steps": 1929 17 | } 18 | }, 19 | "mapping": "last_aligned" 20 | }, 21 | "training": { 22 | "learning_rate": 1e-4, 23 | "weight_decay": 0.01, 24 | "num_epochs": 1, 25 | "max_length": 2048, 26 | "device": "cuda", 27 | "scheduler_type": "linear", 28 | "warmup_ratio": 0.1, 29 | "max_grad_norm": 1.0, 30 | "gradient_accumulation_steps": 8, 31 | "per_device_train_batch_size": 4, 32 | "num_processes": 8, 33 | "freeze": ["teacher","base"], 34 | "seed": 42 35 | }, 36 | "output": { 37 | "output_dir": "local/checkpoints/0.6+0.5B_C2C_general_again", 38 | "save_steps": 500, 39 | "eval_steps": 100, 40 | "wandb_config": { 41 | "project": "Rosetta", 42 | "mode": "online", 43 | "entity": "nics-efc", 44 | "run_name": "0.6B+0.5B_C2C_general_OpenHermes_500k" 45 | } 46 | }, 47 | "data": { 48 | "type": "OpenHermesChatDataset", 49 | "kwargs": { 50 | "split": "train", 51 | "max_word_count": 2048, 52 | "num_samples": 500000 53 | }, 54 | "train_ratio": 0.99 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /recipe/train_recipe/C2C_ablation.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "base_model": "Qwen/Qwen3-0.6B", 4 | "teacher_model": "Qwen/Qwen2.5-0.5B-Instruct", 5 | "is_do_alignment": false, 6 | "alignment_strategy": "first", 7 | "projector": { 8 | "type": "AblationProjector", 9 | "params": { 10 | "hidden_dim": 1024, 11 | "intermediate_dim": 1024, 12 | "num_layers": 3, 13 | "dropout": 0.1, 14 | "initial_temperature": 1.0, 15 | "final_temperature": 0.001, 16 | "anneal_steps": 1929, 17 | "ablation_level": 3 18 | } 19 | }, 20 | "mapping": "last_aligned" 21 | }, 22 | "training": { 23 | "learning_rate": 1e-4, 24 | "weight_decay": 0.01, 25 | "num_epochs": 1, 26 | "max_length": 2048, 27 | 
"device": "cuda", 28 | "scheduler_type": "linear", 29 | "warmup_ratio": 0.1, 30 | "max_grad_norm": 1.0, 31 | "gradient_accumulation_steps": 8, 32 | "per_device_train_batch_size": 4, 33 | "num_processes": 8, 34 | "freeze": ["teacher","base"], 35 | "seed": 42 36 | }, 37 | "output": { 38 | "output_dir": "local/checkpoints/0.6+0.5B_C2C_general_ablation_level_3", 39 | "save_steps": 500, 40 | "eval_steps": 100, 41 | "wandb_config": { 42 | "project": "Rosetta", 43 | "mode": "online", 44 | "entity": "nics-efc", 45 | "run_name": "0.6B+0.5B_C2C_general_OpenHermes_500k_ablation_level_3" 46 | } 47 | }, 48 | "data": { 49 | "type": "OpenHermesChatDataset", 50 | "kwargs": { 51 | "split": "train", 52 | "max_word_count": 2048, 53 | "num_samples": 500000 54 | }, 55 | "train_ratio": 0.99 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /recipe/eval_recipe/unified_eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | # model_name: Qwen/Qwen3-0.6B 3 | model_name: Rosetta 4 | rosetta_config: # Only needed for Rosetta models 5 | base_model: Qwen/Qwen3-0.6B 6 | teacher_model: Qwen/Qwen2.5-0.5B-Instruct 7 | is_do_alignment: false 8 | alignment_strategy: "longest" 9 | checkpoints_dir: local/checkpoints/0.6+0.5B_C2C_general_again/final 10 | 11 | # Two stage 12 | # model_name: "two_stage" # Use two-stage pipeline 13 | # answer_model_path: "Qwen/Qwen3-0.6B" # Second LLM for answering 14 | # context_model_path: "Qwen/Qwen3-4B" 15 | # background_prompt: "In one clear sentence, describe the most essential background knowledge needed to answer the question: {question}" 16 | 17 | # Generation configuration - applied to all models during evaluation 18 | generation_config: 19 | do_sample: false # Whether to use sampling (true) or greedy decoding (false) 20 | max_new_tokens: 64 # Maximum number of tokens to generate 21 | # Sampling parameters (only used when do_sample=true): 22 | # temperature: 0.7 # Controls randomness (0.0 = deterministic, higher = more random) 23 | # top_p: 0.8 # Nucleus sampling threshold 24 | # top_k: 20 # Top-k sampling threshold 25 | # # min_p: 0.05 # Minimum probability threshold 26 | # # repetition_penalty: 1.0 # Penalty for repeating tokens 27 | # presence_penalty: 1.5 # Penalty for repeating tokens 28 | # frequency_penalty: 1.0 29 | 30 | output: 31 | output_dir: local/final_results/0.6+0.5B_C2C_general_again 32 | 33 | eval: 34 | dataset: mmlu-redux 35 | gpu_ids: [0] # GPUs to use for evaluation 36 | answer_method: generate # 'generate' or 'logits' 37 | use_cot: false # Enable chain-of-thought reasoning 38 | use_template: true 39 | sample_interval: 1 # Sample every N examples 40 | # limit: # Limit examples per subject (null for all) 41 | # subjects: ["nutrition"] # Optional: specify specific subjects to evaluate 42 | # response_text: "" 43 | math_grading_method: "comprehensive" -------------------------------------------------------------------------------- /script/evaluation/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 测试配置 4 | MODELS=("Qwen3-0.6B" "Qwen3-1.7B" "Rosetta") 5 | METHODS=("zero_shot" "few_shot") 6 | ANSWER_METHODS=("logits" ) 7 | GPU_ID=1 # 设置使用的GPU ID 8 | MAX_LENGTH=32768 9 | BATCH_SIZE=1 10 | 11 | # 遍历所有配置组合 12 | for model in "${MODELS[@]}"; do 13 | for method in "${METHODS[@]}"; do 14 | for answer_method in "${ANSWER_METHODS[@]}"; do 15 | 16 | # 处理zero_shot情况(只运行ntrain=0) 17 | if [ "$method" == "zero_shot" 
]; then 18 | ntrain=0 19 | 20 | echo "==============================================" 21 | echo "Running: model=$model, method=$method, answer_method=$answer_method, ntrain=$ntrain" 22 | echo "==============================================" 23 | 24 | python evaluator.py \ 25 | --model_name $model \ 26 | --method $method \ 27 | --answer_method $answer_method \ 28 | --gpu_id $GPU_ID \ 29 | --ntrain $ntrain \ 30 | --max_length $MAX_LENGTH \ 31 | --batch_size $BATCH_SIZE 32 | 33 | # 处理few_shot情况(运行ntrain=1到10) 34 | else 35 | for ntrain in {1..10}; do 36 | echo "==============================================" 37 | echo "Running: model=$model, method=$method, answer_method=$answer_method, ntrain=$ntrain" 38 | echo "==============================================" 39 | 40 | python evaluator.py \ 41 | --model_name $model \ 42 | --method $method \ 43 | --answer_method $answer_method \ 44 | --gpu_id $GPU_ID \ 45 | --ntrain $ntrain \ 46 | --max_length $MAX_LENGTH \ 47 | --batch_size $BATCH_SIZE 48 | done 49 | fi 50 | 51 | done 52 | done 53 | done 54 | 55 | echo "All tests completed!" -------------------------------------------------------------------------------- /rosetta/utils/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core utilities for Cache-to-Cache (C2C) operations. 3 | """ 4 | 5 | from typing import List 6 | 7 | 8 | def sharers_to_mask(sharer_indices: List[int]) -> int: 9 | """ 10 | Convert a list of sharer indices to a bitmask. 11 | 12 | Args: 13 | sharer_indices: List of 1-based sharer indices (e.g., [1, 2, 3]) 14 | 15 | Returns: 16 | Bitmask integer (e.g., [1, 2] -> 3, [1, 3] -> 5, [1, 2, 3] -> 7) 17 | 18 | Example: 19 | >>> sharers_to_mask([1]) # 001 = 1 20 | 1 21 | >>> sharers_to_mask([2]) # 010 = 2 22 | 2 23 | >>> sharers_to_mask([1, 2]) # 011 = 3 24 | 3 25 | >>> sharers_to_mask([1, 3]) # 101 = 5 26 | 5 27 | """ 28 | mask = 0 29 | for idx in sharer_indices: 30 | mask |= (1 << (idx - 1)) 31 | return mask 32 | 33 | 34 | def mask_to_sharers(mask: int) -> List[int]: 35 | """ 36 | Convert a bitmask to a list of sharer indices. 37 | 38 | Args: 39 | mask: Bitmask integer 40 | 41 | Returns: 42 | List of 1-based sharer indices 43 | 44 | Example: 45 | >>> mask_to_sharers(1) # 001 -> [1] 46 | [1] 47 | >>> mask_to_sharers(3) # 011 -> [1, 2] 48 | [1, 2] 49 | >>> mask_to_sharers(5) # 101 -> [1, 3] 50 | [1, 3] 51 | >>> mask_to_sharers(7) # 111 -> [1, 2, 3] 52 | [1, 2, 3] 53 | """ 54 | if mask <= 0: 55 | return [] 56 | sharers = [] 57 | idx = 1 58 | while mask: 59 | if mask & 1: 60 | sharers.append(idx) 61 | mask >>= 1 62 | idx += 1 63 | return sharers 64 | 65 | 66 | def all_sharers_mask(num_sharers: int) -> int: 67 | """ 68 | Get bitmask that selects all sharers. 69 | 70 | Args: 71 | num_sharers: Number of sharers 72 | 73 | Returns: 74 | Bitmask with all bits set (e.g., 3 sharers -> 7 = 111) 75 | """ 76 | return (1 << num_sharers) - 1 77 | 78 | 79 | def format_sharer_mask(mask: int) -> str: 80 | """ 81 | Format a sharer mask as a human-readable string. 
82 | 83 | Args: 84 | mask: Bitmask integer (-1=no projection, 0=self projection, >0=sharer bitmask) 85 | 86 | Returns: 87 | Formatted string like "sharers [1, 2]" or "no projection" 88 | """ 89 | if mask < 0: 90 | return "no projection" 91 | if mask == 0: 92 | return "self projection" 93 | sharers = mask_to_sharers(mask) 94 | return f"sharers {sharers}" 95 | -------------------------------------------------------------------------------- /recipe/train_recipe/all_in_one.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "base_model": "Qwen/Qwen3-1.7B", 4 | "teacher_model": "Qwen/Qwen2.5-1.5B-Instruct", 5 | "is_do_alignment": false, 6 | "alignment_strategy": "first", 7 | "projector": { 8 | "type": "AllInOneProjector", 9 | "params": { 10 | "hidden_dim": 1024, 11 | "weight_hidden_dim": 1024, 12 | "num_layers": 3, 13 | "dropout": 0.1, 14 | "activation": "gelu", 15 | "use_layer_norm": true, 16 | "use_residual": true, 17 | "use_swiglu": true, 18 | "use_concat": true, 19 | "gate_granularity": "scalar", 20 | "gate_depends_on_input": false, 21 | "gate_input_features": "target_key", 22 | "gate_init_value": 0.0, 23 | "weight_granularity": "head_merged", 24 | "weight_depends_on_input": true, 25 | "weight_input_features": "target_projected_key", 26 | "weight_init_value": 0.0, 27 | "use_gumbel": true, 28 | "initial_temperature": 1.0, 29 | "final_temperature": 0.001, 30 | "preserve_target_weight": false, 31 | "add_self": true, 32 | "anneal_steps": 112, 33 | "scalar_temperature": 1.0, 34 | "max_sequence_length": 8192 35 | } 36 | }, 37 | "mapping": "last_aligned" 38 | }, 39 | "training": { 40 | "learning_rate": 1e-4, 41 | "weight_decay": 0.01, 42 | "num_epochs": 1, 43 | "max_length": 1024, 44 | "device": "cuda", 45 | "scheduler_type": "linear", 46 | "warmup_ratio": 0.1, 47 | "max_grad_norm": 1.0, 48 | "gradient_accumulation_steps": 2, 49 | "per_device_train_batch_size": 8, 50 | "num_processes": 8, 51 | "freeze": ["teacher","base"], 52 | "seed": 42 53 | }, 54 | "output": { 55 | "output_dir": "local/checkpoints/1.7B+1.5B_test", 56 | "save_steps": 500, 57 | "eval_steps": 100, 58 | "wandb_config": { 59 | "project": "Rosetta", 60 | "mode": "online", 61 | "entity": "nics-efc", 62 | "run_name": "1.7B+1.5B_test_all_in_one" 63 | } 64 | }, 65 | "data": { 66 | "type": "OpenHermesChatDataset", 67 | "kwargs": { 68 | "split": "train", 69 | "max_word_count": 1024, 70 | "num_samples": 15000 71 | }, 72 | "train_ratio": 0.99 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ "setuptools>=61.0", "wheel",] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rosetta" 7 | version = "0.1.0" 8 | description = "Unified Memory for Multi-Model Ensemble with KV-Cache Projection" 9 | readme = "README.md" 10 | classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries :: Python Modules",] 11 | requires-python = ">=3.10" 12 | 
dependencies = [ "torch==2.6.0", "transformers==4.52.4",] 13 | 14 | [project.license] 15 | text = "MIT" 16 | 17 | [project.optional-dependencies] 18 | dev = [ "pytest>=6.0", "pytest-cov>=2.0", "black>=22.0", "isort>=5.0", "flake8>=4.0", "mypy>=0.900", "einops>=0.8", "matplotlib", "scikit-learn", "torchvision==0.21.0",] 19 | training = [ "datasets>=2.0", "accelerate>=0.20", "wandb>=0.13", "peft"] 20 | evaluation = [ "scikit-learn>=1.0", "matplotlib>=3.5", "seaborn>=0.11", "jsonlines", "openai", "math_verify", "latex2sympy2_extended", "spaces>=0.30.0"] 21 | 22 | [project.urls] 23 | Homepage = "https://github.com/your-org/unified_memory" 24 | Repository = "https://github.com/your-org/unified_memory" 25 | Documentation = "https://unified-memory.readthedocs.io" 26 | 27 | [tool.black] 28 | line-length = 88 29 | target-version = [ "py38",] 30 | include = "\\.pyi?$" 31 | extend-exclude = "/(\n # directories\n \\.eggs\n | \\.git\n | \\.hg\n | \\.mypy_cache\n | \\.tox\n | \\.venv\n | build\n | dist\n)/\n" 32 | 33 | [tool.isort] 34 | profile = "black" 35 | multi_line_output = 3 36 | line_length = 88 37 | known_first_party = [ "rosetta",] 38 | 39 | [tool.mypy] 40 | python_version = "3.8" 41 | warn_return_any = true 42 | warn_unused_configs = true 43 | disallow_untyped_defs = true 44 | disallow_incomplete_defs = true 45 | check_untyped_defs = true 46 | disallow_untyped_decorators = true 47 | no_implicit_optional = true 48 | warn_redundant_casts = true 49 | warn_unused_ignores = true 50 | warn_no_return = true 51 | warn_unreachable = true 52 | strict_equality = true 53 | [[tool.mypy.overrides]] 54 | module = [ "transformers.*", "torch.*",] 55 | ignore_missing_imports = true 56 | 57 | [tool.setuptools.package-data] 58 | rosetta = [ "*.json", "*.yaml", "*.yml",] 59 | 60 | [tool.pytest.ini_options] 61 | testpaths = [ "test",] 62 | python_files = [ "test_*.py",] 63 | python_classes = [ "Test*",] 64 | python_functions = [ "test_*",] 65 | addopts = [ "--strict-markers", "--strict-config", "--cov=rosetta", "--cov-report=term-missing", "--cov-report=html", "--cov-report=xml",] 66 | 67 | [tool.setuptools.packages.find] 68 | where = [ ".",] 69 | include = [ "rosetta*",] 70 | exclude = [ "test*", "script*",] 71 | -------------------------------------------------------------------------------- /script/dataset/save_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from datetime import datetime 4 | from datasets import Dataset 5 | 6 | output_dir = "local/teacher_datasets/openhermes_qwen_output" 7 | os.makedirs(output_dir, exist_ok=True) 8 | csv_path = os.path.join(output_dir, "OpenHermes_generated_results.csv") 9 | 10 | if os.path.exists(csv_path): 11 | try: 12 | combined_df = pd.read_csv(csv_path, low_memory=False) 13 | print(f"Found existing CSV with {len(combined_df)} records. 
Skipping re-append and building dataset directly.") 14 | except Exception as e: 15 | print(f"Error reading existing CSV: {e}") 16 | combined_df = None 17 | 18 | # Filter out rows with empty or None model_response to prevent ArrowTypeError 19 | before_count = len(combined_df) 20 | if "model_response" in combined_df.columns: 21 | combined_df = combined_df[ 22 | combined_df["model_response"].notna() 23 | & (combined_df["model_response"].astype(str).str.strip() != "") 24 | ] 25 | after_count = len(combined_df) 26 | removed = before_count - after_count 27 | if removed > 0: 28 | print(f"Filtered out {removed} rows with empty model_response") 29 | 30 | # Normalize dtypes to avoid pyarrow ArrowTypeError (e.g., floats/None in string fields) 31 | if "id" in combined_df.columns: 32 | combined_df["id"] = combined_df["id"].astype(str) 33 | if "input_text" in combined_df.columns: 34 | combined_df["input_text"] = combined_df["input_text"].fillna("").astype(str) 35 | if "model_reasoning" in combined_df.columns: 36 | combined_df["model_reasoning"] = combined_df["model_reasoning"].fillna("").astype(str) 37 | else: 38 | combined_df["model_reasoning"] = "" 39 | if "model_response" in combined_df.columns: 40 | combined_df["model_response"] = combined_df["model_response"].fillna("").astype(str) 41 | if "is_finished" in combined_df.columns: 42 | combined_df["is_finished"] = combined_df["is_finished"].fillna(True).astype(bool) 43 | else: 44 | combined_df["is_finished"] = True 45 | 46 | # Convert combined data to HuggingFace dataset 47 | dataset_dict = { 48 | "id": combined_df["id"].tolist(), 49 | "input_text": combined_df["input_text"].tolist(), 50 | "model_reasoning": combined_df.get("model_reasoning", pd.Series([None]*len(combined_df))).tolist(), 51 | "model_response": combined_df["model_response"].tolist(), 52 | "is_finished": combined_df.get("is_finished", pd.Series([True]*len(combined_df))).tolist(), 53 | } 54 | 55 | hf_dataset = Dataset.from_dict(dataset_dict) 56 | 57 | dataset_path = os.path.join(output_dir, "dataset") 58 | hf_dataset.save_to_disk(dataset_path) 59 | print(f"Saved HuggingFace dataset with {len(hf_dataset)} problems to {dataset_path}") 60 | 61 | hf_dataset_finished = hf_dataset.filter(lambda x: x["is_finished"] == True) 62 | print(f"Filtered dataset with {len(hf_dataset_finished)} finished problems") 63 | 64 | dataset_finished_path = os.path.join(output_dir, "dataset_finished") 65 | hf_dataset_finished.save_to_disk(dataset_finished_path) 66 | print(f"Saved finished HuggingFace dataset with {len(hf_dataset_finished)} problems to {dataset_finished_path}") 67 | -------------------------------------------------------------------------------- /rosetta/model/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Union 4 | 5 | def sample_token(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1) -> Union[int, torch.Tensor]: 6 | """Sample a token from logits using temperature, top-p, and top-k sampling. 7 | Args: 8 | logits: Token logits of shape [vocab_size] or [batch_size, vocab_size] 9 | temperature: Temperature for sampling (>0). Higher values produce more random samples. 
10 | top_p: Top-p probability threshold for nucleus sampling (0 < top_p ≤ 1) 11 | top_k: Top-k threshold for sampling (if -1, no top-k filtering is applied) 12 | Returns: 13 | Sampled token ID (int for single sample, tensor for batch) 14 | """ 15 | if not isinstance(logits, torch.Tensor): 16 | raise TypeError("logits must be a torch.Tensor") 17 | 18 | if logits.dim() not in [1, 2]: 19 | raise ValueError("logits must have shape [vocab_size] or [batch_size, vocab_size]") 20 | 21 | # Handle single dimension input 22 | is_single_input = logits.dim() == 1 23 | if is_single_input: 24 | logits = logits.unsqueeze(0) 25 | 26 | batch_size = logits.shape[0] 27 | 28 | # For greedy sampling (temperature=0), just return argmax 29 | if temperature == 0 or temperature <= 1e-5: 30 | tokens = torch.argmax(logits, dim=-1) 31 | return tokens.item() if is_single_input else tokens 32 | 33 | # Convert to probabilities 34 | probs = torch.nn.functional.softmax(logits / temperature, dim=-1) 35 | 36 | # Apply top-k filtering first (if specified) 37 | if top_k != -1: 38 | # Get top-k values and indices 39 | top_k_values, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1) 40 | 41 | # Create a mask to zero out non-top-k probabilities 42 | mask = torch.zeros_like(probs, dtype=torch.bool) 43 | mask.scatter_(-1, top_k_indices, True) 44 | 45 | # Zero out non-top-k probabilities 46 | probs = probs * mask.float() 47 | 48 | # Renormalize probabilities 49 | probs = probs / probs.sum(dim=-1, keepdim=True) 50 | 51 | # Apply top-p (nucleus) sampling 52 | if top_p < 1.0: 53 | # Sort probabilities in descending order 54 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 55 | 56 | # Calculate cumulative probabilities 57 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1) 58 | 59 | # Create a mask for probabilities to keep 60 | # Values above top_p threshold are masked out 61 | mask = cumulative_probs <= top_p 62 | 63 | # Always keep at least one token 64 | mask[:, 0] = True 65 | 66 | # Zero out masked positions to exclude them from sampling 67 | sorted_probs = sorted_probs * mask.float() 68 | 69 | # Renormalize probabilities 70 | sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True) 71 | 72 | # Sample from the filtered distribution 73 | sampled_indices = torch.multinomial(sorted_probs, num_samples=1) 74 | 75 | # Map back to original vocabulary indices 76 | tokens = torch.gather(sorted_indices, dim=-1, index=sampled_indices) 77 | tokens = tokens.squeeze(-1) # Remove sample dimension 78 | else: 79 | # Direct sampling if no top-p filtering 80 | tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) 81 | 82 | return tokens.item() if is_single_input else tokens 83 | -------------------------------------------------------------------------------- /script/analysis/venn/plot_venn_from_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from typing import Set, Dict, Any 5 | 6 | import matplotlib.pyplot as plt 7 | from matplotlib_venn import venn3 8 | 9 | 10 | def load_regions(json_path: str) -> Dict[str, Set[int]]: 11 | with open(json_path, 'r') as f: 12 | data: Dict[str, Any] = json.load(f) 13 | regions: Dict[str, Set[int]] = {} 14 | for k in [ 15 | "rosetta_only", "slm_only", "llm_only", 16 | "rosetta_slm", "rosetta_llm", "slm_llm", 17 | "all_three" 18 | ]: 19 | regions[k] = set(data.get(k, [])) 20 | return regions 21 | 22 | 23 | def reconstruct_sets(regions: Dict[str, 
Set[int]]) -> tuple[Set[int], Set[int], Set[int]]: 24 | A_only = regions.get("rosetta_only", set()) 25 | B_only = regions.get("slm_only", set()) 26 | C_only = regions.get("llm_only", set()) 27 | AB = regions.get("rosetta_slm", set()) 28 | AC = regions.get("rosetta_llm", set()) 29 | BC = regions.get("slm_llm", set()) 30 | ABC = regions.get("all_three", set()) 31 | 32 | A = set().union(A_only, AB, AC, ABC) 33 | B = set().union(B_only, AB, BC, ABC) 34 | C = set().union(C_only, AC, BC, ABC) 35 | return A, B, C 36 | 37 | 38 | def plot_venn_from_json(json_path: str, out_path: str, 39 | label_a: str = "C2C", label_b: str = "Receiver", label_c: str = "Sharer", 40 | title: str | None = None): 41 | regions = load_regions(json_path) 42 | A, B, C = reconstruct_sets(regions) 43 | 44 | FIGSIZE = (9, 9) 45 | DPI = 300 46 | 47 | fig = plt.figure(figsize=FIGSIZE, constrained_layout=True) 48 | ax = fig.add_subplot(111) 49 | v = venn3([A, B, C], set_labels=(label_a, label_b, label_c)) 50 | 51 | # Improve visibility and reduce empty margins 52 | if v is not None: 53 | if hasattr(v, 'set_labels') and v.set_labels is not None: 54 | for txt in v.set_labels: 55 | if txt is not None: 56 | txt.set_fontsize(30) 57 | for sid in ("100","010","001","110","101","011","111"): 58 | lbl = v.get_label_by_id(sid) 59 | if lbl is not None: 60 | lbl.set_fontsize(30) 61 | for sid in ("100","010","001","110","101","011","111"): 62 | patch = v.get_patch_by_id(sid) 63 | if patch is not None: 64 | patch.set_linewidth(1.5) 65 | 66 | # if title is None: 67 | # title = f"Correct Answer Overlap: {label_a} vs {label_b}(Qwen3-0.6B) vs {label_c}(Qwen3-4B)" 68 | ax.set_position([0.03, 0.03, 0.96, 0.96]) 69 | plt.title(title, fontsize=16, pad=8) 70 | 71 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 72 | plt.savefig(out_path, dpi=DPI, bbox_inches=None) 73 | plt.close() 74 | print(f"Saved Venn diagram to {out_path}") 75 | 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser(description="Plot a Venn diagram directly from precomputed region JSON") 79 | parser.add_argument("--json", type=str, required=True, help="Path to regions JSON file") 80 | parser.add_argument("--out", type=str, required=True, help="Output image path (e.g., .png)") 81 | parser.add_argument("--label_a", type=str, default="C2C", help="Label for set A (default: C2C)") 82 | parser.add_argument("--label_b", type=str, default="Receiver", help="Label for set B (default: Receiver)") 83 | parser.add_argument("--label_c", type=str, default="Sharer", help="Label for set C (default: Sharer)") 84 | parser.add_argument("--title", type=str, default=None, help="Optional custom title") 85 | args = parser.parse_args() 86 | 87 | plot_venn_from_json(args.json, args.out, args.label_a, args.label_b, args.label_c, args.title) 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | 93 | 94 | -------------------------------------------------------------------------------- /script/dataset/qwen3_nonthinking.jinja: -------------------------------------------------------------------------------- 1 | ip{%- if tools %} 2 | {{- '<|im_start|>system\n' }} 3 | {%- if messages[0].role == 'system' %} 4 | {{- messages[0].content + '\n\n' }} 5 | {%- endif %} 6 | {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} 7 | {%- for tool in tools %} 8 | {{- "\n" }} 9 | {{- tool | tojson }} 10 | {%- endfor %} 11 | {{- "\n\n\nFor each function call, return a json object with function name and arguments 
within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} 12 | {%- else %} 13 | {%- if messages[0].role == 'system' %} 14 | {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} 15 | {%- endif %} 16 | {%- endif %} 17 | {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} 18 | {%- for message in messages[::-1] %} 19 | {%- set index = (messages|length - 1) - loop.index0 %} 20 | {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} 21 | {%- set ns.multi_step_tool = false %} 22 | {%- set ns.last_query_index = index %} 23 | {%- endif %} 24 | {%- endfor %} 25 | {%- for message in messages %} 26 | {%- if message.content is string %} 27 | {%- set content = message.content %} 28 | {%- else %} 29 | {%- set content = '' %} 30 | {%- endif %} 31 | {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} 32 | {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} 33 | {%- elif message.role == "assistant" %} 34 | {%- set reasoning_content = '' %} 35 | {%- if message.reasoning_content is string %} 36 | {%- set reasoning_content = message.reasoning_content %} 37 | {%- else %} 38 | {%- if '' in content %} 39 | {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} 40 | {%- set content = content.split('')[-1].lstrip('\n') %} 41 | {%- endif %} 42 | {%- endif %} 43 | {%- if loop.index0 > ns.last_query_index %} 44 | {%- if loop.last or (not loop.last and reasoning_content) %} 45 | {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} 46 | {%- else %} 47 | {{- '<|im_start|>' + message.role + '\n' + content }} 48 | {%- endif %} 49 | {%- else %} 50 | {{- '<|im_start|>' + message.role + '\n' + content }} 51 | {%- endif %} 52 | {%- if message.tool_calls %} 53 | {%- for tool_call in message.tool_calls %} 54 | {%- if (loop.first and content) or (not loop.first) %} 55 | {{- '\n' }} 56 | {%- endif %} 57 | {%- if tool_call.function %} 58 | {%- set tool_call = tool_call.function %} 59 | {%- endif %} 60 | {{- '\n{"name": "' }} 61 | {{- tool_call.name }} 62 | {{- '", "arguments": ' }} 63 | {%- if tool_call.arguments is string %} 64 | {{- tool_call.arguments }} 65 | {%- else %} 66 | {{- tool_call.arguments | tojson }} 67 | {%- endif %} 68 | {{- '}\n' }} 69 | {%- endfor %} 70 | {%- endif %} 71 | {{- '<|im_end|>\n' }} 72 | {%- elif message.role == "tool" %} 73 | {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} 74 | {{- '<|im_start|>user' }} 75 | {%- endif %} 76 | {{- '\n\n' }} 77 | {{- content }} 78 | {{- '\n' }} 79 | {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} 80 | {{- '<|im_end|>\n' }} 81 | {%- endif %} 82 | {%- endif %} 83 | {%- endfor %} 84 | {%- if add_generation_prompt %} 85 | {{- '<|im_start|>assistant\n\n\n\n\n' }} 86 | {%- endif %} -------------------------------------------------------------------------------- /rosetta/train/model_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model setup utilities for RosettaModel training/evaluation 3 | """ 4 | 5 | import torch 6 | from typing import Dict, Any, List 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | 9 | from rosetta.model.wrapper import RosettaModel 10 | from rosetta.model.projector import create_projector 11 | 12 | """ 13 | Mapping strategies 14 | """ 
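# Illustrative examples (added for clarity; not in the original file). With a
# 4-layer target (receiver) model and an 8-layer source (sharer) model:
#   k_nearest_sources(4, 8, k=1)    -> {0: [0], 1: [2], 2: [5], 3: [7]}
#   last_aligned_sources(4, 8, k=1) -> {0: [4], 1: [5], 2: [6], 3: [7]}
# i.e. last_aligned_sources pins the final target layer to the final source layer
# and walks toward the front, skipping the extra source layers at the start.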
15 | def k_nearest_sources(num_target_layers: int, num_source_layers: int, k: int) -> Dict[int, List[int]]: 16 | """ 17 | Compute a per-target mapping to K nearest source layers. 18 | 19 | Returns: Dict[target_idx, List[source_idx]] only for targets we map. 20 | Distances are computed by placing target and source layers uniformly in [0, 1] 21 | and sorting by absolute distance. 22 | """ 23 | if num_target_layers <= 1: 24 | target_positions = [0.0] 25 | else: 26 | target_positions = [i / (num_target_layers - 1) for i in range(num_target_layers)] 27 | if num_source_layers <= 1: 28 | source_positions = [0.0] 29 | else: 30 | source_positions = [j / (num_source_layers - 1) for j in range(num_source_layers)] 31 | 32 | mapping: Dict[int, List[int]] = {} 33 | for t_idx, t_pos in enumerate(target_positions): 34 | sorted_src = sorted(range(num_source_layers), key=lambda j: abs(source_positions[j] - t_pos)) 35 | chosen = sorted_src[:max(0, k)] 36 | if len(chosen) > 0: 37 | mapping[t_idx] = chosen 38 | return mapping 39 | 40 | 41 | def last_aligned_sources(num_target_layers: int, num_source_layers: int, k: int = 1) -> Dict[int, List[int]]: 42 | """ 43 | Return a per-target mapping that aligns the last target layer to the last 44 | source layer and walks toward the front. 45 | 46 | Returns: Dict[target_idx, List[source_idx]] only for targets we map. For each 47 | target t, we choose up to K sources anchored at the aligned index, preferring 48 | backward indices first then forward to satisfy K. 49 | 50 | Example (T=11, S=33): target 10 -> [32, 31, ...], target 9 -> [31, 30, ...] 51 | """ 52 | mapping: Dict[int, List[int]] = {} 53 | if num_target_layers <= 0 or num_source_layers <= 0: 54 | return mapping 55 | 56 | # Align ends; offset >= 0 means extra source layers at the front 57 | offset = num_source_layers - num_target_layers 58 | 59 | def take_k_from(s0: int) -> List[int]: 60 | result: List[int] = [] 61 | # Prefer moving backward from the anchor (last-to-front) 62 | for back in range(k): 63 | idx = s0 - back 64 | if 0 <= idx < num_source_layers: 65 | result.append(idx) 66 | # If not enough due to boundary, extend forward 67 | next_idx = s0 + 1 68 | while len(result) < k and next_idx < num_source_layers: 69 | result.append(next_idx) 70 | next_idx += 1 71 | return result 72 | 73 | for t in range(num_target_layers): 74 | s0 = offset + t 75 | # Clamp to valid range for edge cases (e.g., fewer source layers) 76 | if s0 < 0: 77 | s0 = 0 78 | elif s0 > num_source_layers - 1: 79 | s0 = num_source_layers - 1 80 | chosen = take_k_from(s0) 81 | if len(chosen) > 0: 82 | mapping[t] = chosen 83 | 84 | return mapping 85 | 86 | 87 | def setup_models(model_config: Dict[str, Any], device: str = "cuda", dtype: torch.dtype = torch.bfloat16): 88 | """Setup RosettaModel with base model, teacher model, and projectors""" 89 | 90 | # Load tokenizer 91 | tokenizer = AutoTokenizer.from_pretrained(model_config["base_model"]) 92 | if tokenizer.pad_token is None: 93 | tokenizer.pad_token = tokenizer.eos_token 94 | 95 | # Load models 96 | base_model = AutoModelForCausalLM.from_pretrained( 97 | model_config["base_model"], 98 | torch_dtype=dtype, 99 | device_map=device 100 | ) 101 | 102 | teacher_model = AutoModelForCausalLM.from_pretrained( 103 | model_config["teacher_model"], 104 | torch_dtype=dtype, 105 | device_map=device 106 | ) 107 | 108 | # Create projector 109 | projector_config = model_config["projector"] 110 | projector_params = projector_config["params"].copy() 111 | projector_params["dtype"] = dtype 112 | 113 | projector 
= create_projector( 114 | projector_config["type"], 115 | source_dim=teacher_model.config.head_dim, 116 | target_dim=base_model.config.head_dim, 117 | **projector_params 118 | ) 119 | 120 | # Setup RosettaModel 121 | rosetta_model = RosettaModel( 122 | model_list=[base_model, teacher_model], 123 | base_model_idx=0, 124 | projector_list=[projector] 125 | ).to(device) 126 | 127 | # Configure projector mappings 128 | num_layers_to_map = min( 129 | base_model.config.num_hidden_layers, 130 | teacher_model.config.num_hidden_layers 131 | ) 132 | 133 | for layer_idx in range(num_layers_to_map): 134 | rosetta_model.set_projector_config( 135 | source_model_idx=1, # Teacher 136 | source_model_layer_idx=layer_idx, 137 | target_model_idx=0, # Base 138 | target_model_layer_idx=layer_idx, 139 | projector_idx=0 140 | ) 141 | 142 | return rosetta_model, tokenizer -------------------------------------------------------------------------------- /script/analysis/gate_weight/collect_projector_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from typing import Dict, List, Any 5 | 6 | import torch 7 | import pandas as pd 8 | 9 | from rosetta.utils.evaluate import load_rosetta_model 10 | from rosetta.train.dataset_adapters import MMLUChatDataset 11 | 12 | 13 | def ensure_dir(path: str) -> None: 14 | if not os.path.exists(path): 15 | os.makedirs(path, exist_ok=True) 16 | 17 | 18 | def build_inputs(tokenizer, conversation: List[Dict[str, str]], device: torch.device): 19 | text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) 20 | enc = tokenizer(text, return_tensors="pt") 21 | return {k: v.to(device) for k, v in enc.items()} 22 | 23 | 24 | def load_model_and_tokenizer(checkpoint_dir: str, device: torch.device): 25 | model_config = { 26 | "model_name": "Rosetta", 27 | "rosetta_config": { 28 | "checkpoints_dir": checkpoint_dir, 29 | "base_model": "Qwen/Qwen3-0.6B", 30 | "teacher_model": "Qwen/Qwen2.5-0.5B-Instruct", 31 | }, 32 | } 33 | eval_config = {"checkpoints_dir": checkpoint_dir} 34 | 35 | rosetta_model, tokenizer = load_rosetta_model(model_config, eval_config, device) 36 | rosetta_model.eval() 37 | 38 | if getattr(tokenizer, "chat_template", None) is None: 39 | tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n{% endif %}{% endfor %}<|assistant|>\n" 40 | 41 | return rosetta_model, tokenizer 42 | 43 | 44 | @torch.no_grad() 45 | def run_and_collect(model, dataset, tokenizer, device: torch.device, num_samples: int) -> pd.DataFrame: 46 | rows: List[Dict[str, Any]] = [] 47 | 48 | for sample_idx in range(min(num_samples, len(dataset))): 49 | conv = dataset[sample_idx] 50 | inputs = build_inputs(tokenizer, conv, device) 51 | 52 | # Prepare kv_cache indexes similar to compare_projector_terms.py to trigger forward once 53 | full_length = inputs["input_ids"].shape[1] 54 | instruction_length = full_length - 1 55 | instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(instruction_length, 1).unsqueeze(0).to(device) 56 | response_index = torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(device) 57 | kv_cache_list = [instruction_index, response_index] 58 | 59 | _ = model.generate( 60 | input_ids=inputs["input_ids"], 61 | kv_cache_index=kv_cache_list, 62 | max_new_tokens=1, 63 | 
do_sample=False, 64 | use_cache=True, 65 | ) 66 | 67 | # After generation, each projector should expose capture attributes 68 | for proj_idx, proj in enumerate(model.projector_list): 69 | # Expect attributes set in C2CProjector.forward 70 | norm_key_scalar = getattr(proj, "last_norm_key_scalar", None) 71 | norm_value_scalar = getattr(proj, "last_norm_value_scalar", None) 72 | key_gate_logit = getattr(proj, "last_key_gate_logit", None) 73 | value_gate_logit = getattr(proj, "last_value_gate_logit", None) 74 | 75 | # Convert tensors to nested Python lists for CSV (JSON-encoded) 76 | norm_key_scalar_list = norm_key_scalar.tolist() if norm_key_scalar is not None else None 77 | norm_value_scalar_list = norm_value_scalar.tolist() if norm_value_scalar is not None else None 78 | 79 | row = { 80 | "sample_index": sample_idx, 81 | "projector_index": proj_idx, 82 | "norm_key_scalar": json.dumps(norm_key_scalar_list) if norm_key_scalar_list is not None else None, 83 | "norm_value_scalar": json.dumps(norm_value_scalar_list) if norm_value_scalar_list is not None else None, 84 | "key_gate_logit": key_gate_logit, 85 | "value_gate_logit": value_gate_logit, 86 | } 87 | rows.append(row) 88 | 89 | return pd.DataFrame(rows) 90 | 91 | 92 | def main() -> None: 93 | parser = argparse.ArgumentParser(description="Collect per-projector scalar weights and gate logits into CSV") 94 | parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint directory") 95 | parser.add_argument("--split", type=str, default="validation", help="Dataset split to use (e.g., validation)") 96 | parser.add_argument("--num_samples", type=int, default=None, help="Number of samples to process") 97 | parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)") 98 | parser.add_argument("--output_csv", type=str, default="projector_weights.csv", help="Output CSV path") 99 | args = parser.parse_args() 100 | 101 | device = torch.device(args.device if torch.cuda.is_available() and args.device.startswith("cuda") else "cpu") 102 | 103 | model, tokenizer = load_model_and_tokenizer(args.checkpoint, device) 104 | dataset = MMLUChatDataset(split=args.split, num_samples=args.num_samples) 105 | 106 | df = run_and_collect(model, dataset, tokenizer, device, args.num_samples) 107 | 108 | ensure_dir(os.path.dirname(args.output_csv) or ".") 109 | df.to_csv(args.output_csv, index=False) 110 | print(f"Saved: {args.output_csv} (rows={len(df)})") 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | 116 | 117 | -------------------------------------------------------------------------------- /script/train/generate_configs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Generate different configuration files for testing various projector types and freeze configurations. 
4 | """ 5 | 6 | import json 7 | import os 8 | import itertools 9 | from typing import Dict, Any, List 10 | 11 | def load_base_config(config_path: str = "recipe/default_config.json") -> Dict[str, Any]: 12 | """Load the base configuration from the default config file""" 13 | if not os.path.exists(config_path): 14 | raise FileNotFoundError(f"Default config file not found: {config_path}") 15 | 16 | with open(config_path, 'r') as f: 17 | config = json.load(f) 18 | 19 | return config 20 | 21 | def generate_configs(): 22 | """Generate all configuration combinations""" 23 | 24 | # Load base configuration from default config file 25 | base_config = load_base_config() 26 | 27 | # Define the parameter variations 28 | projector_configs = { 29 | "AdditiveProjector": { 30 | "type": "AdditiveProjector", 31 | "params": { 32 | "hidden_dim": 1024, 33 | "num_layers": 3, 34 | "dropout": 0.1, 35 | "activation": "gelu", 36 | "use_layer_norm": True, 37 | "init_weight": 0.1 38 | } 39 | }, 40 | "MLPProjector": { 41 | "type": "MLPProjector", 42 | "params": { 43 | "hidden_dim": 1024, 44 | "num_layers": 3, 45 | "dropout": 0.1, 46 | "activation": "gelu", 47 | "use_layer_norm": True, 48 | "residual_connection": False 49 | } 50 | } 51 | } 52 | 53 | freeze_configs = { 54 | "freeze_teacher": ["teacher"], 55 | # "freeze_base": ["base"], 56 | # "freeze_projector": ["projector"], 57 | "freeze_base_teacher": ["base", "teacher"], 58 | # "freeze_base_projector": ["base", "projector"], 59 | # "freeze_teacher_projector": ["teacher", "projector"], 60 | # "freeze_none": [] 61 | } 62 | 63 | # Create output directory 64 | output_dir = "recipe/experiments" 65 | os.makedirs(output_dir, exist_ok=True) 66 | 67 | # Generate all combinations 68 | config_files = [] 69 | 70 | for projector_name, projector_config in projector_configs.items(): 71 | for freeze_name, freeze_config in freeze_configs.items(): 72 | # Create copy of base config 73 | config = json.loads(json.dumps(base_config)) # Deep copy 74 | 75 | # Update projector 76 | config["model"]["projector"] = projector_config 77 | 78 | # Update freeze configuration 79 | config["training"]["freeze"] = freeze_config 80 | 81 | # Create descriptive run name 82 | run_name = f"rosetta_{projector_name.lower()}_{freeze_name}" 83 | config["output"]["run_name"] = run_name 84 | 85 | # Create filename 86 | filename = f"{projector_name.lower()}_{freeze_name}.json" 87 | filepath = os.path.join(output_dir, filename) 88 | 89 | # Save config 90 | with open(filepath, 'w') as f: 91 | json.dump(config, f, indent=4) 92 | 93 | config_files.append(filepath) 94 | print(f"Generated: {filepath}") 95 | 96 | print(f"\nGenerated {len(config_files)} configuration files in {output_dir}/") 97 | return config_files 98 | 99 | def generate_summary(): 100 | """Generate a summary of all configurations""" 101 | output_dir = "recipe/experiments" 102 | summary_file = os.path.join(output_dir, "experiment_summary.txt") 103 | 104 | projector_types = ["AdditiveProjector", "MLPProjector"] 105 | freeze_options = [ 106 | "freeze_teacher", "freeze_base", "freeze_projector", 107 | "freeze_base_teacher", "freeze_base_projector", "freeze_teacher_projector", 108 | "freeze_none" 109 | ] 110 | 111 | with open(summary_file, 'w') as f: 112 | f.write("Experiment Configuration Summary\n") 113 | f.write("=" * 50 + "\n\n") 114 | 115 | f.write("Projector Types:\n") 116 | for proj in projector_types: 117 | f.write(f" - {proj}\n") 118 | f.write("\n") 119 | 120 | f.write("Freeze Configurations:\n") 121 | freeze_descriptions = { 122 | "freeze_teacher": 
"Only teacher model frozen", 123 | "freeze_base": "Only base model frozen", 124 | "freeze_projector": "Only projector frozen", 125 | "freeze_base_teacher": "Base and teacher models frozen", 126 | "freeze_base_projector": "Base model and projector frozen", 127 | "freeze_teacher_projector": "Teacher model and projector frozen", 128 | "freeze_none": "No components frozen" 129 | } 130 | 131 | for freeze, desc in freeze_descriptions.items(): 132 | f.write(f" - {freeze}: {desc}\n") 133 | f.write("\n") 134 | 135 | f.write("Total Experiments:\n") 136 | f.write(f" {len(projector_types)} projectors × {len(freeze_options)} freeze configs = {len(projector_types) * len(freeze_options)} experiments\n\n") 137 | 138 | f.write("Generated Files:\n") 139 | for proj in projector_types: 140 | for freeze in freeze_options: 141 | filename = f"{proj.lower()}_{freeze}.json" 142 | f.write(f" - {filename}\n") 143 | 144 | print(f"Summary saved to: {summary_file}") 145 | 146 | if __name__ == "__main__": 147 | print("Generating experiment configurations...") 148 | config_files = generate_configs() 149 | generate_summary() 150 | print("\nDone!") -------------------------------------------------------------------------------- /script/analysis/length_ratio/extract_metrics_enhanced.py: -------------------------------------------------------------------------------- 1 | """ 2 | 简单提取模型结果数据 3 | 提取Rosetta和Qwen的子类准确率,以及长度比信息 4 | """ 5 | 6 | import json 7 | import pandas as pd 8 | 9 | def extract_data(): 10 | """提取两个模型的数据并整理成表格""" 11 | 12 | # 文件路径 13 | rosetta_file = "/share/minzihan/unified_memory/cot_eval_results/Rosetta_context_cot_2_generate_20250818_173215_summary.json" 14 | qwen_file = "/share/minzihan/unified_memory/cot_eval_results/Qwen3-4B_generate_20250818_181807_summary.json" 15 | 16 | # 加载数据 17 | with open(rosetta_file, 'r', encoding='utf-8') as f: 18 | rosetta_data = json.load(f) 19 | 20 | with open(qwen_file, 'r', encoding='utf-8') as f: 21 | qwen_data = json.load(f) 22 | 23 | # 提取Rosetta数据 24 | rosetta_metrics = {} 25 | if 'length_statistics' in rosetta_data and 'subcategories' in rosetta_data['length_statistics']: 26 | for subcategory, stats in rosetta_data['length_statistics']['subcategories'].items(): 27 | rosetta_metrics[subcategory] = { 28 | 'accuracy': stats.get('accuracy', 0.0), 29 | 'length_ratio': stats.get('avg_length_ratio', 0.0), 30 | 'samples': stats.get('total_samples', 0) 31 | } 32 | 33 | # 提取Qwen数据 34 | qwen_metrics = {} 35 | if 'length_statistics' in qwen_data and 'subcategories' in qwen_data['length_statistics']: 36 | for subcategory, stats in qwen_data['length_statistics']['subcategories'].items(): 37 | qwen_metrics[subcategory] = { 38 | 'accuracy': stats.get('accuracy', 0.0), 39 | 'length_ratio': stats.get('avg_length_ratio', 0.0), 40 | 'samples': stats.get('total_samples', 0) 41 | } 42 | 43 | # 获取所有子类 44 | all_subcategories = sorted(set(rosetta_metrics.keys()) | set(qwen_metrics.keys())) 45 | 46 | # 整理数据 47 | data = [] 48 | for subcategory in all_subcategories: 49 | row = { 50 | '子类': subcategory, 51 | 'Rosetta准确率': rosetta_metrics.get(subcategory, {}).get('accuracy', None), 52 | 'Qwen准确率': qwen_metrics.get(subcategory, {}).get('accuracy', None), 53 | 'Rosetta长度比': rosetta_metrics.get(subcategory, {}).get('length_ratio', None), 54 | 'Qwen长度比': qwen_metrics.get(subcategory, {}).get('length_ratio', None), 55 | 'Rosetta样本数': rosetta_metrics.get(subcategory, {}).get('samples', None), 56 | 'Qwen样本数': qwen_metrics.get(subcategory, {}).get('samples', None) 57 | } 58 | data.append(row) 59 | 60 | # 
创建DataFrame 61 | df = pd.DataFrame(data) 62 | 63 | # 格式化数值,保留4位小数 64 | for col in ['Rosetta准确率', 'Qwen准确率', 'Rosetta长度比', 'Qwen长度比']: 65 | df[col] = df[col].apply(lambda x: f"{x:.4f}" if x is not None else "N/A") 66 | 67 | return df 68 | 69 | def main(): 70 | """主函数""" 71 | print("提取模型评估数据...") 72 | 73 | # 提取数据 74 | df = extract_data() 75 | 76 | # 显示结果 77 | print("\n模型评估数据汇总:") 78 | print("=" * 80) 79 | print(df.to_string(index=False)) 80 | 81 | # 显示一些基本统计 82 | print("\n基本统计:") 83 | print("-" * 40) 84 | 85 | # 计算有效数据的统计 86 | rosetta_acc = df[df['Rosetta准确率'] != "N/A"]['Rosetta准确率'].apply(lambda x: float(x)) 87 | qwen_acc = df[df['Qwen准确率'] != "N/A"]['Qwen准确率'].apply(lambda x: float(x)) 88 | 89 | if len(rosetta_acc) > 0: 90 | print(f"Rosetta平均准确率: {rosetta_acc.mean():.4f}") 91 | if len(qwen_acc) > 0: 92 | print(f"Qwen平均准确率: {qwen_acc.mean():.4f}") 93 | 94 | print(f"共有 {len(df)} 个子类") 95 | print(f"Rosetta有数据的子类: {len(rosetta_acc)} 个") 96 | print(f"Qwen有数据的子类: {len(qwen_acc)} 个") 97 | 98 | # 分析Rosetta表现与长度比的关系 99 | print("\nRosetta表现与长度比分析:") 100 | print("-" * 50) 101 | 102 | # 筛选同时有Rosetta和Qwen准确率数据的子类 103 | valid_comparison = df[(df['Rosetta准确率'] != "N/A") & 104 | (df['Qwen准确率'] != "N/A") & 105 | (df['Rosetta长度比'] != "N/A")] 106 | 107 | if len(valid_comparison) > 0: 108 | # 转换为数值 109 | valid_comparison = valid_comparison.copy() 110 | valid_comparison['Rosetta准确率_num'] = valid_comparison['Rosetta准确率'].apply(float) 111 | valid_comparison['Qwen准确率_num'] = valid_comparison['Qwen准确率'].apply(float) 112 | valid_comparison['Rosetta长度比_num'] = valid_comparison['Rosetta长度比'].apply(float) 113 | 114 | # 计算Rosetta相对Qwen的表现差异 115 | valid_comparison['准确率差异'] = valid_comparison['Rosetta准确率_num'] - valid_comparison['Qwen准确率_num'] 116 | 117 | # 分为强势和弱势子类 118 | rosetta_strong = valid_comparison[valid_comparison['准确率差异'] > 0] # Rosetta比Qwen强的子类 119 | rosetta_weak = valid_comparison[valid_comparison['准确率差异'] < 0] # Rosetta比Qwen弱的子类 120 | rosetta_equal = valid_comparison[valid_comparison['准确率差异'] == 0] # 表现相等的子类 121 | 122 | print(f"可比较的子类总数: {len(valid_comparison)}") 123 | print(f"Rosetta表现更好的子类: {len(rosetta_strong)} 个") 124 | print(f"Rosetta表现较差的子类: {len(rosetta_weak)} 个") 125 | print(f"表现相等的子类: {len(rosetta_equal)} 个") 126 | 127 | if len(rosetta_strong) > 0: 128 | strong_avg_length = rosetta_strong['Rosetta长度比_num'].mean() 129 | print(f"\nRosetta强势子类的平均长度比: {strong_avg_length:.4f}") 130 | print("强势子类详情:") 131 | for _, row in rosetta_strong.iterrows(): 132 | print(f" {row['子类']}: 准确率差异 +{row['准确率差异']:.4f}, 长度比 {row['Rosetta长度比_num']:.4f}") 133 | 134 | if len(rosetta_weak) > 0: 135 | weak_avg_length = rosetta_weak['Rosetta长度比_num'].mean() 136 | print(f"\nRosetta弱势子类的平均长度比: {weak_avg_length:.4f}") 137 | print("弱势子类详情:") 138 | for _, row in rosetta_weak.iterrows(): 139 | print(f" {row['子类']}: 准确率差异 {row['准确率差异']:.4f}, 长度比 {row['Rosetta长度比_num']:.4f}") 140 | 141 | if len(rosetta_strong) > 0 and len(rosetta_weak) > 0: 142 | length_diff = strong_avg_length - weak_avg_length 143 | print(f"\n长度比差异分析:") 144 | print(f"强势子类平均长度比 - 弱势子类平均长度比 = {length_diff:.4f}") 145 | if length_diff > 0: 146 | print("→ Rosetta在长度比较高的子类上表现更好") 147 | elif length_diff < 0: 148 | print("→ Rosetta在长度比较低的子类上表现更好") 149 | else: 150 | print("→ 长度比与Rosetta的相对表现无明显关系") 151 | else: 152 | print("没有足够的数据进行比较分析") 153 | 154 | if __name__ == "__main__": 155 | main() 156 | -------------------------------------------------------------------------------- /script/playground/live_chat_example.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Live Chat Example with C2C Models 3 | 4 | Demonstrates the key RosettaModel interface for cache-to-cache communication. 5 | 6 | Usage: 7 | python live_chat_example.py --checkpoint_dir path/to/checkpoint 8 | python live_chat_example.py --checkpoint_dir path/to/ckpt1 path/to/ckpt2 # multi-source 9 | """ 10 | 11 | import argparse 12 | import torch 13 | import json, yaml, re, os 14 | from pathlib import Path 15 | from transformers import AutoTokenizer, AutoModelForCausalLM 16 | 17 | from rosetta.utils.evaluate import set_default_chat_template 18 | from rosetta.utils.core import all_sharers_mask, format_sharer_mask 19 | from rosetta.model.wrapper import RosettaModel 20 | from rosetta.model.projector import load_projector 21 | 22 | def load_model(checkpoint_dirs: list, subfolder: str = "final", device: str = "cuda:0"): 23 | """ 24 | Load RosettaModel from checkpoint directories. 25 | Supports both single-source and multi-source modes. 26 | """ 27 | 28 | device = torch.device(device) 29 | 30 | # Read config from each checkpoint 31 | configs = [] 32 | for ckpt_dir in checkpoint_dirs: 33 | with open(Path(ckpt_dir) / "config.json") as f: 34 | configs.append(yaml.safe_load(f)) 35 | 36 | base_model = configs[0]["model"]["base_model"] # assume all checkpoints have the same base model 37 | teacher_models = [cfg["model"]["teacher_model"] for cfg in configs] 38 | 39 | print(f"Base model: {base_model}") 40 | print(f"Teacher models: {teacher_models}") 41 | 42 | # Load tokenizer and base model 43 | tokenizer = AutoTokenizer.from_pretrained(base_model) 44 | set_default_chat_template(tokenizer, base_model) 45 | 46 | base_llm = AutoModelForCausalLM.from_pretrained( 47 | base_model, torch_dtype=torch.bfloat16, device_map={"": device} 48 | ).eval() 49 | 50 | # Load teacher models 51 | sharer_llms = [ 52 | AutoModelForCausalLM.from_pretrained( 53 | tm, torch_dtype=torch.bfloat16, device_map={"": device} 54 | ).eval() 55 | for tm in teacher_models 56 | ] 57 | 58 | # Load projectors from each checkpoint 59 | projector_list = [] 60 | projector_offsets = [0] 61 | 62 | for ckpt_dir in checkpoint_dirs: 63 | proj_dir = Path(ckpt_dir) / subfolder 64 | num_proj = len([f for f in os.listdir(proj_dir) if re.match(r"projector_\d+\.pt", f)]) 65 | 66 | for i in range(num_proj): 67 | proj = load_projector(str(proj_dir / f"projector_{i}.json")).to(device) 68 | proj.load_state_dict(torch.load(proj_dir / f"projector_{i}.pt", map_location=device), strict=False) 69 | projector_list.append(proj) 70 | projector_offsets.append(len(projector_list)) 71 | 72 | # Create RosettaModel 73 | rosetta_model = RosettaModel( 74 | model_list=[base_llm] + sharer_llms, 75 | base_model_idx=0, 76 | projector_list=projector_list, 77 | ).to(device).eval() 78 | 79 | # Load projector configs 80 | for llm_idx, ckpt_dir in enumerate(checkpoint_dirs): 81 | cfg_path = Path(ckpt_dir) / subfolder / "projector_config.json" 82 | if cfg_path.exists(): 83 | with open(cfg_path) as f: 84 | cfg = json.load(f) 85 | # Adjust indices and merge into projector_dict 86 | for tgt_idx, sources in cfg.items(): 87 | tgt_idx = int(tgt_idx) 88 | if tgt_idx not in rosetta_model.projector_dict: 89 | rosetta_model.projector_dict[tgt_idx] = {} 90 | src_idx = llm_idx + 1 # actual source model index 91 | rosetta_model.projector_dict[tgt_idx][src_idx] = { 92 | int(layer): [(sl, pi + projector_offsets[llm_idx]) for sl, pi in mappings] 93 | for layer, mappings in list(sources.values())[0].items() 94 
| } 95 | 96 | return rosetta_model, tokenizer, len(teacher_models) 97 | 98 | 99 | def generate(model, tokenizer, prompt: str, num_sharers: int, device): 100 | """ 101 | Generate response using RosettaModel. 102 | 103 | Key interface: kv_cache_index[i][0][0][0] controls sharer selection: 104 | - -1: no projection (receiver only) 105 | - >0: bitmask (1=sharer1, 2=sharer2, 3=both, 7=all three) 106 | """ 107 | # Prepare input 108 | messages = [{"role": "user", "content": prompt}] 109 | text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False) 110 | inputs = tokenizer(text, return_tensors="pt").to(device) 111 | 112 | # Create kv_cache_index with bitmask (use all sharers) 113 | sharer_mask = all_sharers_mask(num_sharers) 114 | seq_len = inputs.input_ids.shape[1] 115 | 116 | # instruction_index: apply C2C projection for prompt tokens 117 | # label_index: -1 means no projection for generation 118 | instruction_index = torch.tensor([sharer_mask, 0]).repeat(seq_len - 1, 1).unsqueeze(0).to(device) 119 | label_index = torch.tensor([[-1, 0]]).unsqueeze(0).to(device) 120 | 121 | print(f" Using {format_sharer_mask(sharer_mask)} for prompt encoding") 122 | 123 | # Generate 124 | with torch.no_grad(): 125 | outputs = model.generate( 126 | kv_cache_index=[instruction_index, label_index], 127 | input_ids=inputs.input_ids, 128 | attention_mask=inputs.attention_mask, 129 | do_sample=False, 130 | max_new_tokens=256, 131 | ) 132 | 133 | return tokenizer.decode(outputs[0], skip_special_tokens=True) 134 | 135 | 136 | def main(): 137 | parser = argparse.ArgumentParser(description='Live Chat with C2C Models') 138 | parser.add_argument("--checkpoint_dir", type=str, nargs="+", required=True, help="Checkpoint directory(s)") 139 | parser.add_argument("--subfolder", type=str, default="final") 140 | parser.add_argument("--device", type=str, default="cuda:0") 141 | args = parser.parse_args() 142 | 143 | checkpoint_dirs = args.checkpoint_dir 144 | 145 | # Load model 146 | device = torch.device(args.device) 147 | model, tokenizer, num_sharers = load_model(checkpoint_dirs, args.subfolder, args.device) 148 | print(f"Loaded {num_sharers} sharer(s). 
Type 'q' to quit.\n") 149 | 150 | # Chat loop 151 | while True: 152 | user_input = input("You: ").strip() 153 | if user_input.lower() in ['q', 'quit', 'exit']: 154 | break 155 | if not user_input: 156 | continue 157 | 158 | response = generate(model, tokenizer, user_input, num_sharers, device) 159 | print(f"Bot: {response}\n") 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /script/playground/sample_response.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | import os 4 | import json 5 | import torch 6 | import numpy as np 7 | import pandas as pd 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | from datasets import load_dataset, DatasetDict 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 12 | try: 13 | from transformers.cache_utils import DynamicCache 14 | except ImportError: 15 | DynamicCache = None 16 | 17 | from rosetta.model.projector import create_projector 18 | from rosetta.model.wrapper import RosettaModel 19 | from rosetta.train.dataset_adapters import MMLUChatDataset, ChatDataset 20 | 21 | def load_qwen_model(model_name): 22 | """加载Qwen模型""" 23 | print(f"Loading Qwen model: {model_name}") 24 | model_path = "Qwen/" + model_name 25 | 26 | tokenizer = AutoTokenizer.from_pretrained( 27 | str(model_path), 28 | trust_remote_code=True, 29 | padding_side='left' 30 | ) 31 | 32 | if tokenizer.pad_token is None: 33 | tokenizer.pad_token = tokenizer.eos_token 34 | 35 | model = AutoModelForCausalLM.from_pretrained( 36 | str(model_path), 37 | torch_dtype=torch.bfloat16, 38 | ) 39 | 40 | return model, tokenizer 41 | 42 | def load_rosetta_model(): 43 | """加载Rosetta模型""" 44 | 45 | slm_model_path = "Qwen/Qwen3-0.6B" 46 | llm_model_path = "Qwen/Qwen3-4B" 47 | 48 | slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path)) 49 | 50 | slm_model = AutoModelForCausalLM.from_pretrained( 51 | str(slm_model_path), 52 | torch_dtype=torch.bfloat16, 53 | device_map = 'cuda' 54 | ).eval() 55 | 56 | llm_model = AutoModelForCausalLM.from_pretrained( 57 | str(llm_model_path), 58 | torch_dtype=torch.bfloat16, 59 | device_map = 'cuda' 60 | ).eval() 61 | 62 | # 创建投影器 63 | projector_config = { 64 | "type": "AdditiveProjector", 65 | "params": { 66 | "hidden_dim": 1024, 67 | "num_layers": 3, 68 | "dropout": 0.1, 69 | "activation": "gelu", 70 | "use_layer_norm": True, 71 | "init_weight": 0.1 72 | } 73 | } 74 | projector_params = projector_config["params"].copy() 75 | projector_params["dtype"] = torch.bfloat16 76 | projector = create_projector( 77 | projector_config["type"], 78 | source_dim=llm_model.config.head_dim, 79 | target_dim=slm_model.config.head_dim, 80 | **projector_params 81 | ) 82 | 83 | # 初始化Rosetta模型 84 | rosetta_model = RosettaModel( 85 | model_list=[slm_model, llm_model], 86 | base_model_idx=0, 87 | projector_list=[projector] 88 | ).to('cuda').eval() 89 | 90 | # 导入projector权重 91 | projector_weight_path = "local/checkpoints/4b_21-27_cot_1e-4/final/projector_0.pt" 92 | rosetta_model.projector_list[0].load_state_dict(torch.load(projector_weight_path, map_location='cpu')) 93 | layer_offset = llm_model.config.num_hidden_layers - slm_model.config.num_hidden_layers 94 | # layer_offset = 0 95 | 96 | # 配置投影器映射 97 | for layer_idx in range(21, 28): 98 | rosetta_model.set_projector_config( 99 | source_model_idx=1, # Teacher model 100 | source_model_layer_idx=layer_idx + layer_offset, 101 | 
target_model_idx=0, # Base model 102 | target_model_layer_idx=layer_idx, 103 | projector_idx=0 104 | ) 105 | 106 | return rosetta_model, slm_tokenizer 107 | 108 | import debugpy 109 | debugpy.listen(("0.0.0.0", 5678)) 110 | print("Waiting for debugger attach...") 111 | debugpy.wait_for_client() 112 | print("Debugger attached, running...") 113 | 114 | rosetta_model, rosetta_tokenizer = load_rosetta_model() 115 | slm_model, slm_tokenizer = load_qwen_model("Qwen3-0.6B") 116 | llm_model, llm_tokenizer = load_qwen_model("Qwen3-4B") 117 | 118 | instruct_ds = MMLUChatDataset(split="validation", num_samples=None) 119 | 120 | sampling_params = { 121 | 'do_sample': True, 122 | 'temperature': 0.7, 123 | 'top_p': 0.8, 124 | 'top_k': 20, 125 | 'min_p': 0.0, 126 | 'repetition_penalty': 1.1, 127 | 'max_new_tokens': 256 128 | } 129 | 130 | correct = 0 131 | total = 0 132 | 133 | rosetta_model.eval() 134 | rosetta_model.cuda() # 如果你在GPU上运行 135 | slm_model.eval() 136 | slm_model.cuda() # 如果你在GPU上运行 137 | llm_model.eval() 138 | llm_model.cuda() # 如果你在GPU上运行 139 | 140 | with open("analysis/venn_regions.json", "r") as f: 141 | results = json.load(f) 142 | 143 | idx = results["rosetta_llm"] 144 | 145 | for i in idx: 146 | sample = instruct_ds[int(i)] 147 | sample[0]['content'] += "\nYou should first give a short explanation and then output the final answer in the format 'The correct answer is ...'. Don't output the answer directly. Don't give a very long explanation, just a few sentences is enough." 148 | # sample[0]['content'] += "Give your answer in the format: 'The correct answer is A/B/C/D'. You should only output 'The correct answer is ...', without any additional text." 149 | # 用三个模型分别构造并且输出回答 150 | instruction_rosetta = rosetta_tokenizer.apply_chat_template( 151 | sample[:1], 152 | tokenize=False, 153 | add_generation_prompt=True, 154 | enable_thinking=False, 155 | ) 156 | input_rosetta = rosetta_tokenizer(instruction_rosetta, add_special_tokens=False) 157 | with torch.no_grad(): 158 | output_ids = rosetta_model.generate( 159 | input_ids=torch.tensor(input_rosetta["input_ids"]).unsqueeze(0).cuda(), 160 | attention_mask=torch.tensor(input_rosetta["attention_mask"]).unsqueeze(0).cuda(), 161 | **sampling_params 162 | )[0] 163 | full_output_rosetta = rosetta_tokenizer.decode(output_ids[len(input_rosetta['input_ids']):], skip_special_tokens=True) 164 | 165 | instruction_slm = slm_tokenizer.apply_chat_template( 166 | sample[:1], 167 | tokenize=False, 168 | add_generation_prompt=True, 169 | enable_thinking=False, 170 | ) 171 | input_slm = slm_tokenizer(instruction_slm, add_special_tokens=False) 172 | with torch.no_grad(): 173 | output_ids_slm = slm_model.generate( 174 | input_ids=torch.tensor(input_slm["input_ids"]).unsqueeze(0).cuda(), 175 | attention_mask=torch.tensor(input_slm["attention_mask"]).unsqueeze(0).cuda(), 176 | **sampling_params 177 | )[0] 178 | full_output_slm = slm_tokenizer.decode(output_ids_slm[len(input_slm['input_ids']):], skip_special_tokens=True) 179 | 180 | instruction_llm = llm_tokenizer.apply_chat_template( 181 | sample[:1], 182 | tokenize=False, 183 | add_generation_prompt=True, 184 | enable_thinking=False, 185 | ) 186 | input_llm = llm_tokenizer(instruction_llm, add_special_tokens=False) 187 | with torch.no_grad(): 188 | output_ids_llm = llm_model.generate( 189 | input_ids=torch.tensor(input_llm["input_ids"]).unsqueeze(0).cuda(), 190 | attention_mask=torch.tensor(input_llm["attention_mask"]).unsqueeze(0).cuda(), 191 | **sampling_params 192 | )[0] 193 | full_output_llm = 
llm_tokenizer.decode(output_ids_llm[len(input_llm['input_ids']):], skip_special_tokens=True) 194 | -------------------------------------------------------------------------------- /script/examples/two_stage_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example usage of the TwoStageInference pipeline for LLM+LLM evaluation. 3 | """ 4 | 5 | from multi_stage import TwoStageInference 6 | 7 | 8 | def example_standalone(): 9 | """Example of standalone usage.""" 10 | # Initialize the two-stage pipeline 11 | pipeline = TwoStageInference( 12 | context_model_path="Qwen/Qwen3-4B", # e.g., "Qwen/Qwen2.5-7B-Instruct" 13 | answer_model_path="Qwen/Qwen3-0.6B", # e.g., "Qwen/Qwen2.5-72B-Instruct" 14 | device="cuda", 15 | max_new_tokens=512, 16 | background_prompt="Analyze the key concepts and provide relevant background information needed to solve this problem:\n\n{question}" 17 | ) 18 | 19 | # Example MMLU question 20 | question_without_options = "What is the primary function of mitochondria in cells?" 21 | 22 | question_with_options = """What is the primary function of mitochondria in cells? 23 | 24 | A. Protein synthesis 25 | B. Energy production through ATP synthesis 26 | C. DNA replication 27 | D. Waste removal 28 | 29 | Answer: The correct answer is""" 30 | 31 | # Generate answer using the new generate method (model-like interface) 32 | answer = pipeline.generate( 33 | question_without_options=question_without_options, 34 | question_with_options=question_with_options, 35 | max_new_tokens=512 36 | ) 37 | 38 | print("Final Answer from Two-Stage Pipeline:") 39 | print(answer) 40 | 41 | # For detailed context, use process method 42 | result = pipeline.process( 43 | question_without_options=question_without_options, 44 | question_with_options=question_with_options 45 | ) 46 | 47 | print("\n" + "="*50) 48 | print("Detailed Breakdown:") 49 | print("Background Context from First LLM:") 50 | print(result["context"]) 51 | print("\nFinal Answer from Second LLM:") 52 | print(result["answer"]) 53 | 54 | 55 | def example_with_evaluator_integration(): 56 | """ 57 | Example showing how to integrate with unified_evaluator.py 58 | 59 | Minimal changes needed in unified_evaluator.py: 60 | 1. Import TwoStageInference 61 | 2. Add a flag in config to enable two-stage mode 62 | 3. In evaluate_subject method, check flag and use TwoStageInference 63 | """ 64 | 65 | # This would be in the evaluator's evaluate_subject method 66 | def modified_evaluate_subject_snippet(example, use_two_stage=True): 67 | """ 68 | Pseudo-code showing integration points. 69 | """ 70 | if use_two_stage: 71 | # Extract question without options 72 | question_text = example.get('question', '') 73 | 74 | # Build question with options (existing code) 75 | choices = "" 76 | for i, choice in enumerate(example.get('choices', [])): 77 | choices += f"{chr(65+i)}. 
{choice}\n" 78 | 79 | question_with_options = f"""{question_text} 80 | 81 | {choices} 82 | Answer: The correct answer is""" 83 | 84 | # Use TwoStageInference (now treated as a model) 85 | pipeline = TwoStageInference( 86 | context_model_path="model1_path", 87 | answer_model_path="model2_path" 88 | ) 89 | 90 | # Use generate method (model-like interface) 91 | answer = pipeline.generate( 92 | question_without_options=question_text, 93 | question_with_options=question_with_options, 94 | max_new_tokens=1024 95 | ) 96 | 97 | return answer 98 | else: 99 | # Use existing single-model approach 100 | pass 101 | 102 | 103 | def format_question_for_stages(example, dataset_name="mmlu-redux"): 104 | """ 105 | Helper function to format questions for two-stage processing. 106 | 107 | Args: 108 | example: Dataset example 109 | dataset_name: Name of the dataset 110 | 111 | Returns: 112 | Tuple of (question_without_options, question_with_options) 113 | """ 114 | if dataset_name == "mmlu-redux": 115 | question_text = example['question'] 116 | 117 | # Build choices 118 | choices = "" 119 | for i, choice in enumerate(example['choices']): 120 | choices += f"{chr(65+i)}. {choice}\n" 121 | 122 | # Question with full template for answering 123 | question_with_options = f"""{question_text} 124 | 125 | {choices} 126 | Answer: The correct answer is""" 127 | 128 | return question_text, question_with_options 129 | 130 | elif dataset_name == "mmmlu": 131 | question_text = example['Question'] 132 | 133 | # Build choices 134 | choices = "" 135 | for i, choice_key in enumerate(['A', 'B', 'C', 'D']): 136 | if choice_key in example: 137 | choices += f"{choice_key}. {example[choice_key]}\n" 138 | 139 | question_with_options = f"""{question_text} 140 | 141 | {choices} 142 | Answer: The correct answer is""" 143 | 144 | return question_text, question_with_options 145 | 146 | else: 147 | raise ValueError(f"Unknown dataset: {dataset_name}") 148 | 149 | 150 | 151 | def example_different_prompts(): 152 | """Example showing different background prompts.""" 153 | print("\n" + "="*60) 154 | print("Example with Different Background Prompts") 155 | print("="*60) 156 | 157 | # Different prompt styles 158 | prompts = { 159 | "concise": "Briefly describe the most useful background to solve the problem:\n\n{question}", 160 | "detailed": "Analyze the key concepts and provide relevant background information needed to solve this problem:\n\n{question}", 161 | "step_by_step": "Break down the problem and explain the key concepts step by step:\n\n{question}", 162 | "domain_focused": "Provide domain-specific knowledge and context relevant to this question:\n\n{question}" 163 | } 164 | 165 | question = "What is the primary function of mitochondria in cells?" 
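    # Illustrative sketch (kept as comments so this demo stays model-free): any of the
    # templates above can be supplied through the pipeline's `background_prompt`
    # argument, exactly as in example_standalone(); the model paths are placeholders.
    #
    # pipeline = TwoStageInference(
    #     context_model_path="Qwen/Qwen3-4B",
    #     answer_model_path="Qwen/Qwen3-0.6B",
    #     background_prompt=prompts["concise"],
    # )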
166 | 167 | for style, prompt in prompts.items(): 168 | print(f"\n--- {style.upper()} STYLE ---") 169 | print(f"Prompt: {prompt}") 170 | print(f"Full prompt: {prompt.format(question=question)}") 171 | print("-" * 40) 172 | 173 | if __name__ == "__main__": 174 | print("TwoStageInference Example") 175 | print("=" * 50) 176 | 177 | # Note: For actual usage, replace model paths with real model names 178 | print("\nThis script demonstrates the usage of TwoStageInference.") 179 | print("To run with real models, update the model paths in the code.") 180 | 181 | # Show the formatting example 182 | example_data = { 183 | 'question': 'What is the capital of France?', 184 | 'choices': ['London', 'Berlin', 'Paris', 'Madrid'] 185 | } 186 | 187 | q_without, q_with = format_question_for_stages(example_data) 188 | print("\nExample Question Formatting:") 189 | print("\n1. Question without options (for context generation):") 190 | print(q_without) 191 | print("\n2. Question with options (for final answer):") 192 | print(q_with) 193 | 194 | example_standalone() 195 | example_different_prompts() 196 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: rosetta 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/ 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 8 | - https://repo.anaconda.com/pkgs/main 9 | - https://repo.anaconda.com/pkgs/r 10 | - defaults 11 | dependencies: 12 | - _libgcc_mutex=0.1=main 13 | - _openmp_mutex=5.1=1_gnu 14 | - bzip2=1.0.8=h5eee18b_6 15 | - ca-certificates=2025.2.25=h06a4308_0 16 | - expat=2.7.1=h6a678d5_0 17 | - ld_impl_linux-64=2.40=h12ee557_0 18 | - libffi=3.4.4=h6a678d5_1 19 | - libgcc-ng=11.2.0=h1234567_1 20 | - libgomp=11.2.0=h1234567_1 21 | - libstdcxx-ng=11.2.0=h1234567_1 22 | - libuuid=1.41.5=h5eee18b_0 23 | - libxcb=1.17.0=h9b100fa_0 24 | - ncurses=6.5=h7934f7d_0 25 | - openssl=3.0.17=h5eee18b_0 26 | - pip=25.1=pyhc872135_2 27 | - pthread-stubs=0.3=h0ce48e5_1 28 | - python=3.10.18=h1a3bd86_0 29 | - readline=8.2=h5eee18b_0 30 | - setuptools=78.1.1=py310h06a4308_0 31 | - sqlite=3.50.2=hb25bd0a_1 32 | - tk=8.6.14=h993c535_1 33 | - wheel=0.45.1=py310h06a4308_0 34 | - xorg-libx11=1.8.12=h9b100fa_1 35 | - xorg-libxau=1.0.12=h9b100fa_0 36 | - xorg-libxdmcp=1.1.5=h9b100fa_0 37 | - xorg-xorgproto=2024.1=h5eee18b_1 38 | - xz=5.6.4=h5eee18b_1 39 | - zlib=1.2.13=h5eee18b_1 40 | - pip: 41 | - accelerate==1.9.0 42 | - aiohappyeyeballs==2.6.1 43 | - aiohttp==3.12.15 44 | - aiosignal==1.4.0 45 | - airportsdata==20250811 46 | - annotated-types==0.7.0 47 | - anthropic==0.64.0 48 | - antlr4-python3-runtime==4.13.2 49 | - anyio==4.10.0 50 | - asttokens==3.0.0 51 | - async-timeout==5.0.1 52 | - attrs==25.3.0 53 | - black==25.1.0 54 | - certifi==2025.8.3 55 | - cffi==1.17.1 56 | - charset-normalizer==3.4.2 57 | - click==8.2.1 58 | - cloudpickle==3.1.1 59 | - comm==0.2.3 60 | - compressed-tensors==0.10.2 61 | - contourpy==1.3.2 62 | - coverage==7.10.2 63 | - cuda-bindings==13.0.0 64 | - cuda-pathfinder==1.1.0 65 | - cuda-python==13.0.0 66 | - cycler==0.12.1 67 | - datasets==4.0.0 68 | - debugpy==1.8.15 69 | - decorator==5.2.1 70 | - decord==0.6.0 71 | - dill==0.3.8 72 | - diskcache==5.6.3 73 | - distro==1.9.0 74 | - 
einops==0.8.1 75 | - exceptiongroup==1.3.0 76 | - executing==2.2.0 77 | - fastapi==0.116.1 78 | - filelock==3.18.0 79 | - flake8==7.3.0 80 | - flashinfer-python==0.2.3 81 | - fonttools==4.59.0 82 | - frozenlist==1.7.0 83 | - fsspec==2025.3.0 84 | - gitdb==4.0.12 85 | - gitpython==3.1.45 86 | - h11==0.16.0 87 | - hf-transfer==0.1.9 88 | - hf-xet==1.1.5 89 | - httpcore==1.0.9 90 | - httpx==0.28.1 91 | - huggingface-hub==0.34.3 92 | - idna==3.10 93 | - importlib-metadata==8.7.0 94 | - iniconfig==2.1.0 95 | - interegular==0.3.3 96 | - ipykernel==6.30.1 97 | - ipython==8.37.0 98 | - isort==6.0.1 99 | - jedi==0.19.2 100 | - jinja2==3.1.6 101 | - jiter==0.10.0 102 | - joblib==1.5.1 103 | - jsonlines==4.0.0 104 | - jsonschema==4.25.0 105 | - jsonschema-specifications==2025.4.1 106 | - jupyter-client==8.6.3 107 | - jupyter-core==5.8.1 108 | - kiwisolver==1.4.8 109 | - lark==1.2.2 110 | - latex2sympy2-extended==1.10.2 111 | - litellm==1.75.5.post1 112 | - llguidance==0.7.30 113 | - markupsafe==3.0.2 114 | - math-verify==0.8.0 115 | - matplotlib==3.10.5 116 | - matplotlib-inline==0.1.7 117 | - matplotlib-venn==1.1.2 118 | - mccabe==0.7.0 119 | - modelscope==1.29.0 120 | - mpmath==1.3.0 121 | - multidict==6.6.3 122 | - multiprocess==0.70.16 123 | - mypy==1.17.1 124 | - mypy-extensions==1.1.0 125 | - nanobind==2.8.0 126 | - nest-asyncio==1.6.0 127 | - networkx==3.4.2 128 | - ninja==1.13.0 129 | - numpy==2.2.6 130 | - nvidia-cublas-cu12==12.4.5.8 131 | - nvidia-cuda-cupti-cu12==12.4.127 132 | - nvidia-cuda-nvrtc-cu12==12.4.127 133 | - nvidia-cuda-runtime-cu12==12.4.127 134 | - nvidia-cudnn-cu12==9.1.0.70 135 | - nvidia-cufft-cu12==11.2.1.3 136 | - nvidia-curand-cu12==10.3.5.147 137 | - nvidia-cusolver-cu12==11.6.1.9 138 | - nvidia-cusparse-cu12==12.3.1.170 139 | - nvidia-cusparselt-cu12==0.6.2 140 | - nvidia-ml-py==12.575.51 141 | - nvidia-nccl-cu12==2.21.5 142 | - nvidia-nvjitlink-cu12==12.4.127 143 | - nvidia-nvtx-cu12==12.4.127 144 | - openai==1.99.9 145 | - orjson==3.11.2 146 | - outlines==0.1.11 147 | - outlines-core==0.1.26 148 | - packaging==25.0 149 | - pandas==2.3.1 150 | - parso==0.8.4 151 | - partial-json-parser==0.2.1.1.post6 152 | - pathspec==0.12.1 153 | - peft==0.17.1 154 | - pexpect==4.9.0 155 | - pillow==11.3.0 156 | - platformdirs==4.3.8 157 | - pluggy==1.6.0 158 | - prometheus-client==0.22.1 159 | - prompt-toolkit==3.0.51 160 | - propcache==0.3.2 161 | - protobuf==6.31.1 162 | - psutil==7.0.0 163 | - ptyprocess==0.7.0 164 | - pure-eval==0.2.3 165 | - pyarrow==21.0.0 166 | - pycodestyle==2.14.0 167 | - pycountry==24.6.1 168 | - pycparser==2.22 169 | - pydantic==2.11.7 170 | - pydantic-core==2.33.2 171 | - pyflakes==3.4.0 172 | - pygments==2.19.2 173 | - pylatexenc==2.10 174 | - pynvml==12.0.0 175 | - pyparsing==3.2.3 176 | - pytest==8.4.1 177 | - pytest-cov==6.2.1 178 | - python-dateutil==2.9.0.post0 179 | - python-dotenv==1.1.1 180 | - python-multipart==0.0.20 181 | - pytz==2025.2 182 | - pyyaml==6.0.2 183 | - pyzmq==27.0.1 184 | - referencing==0.36.2 185 | - regex==2025.7.34 186 | - requests==2.32.4 187 | - rosetta==0.1.0 188 | - rpds-py==0.27.0 189 | - safetensors==0.5.3 190 | - scikit-learn==1.7.1 191 | - scipy==1.15.3 192 | - seaborn==0.13.2 193 | - sentencepiece==0.2.1 194 | - sentry-sdk==2.34.1 195 | - setproctitle==1.3.6 196 | - sgl-kernel==0.0.9.post2 197 | - sglang==0.4.6 198 | - six==1.17.0 199 | - smmap==5.0.2 200 | - sniffio==1.3.1 201 | - soundfile==0.13.1 202 | - stack-data==0.6.3 203 | - starlette==0.47.2 204 | - sympy==1.13.1 205 | - threadpoolctl==3.6.0 206 | - 
tiktoken==0.11.0 207 | - tokenizers==0.21.4 208 | - tomli==2.2.1 209 | - torch==2.6.0 210 | - torchao==0.12.0 211 | - torchvision==0.21.0 212 | - tornado==6.5.2 213 | - tqdm==4.67.1 214 | - traitlets==5.14.3 215 | - transformers==4.52.4 216 | - triton==3.2.0 217 | - typing-extensions==4.14.1 218 | - typing-inspection==0.4.1 219 | - tzdata==2025.2 220 | - urllib3==2.5.0 221 | - uvicorn==0.35.0 222 | - uvloop==0.21.0 223 | - wandb==0.21.0 224 | - wcwidth==0.2.13 225 | - xgrammar==0.1.17 226 | - xxhash==3.5.0 227 | - yarl==1.20.1 228 | - zipp==3.23.0 229 | prefix: /share/minzihan/miniconda3/envs/rosetta 230 | -------------------------------------------------------------------------------- /script/analysis/gate_weight/analyze_projector_weights.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Projector Weight Scalars Analysis\n", 8 | "\n", 9 | "This notebook loads a CSV collected by `collect_projector_weights.py`, computes per-sample mean key/value normalized scalars for each projector, and visualizes their distributions. For each projector, it plots:\n", 10 | "- Key scalar distribution with the corresponding key gate logit (shown as sigmoid(logit)) marked.\n", 11 | "- Value scalar distribution with the corresponding value gate logit (shown as sigmoid(logit)) marked.\n", 12 | "\n", 13 | "Set the CSV path and output directory below and run all cells.\n" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "import json\n", 24 | "import math\n", 25 | "from typing import List\n", 26 | "\n", 27 | "import numpy as np\n", 28 | "import pandas as pd\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "\n", 31 | "# User parameters\n", 32 | "CSV_PATH = \"/mnt/public/minzihan/unified_memory/projector_weights.csv\" # path to your collected CSV\n", 33 | "OUTPUT_DIR = \"/mnt/public/minzihan/unified_memory/weight_distribution_plots\"\n", 34 | "SAVE_FIG = True\n", 35 | "FIG_DPI = 220\n", 36 | "\n", 37 | "os.makedirs(OUTPUT_DIR, exist_ok=True)\n", 38 | "\n", 39 | "df = pd.read_csv(CSV_PATH)\n", 40 | "print(df.head())\n", 41 | "print({c: df[c].dtype for c in df.columns})\n", 42 | "\n", 43 | "# Parse JSON columns into arrays\n", 44 | "\n", 45 | "def parse_json_array_safe(x):\n", 46 | " if pd.isna(x):\n", 47 | " return None\n", 48 | " try:\n", 49 | " return np.array(json.loads(x))\n", 50 | " except Exception:\n", 51 | " return None\n", 52 | "\n", 53 | "# Apply parsing\n", 54 | "parsed_key = df[\"norm_key_scalar\"].apply(parse_json_array_safe)\n", 55 | "parsed_value = df[\"norm_value_scalar\"].apply(parse_json_array_safe)\n", 56 | "\n", 57 | "df_parsed = df.copy()\n", 58 | "df_parsed[\"norm_key_scalar_arr\"] = parsed_key\n", 59 | "df_parsed[\"norm_value_scalar_arr\"] = parsed_value\n", 60 | "\n", 61 | "# Compute per-sample mean scalars for each projector\n", 62 | "\n", 63 | "def compute_mean_scalars(arr: np.ndarray) -> float:\n", 64 | " if arr is None:\n", 65 | " return np.nan\n", 66 | " try:\n", 67 | " return float(np.nanmean(arr))\n", 68 | " except Exception:\n", 69 | " return np.nan\n", 70 | "\n", 71 | "df_parsed[\"mean_key_scalar\"] = df_parsed[\"norm_key_scalar_arr\"].apply(compute_mean_scalars)\n", 72 | "df_parsed[\"mean_value_scalar\"] = df_parsed[\"norm_value_scalar_arr\"].apply(compute_mean_scalars)\n", 73 | "\n", 74 | "# Helper: convert gate logits to probabilities 
via sigmoid\n", 75 | "sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))\n", 76 | "\n", 77 | "# Aggregate per projector\n", 78 | "projector_ids = sorted(df_parsed[\"projector_index\"].unique().tolist())\n", 79 | "print(f\"Found {len(projector_ids)} projectors\")\n", 80 | "\n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "def plot_projector_distributions(df_proj: pd.DataFrame, projector_index: int, bins: int = 30):\n", 90 | " data_key = df_proj[\"mean_key_scalar\"].dropna().values\n", 91 | " data_val = df_proj[\"mean_value_scalar\"].dropna().values\n", 92 | "\n", 93 | " # For logits, take unique values (they should be constant per projector)\n", 94 | " key_logit_vals = df_proj[\"key_gate_logit\"].dropna().unique()\n", 95 | " val_logit_vals = df_proj[\"value_gate_logit\"].dropna().unique()\n", 96 | "\n", 97 | " key_logit = key_logit_vals[0] if len(key_logit_vals) > 0 else None\n", 98 | " val_logit = val_logit_vals[0] if len(val_logit_vals) > 0 else None\n", 99 | "\n", 100 | " key_enabled = (key_logit is not None) and (key_logit >= 0)\n", 101 | " val_enabled = (val_logit is not None) and (val_logit >= 0)\n", 102 | "\n", 103 | " # Plot key distribution only if enabled (logit >= 0)\n", 104 | " if key_enabled and data_key.size > 0:\n", 105 | " plt.figure(figsize=(6.5, 4.0))\n", 106 | " plt.hist(data_key, bins=bins, color=\"#225ea8\", alpha=1.0, edgecolor=None)\n", 107 | " plt.title(f\"Projector {projector_index} - Key Scalar Mean Distribution\")\n", 108 | " plt.xlabel(\"Mean Key Scalar (per sample)\")\n", 109 | " plt.ylabel(\"Count\")\n", 110 | " plt.xlim(0.0, 1.0)\n", 111 | " plt.grid(True, linestyle=\":\", alpha=0.5)\n", 112 | " if SAVE_FIG:\n", 113 | " out_path = os.path.join(OUTPUT_DIR, f\"projector_{projector_index:03d}_key_dist.png\")\n", 114 | " plt.tight_layout()\n", 115 | " plt.savefig(out_path, dpi=FIG_DPI, bbox_inches=\"tight\")\n", 116 | " print(f\"Saved: {out_path}\")\n", 117 | " plt.show()\n", 118 | " else:\n", 119 | " print(f\"Skip key plot for projector {projector_index}: logit<{0} or no data\")\n", 120 | "\n", 121 | " # Plot value distribution only if enabled (logit >= 0)\n", 122 | " if val_enabled and data_val.size > 0:\n", 123 | " plt.figure(figsize=(6.5, 4.0))\n", 124 | " plt.hist(data_val, bins=bins, color=\"#fb6a4a\", alpha=1.0, edgecolor=None)\n", 125 | " plt.title(f\"Projector {projector_index} - Value Scalar Mean Distribution\")\n", 126 | " plt.xlabel(\"Mean Value Scalar (per sample)\")\n", 127 | " plt.ylabel(\"Count\")\n", 128 | " plt.xlim(0.0, 1.0)\n", 129 | " plt.grid(True, linestyle=\":\", alpha=0.5)\n", 130 | " if SAVE_FIG:\n", 131 | " out_path = os.path.join(OUTPUT_DIR, f\"projector_{projector_index:03d}_value_dist.png\")\n", 132 | " plt.tight_layout()\n", 133 | " plt.savefig(out_path, dpi=FIG_DPI, bbox_inches=\"tight\")\n", 134 | " print(f\"Saved: {out_path}\")\n", 135 | " plt.show()\n", 136 | " else:\n", 137 | " print(f\"Skip value plot for projector {projector_index}: logit<{0} or no data\")\n", 138 | "\n", 139 | "# Run over all projectors\n", 140 | "for pid in projector_ids:\n", 141 | " df_proj = df_parsed[df_parsed[\"projector_index\"] == pid]\n", 142 | " plot_projector_distributions(df_proj, projector_index=pid, bins=30)\n", 143 | "\n" 144 | ] 145 | } 146 | ], 147 | "metadata": { 148 | "kernelspec": { 149 | "display_name": "new_rosetta", 150 | "language": "python", 151 | "name": "python3" 152 | }, 153 | "language_info": { 154 | "codemirror_mode": { 155 | "name": 
"ipython", 156 | "version": 3 157 | }, 158 | "file_extension": ".py", 159 | "mimetype": "text/x-python", 160 | "name": "python", 161 | "nbconvert_exporter": "python", 162 | "pygments_lexer": "ipython3", 163 | "version": "3.10.18" 164 | } 165 | }, 166 | "nbformat": 4, 167 | "nbformat_minor": 2 168 | } 169 | -------------------------------------------------------------------------------- /script/analysis/length_ratio/compute_response_lengths.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute statistics for model response lengths in a dataset directory. 3 | 4 | Supports: 5 | - CSV file at /OpenHermes_generated_results.csv 6 | - Hugging Face dataset saved at /dataset (via save_to_disk) 7 | 8 | Optionally uses a tokenizer to compute token lengths; otherwise uses character lengths. 9 | Saves a JSON summary to /response_length_stats.json and prints to stdout. 10 | """ 11 | 12 | import os 13 | import json 14 | import argparse 15 | from typing import List, Optional, Dict, Any 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from tqdm import tqdm 20 | 21 | 22 | def read_csv_if_exists(dataset_dir: str) -> Optional[pd.DataFrame]: 23 | csv_path = os.path.join(dataset_dir, "OpenHermes_generated_results.csv") 24 | if os.path.exists(csv_path): 25 | try: 26 | df = pd.read_csv(csv_path) 27 | return df 28 | except Exception as e: 29 | print(f"Failed to read CSV at {csv_path}: {e}") 30 | return None 31 | 32 | 33 | def read_hf_dataset_if_exists(dataset_dir: str) -> Optional[Any]: 34 | dataset_path = os.path.join(dataset_dir, "dataset") 35 | if os.path.exists(dataset_path): 36 | try: 37 | from datasets import load_from_disk 38 | ds = load_from_disk(dataset_path) 39 | return ds 40 | except Exception as e: 41 | print(f"Failed to load HF dataset at {dataset_path}: {e}") 42 | return None 43 | 44 | 45 | def compute_char_lengths(texts: List[str]) -> List[int]: 46 | lengths: List[int] = [] 47 | for t in texts: 48 | if t is None or (isinstance(t, float) and np.isnan(t)): 49 | lengths.append(0) 50 | else: 51 | lengths.append(len(str(t))) 52 | return lengths 53 | 54 | 55 | def compute_token_lengths(texts: List[str], tokenizer_name: str, batch_size: int = 1024) -> List[int]: 56 | from transformers import AutoTokenizer 57 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, trust_remote_code=True) 58 | 59 | lengths: List[int] = [] 60 | batch: List[str] = [] 61 | for t in tqdm(texts, desc="Tokenizing", unit="ex"): 62 | if t is None or (isinstance(t, float) and np.isnan(t)): 63 | t = "" 64 | batch.append(str(t)) 65 | if len(batch) >= batch_size: 66 | enc = tokenizer(batch, add_special_tokens=False, padding=False, truncation=False) 67 | lens = [len(ids) for ids in enc["input_ids"]] 68 | lengths.extend(lens) 69 | batch = [] 70 | if batch: 71 | enc = tokenizer(batch, add_special_tokens=False, padding=False, truncation=False) 72 | lens = [len(ids) for ids in enc["input_ids"]] 73 | lengths.extend(lens) 74 | return lengths 75 | 76 | 77 | def summarize_lengths(lengths: List[int]) -> Dict[str, Any]: 78 | if len(lengths) == 0: 79 | return {"count": 0} 80 | arr = np.array(lengths, dtype=np.int64) 81 | percentiles = [50, 90, 95, 99] 82 | perc_vals = np.percentile(arr, percentiles).tolist() 83 | summary = { 84 | "count": int(arr.size), 85 | "mean": float(arr.mean()), 86 | "std": float(arr.std(ddof=0)), 87 | "min": int(arr.min()), 88 | "p50": float(perc_vals[0]), 89 | "p90": float(perc_vals[1]), 90 | "p95": float(perc_vals[2]), 91 | "p99": 
float(perc_vals[3]), 92 | "max": int(arr.max()), 93 | } 94 | return summary 95 | 96 | 97 | def run(dataset_dir: str, column: str, tokenizer_name: Optional[str], batch_size: int, sum_columns: Optional[List[str]] = None) -> Dict[str, Any]: 98 | # Load data source (CSV preferred for speed; fallback to HF dataset) 99 | df = read_csv_if_exists(dataset_dir) 100 | if df is not None: 101 | if sum_columns: 102 | # Concatenate multiple columns into one text 103 | for col in sum_columns: 104 | if col not in df.columns: 105 | raise KeyError(f"Column '{col}' not found in CSV columns: {list(df.columns)}") 106 | texts = [ 107 | "\n".join([ 108 | "" if (pd.isna(row.get(col)) if isinstance(row, dict) else pd.isna(row[col])) else str(row[col]) 109 | for col in sum_columns 110 | ]) 111 | for _, row in df.iterrows() 112 | ] 113 | else: 114 | texts = df.get(column, pd.Series(dtype=object)).tolist() 115 | else: 116 | ds = read_hf_dataset_if_exists(dataset_dir) 117 | if ds is None: 118 | raise FileNotFoundError( 119 | f"No CSV or HF dataset found under {dataset_dir}. Expected CSV 'OpenHermes_generated_results.csv' or dataset folder 'dataset'." 120 | ) 121 | # Avoid loading the full dataset into memory; stream the column 122 | texts = [] 123 | if sum_columns: 124 | for col in sum_columns: 125 | if col not in ds.column_names: 126 | raise KeyError(f"Column '{col}' not found in dataset columns: {ds.column_names}") 127 | for ex in tqdm(ds, desc="Reading dataset", unit="ex"): 128 | parts = [] 129 | for col in sum_columns: 130 | val = ex.get(col, "") 131 | if val is None: 132 | val = "" 133 | parts.append(str(val)) 134 | texts.append("\n".join(parts)) 135 | else: 136 | if column not in ds.column_names: 137 | raise KeyError(f"Column '{column}' not found in dataset columns: {ds.column_names}") 138 | for ex in tqdm(ds, desc="Reading dataset", unit="ex"): 139 | texts.append(ex.get(column, "")) 140 | 141 | if tokenizer_name: 142 | lengths = compute_token_lengths(texts, tokenizer_name=tokenizer_name, batch_size=batch_size) 143 | unit = "tokens" 144 | else: 145 | lengths = compute_char_lengths(texts) 146 | unit = "chars" 147 | 148 | summary = summarize_lengths(lengths) 149 | summary["unit"] = unit 150 | return summary 151 | 152 | 153 | def main(): 154 | parser = argparse.ArgumentParser(description="Compute model response length statistics.") 155 | parser.add_argument( 156 | "--dataset_dir", 157 | type=str, 158 | default="local/teacher_datasets/openhermes_qwen_output", 159 | help="Directory containing CSV and/or HF dataset subfolder", 160 | ) 161 | parser.add_argument( 162 | "--column", 163 | type=str, 164 | default="model_response", 165 | help="Column name to analyze", 166 | ) 167 | parser.add_argument( 168 | "--sum_columns", 169 | nargs="+", 170 | default=None, 171 | help="Concatenate these columns (e.g., input_text model_response) and analyze their combined length", 172 | ) 173 | parser.add_argument( 174 | "--tokenizer", 175 | type=str, 176 | default=None, 177 | help="Tokenizer name/path to compute token lengths. 
If omitted, character lengths are used.", 178 | ) 179 | parser.add_argument( 180 | "--batch_size", 181 | type=int, 182 | default=1024, 183 | help="Batch size for tokenization", 184 | ) 185 | args = parser.parse_args() 186 | 187 | summary = run( 188 | dataset_dir=args.dataset_dir, 189 | column=args.column, 190 | tokenizer_name=args.tokenizer, 191 | batch_size=args.batch_size, 192 | sum_columns=args.sum_columns, 193 | ) 194 | 195 | # Print 196 | print("\n=== Response Length Summary ===") 197 | for k, v in summary.items(): 198 | print(f"{k}: {v}") 199 | 200 | # Save JSON next to dataset 201 | out_path = os.path.join(args.dataset_dir, "response_length_stats.json") 202 | try: 203 | with open(out_path, "w") as f: 204 | json.dump(summary, f, indent=2) 205 | print(f"Saved summary to {out_path}") 206 | except Exception as e: 207 | print(f"Failed to save summary JSON: {e}") 208 | 209 | 210 | if __name__ == "__main__": 211 | main() 212 | 213 | 214 | -------------------------------------------------------------------------------- /script/examples/two_stage_rosetta_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example usage of the TwoStageRosetta pipeline for LLM+Rosetta evaluation. 3 | """ 4 | 5 | from rosetta.baseline.multi_stage import TwoStageRosetta 6 | 7 | 8 | def example_standalone(): 9 | """Example of standalone usage.""" 10 | # Initialize the two-stage pipeline with Rosetta (following load_model_from_checkpoint pattern) 11 | pipeline = TwoStageRosetta( 12 | context_model_path="Qwen/Qwen3-4B", # First LLM for context generation 13 | rosetta_checkpoint_dir="local/checkpoints/Qwen3_0.6B_4B_general_LLM_data", # Path to checkpoint directory 14 | rosetta_subfolder="checkpoint-2000", # Subfolder name (e.g., 'final', 'checkpoint-1000') 15 | device="cuda", 16 | max_new_tokens=512, 17 | background_prompt="In one clear sentence, describe the most essential background knowledge needed to answer the question: {question}" 18 | ) 19 | 20 | # Example MMLU question 21 | question_without_options = "What is the primary function of mitochondria in cells?" 22 | 23 | question_with_options = """What is the primary function of mitochondria in cells? 24 | 25 | A. Protein synthesis 26 | B. Energy production through ATP synthesis 27 | C. DNA replication 28 | D. Waste removal 29 | 30 | Answer: The correct answer is""" 31 | 32 | # Generate answer using the new generate method (model-like interface) 33 | answer = pipeline.generate( 34 | question_without_options=question_without_options, 35 | question_with_options=question_with_options, 36 | max_new_tokens=512 37 | ) 38 | 39 | print("Final Answer from Two-Stage Rosetta Pipeline:") 40 | print(answer) 41 | 42 | # For detailed context, use process method 43 | result = pipeline.process( 44 | question_without_options=question_without_options, 45 | question_with_options=question_with_options 46 | ) 47 | 48 | print("\n" + "="*50) 49 | print("Detailed Breakdown:") 50 | print("Background Context from First LLM:") 51 | print(result["context"]) 52 | print("\nFinal Answer from Rosetta Model:") 53 | print(result["answer"]) 54 | 55 | 56 | def example_with_evaluator_integration(): 57 | """ 58 | Example showing how to integrate with unified_evaluator.py 59 | 60 | The integration would be similar to TwoStageInference but with Rosetta-specific config. 
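    Minimal changes needed in unified_evaluator.py (mirroring the TwoStageInference example):
    1. Import TwoStageRosetta
    2. Add a flag in the eval config to enable two-stage Rosetta mode
    3. In evaluate_subject, build the two question variants and call pipeline.generate(...)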
61 | """ 62 | 63 | # This would be in the evaluator's evaluate_subject method 64 | def modified_evaluate_subject_snippet(example, use_two_stage_rosetta=True): 65 | """ 66 | Pseudo-code showing integration points for TwoStageRosetta. 67 | """ 68 | if use_two_stage_rosetta: 69 | # Extract question without options 70 | question_text = example.get('question', '') 71 | 72 | # Build question with options (existing code) 73 | choices = "" 74 | for i, choice in enumerate(example.get('choices', [])): 75 | choices += f"{chr(65+i)}. {choice}\n" 76 | 77 | question_with_options = f"""{question_text} 78 | 79 | {choices} 80 | Answer: The correct answer is""" 81 | 82 | # Use TwoStageRosetta (following load_model_from_checkpoint pattern) 83 | pipeline = TwoStageRosetta( 84 | context_model_path="model1_path", 85 | rosetta_checkpoint_dir="/path/to/rosetta/checkpoints", 86 | rosetta_subfolder="final" 87 | ) 88 | 89 | # Use generate method (model-like interface) 90 | answer = pipeline.generate( 91 | question_without_options=question_text, 92 | question_with_options=question_with_options, 93 | max_new_tokens=1024 94 | ) 95 | 96 | return answer 97 | else: 98 | # Use existing single-model approach 99 | pass 100 | 101 | 102 | def format_question_for_stages(example, dataset_name="mmlu-redux"): 103 | """ 104 | Helper function to format questions for two-stage processing. 105 | 106 | Args: 107 | example: Dataset example 108 | dataset_name: Name of the dataset 109 | 110 | Returns: 111 | Tuple of (question_without_options, question_with_options) 112 | """ 113 | if dataset_name == "mmlu-redux": 114 | question_text = example['question'] 115 | 116 | # Build choices 117 | choices = "" 118 | for i, choice in enumerate(example['choices']): 119 | choices += f"{chr(65+i)}. {choice}\n" 120 | 121 | # Question with full template for answering 122 | question_with_options = f"""{question_text} 123 | 124 | {choices} 125 | Answer: The correct answer is""" 126 | 127 | return question_text, question_with_options 128 | 129 | elif dataset_name == "mmmlu": 130 | question_text = example['Question'] 131 | 132 | # Build choices 133 | choices = "" 134 | for i, choice_key in enumerate(['A', 'B', 'C', 'D']): 135 | if choice_key in example: 136 | choices += f"{choice_key}. {example[choice_key]}\n" 137 | 138 | question_with_options = f"""{question_text} 139 | 140 | {choices} 141 | Answer: The correct answer is""" 142 | 143 | return question_text, question_with_options 144 | 145 | else: 146 | raise ValueError(f"Unknown dataset: {dataset_name}") 147 | 148 | 149 | def example_different_prompts(): 150 | """Example showing different background prompts.""" 151 | print("\n" + "="*60) 152 | print("Example with Different Background Prompts") 153 | print("="*60) 154 | 155 | # Different prompt styles 156 | prompts = { 157 | "concise": "Briefly describe the most useful background to solve the problem:\n\n{question}", 158 | "detailed": "Analyze the key concepts and provide relevant background information needed to solve this problem:\n\n{question}", 159 | "step_by_step": "Break down the problem and explain the key concepts step by step:\n\n{question}", 160 | "domain_focused": "Provide domain-specific knowledge and context relevant to this question:\n\n{question}" 161 | } 162 | 163 | question = "What is the primary function of mitochondria in cells?" 
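    # Illustrative sketch (kept as comments so this demo stays model-free): any of the
    # templates above can be supplied through TwoStageRosetta's `background_prompt`
    # argument, exactly as in example_standalone(); the checkpoint paths are placeholders.
    #
    # pipeline = TwoStageRosetta(
    #     context_model_path="Qwen/Qwen3-4B",
    #     rosetta_checkpoint_dir="/path/to/rosetta/checkpoints",
    #     rosetta_subfolder="final",
    #     background_prompt=prompts["step_by_step"],
    # )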
164 | 165 | for style, prompt in prompts.items(): 166 | print(f"\n--- {style.upper()} STYLE ---") 167 | print(f"Prompt: {prompt}") 168 | print(f"Full prompt: {prompt.format(question=question)}") 169 | print("-" * 40) 170 | 171 | 172 | def example_configuration_options(): 173 | """Example showing different configuration options for Rosetta.""" 174 | print("\n" + "="*60) 175 | print("Configuration Options for TwoStageRosetta") 176 | print("="*60) 177 | 178 | # Example configurations (following load_model_from_checkpoint pattern) 179 | configs = { 180 | "basic": { 181 | "context_model_path": "Qwen/Qwen3-4B", 182 | "rosetta_checkpoint_dir": "/path/to/checkpoints", 183 | "rosetta_subfolder": "final", 184 | "device": "cuda", 185 | "max_new_tokens": 512 186 | }, 187 | "different_subfolder": { 188 | "context_model_path": "Qwen/Qwen3-4B", 189 | "rosetta_checkpoint_dir": "/path/to/checkpoints", 190 | "rosetta_subfolder": "checkpoint-1000", 191 | "device": "cuda", 192 | "max_new_tokens": 1024 193 | }, 194 | "custom_prompt": { 195 | "context_model_path": "Qwen/Qwen3-4B", 196 | "rosetta_checkpoint_dir": "/path/to/checkpoints", 197 | "rosetta_subfolder": "final", 198 | "device": "cuda", 199 | "max_new_tokens": 512, 200 | "background_prompt": "Provide detailed background information for this question:\n\n{question}" 201 | } 202 | } 203 | 204 | for config_name, config in configs.items(): 205 | print(f"\n--- {config_name.upper()} CONFIG ---") 206 | for key, value in config.items(): 207 | print(f"{key}: {value}") 208 | print("-" * 40) 209 | 210 | 211 | if __name__ == "__main__": 212 | print("TwoStageRosetta Example") 213 | print("=" * 50) 214 | 215 | # Note: For actual usage, replace model paths with real model names 216 | print("\nThis script demonstrates the usage of TwoStageRosetta.") 217 | print("To run with real models, update the model paths and checkpoint directories.") 218 | 219 | # Show the formatting example 220 | example_data = { 221 | 'question': 'What is the capital of France?', 222 | 'choices': ['London', 'Berlin', 'Paris', 'Madrid'] 223 | } 224 | 225 | q_without, q_with = format_question_for_stages(example_data) 226 | print("\nExample Question Formatting:") 227 | print("\n1. Question without options (for context generation):") 228 | print(q_without) 229 | print("\n2. Question with options (for final answer):") 230 | print(q_with) 231 | 232 | example_standalone() 233 | example_different_prompts() 234 | example_configuration_options() 235 | -------------------------------------------------------------------------------- /rosetta/utils/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unified registry utilities and simple JSON-based save/load helpers. 3 | 4 | This module provides: 5 | - create_registry: factory to create (registry dict, register decorator, get_class) 6 | - capture_init_args: decorator to record __init__ kwargs on instances as _init_args 7 | - save_object / load_object: serialize/deserialize object configs via registry 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import inspect 13 | import json 14 | from typing import Dict, Type, Callable, Optional, Tuple, TypeVar, Any 15 | import torch 16 | 17 | T = TypeVar("T") 18 | 19 | 20 | def create_registry( 21 | registry_name: str, 22 | case_insensitive: bool = False, 23 | ) -> Tuple[Dict[str, Type[T]], Callable[..., Type[T]], Callable[[str], Type[T]]]: 24 | """ 25 | Create a registry system with register and get functions. 
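    For example, the projector registry at the bottom of this module is created as:

        PROJECTOR_REGISTRY, register_model, get_projector_class = create_registry(
            "projector", case_insensitive=True
        )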
26 | 27 | Args: 28 | registry_name: Name used in error messages (e.g., "projector") 29 | case_insensitive: Whether to store lowercase versions of names 30 | 31 | Returns: 32 | (registry_dict, register_function, get_function) 33 | """ 34 | 35 | registry: Dict[str, Type[T]] = {} 36 | 37 | def register(cls_or_name=None, name: Optional[str] = None): 38 | """Register a class in the registry. Supports multiple usage patterns. 39 | 40 | Usage: 41 | @register 42 | class Foo: ... 43 | 44 | @register("foo") 45 | class Foo: ... 46 | 47 | @register(name="foo") 48 | class Foo: ... 49 | """ 50 | 51 | def _register(c: Type[T]) -> Type[T]: 52 | # Determine the name to use 53 | if isinstance(cls_or_name, str): 54 | class_name = cls_or_name 55 | elif name is not None: 56 | class_name = name 57 | else: 58 | class_name = c.__name__ 59 | 60 | registry[class_name] = c 61 | if case_insensitive: 62 | registry[class_name.lower()] = c 63 | return c 64 | 65 | if cls_or_name is not None and not isinstance(cls_or_name, str): 66 | # Called as @register or register(cls) 67 | return _register(cls_or_name) 68 | else: 69 | # Called as @register("name") or @register(name="name") 70 | return _register 71 | 72 | def get_class(name: str) -> Type[T]: 73 | """Get class by name from registry.""" 74 | if name not in registry: 75 | # Build readable available list without duplicates when case_insensitive 76 | seen = set() 77 | available = [] 78 | for k in registry.keys(): 79 | if k.lower() in seen: 80 | continue 81 | seen.add(k.lower()) 82 | available.append(k) 83 | raise ValueError( 84 | f"Unknown {registry_name} class: {name}. Available: {available}" 85 | ) 86 | return registry[name] 87 | 88 | return registry, register, get_class 89 | 90 | 91 | def capture_init_args(cls): 92 | """ 93 | Decorator to capture initialization arguments of a class. 94 | 95 | Stores the mapping of the constructor's parameters to the values supplied 96 | at instantiation time into `self._init_args` for later serialization. 
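    Example (illustrative sketch; `Toy` is a stand-in class, not part of this repo):

        @capture_init_args
        class Toy:
            def __init__(self, hidden_dim, dropout=0.1):
                self.hidden_dim, self.dropout = hidden_dim, dropout

        toy = Toy(256, dropout=0.2)
        assert toy._init_args == {"hidden_dim": 256, "dropout": 0.2}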
97 | """ 98 | original_init = cls.__init__ 99 | 100 | def new_init(self, *args, **kwargs): 101 | # Store all initialization arguments 102 | init_args: Dict[str, Any] = {} 103 | 104 | # Get parameter names from the original __init__ method 105 | sig = inspect.signature(original_init) 106 | param_names = list(sig.parameters.keys())[1:] # Skip 'self' 107 | 108 | # Map positional args to parameter names 109 | for i, arg in enumerate(args): 110 | if i < len(param_names): 111 | init_args[param_names[i]] = arg 112 | 113 | # Add keyword args 114 | init_args.update(kwargs) 115 | 116 | self._init_args = init_args 117 | 118 | # Call the original __init__ 119 | original_init(self, *args, **kwargs) 120 | 121 | cls.__init__ = new_init 122 | return cls 123 | 124 | 125 | # ------------------------- 126 | # Serialization utilities 127 | # ------------------------- 128 | 129 | def _encode_value(value: Any) -> Any: 130 | """Best-effort JSON encoding for common ML types.""" 131 | # Primitives and None 132 | if value is None or isinstance(value, (bool, int, float, str)): 133 | return value 134 | 135 | # Tuples -> lists 136 | if isinstance(value, tuple): 137 | return [ 138 | _encode_value(v) for v in value 139 | ] 140 | 141 | # Lists 142 | if isinstance(value, list): 143 | return [ 144 | _encode_value(v) for v in value 145 | ] 146 | 147 | # Dicts 148 | if isinstance(value, dict): 149 | return {k: _encode_value(v) for k, v in value.items()} 150 | 151 | # torch-specific types 152 | if torch is not None: 153 | # torch.dtype 154 | if isinstance(value, type(getattr(torch, "float32", object))): 155 | # Guard: torch.dtype is not a class; rely on str(value) format 156 | s = str(value) 157 | if s.startswith("torch."): 158 | return {"__type__": "torch.dtype", "value": s.split(".")[-1]} 159 | 160 | # torch.device 161 | if isinstance(value, getattr(torch, "device", ())): 162 | return {"__type__": "torch.device", "value": str(value)} 163 | 164 | # Fallback to string representation 165 | return {"__type__": "str", "value": str(value)} 166 | 167 | 168 | def _decode_value(value: Any) -> Any: 169 | """Decode values produced by _encode_value, recursively for containers.""" 170 | # Lists: decode each element 171 | if isinstance(value, list): 172 | return [_decode_value(v) for v in value] 173 | 174 | # Dicts: either a typed-marker dict or a regular mapping that needs recursive decoding 175 | if isinstance(value, dict): 176 | if "__type__" in value: 177 | t = value.get("__type__") 178 | v = value.get("value") 179 | 180 | if t == "torch.dtype" and torch is not None: 181 | dtype = getattr(torch, str(v), None) 182 | if dtype is None: 183 | raise ValueError(f"Unknown torch.dtype: {v}") 184 | return dtype 185 | 186 | if t == "torch.device" and torch is not None: 187 | return torch.device(v) 188 | 189 | if t == "str": 190 | return str(v) 191 | 192 | # Unknown type marker; return raw as-is 193 | return value 194 | 195 | # Regular dict: decode values recursively 196 | return {k: _decode_value(v) for k, v in value.items()} 197 | 198 | # Primitives and anything else: return as-is 199 | return value 200 | 201 | 202 | def save_object(obj: Any, file_path: str) -> None: 203 | """ 204 | Save an object's construction config to a JSON file. 205 | 206 | The object is expected to have been decorated with capture_init_args, 207 | so that `obj._init_args` exists. 
208 | """ 209 | class_name = obj.__class__.__name__ 210 | init_args = getattr(obj, "_init_args", {}) 211 | 212 | serializable_args = _encode_value(init_args) 213 | payload = { 214 | "class": class_name, 215 | "init_args": serializable_args, 216 | } 217 | 218 | with open(file_path, "w", encoding="utf-8") as f: 219 | json.dump(payload, f, indent=2) 220 | 221 | 222 | def load_object( 223 | file_path: str, 224 | get_class_fn: Callable[[str], Type[T]], 225 | override_args: Optional[Dict[str, Any]] = None, 226 | ) -> T: 227 | """ 228 | Load an object from a JSON config file previously saved by save_object. 229 | 230 | Args: 231 | file_path: Path to JSON file 232 | get_class_fn: Function to resolve class names from registry 233 | override_args: Optional dict to override stored init args 234 | 235 | Returns: 236 | Instantiated object of type T 237 | """ 238 | with open(file_path, "r", encoding="utf-8") as f: 239 | payload = json.load(f) 240 | 241 | class_name = payload["class"] 242 | encoded_args = payload.get("init_args", {}) 243 | init_args = _decode_value(encoded_args) 244 | 245 | if override_args: 246 | init_args.update(override_args) 247 | 248 | cls = get_class_fn(class_name) 249 | return cls(**init_args) 250 | 251 | 252 | def dumps_object_config(obj: Any) -> str: 253 | """Return a JSON string with the object's class and init args.""" 254 | class_name = obj.__class__.__name__ 255 | init_args = getattr(obj, "_init_args", {}) 256 | serializable_args = _encode_value(init_args) 257 | return json.dumps({"class": class_name, "init_args": serializable_args}, indent=2) 258 | 259 | 260 | def loads_object_config( 261 | s: str, 262 | get_class_fn: Callable[[str], Type[T]], 263 | override_args: Optional[Dict[str, Any]] = None, 264 | ) -> T: 265 | """Instantiate an object from a JSON string produced by dumps_object_config.""" 266 | payload = json.loads(s) 267 | class_name = payload["class"] 268 | encoded_args = payload.get("init_args", {}) 269 | init_args = _decode_value(encoded_args) 270 | if override_args: 271 | init_args.update(override_args) 272 | cls = get_class_fn(class_name) 273 | return cls(**init_args) 274 | 275 | 276 | # Model Registry System (case-insensitive for backward compatibility) 277 | PROJECTOR_REGISTRY, register_model, get_projector_class = create_registry( 278 | "projector", case_insensitive=True 279 | ) -------------------------------------------------------------------------------- /script/analysis/scaling/scaling_curve.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import seaborn as sns 4 | from matplotlib import colors as mcolors 5 | from matplotlib.ticker import MaxNLocator 6 | from matplotlib import transforms as mtransforms 7 | import matplotlib.pyplot as plt 8 | 9 | # # 定义 base 模型 10 | # base_models = ["Pure Teacher", "0.6B", "1.7B", "4B", "8B", "14B"] 11 | 12 | # # 定义 teacher size 13 | # teacher_sizes = ["0", "0.5B", "1.5B", "7B", "14B"] 14 | 15 | # # 数据表(先空着,后续你填真实 acc 数据) 16 | # # 每个 base 对应一个列表,长度等于 teacher_sizes 17 | # data = { 18 | # "Pure Teacher": [None,36.49, 58.79, 72.25, 77.98], 19 | # "0.6B": [35.03, 47.19, 57.62, 69.67, 65.75], 20 | # "1.7B": [58.26, None,61.43, 71.22, 74.20], 21 | # "4B": [71.47, None, None, 73.67, 75.73], 22 | # "8B": [75.53, None, None, 75.83, 77.45], 23 | # "14B": [79.49, None, None, None, 80.33], 24 | # } 25 | 26 | # 定义 base 模型 27 | base_models = ["Pure Teacher", "0.6B", "1.7B", "4B", "8B", "14B"] 28 | 29 | # 定义 teacher size 30 | teacher_sizes 
= ["0", "0.5B", "1.5B", "3B", "7B", "14B"] 31 | 32 | # 数据表(先空着,后续你填真实 acc 数据) 33 | # 每个 base 对应一个列表,长度等于 teacher_sizes 34 | data = { 35 | "Pure Teacher": [None,38.42, 58.79, 63.32, 72.25, 77.98], 36 | "0.6B": [35.53, 47.19, 57.62, 59.13, 68.43, 67.49], 37 | "1.7B": [58.26, None,61.43, 63.42, 71.22, 74.20], 38 | "4B": [71.47, None, None, 71.22, 73.67, 75.73], 39 | "8B": [75.53, None, None, None, 75.83, 77.45], 40 | "14B": [79.49, None, None, None, None, 80.33], 41 | } 42 | 43 | # data = { 44 | # "Pure Teacher": [None,38.42, 58.79, 63.32, 72.25, 77.98], 45 | # "0.6B": [35.53, 41.51, 43.22, 43.32, 43.50, 42.74], 46 | # "1.7B": [58.26, None,60.00, 60.94, 60.46, 60.40], 47 | # "4B": [71.47, None, None, 72.67, 72.66, 73.28], 48 | # "8B": [75.53, None, None, None, 76.83, 76.74], 49 | # "14B": [79.49, None, None, None, None, 75.41], 50 | # } 51 | 52 | # 使用 seaborn 美化风格(学术风格) 53 | sns.set_theme(context="paper", style="whitegrid", font_scale=1.0) 54 | 55 | # 统一风格设置(参考 style.md) 56 | FIGSIZE = (6, 3.2) 57 | BASE_FONT_SIZE = 14 58 | LEGEND_FONT_SIZE = 12 59 | TITLE_FONT_SIZE = 14 60 | TEXT_COLOR = "#000000" # Black 61 | plt.rcParams.update({ 62 | "axes.titlesize": TITLE_FONT_SIZE, 63 | "axes.labelsize": BASE_FONT_SIZE, 64 | "xtick.labelsize": BASE_FONT_SIZE, 65 | "ytick.labelsize": BASE_FONT_SIZE, 66 | "legend.fontsize": LEGEND_FONT_SIZE, 67 | "legend.title_fontsize": LEGEND_FONT_SIZE, 68 | # 文本与坐标轴颜色 69 | "text.color": TEXT_COLOR, 70 | "axes.labelcolor": TEXT_COLOR, 71 | "xtick.color": TEXT_COLOR, 72 | "ytick.color": TEXT_COLOR, 73 | # Matplotlib >=3.6 支持 legend.labelcolor 74 | "legend.labelcolor": TEXT_COLOR, 75 | }) 76 | 77 | # 为当前绘图排除 teacher-size 为 0 的数据,但保留原始数据以备后用 78 | plot_indices = list(range(1, len(teacher_sizes))) 79 | teacher_sizes_plot = [teacher_sizes[i] for i in plot_indices] 80 | 81 | # 将教师规模标签转换为数值(单位:B)以在 x 轴上使用真实间距 82 | def _parse_size(label: str) -> float: 83 | if label.endswith("B"): 84 | label = label[:-1] 85 | return float(label) 86 | 87 | numeric_teacher_sizes = [_parse_size(s) for s in teacher_sizes] 88 | x_plot = [numeric_teacher_sizes[i] for i in plot_indices] 89 | pure_teacher_ticks = [data["Pure Teacher"][i] for i in plot_indices] 90 | 91 | # 颜色配置:非 Pure Teacher 使用调色板,Pure Teacher 使用灰色虚线 92 | non_pure_models = [m for m in base_models if m != "Pure Teacher"] 93 | palette = sns.color_palette("tab10", n_colors=len(non_pure_models)) 94 | model_to_color = {m: palette[i] for i, m in enumerate(non_pure_models)} 95 | 96 | # plt.figure(figsize=(6.4, 4), dpi=300) 97 | # plt.xscale('log', base=10) 98 | # ax = plt.gca() 99 | # trans_data_axes = mtransforms.blended_transform_factory(ax.transData, ax.transAxes) 100 | 101 | # for model in base_models: 102 | # y = data[model] 103 | # valid_orig_indices = [i for i in plot_indices if y[i] is not None] 104 | # if not valid_orig_indices: 105 | # continue 106 | # plot_x = [numeric_teacher_sizes[i] for i in valid_orig_indices] 107 | # plot_y = [y[i] for i in valid_orig_indices] 108 | 109 | # if model == "Pure Teacher": 110 | # sns.lineplot(x=plot_x, y=plot_y, marker="o", color="0.4", linestyle="--", linewidth=2, label=model) 111 | # else: 112 | # sns.lineplot(x=plot_x, y=plot_y, marker="o", color=model_to_color[model], linewidth=2, label=model) 113 | 114 | 115 | 116 | # plt.xticks(ticks=x_plot, labels=teacher_sizes_plot) 117 | # plt.gca().set_xticks(x_plot) 118 | # plt.gca().set_xticklabels(teacher_sizes_plot) 119 | # plt.xlabel("Teacher Size (B)") 120 | # plt.ylabel("Performance (Acc)") 121 | # plt.title("Performance vs Teacher Size") 122 | # 
plt.legend(title="Base Model") 123 | # plt.gca().yaxis.set_major_locator(MaxNLocator(nbins=4)) 124 | # ax = plt.gca() 125 | # ax.xaxis.set_label_coords(0.5, -0.18) 126 | # plt.gcf().subplots_adjust(bottom=0.26) 127 | # sns.despine() 128 | 129 | # plt.tight_layout() 130 | # plt.savefig("scaling_curve.png", bbox_inches="tight") 131 | # plt.show() 132 | 133 | # 第二张图:展示相对 base-only 的提升(Δ Accuracy) 134 | plt.figure(figsize=FIGSIZE, dpi=300) 135 | plt.xscale('log', base=10) 136 | ax = plt.gca() 137 | trans_data_axes = mtransforms.blended_transform_factory(ax.transData, ax.transAxes) 138 | 139 | baseline_dot_y = 0 140 | label_y_for_baseline = -0.7 141 | 142 | # 使用蓝色渐变为不同曲线着色(最浅 #8FAADC → 最深 #203864) 143 | blue_cmap = mcolors.LinearSegmentedColormap.from_list("blue_grad", ["#8FAADC", "#203864"]) 144 | blue_shades = [blue_cmap(x) for x in np.linspace(0, 1, len(non_pure_models))] 145 | improv_color_map = {m: blue_shades[i] for i, m in enumerate(non_pure_models)} 146 | 147 | for model in non_pure_models: 148 | y = data[model] 149 | baseline = y[0] 150 | valid_orig_indices = [i for i in plot_indices if y[i] is not None] 151 | if not valid_orig_indices: 152 | continue 153 | plot_x = [numeric_teacher_sizes[i] for i in valid_orig_indices] 154 | plot_y = [y[i] - baseline for i in valid_orig_indices] 155 | 156 | # legend: model size 正常,括号里的 baseline 用斜体 157 | legend_label = f"{model} $\\mathit{{({baseline:.2f})}}$" 158 | sns.lineplot( 159 | x=plot_x, y=plot_y, marker="o", 160 | color=improv_color_map[model], linewidth=2, 161 | label=legend_label 162 | ) 163 | 164 | # 在每条曲线的起点正下方标记空心点 165 | # first_x = plot_x[0] 166 | # plt.scatter( 167 | # first_x, baseline_dot_y, s=20, 168 | # facecolors='none', edgecolors=improv_color_map[model], 169 | # linewidths=1.5, zorder=3 170 | # ) 171 | 172 | plt.axhline(0, color="0.7", linewidth=1) 173 | plt.ylim(bottom=label_y_for_baseline - 5) 174 | 175 | # xticks: size 正常,括号里的 acc 用斜体 176 | _xtick_labels = [ 177 | f"{size}\n${{({acc:.2f})}}$" 178 | for size, acc in zip(teacher_sizes_plot, pure_teacher_ticks) 179 | ] 180 | plt.xticks(ticks=x_plot, labels=_xtick_labels) 181 | plt.gca().set_xticks(x_plot) 182 | plt.gca().set_xticklabels(_xtick_labels) 183 | 184 | # plt.xlabel(r"Sharer Model Size $\mathit{(Accuracy)}$") 185 | plt.xlabel("Sharer Model Size (Accuracy)") 186 | plt.ylabel("Δ Accuracy") 187 | # plt.legend(title="Reciever Model Size\n" + r" $\mathit{(Accuracy)}$") 188 | plt.legend( 189 | title="Reciever Model Size (Accuracy)", 190 | ncol=5, 191 | loc='lower center', 192 | bbox_to_anchor=(0.5, 1.02), 193 | frameon=False, 194 | columnspacing=0.8, 195 | handlelength=1.2, 196 | handletextpad=0.4, 197 | borderpad=0.2, 198 | labelspacing=0.2, 199 | fontsize=10, 200 | title_fontsize=12, 201 | ) 202 | plt.gca().set_yticks([0, 10, 20, 30]) 203 | 204 | ax = plt.gca() 205 | ax.xaxis.set_label_coords(0.5, -0.30) 206 | plt.gcf().subplots_adjust(bottom=0.66, top=0.86) 207 | sns.despine() 208 | 209 | plt.tight_layout() 210 | plt.savefig("scaling_improvement_T2T.pdf", bbox_inches="tight") 211 | plt.show() 212 | 213 | # # 第三张图:基于 base 错误率的相对提升(Relative Error Reduction) 214 | # plt.figure(figsize=(6.4, 4), dpi=300) 215 | # plt.xscale('log', base=10) 216 | # ax = plt.gca() 217 | # trans_data_axes = mtransforms.blended_transform_factory(ax.transData, ax.transAxes) 218 | 219 | # baseline_dot_y = -0.65 220 | # label_y_for_baseline = baseline_dot_y - 0.30 221 | 222 | # for model in non_pure_models: 223 | # y = data[model] 224 | # baseline = y[0] 225 | # err_base = max(1e-6, 100.0 - baseline) 
# baseline error rate (percent); guard against division by zero 226 | # valid_orig_indices = [i for i in plot_indices if y[i] is not None] 227 | # if not valid_orig_indices: 228 | # continue 229 | # plot_x = [numeric_teacher_sizes[i] for i in valid_orig_indices] 230 | # # Relative error reduction (percent): (y - baseline) / (100 - baseline) * 100 231 | # plot_y = [((y[i] - baseline) / err_base) * 100.0 for i in valid_orig_indices] 232 | # legend_label = f"{model} $({baseline:.2f})$" 233 | # sns.lineplot(x=plot_x, y=plot_y, marker="o", color=model_to_color[model], linewidth=2, label=legend_label) 234 | 235 | # # Place a hollow circle below the start of each curve (y<0) and annotate the absolute baseline value 236 | # first_x = plot_x[0] 237 | # plt.scatter(first_x, baseline_dot_y, s=40, facecolors='none', edgecolors=model_to_color[model], linewidths=1.5, zorder=3) 238 | # ax.text(first_x, -0.08, f"({baseline:.2f})", transform=trans_data_axes, ha="center", va="top", fontsize=BASE_FONT_SIZE, color=TEXT_COLOR, fontstyle="italic", clip_on=False) 239 | 240 | # plt.axhline(0, color="0.7", linewidth=1) 241 | # plt.ylim(bottom=label_y_for_baseline - 1) 242 | # plt.xticks(ticks=x_plot, labels=teacher_sizes_plot) 243 | # plt.gca().set_xticks(x_plot) 244 | # plt.gca().set_xticklabels(teacher_sizes_plot) 245 | # plt.xlabel("Teacher Size (B)") 246 | # plt.ylabel("Relative Error Reduction vs Base-only (%)") 247 | # plt.title("Normalized Improvement by Base Error Rate") 248 | # plt.legend(title="Base Model") 249 | # plt.gca().yaxis.set_major_locator(MaxNLocator(nbins=4)) 250 | # ax = plt.gca() 251 | # ax.xaxis.set_label_coords(0.5, -0.18) 252 | # plt.gcf().subplots_adjust(bottom=0.28) 253 | # sns.despine() 254 | 255 | # plt.tight_layout() 256 | # plt.savefig("scaling_error_reduction.png", bbox_inches="tight") 257 | # plt.show() -------------------------------------------------------------------------------- /script/analysis/scaling/batch_evaluate_T2T.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Batch T2T (Two-Stage) evaluation script 4 | - Builds evaluation configs directly from combinations of two model families (family1: Qwen3, family2: Qwen2.5-Instruct) 5 | - Uses T2T_scaling.yaml as a template to generate temporary evaluation configs 6 | - Pairing rule: family1 size ≤ family2 size (e.g., tiny→tiny/small/medium/large/xlarge) 7 | - Supports skipping experiments that already have results 8 | """ 9 | 10 | import os 11 | import sys 12 | import yaml 13 | import subprocess 14 | from pathlib import Path 15 | from typing import Dict, List, Tuple 16 | import logging 17 | 18 | # Logging setup 19 | logging.basicConfig( 20 | level=logging.INFO, 21 | format='%(asctime)s - %(levelname)s - %(message)s', 22 | handlers=[ 23 | logging.FileHandler('batch_evaluate_T2T.log'), 24 | logging.StreamHandler() 25 | ] 26 | ) 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class T2TBatchEvaluator: 31 | """Evaluator for T2T model-family combinations""" 32 | 33 | def __init__(self, 34 | template_path: str = "T2T_scaling.yaml", 35 | evaluation_script: str = "script/evaluation/unified_evaluator.py", 36 | output_base_dir: str = "local/auto_eval_results_scaling_T2T"): 37 | # Template and script paths 38 | self.template_path = template_path 39 | self.evaluation_script = evaluation_script 40 | self.output_base_dir = output_base_dir 41 | 42 | # Model family definitions (kept consistent with batch_evaluate_checkpoints.py) 43 | self.family1: Dict[str, str] = { 44 | "tiny": "Qwen/Qwen3-0.6B", 45 | "small": "Qwen/Qwen3-1.7B", 46 | "medium": "Qwen/Qwen3-4B", 47 | "large": "Qwen/Qwen3-8B", 48 | "xlarge": "Qwen/Qwen3-14B", 49 | } 50 | self.family2: Dict[str, str] = { 51 | "tiny": "Qwen/Qwen2.5-0.5B-Instruct", 52 | "small": "Qwen/Qwen2.5-1.5B-Instruct", 53 | "medium": "Qwen/Qwen2.5-3B-Instruct", 54 | "large": "Qwen/Qwen2.5-7B-Instruct", 55 | "xlarge": 
"Qwen/Qwen2.5-14B-Instruct", 56 | } 57 | 58 | # 尺寸标签(用于生成易读实验名) 59 | self.family1_size_label: Dict[str, str] = { 60 | "tiny": "0.6B", 61 | "small": "1.7B", 62 | "medium": "4B", 63 | "large": "8B", 64 | "xlarge": "14B", 65 | } 66 | self.family2_size_label: Dict[str, str] = { 67 | "tiny": "0.5B", 68 | "small": "1.5B", 69 | "medium": "3B", 70 | "large": "7B", 71 | "xlarge": "14B", 72 | } 73 | 74 | # 尺寸顺序与等级 75 | self.size_order: List[str] = ["tiny", "small", "medium", "large", "xlarge"] 76 | self.size_rank: Dict[str, int] = {k: i for i, k in enumerate(self.size_order)} 77 | 78 | # ------------------------------ 模板与配置 ------------------------------ 79 | def load_template(self) -> Dict: 80 | """加载 T2T 评测模板配置""" 81 | try: 82 | with open(self.template_path, 'r', encoding='utf-8') as f: 83 | return yaml.safe_load(f) 84 | except Exception as e: 85 | logger.error(f"Failed to load template {self.template_path}: {e}") 86 | raise 87 | 88 | def create_eval_config(self, 89 | experiment_name: str, 90 | answer_model_path: str, 91 | context_model_path: str) -> str: 92 | """ 93 | 根据模板创建单次实验的评测配置文件 94 | 返回生成的配置文件路径 95 | """ 96 | cfg = self.load_template() 97 | 98 | # 覆盖模型路径 99 | cfg.setdefault("model", {}) 100 | cfg["model"]["model_name"] = "two_stage" 101 | cfg["model"]["answer_model_path"] = answer_model_path 102 | cfg["model"]["context_model_path"] = context_model_path 103 | 104 | # 输出路径带上数据集名(如 mmlu-redux) 105 | dataset = cfg.get("eval", {}).get("dataset", "mmlu-redux") 106 | output_dir = f"{self.output_base_dir}/{experiment_name}_{dataset}" 107 | cfg.setdefault("output", {}) 108 | cfg["output"]["output_dir"] = output_dir 109 | 110 | # 将生成的配置写入临时文件 111 | config_filename = f"eval_recipe/T2T_scaling_{experiment_name}_eval.yaml" 112 | os.makedirs(Path(config_filename).parent, exist_ok=True) 113 | with open(config_filename, 'w', encoding='utf-8') as f: 114 | yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True) 115 | 116 | logger.info(f"Created eval config: {config_filename}") 117 | logger.info(f" Answer model: {answer_model_path}") 118 | logger.info(f" Context model: {context_model_path}") 119 | logger.info(f" Output dir: {output_dir}") 120 | return config_filename 121 | 122 | # ------------------------------ 组合生成 ------------------------------ 123 | def _experiment_name(self, size_key_answer: str, size_key_context: str) -> str: 124 | ans = self.family1_size_label.get(size_key_answer, size_key_answer) 125 | ctx = self.family2_size_label.get(size_key_context, size_key_context) 126 | return f"{ans}+{ctx}" 127 | 128 | def generate_ge_pairs(self) -> List[Tuple[str, str, str]]: 129 | """ 130 | 生成满足 family1 尺寸 ≤ family2 尺寸 的组合 131 | 返回 (experiment_name, answer_model_path, context_model_path) 132 | """ 133 | pairs: List[Tuple[str, str, str]] = [] 134 | f1_keys = [k for k in self.size_order if k in self.family1] 135 | f2_keys = [k for k in self.size_order if k in self.family2] 136 | for k1 in f1_keys: 137 | for k2 in f2_keys: 138 | if self.size_rank[k2] >= self.size_rank[k1]: 139 | exp = self._experiment_name(k1, k2) 140 | pairs.append((exp, self.family1[k1], self.family2[k2])) 141 | return pairs 142 | 143 | # ------------------------------ 运行评测 ------------------------------ 144 | def check_existing_results(self, experiment_name: str, dataset: str) -> bool: 145 | out_dir = Path(f"{self.output_base_dir}/{experiment_name}_{dataset}") 146 | if out_dir.exists() and any(out_dir.iterdir()): 147 | logger.info(f"Results already exist for {experiment_name}, skipping") 148 | return True 149 | return False 150 | 151 
| def run_evaluation(self, config_path: str, experiment_name: str) -> bool: 152 | logger.info(f"Starting evaluation for experiment: {experiment_name}") 153 | try: 154 | cmd = [ 155 | "python", self.evaluation_script, 156 | "--config", config_path, 157 | ] 158 | logger.info(f"Running command: {' '.join(cmd)}") 159 | 160 | process = subprocess.Popen( 161 | cmd, 162 | cwd=os.getcwd(), 163 | stdout=subprocess.PIPE, 164 | stderr=subprocess.STDOUT, 165 | text=True, 166 | bufsize=1, 167 | ) 168 | if process.stdout: 169 | for line in process.stdout: 170 | sys.stdout.write(line) 171 | sys.stdout.flush() 172 | process.wait() 173 | if process.returncode == 0: 174 | logger.info(f"Evaluation completed successfully for {experiment_name}") 175 | return True 176 | else: 177 | logger.error(f"Evaluation failed for {experiment_name}") 178 | logger.error(f"Process returned code {process.returncode}") 179 | return False 180 | except Exception as e: 181 | logger.error(f"Exception during evaluation {experiment_name}: {e}") 182 | return False 183 | 184 | def cleanup_temp_configs(self, temp_configs: List[str]): 185 | for config_path in temp_configs: 186 | try: 187 | if os.path.exists(config_path): 188 | os.remove(config_path) 189 | logger.info(f"Removed temporary config: {config_path}") 190 | except Exception as e: 191 | logger.warning(f"Failed to remove temporary config {config_path}: {e}") 192 | 193 | # ------------------------------ Batch main flow ------------------------------ 194 | def run_batch(self, skip_existing: bool = True): 195 | logger.info("=" * 80) 196 | logger.info("STARTING BATCH T2T EVALUATION") 197 | logger.info("=" * 80) 198 | 199 | # Pre-read the template to get the dataset name for the skip check 200 | template = self.load_template() 201 | dataset = template.get("eval", {}).get("dataset", "mmlu-redux") 202 | 203 | pairs = self.generate_ge_pairs() 204 | logger.info(f"Prepared {len(pairs)} experiment pairs (F1<=F2)") 205 | for exp, ans, ctx in pairs: 206 | logger.info(f" ✓ {exp} -> answer={ans} | context={ctx}") 207 | 208 | success = 0 209 | failed = 0 210 | skipped = 0 211 | temp_configs: List[str] = [] 212 | 213 | for i, (exp_name, ans_path, ctx_path) in enumerate(pairs, 1): 214 | logger.info(f"\n[{i}/{len(pairs)}] Processing experiment: {exp_name}") 215 | 216 | if skip_existing and self.check_existing_results(exp_name, dataset): 217 | skipped += 1 218 | continue 219 | 220 | try: 221 | cfg_path = self.create_eval_config(exp_name, ans_path, ctx_path) 222 | temp_configs.append(cfg_path) 223 | ok = self.run_evaluation(cfg_path, exp_name) 224 | if ok: 225 | success += 1 226 | else: 227 | failed += 1 228 | except Exception as e: 229 | logger.error(f"Failed to process {exp_name}: {e}") 230 | failed += 1 231 | 232 | self.cleanup_temp_configs(temp_configs) 233 | 234 | logger.info("\n" + "=" * 80) 235 | logger.info("BATCH T2T EVALUATION SUMMARY") 236 | logger.info("=" * 80) 237 | logger.info(f"Total experiments: {len(pairs)}") 238 | logger.info(f"Successful evaluations: {success}") 239 | logger.info(f"Failed evaluations: {failed}") 240 | logger.info(f"Skipped evaluations: {skipped}") 241 | 242 | 243 | def main(): 244 | import argparse 245 | parser = argparse.ArgumentParser(description='Batch T2T evaluation across model families') 246 | parser.add_argument('--skip-existing', action='store_true', default=True, 247 | help='Skip experiments that already have results (default: True)') 248 | parser.add_argument('--no-skip-existing', action='store_false', dest='skip_existing', 249 | help='Evaluate all experiments, even if results exist') 250 | 251 | 
parser.add_argument('--template', type=str, default='T2T_scaling.yaml', 252 | help='Path to the T2T template yaml (default: T2T_scaling.yaml)') 253 | parser.add_argument('--output-base-dir', type=str, default='local/auto_eval_results_scaling_T2T', 254 | help='Base output directory for results') 255 | 256 | args = parser.parse_args() 257 | 258 | evaluator = T2TBatchEvaluator(template_path=args.template, 259 | output_base_dir=args.output_base_dir) 260 | evaluator.run_batch(skip_existing=args.skip_existing) 261 | 262 | 263 | if __name__ == "__main__": 264 | main() 265 | -------------------------------------------------------------------------------- /script/analysis/proportion/auto_kv_cache_evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Automatic evaluation script for KV cache proportion and order 4 | Based on the test_eval.yaml template; automatically tests different combinations of kv_cache_proportion and kv_cache_order_mode 5 | """ 6 | 7 | import os 8 | import yaml 9 | import subprocess 10 | import logging 11 | import json 12 | from pathlib import Path 13 | from typing import Dict, List, Tuple 14 | from datetime import datetime 15 | import time 16 | 17 | # Logging setup 18 | logging.basicConfig( 19 | level=logging.INFO, 20 | format='%(asctime)s - %(levelname)s - %(message)s', 21 | handlers=[ 22 | logging.FileHandler('auto_kv_cache_evaluation.log'), 23 | logging.StreamHandler() 24 | ] 25 | ) 26 | logger = logging.getLogger(__name__) 27 | 28 | class KVCacheAutoEvaluator: 29 | """Automatic KV cache evaluator""" 30 | 31 | def __init__(self, base_config_path: str = "eval_recipe/test_eval.yaml"): 32 | self.base_config_path = base_config_path 33 | self.base_config = self.load_base_config() 34 | 35 | # Define the test parameter grid 36 | self.proportions = [0.0, 0.25, 0.5, 0.75, 1.0] 37 | self.order_modes = ["front", "back"] 38 | 39 | # Generate all combinations 40 | self.test_combinations = [] 41 | for proportion in self.proportions: 42 | for order_mode in self.order_modes: 43 | experiment_name = f"prop_{proportion:0.2f}_order_{order_mode}" 44 | self.test_combinations.append((proportion, order_mode, experiment_name)) 45 | 46 | # Create config and result directories 47 | self.config_dir = Path("eval_configs_kv_cache") 48 | self.config_dir.mkdir(exist_ok=True) 49 | 50 | logger.info(f"Initialization complete; {len(self.test_combinations)} configuration combinations will be tested") 51 | 52 | def load_base_config(self) -> Dict: 53 | """Load the base config file""" 54 | try: 55 | with open(self.base_config_path, 'r', encoding='utf-8') as f: 56 | return yaml.safe_load(f) 57 | except Exception as e: 58 | logger.error(f"Failed to load base config {self.base_config_path}: {e}") 59 | raise 60 | 61 | def create_config_for_combination(self, proportion: float, order_mode: str, experiment_name: str) -> str: 62 | """Create a config file for a specific parameter combination""" 63 | # Deep-copy the base config 64 | config = yaml.safe_load(yaml.dump(self.base_config)) 65 | 66 | # Modify the KV cache parameters 67 | config["eval"]["kv_cache_proportion"] = proportion 68 | config["eval"]["kv_cache_order_mode"] = order_mode 69 | 70 | # Modify the output directory 71 | base_output_dir = config["output"]["output_dir"] 72 | config["output"]["output_dir"] = f"{base_output_dir}/{experiment_name}" 73 | 74 | # Save the config file 75 | config_path = self.config_dir / f"config_{experiment_name}.yaml" 76 | with open(config_path, 'w', encoding='utf-8') as f: 77 | yaml.dump(config, f, default_flow_style=False, allow_unicode=True) 78 | 79 | logger.info(f"Created config: {config_path}") 80 | logger.info(f" - Proportion: {proportion}") 81 | logger.info(f" - Order mode: {order_mode}") 82 | logger.info(f" - Output dir: {config['output']['output_dir']}") 83 | 84 | return str(config_path) 85 | 86 | def 
run_evaluation(self, config_path: str, experiment_name: str) -> bool: 87 | """Run a single evaluation""" 88 | logger.info(f"Starting evaluation for experiment: {experiment_name}") 89 | 90 | try: 91 | # Build the evaluation command 92 | cmd = [ 93 | "python", "script/evaluation/unified_evaluator.py", 94 | "--config", config_path 95 | ] 96 | 97 | logger.info(f"Running command: {' '.join(cmd)}") 98 | 99 | # Run the evaluation (streaming output) 100 | process = subprocess.Popen( 101 | cmd, 102 | cwd=os.getcwd(), 103 | stdout=subprocess.PIPE, 104 | stderr=subprocess.STDOUT, 105 | text=True, 106 | bufsize=1, 107 | ) 108 | 109 | # Stream the logs in real time 110 | assert process.stdout is not None 111 | for line in process.stdout: 112 | print(f"[{experiment_name}] {line.rstrip()}") 113 | 114 | process.wait() 115 | 116 | if process.returncode == 0: 117 | logger.info(f"Evaluation completed successfully: {experiment_name}") 118 | return True 119 | else: 120 | logger.error(f"Evaluation failed: {experiment_name}, return code: {process.returncode}") 121 | return False 122 | 123 | except Exception as e: 124 | logger.error(f"Exception during evaluation: {experiment_name}, error: {e}") 125 | return False 126 | 127 | def extract_results(self, output_dir: str) -> Dict: 128 | """Extract results from the output directory""" 129 | result_info = { 130 | "status": "unknown", 131 | "accuracy": None, 132 | "details": {} 133 | } 134 | 135 | output_path = Path(output_dir) 136 | if not output_path.exists(): 137 | result_info["status"] = "missing_output" 138 | return result_info 139 | 140 | # Look for a results file (possible naming patterns) 141 | possible_files = [ 142 | "final_results.json", 143 | "evaluation_results.json", 144 | "results.json", 145 | "summary.json" 146 | ] 147 | 148 | results_file = None 149 | for filename in possible_files: 150 | candidate = output_path / filename 151 | if candidate.exists(): 152 | results_file = candidate 153 | break 154 | 155 | if results_file: 156 | try: 157 | with open(results_file, 'r', encoding='utf-8') as f: 158 | data = json.load(f) 159 | result_info["status"] = "success" 160 | result_info["details"] = data 161 | 162 | # Try to extract an accuracy value 163 | if "accuracy" in data: 164 | result_info["accuracy"] = data["accuracy"] 165 | elif "overall_accuracy" in data: 166 | result_info["accuracy"] = data["overall_accuracy"] 167 | elif "avg_accuracy" in data: 168 | result_info["accuracy"] = data["avg_accuracy"] 169 | 170 | except Exception as e: 171 | logger.warning(f"Failed to parse results file {results_file}: {e}") 172 | result_info["status"] = "parse_error" 173 | else: 174 | result_info["status"] = "no_results_file" 175 | logger.warning(f"No results file found in {output_dir}") 176 | 177 | return result_info 178 | 179 | def run_all_experiments(self) -> Dict: 180 | """Run all experiment combinations""" 181 | logger.info("="*80) 182 | logger.info("STARTING AUTOMATIC KV CACHE EVALUATION") 183 | logger.info("="*80) 184 | 185 | logger.info(f"Planning to test {len(self.test_combinations)} configuration combinations:") 186 | for i, (proportion, order_mode, experiment_name) in enumerate(self.test_combinations, 1): 187 | logger.info(f" {i:2d}. {experiment_name} (proportion={proportion}, order_mode={order_mode})") 188 | 189 | results = {} 190 | successful_experiments = 0 191 | 192 | for i, (proportion, order_mode, experiment_name) in enumerate(self.test_combinations, 1): 193 | logger.info(f"\n[{i}/{len(self.test_combinations)}] Processing combination: {experiment_name}") 194 | 195 | try: 196 | # 1. Create the config file 197 | config_path = self.create_config_for_combination(proportion, order_mode, experiment_name) 198 | 199 | # 2. Run the evaluation 200 | success = self.run_evaluation(config_path, experiment_name) 201 | 202 | # 3. 
Extract the results 203 | if success: 204 | output_dir = f"{self.base_config['output']['output_dir']}/{experiment_name}" 205 | result_info = self.extract_results(output_dir) 206 | results[experiment_name] = { 207 | "proportion": proportion, 208 | "order_mode": order_mode, 209 | "evaluation_success": True, 210 | "result_info": result_info 211 | } 212 | successful_experiments += 1 213 | else: 214 | results[experiment_name] = { 215 | "proportion": proportion, 216 | "order_mode": order_mode, 217 | "evaluation_success": False, 218 | "result_info": {"status": "evaluation_failed"} 219 | } 220 | 221 | except Exception as e: 222 | logger.error(f"Exception in experiment {experiment_name}: {e}") 223 | results[experiment_name] = { 224 | "proportion": proportion, 225 | "order_mode": order_mode, 226 | "evaluation_success": False, 227 | "result_info": {"status": "exception", "error": str(e)} 228 | } 229 | 230 | # Print the final results 231 | logger.info(f"\n" + "="*80) 232 | logger.info("FINAL RESULTS SUMMARY") 233 | logger.info("="*80) 234 | logger.info(f"Total experiments: {len(self.test_combinations)}") 235 | logger.info(f"Successfully completed: {successful_experiments}") 236 | logger.info(f"Failed: {len(self.test_combinations) - successful_experiments}") 237 | 238 | # Detailed results 239 | logger.info(f"\nDetailed results:") 240 | for exp_name, result in results.items(): 241 | status = "SUCCESS ✓" if result["evaluation_success"] else "FAILED ✗" 242 | accuracy = result["result_info"].get("accuracy", "N/A") 243 | logger.info(f" {exp_name}: {status} (accuracy: {accuracy})") 244 | 245 | return results 246 | 247 | def save_summary_report(self, results: Dict): 248 | """Save a summary report""" 249 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 250 | 251 | # Save the detailed results 252 | detailed_results_file = f"kv_cache_evaluation_results_{timestamp}.json" 253 | with open(detailed_results_file, 'w', encoding='utf-8') as f: 254 | json.dump(results, f, indent=4, ensure_ascii=False) 255 | logger.info(f"Detailed results saved to: {detailed_results_file}") 256 | 257 | # Create a CSV summary table 258 | csv_file = f"kv_cache_evaluation_summary_{timestamp}.csv" 259 | with open(csv_file, 'w', encoding='utf-8') as f: 260 | f.write("experiment_name,proportion,order_mode,evaluation_success,accuracy,status\n") 261 | for exp_name, result in results.items(): 262 | proportion = result["proportion"] 263 | order_mode = result["order_mode"] 264 | eval_success = result["evaluation_success"] 265 | accuracy = result["result_info"].get("accuracy", "") 266 | status = result["result_info"].get("status", "") 267 | f.write(f"{exp_name},{proportion},{order_mode},{eval_success},{accuracy},{status}\n") 268 | logger.info(f"CSV summary saved to: {csv_file}") 269 | 270 | return detailed_results_file, csv_file 271 | 272 | 273 | def main(): 274 | """Main entry point""" 275 | evaluator = KVCacheAutoEvaluator() 276 | 277 | # Run all experiments 278 | results = evaluator.run_all_experiments() 279 | 280 | # Save the summary report 281 | evaluator.save_summary_report(results) 282 | 283 | logger.info("\n" + "="*80) 284 | logger.info("Automatic KV cache evaluation complete!") 285 | logger.info("="*80) 286 | 287 | 288 | if __name__ == "__main__": 289 | main() 290 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | [Cache-to-Cache Logo] 3 | 4 | # Cache-to-Cache 5 | 6 | ### Direct Semantic Communication Between Large Language Models 7 | 8 | 🌐 Project Page • 9 | 📑 Paper • 10 | 🤗 HuggingFace • 11 | 🚀 Live Demo 12 | 13 | 14 | 
15 | 16 | Cache-to-Cache (C2C) enables Large Language Models to communicate directly through their KV-Caches, bypassing text generation. By projecting and fusing KV-Caches between models, C2C achieves 8.5–10.5% higher accuracy than individual models and 3.0–5.0% better performance than text-based communication, with 2.0× speedup in latency. 17 | 18 | Feel free to star the repo or cite the paper if you find it interesting. 19 | 20 | ```bibtex 21 | @article{fu2025c2c, 22 | title={Cache-to-Cache: Direct Semantic Communication Between Large Language Models}, 23 | author={Tianyu Fu and Zihan Min and Hanling Zhang and Jichao Yan and Guohao Dai and Wanli Ouyang and Yu Wang}, 24 | journal={arXiv preprint arXiv:2510.03215}, 25 | year={2025}, 26 | } 27 | ``` 28 | 29 | > **Why "Rosetta"?** The Python package is named after the **Rosetta Stone**, the ancient artefact that unlocked the translation of Egyptian hieroglyphs by presenting the same text in multiple scripts. Likewise, C2C translates KV-cache representations between otherwise independent LLMs, allowing them to speak a common language in a richer and more direct way. 30 | 31 | 32 | ## News 33 | 34 | [2025/12] 🧪 Multi-sharer support is now available! Fuse KV-caches from multiple sharer models to a single receiver. This feature is in preliminary stages and we are still actively working on it. See `live_chat_example.py` for usage. 35 | 36 | [2025/11] 🚀 Thank you for the enthusiasm from the community! [Live demo](https://huggingface.co/spaces/nics-efc/C2C_demo) is now available! Try C2C in action with side-by-side model comparison. 37 | 38 | [2025/10] 🤗 Our paper was featured as the **#1 Paper of the Day** on [Hugging Face Daily Papers](https://huggingface.co/papers/2510.03215) 39 | 40 | ## Demo 41 | 42 |
43 | [Combined Wisdom] 44 | 45 | > Only by combining latent semantics from both Qwen2.5 and Qwen3 can this philosophical question be correctly answered. 46 | > 47 | > https://github.com/user-attachments/assets/c36ffaa1-0297-4ed8-b472-1fbbd9cc397f 48 | 49 | 
50 | 51 | The demo can be reproduced with `script/playground/gradio_demo.py`. 52 | 53 | ## Environment Setup 54 | 55 | Create a new environment: 56 | 57 | ```bash 58 | conda create -n rosetta python=3.10 59 | conda activate rosetta 60 | ``` 61 | 62 | Install the package: 63 | 64 | ```bash 65 | pip install -e . 66 | ``` 67 | 68 | For training and evaluation, install additional dependencies: 69 | 70 | ```bash 71 | pip install -e ".[training,evaluation]" 72 | ``` 73 | 74 | ## How to 75 | 76 | ### Use Hugging Face weights 77 | 78 | Minimal example to load published C2C weights from the Hugging Face collection and run the provided inference script: 79 | 80 | ```python 81 | import torch 82 | from huggingface_hub import snapshot_download 83 | from script.playground.inference_example import load_rosetta_model, run_inference_example 84 | 85 | checkpoint_dir = snapshot_download( 86 | repo_id="nics-efc/C2C_Fuser", 87 | allow_patterns=["qwen3_0.6b+qwen2.5_0.5b_Fuser/*"], 88 | ) 89 | 90 | model_config = { 91 | "rosetta_config": { 92 | "base_model": "Qwen/Qwen3-0.6B", 93 | "teacher_model": "Qwen/Qwen2.5-0.5B-Instruct", 94 | "checkpoints_dir": f"{checkpoint_dir}/qwen3_0.6b+qwen2.5_0.5b_Fuser/final", 95 | } 96 | } 97 | 98 | rosetta_model, tokenizer = load_rosetta_model(model_config, eval_config={}, device=torch.device("cuda")) 99 | device = rosetta_model.device 100 | 101 | prompt = [{"role": "user", "content": "Say hello in one short sentence."}] 102 | input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True, enable_thinking=False) 103 | inputs = tokenizer(input_text, return_tensors="pt").to(device) 104 | 105 | instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(inputs['input_ids'].shape[1] - 1, 1).unsqueeze(0).to(device) 106 | label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(device) 107 | kv_cache_index = [instruction_index, label_index] 108 | 109 | with torch.no_grad(): 110 | sampling_params = { 111 | 'do_sample': False, 112 | 'max_new_tokens': 256 113 | } 114 | outputs = rosetta_model.generate(**inputs, kv_cache_index=kv_cache_index, **sampling_params) 115 | output_text = tokenizer.decode(outputs[0, instruction_index.shape[1] + 1:], skip_special_tokens=True) 116 | print(f"C2C output text: {output_text}") 117 | ``` 118 | 119 | ### Run an example 120 | 121 | We provide an interactive chat example to demonstrate cache-to-cache communication with pre-trained projectors in `script/playground/live_chat_example.py`. 122 | 123 | ```bash 124 | # Single sharer 125 | python script/playground/live_chat_example.py --checkpoint_dir path/to/checkpoint 126 | 127 | # Multiple sharers (teacher models read from each checkpoint's config.json) 128 | python script/playground/live_chat_example.py --checkpoint_dir path/to/ckpt1 path/to/ckpt2 129 | ``` 130 | 131 | ### Apply Cache-to-Cache 132 | 133 | You can apply C2C to your own models with a few lines of code. 
Here is an example: 134 | 135 | ```python 136 | import torch 137 | from transformers import AutoModelForCausalLM 138 | from rosetta.model.wrapper import RosettaModel 139 | from rosetta.model.projector import C2CProjector 140 | 141 | # Load target (receiver) and source (sharer) models 142 | target_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B") 143 | source_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") 144 | 145 | # Create C2C projector for KV-Cache transformation 146 | projector_list = [] 147 | for i in range(target_model.config.num_hidden_layers): 148 | projector = C2CProjector( 149 | source_dim=128, target_dim=128, 150 | source_num_heads=8, target_num_heads=8, 151 | hidden_dim=1024, num_layers=3 152 | ) 153 | projector_list.append(projector) 154 | # If you want to use a pretrained projector, you can load it from the checkpoint directory 155 | 156 | # Wrap with RosettaModel for cache-to-cache communication 157 | c2c_model = RosettaModel( 158 | model_list=[target_model, source_model], 159 | base_model_idx=0, 160 | projector_list=projector_list 161 | ) 162 | 163 | # Configure layer-wise projection mappings 164 | for idx, layer_idx in enumerate(range(target_model.config.num_hidden_layers)): 165 | c2c_model.set_projector_config( 166 | source_model_idx=1, source_model_layer_idx=layer_idx, 167 | target_model_idx=0, target_model_layer_idx=layer_idx, 168 | projector_idx=idx 169 | ) 170 | 171 | # Generate: kv_cache_index controls when to apply C2C projection 172 | # [1, 0] = apply projection from sharer 1, [-1, 0] = no projection 173 | seq_len = input_ids.shape[1] 174 | instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(seq_len-1, 1)[None, :, :] 175 | response_index = torch.tensor([[-1, 0]], dtype=torch.long)[None, :, :] 176 | outputs = c2c_model.generate( 177 | kv_cache_index=[instruction_index, response_index], 178 | input_ids=inputs.input_ids, 179 | ) 180 | ``` 181 | 182 | ### Train C2C Projectors 183 | 184 | Prepare a training configuration file in `recipe/train_recipe/`. Specify the base model, teacher model, projector type and parameters, training hyperparameters, dataset, and output directory. See `recipe/train_recipe/C2C_0.6+0.5.json` for a complete example. 185 | 186 | Run training: 187 | 188 | ```bash 189 | # Single GPU 190 | python script/train/SFT_train.py --config recipe/train_recipe/C2C_0.6+0.5.json 191 | 192 | # Multi-GPU 193 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 194 | torchrun --nproc_per_node=8 script/train/SFT_train.py \ 195 | --config recipe/train_recipe/C2C_0.6+0.5.json 196 | ``` 197 | 198 | During training, only the C2C projector parameters are updated while both source and target models remain frozen. 199 | 200 | ### Evaluate C2C 201 | 202 | Prepare an evaluation configuration file in `recipe/eval_recipe/`. Specify the model configuration with base model, teacher model, checkpoint directory, generation config, evaluation dataset, and output directory. See `recipe/eval_recipe/unified_eval.yaml` for a complete example. 203 | 204 | Run evaluation: 205 | 206 | ```bash 207 | python script/evaluation/unified_evaluator.py --config recipe/eval_recipe/unified_eval.yaml 208 | ``` 209 | 210 | ## Understanding the Code 211 | 212 | ### Code Structure 213 | 214 | - `rosetta/`: The main package for Cache-to-Cache. 215 | - `model/`: Core model components. 216 | - `train/`: Training utilities. 217 | - `baseline/`: Baseline implementations. 218 | - `utils/`: Utility functions for evaluation and model registry. 
219 | - `script/`: Scripts for running experiments. 220 | - `train/`: Training scripts including `SFT_train.py`. 221 | - `evaluation/`: Evaluation scripts including `unified_evaluator.py`. 222 | - `dataset/`: Dataset preparation scripts. 223 | - `examples/`: Usage examples. 224 | - `recipe/`: Configuration files. 225 | - `train_recipe/`: Training configurations (e.g., `C2C_0.6+0.5.json`). 226 | - `eval_recipe/`: Evaluation configurations (e.g., `unified_eval.yaml`). 227 | 228 | ### Adding Projector 229 | 230 | Add a new projector architecture in `rosetta/model/projector.py`. The projector transforms source model's KV-Cache to target model's semantic space. 231 | 232 | > Key components: projection networks (MLP/concat-based), gating mechanism for layer-wise selection, and temperature-annealed training. See `C2CProjector` in `rosetta/model/projector.py` as an example. 233 | 234 | ```python 235 | from rosetta.utils.registry import register_model, capture_init_args 236 | 237 | @register_model 238 | @capture_init_args 239 | class MyProjector(Projector): 240 | def __init__(self, source_dim, target_dim, **kwargs): 241 | super().__init__() 242 | # Your architecture 243 | self.projection = nn.Linear(source_dim, target_dim) 244 | self.gate = nn.Parameter(torch.tensor(0.0)) 245 | def forward(self, source_kv, target_kv): 246 | # Project and fuse KV-caches 247 | return projected_kv 248 | ``` 249 | 250 | Register in configuration: `{"projector": {"type": "MyProjector", "params": {...}}}` 251 | 252 | ### Adding Dataset 253 | 254 | Add a new dataset in `rosetta/train/dataset_adapters.py` for training with your data. 255 | 256 | ```python 257 | @dataclass 258 | class MyDatasetConfig(DatasetConfig): 259 | dataset_name: str = "my_dataset" 260 | def load(self): 261 | return load_dataset("path/to/dataset") 262 | 263 | def my_formatting_func(examples): 264 | return {"text": [f"Q: {q}\nA: {a}" for q, a in zip(...)]} 265 | 266 | DATASET_CONFIGS["MyDataset"] = MyDatasetConfig 267 | FORMATTING_FUNCS["MyDataset"] = my_formatting_func 268 | ``` 269 | 270 | Use in configuration: `{"data": {"type": "MyDataset"}}` 271 | 272 | ### Adding Benchmark 273 | 274 | Add evaluation logic in `script/evaluation/` following the pattern in `unified_evaluator.py`. The evaluator loads models, runs inference, and computes metrics for your benchmark dataset. 275 | 276 | ## Supported Model Pairs 277 | 278 | ### Qwen Family 279 | 280 | * Qwen3-0.6B + Qwen2.5-0.5B-Instruct 281 | * Qwen3-0.6B + Llama-3.2-1B-Instruct 282 | * Qwen3-0.6B + Qwen3-4B-Base 283 | 284 | ### Other Configurations 285 | 286 | C2C supports arbitrary model pairs. The framework automatically handles: 287 | - Different hidden dimensions 288 | - Different number of layers 289 | - Different attention head configurations 290 | - Different tokenizers 291 | 292 | To use custom model pairs, simply specify them in your configurations. 
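As a starting point, the sketch below shows how to read off the quantities that a custom pair has to agree on — the layer count, the number of KV heads, and the per-head dimension. It is illustrative only: the receiver/sharer pair is arbitrary, and the attribute names assume standard Hugging Face configs.

```python
from transformers import AutoConfig

# Hypothetical custom pair: receiver = Qwen3-0.6B, sharer = Llama-3.2-1B-Instruct
receiver_cfg = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
sharer_cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

for role, cfg in [("receiver", receiver_cfg), ("sharer", sharer_cfg)]:
    # Fall back to hidden_size // num_attention_heads when head_dim is absent
    head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
    print(f"{role}: {cfg.num_hidden_layers} layers, "
          f"{cfg.num_key_value_heads} KV heads, head_dim={head_dim}")
```

These are the values that the `C2CProjector` arguments (`source_dim`/`target_dim`, `source_num_heads`/`target_num_heads`) and the `set_projector_config` layer mapping shown above need to be consistent with; check `rosetta/model/projector.py` for the exact convention expected by your checkpoint.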
293 | -------------------------------------------------------------------------------- /script/playground/inference_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example inference using RosettaModel with Qwen3-0.6B and Qwen3-1.7B models and MLP projector 3 | """ 4 | 5 | import torch 6 | import sys 7 | import os 8 | from pathlib import Path 9 | 10 | # Add the project root to the path 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 12 | 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | from rosetta.model.wrapper import RosettaModel 15 | from rosetta.model.aligner import TokenAligner, AlignmentStrategy 16 | from rosetta.model.projector import AllInOneProjector 17 | from rosetta.train.dataset_adapters import generate_kv_cache_index 18 | from typing import Dict, Any, List, Tuple, Optional 19 | from transformers import AutoModelForCausalLM, AutoTokenizer 20 | 21 | from rosetta.model.projector import load_projector 22 | from rosetta.model.wrapper import RosettaModel 23 | from rosetta.utils.evaluate import set_default_chat_template 24 | import re 25 | 26 | def test_token_aligner(slm_tokenizer: AutoTokenizer, llm_tokenizer: AutoTokenizer): 27 | """Test the TokenAligner functionality 28 | Args: 29 | slm_tokenizer: SLM tokenizer 30 | llm_tokenizer: LLM tokenizer 31 | """ 32 | print("\n" + "="*80) 33 | print("Testing TokenAligner") 34 | print("="*80) 35 | 36 | # Test with FIRST strategy 37 | aligner_first = TokenAligner( 38 | slm_tokenizer=slm_tokenizer, 39 | llm_tokenizer=llm_tokenizer, 40 | strategy=AlignmentStrategy.FIRST, 41 | verbose=True 42 | ) 43 | 44 | # Test with LONGEST strategy 45 | aligner_longest = TokenAligner( 46 | slm_tokenizer=slm_tokenizer, 47 | llm_tokenizer=llm_tokenizer, 48 | strategy=AlignmentStrategy.LONGEST, 49 | verbose=True 50 | ) 51 | 52 | # Test text samples 53 | test_texts = [ 54 | "Hello world!", 55 | "The future of artificial intelligence is", 56 | "北京是中国的首都", # Chinese text 57 | "🚀 Emojis and special characters!", 58 | ] 59 | 60 | for text in test_texts: 61 | print(f"\nTest text: '{text}'") 62 | print("-" * 40) 63 | 64 | # Test FIRST strategy 65 | print("\nFIRST Strategy:") 66 | aligner_first.visualize_alignment(text) 67 | 68 | # Test LONGEST strategy 69 | print("\nLONGEST Strategy:") 70 | aligner_longest.visualize_alignment(text) 71 | 72 | # Test alignment without visualization 73 | sample_text = "This is a test." 
74 | slm_tokens, aligned_llm_tokens = aligner_first.align_sequence(sample_text) 75 | print(f"\nQuick alignment test for: '{sample_text}'") 76 | print(f"SLM tokens: {slm_tokens}") 77 | print(f"Aligned LLM tokens: {aligned_llm_tokens}") 78 | 79 | print("\n✅ TokenAligner test completed") 80 | 81 | 82 | def run_inference_example(rosetta_model: RosettaModel, tokenizer: AutoTokenizer, prompt: str): 83 | """Run inference example with RosettaModel 84 | Args: 85 | rosetta_model: RosettaModel 86 | tokenizer: AutoTokenizer 87 | prompt: str 88 | """ 89 | print("Running inference example...") 90 | 91 | device = rosetta_model.device 92 | 93 | # Prepare input 94 | 95 | 96 | input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True, enable_thinking=False) 97 | print(f"Input text: {input_text}") 98 | inputs = tokenizer(input_text, return_tensors="pt").to(device) 99 | print(f"Input tokens: {inputs['input_ids']}") 100 | instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(inputs['input_ids'].shape[1] - 1, 1).unsqueeze(0).to(device) 101 | label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(device) 102 | kv_cache_index = [instruction_index, label_index] 103 | # slm_tokenizer = tokenizer 104 | # llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") 105 | # strategy = "first" 106 | # aligner = TokenAligner(slm_tokenizer=slm_tokenizer, llm_tokenizer=llm_tokenizer, strategy=AlignmentStrategy(strategy)) 107 | # messages = [ 108 | # {"role": "user", "content": prompt} 109 | # ] 110 | # details = aligner.align_chat_messages(messages, add_generation_prompt=True, return_details=True) 111 | # slm_ids = torch.tensor(details['slm_ids_padded']).unsqueeze(0) 112 | # llm_ids = torch.tensor(details['llm_ids_padded']).unsqueeze(0) 113 | 114 | # slm_pad_mask = torch.tensor(details['slm_padding_mask']).unsqueeze(0) 115 | # llm_pad_mask = torch.tensor(details['llm_padding_mask']).unsqueeze(0) 116 | 117 | # slm_attention_mask = (~slm_pad_mask).float() 118 | # llm_attention_mask = (~llm_pad_mask).float() 119 | 120 | # message_mask = torch.tensor(details['message_mask']) 121 | # kv_cache_index = generate_kv_cache_index(slm_ids.shape[1], slm_ids.shape[1]) 122 | # kv_cache_index[~message_mask] = torch.tensor([[-1,0]]) 123 | 124 | # kv_idx = kv_cache_index 125 | # change_points = [0] 126 | # for i in range(1, kv_idx.size(0)): 127 | # if not torch.equal(kv_idx[i], kv_idx[i - 1]): 128 | # change_points.append(i) 129 | # change_points.append(kv_idx.size(0)) 130 | 131 | # kv_cache_list = [] 132 | 133 | # for i in range(len(change_points) - 1): 134 | # start = change_points[i] 135 | # end = change_points[i + 1] 136 | # kv_cache_list.append(kv_idx[start:end, :].unsqueeze(0).to(device)) 137 | # prefill_kv_cache_list = kv_cache_list[:-1] 138 | # print(f"Input prompt: '{prompt}'") 139 | # print(f"Input shape: {slm_ids.shape}") 140 | # print(f"Device: {device}") 141 | 142 | # slm_ids = slm_ids.to(device) 143 | # llm_ids = llm_ids.to(device) 144 | # slm_attention_mask = slm_attention_mask.to(device) 145 | # llm_attention_mask = llm_attention_mask.to(device) 146 | 147 | # Run inference 148 | # with torch.no_grad(): 149 | # # outputs = rosetta_model.forward( 150 | # # input_ids=[slm_ids, llm_ids], 151 | # # attention_mask=[slm_attention_mask, llm_attention_mask], 152 | # # kv_cache_index=kv_cache_list, 153 | # # position_ids=torch.arange(slm_ids.shape[1]).unsqueeze(0).to(device), 154 | # # use_cache=True, 155 | # # output_attentions=False, 
156 | # # output_hidden_states=False, 157 | # # sample=False, 158 | # # ) 159 | # outputs = rosetta_model(**inputs, kv_cache_index=kv_cache_index) 160 | 161 | # # Get logits and generate next token 162 | # logits = outputs.logits 163 | # next_token_logits = logits[0, -1, :] 164 | # next_token_id = torch.argmax(next_token_logits, dim=-1) 165 | # next_token = tokenizer.decode(next_token_id) 166 | 167 | # print(f"Output logits shape: {logits.shape}") 168 | # print(f"Next predicted token: '{next_token}'") 169 | # print("✅ Inference completed successfully") 170 | 171 | # Run generation 172 | with torch.no_grad(): 173 | # outputs = rosetta_model.generate( 174 | # prefill_kv_cache_index=prefill_kv_cache_list, 175 | # input_ids=[slm_ids, llm_ids], 176 | # attention_mask=[slm_attention_mask, llm_attention_mask], 177 | # use_cache=True, 178 | # output_attentions=False, 179 | # output_hidden_states=False, 180 | # max_new_tokens=256, 181 | # do_sample=False, 182 | # ) 183 | sampling_params = { 184 | 'do_sample': True, 185 | 'temperature': 0.7, 186 | 'top_p': 0.8, 187 | 'top_k': 20, 188 | 'min_p': 0.0, 189 | 'repetition_penalty': 1.2, 190 | 'max_new_tokens': 1024 191 | } 192 | outputs = rosetta_model.generate(**inputs, kv_cache_index=kv_cache_index, **sampling_params) 193 | output_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 194 | print(f"Rosetta output text: {output_text}") 195 | 196 | with torch.no_grad(): 197 | sampling_params = { 198 | 'do_sample': True, 199 | 'temperature': 0.7, 200 | 'top_p': 0.8, 201 | 'top_k': 20, 202 | 'min_p': 0.0, 203 | 'repetition_penalty': 1.2, 204 | 'max_new_tokens': 1024 205 | } 206 | slm_model = rosetta_model.model_list[0] 207 | outputs = slm_model.generate(**inputs, **sampling_params) 208 | output_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 209 | print(f"SLM output text: {output_text}") 210 | 211 | with torch.no_grad(): 212 | sampling_params = { 213 | 'do_sample': True, 214 | 'temperature': 0.7, 215 | 'top_p': 0.8, 216 | 'top_k': 20, 217 | 'min_p': 0.0, 218 | 'repetition_penalty': 1.2, 219 | 'max_new_tokens': 1024 220 | } 221 | llm_model = rosetta_model.model_list[1] 222 | outputs = llm_model.generate(**inputs, **sampling_params) 223 | output_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 224 | print(f"LLM output text: {output_text}") 225 | 226 | def load_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any], 227 | device: torch.device) -> Tuple[Any, Any]: 228 | """ 229 | Load Rosetta model with projectors and aggregators. 
230 | 231 | Args: 232 | model_config: Model configuration dict 233 | eval_config: Evaluation configuration dict 234 | device: Device to load model on 235 | 236 | Returns: 237 | Tuple of (rosetta_model, tokenizer) 238 | """ 239 | # Prefer checkpoints_dir under model.rosetta_config; fall back to eval config for backward compatibility 240 | rosetta_config = model_config["rosetta_config"] 241 | checkpoint_dir = rosetta_config.get("checkpoints_dir", eval_config.get("checkpoints_dir")) 242 | if checkpoint_dir is None: 243 | raise KeyError("checkpoints_dir must be provided under model.rosetta_config (preferred) or eval config (legacy)") 244 | slm_model_path = rosetta_config["base_model"] 245 | llm_model_path = rosetta_config["teacher_model"] 246 | 247 | # Load tokenizer 248 | slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path)) 249 | set_default_chat_template(slm_tokenizer, slm_model_path) 250 | 251 | # Load models 252 | slm_model = AutoModelForCausalLM.from_pretrained( 253 | str(slm_model_path), 254 | torch_dtype=torch.bfloat16, 255 | device_map={"": device} 256 | ).eval() 257 | 258 | llm_model = AutoModelForCausalLM.from_pretrained( 259 | str(llm_model_path), 260 | torch_dtype=torch.bfloat16, 261 | device_map={"": device} 262 | ).eval() 263 | 264 | # Load projectors 265 | num_projectors = len([f for f in os.listdir(checkpoint_dir) if re.match(r"projector_\d+\.pt", f)]) 266 | projector_list = [] 267 | for t in range(num_projectors): 268 | json_cfg = os.path.join(checkpoint_dir, f"projector_{t}.json") 269 | proj = load_projector(json_cfg) 270 | proj = proj.to(device) 271 | pt_path = os.path.join(checkpoint_dir, f"projector_{t}.pt") 272 | if os.path.exists(pt_path): 273 | state_dict = torch.load(pt_path, map_location=device) 274 | proj.load_state_dict(state_dict, strict=False) 275 | projector_list.append(proj) 276 | 277 | # Load aggregators 278 | num_aggregators = len([f for f in os.listdir(checkpoint_dir) if re.match(r"aggregator_\d+\.pt", f)]) 279 | aggregator_list = [] 280 | for t in range(num_aggregators): 281 | json_cfg = os.path.join(checkpoint_dir, f"aggregator_{t}.json") 282 | agg_path = os.path.join(checkpoint_dir, f"aggregator_{t}.pt") 283 | agg = load_aggregator(json_cfg) 284 | if os.path.exists(agg_path): 285 | sd = torch.load(agg_path, map_location="cpu") 286 | agg.load_state_dict(sd, strict=False) 287 | agg = agg.to(device) 288 | aggregator_list.append(agg) 289 | 290 | # Initialize Rosetta model 291 | rosetta_model = RosettaModel( 292 | model_list=[slm_model, llm_model], 293 | base_model_idx=0, 294 | projector_list=projector_list, 295 | aggregator_list=aggregator_list, 296 | ).to(device).eval() 297 | 298 | # Load projector/aggregator mapping configs 299 | proj_cfg_path = os.path.join(checkpoint_dir, "projector_config.json") 300 | agg_cfg_path = os.path.join(checkpoint_dir, "aggregator_config.json") 301 | rosetta_model.load_projector_config(proj_cfg_path) 302 | rosetta_model.load_aggregator_config(agg_cfg_path) 303 | 304 | return rosetta_model, slm_tokenizer 305 | 306 | def main(): 307 | """Main function to run the inference example""" 308 | 309 | rosetta_model, slm_tokenizer = load_rosetta_model( 310 | model_config={ 311 | "rosetta_config": { 312 | "base_model": "Qwen/Qwen3-0.6B", 313 | "teacher_model": "Qwen/Qwen3-4B", 314 | "checkpoints_dir": "local/checkpoints/0.6B_4B_general/final" 315 | } 316 | }, 317 | eval_config={}, 318 | device=torch.device("cuda") 319 | ) 320 | 321 | # Test token aligner 322 | # test_token_aligner(slm_tokenizer, llm_tokenizer) 323 | 324 | # 
Run inference 325 | prompt = [{ 326 | "role": "user", 327 | "content": "Accurately answer the following question:\n\nStatement 1 | If T: V -> W is a linear transformation and dim(V ) < dim(W) < 1, then T must be injective. Statement 2 | Let dim(V) = n and suppose that T: V -> V is linear. If T is injective, then it is a bijection.\n\nAre these statements correct? Let's think step by step and then answer the question starting with Answer:" 328 | }] 329 | run_inference_example(rosetta_model, slm_tokenizer, prompt) 330 | # run_inference_example(rosetta_model, slm_tokenizer, "从美国向北进入加拿大时,您会看到北星(北极星)越来越") 331 | 332 | 333 | if __name__ == "__main__": 334 | # import debugpy 335 | # debugpy.listen(("0.0.0.0", 5678)) 336 | # print("Waiting for debugger attach...") 337 | # debugpy.wait_for_client() 338 | # print("Debugger attached, running...") 339 | main() 340 |
-------------------------------------------------------------------------------- /script/analysis/scaling/batch_evaluate_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Batch-evaluate trained model checkpoints. 4 | Automatically walks every checkpoint under the scaling_MMLU_15k directory, generates an evaluation config, and runs the evaluation. 5 | """ 6 | 7 | import os 8 | import sys 9 | import json 10 | import yaml 11 | import subprocess 12 | from pathlib import Path 13 | from typing import Dict, List, Tuple, Optional 14 | import logging 15 | from datetime import datetime 16 | 17 | # Set up logging 18 | logging.basicConfig( 19 | level=logging.INFO, 20 | format='%(asctime)s - %(levelname)s - %(message)s', 21 | handlers=[ 22 | logging.FileHandler('batch_evaluate.log'), 23 | logging.StreamHandler() 24 | ] 25 | ) 26 | logger = logging.getLogger(__name__) 27 | 28 | class BatchCheckpointEvaluator: 29 | """Batch checkpoint evaluator""" 30 | 31 | def __init__(self): 32 | # Path configuration 33 | self.checkpoints_base_dir = Path("local/checkpoints/scaling_MMLU_15k") 34 | self.eval_template_path = "eval_recipe/test_eval.yaml" 35 | self.training_configs_dir = Path("recipe") 36 | self.evaluation_script = "script/evaluation/unified_evaluator.py" 37 | 38 | # Model family definitions (kept consistent with auto_scaling_experiment.py) 39 | self.family1 = { 40 | "tiny": "Qwen/Qwen3-0.6B", 41 | "small": "Qwen/Qwen3-1.7B", 42 | "medium": "Qwen/Qwen3-4B", 43 | "large": "Qwen/Qwen3-8B", 44 | "xlarge": "Qwen/Qwen3-14B" 45 | } 46 | 47 | self.family2 = { 48 | "tiny": "Qwen/Qwen2.5-0.5B-Instruct", 49 | "small": "Qwen/Qwen2.5-1.5B-Instruct", 50 | "medium": "Qwen/Qwen2.5-3B-Instruct", 51 | "large": "Qwen/Qwen2.5-7B-Instruct", 52 | "xlarge": "Qwen/Qwen2.5-14B-Instruct" 53 | } 54 | 55 | # Size-name-to-key mapping 56 | self.size_name_to_key = { 57 | "0.6B": "tiny", "1.7B": "small", "4B": "medium", "8B": "large", "14B": "xlarge", 58 | "0.5B": "tiny", "1.5B": "small", "3B": "medium", "7B": "large" 59 | } 60 | 61 | def find_checkpoints(self) -> List[Tuple[str, Path]]: 62 | """ 63 | Find all available checkpoint directories 64 | 65 | Returns: 66 | List of (experiment_name, checkpoint_path) tuples 67 | """ 68 | checkpoints = [] 69 | 70 | if not self.checkpoints_base_dir.exists(): 71 | logger.error(f"Checkpoints base directory does not exist: {self.checkpoints_base_dir}") 72 | return checkpoints 73 | 74 | for item in self.checkpoints_base_dir.iterdir(): 75 | if item.is_dir() and item.name.startswith("scaling_"): 76 | # Check whether a 'final' subdirectory exists 77 | final_dir = item / "final" 78 | if final_dir.exists() and final_dir.is_dir(): 79 | experiment_name = item.name.replace("scaling_", "") 80 | checkpoints.append((experiment_name, final_dir)) 81 | logger.info(f"Found checkpoint: {experiment_name} -> {final_dir}") 82 | else: 83 | logger.warning(f"No 'final' directory found in {item}") 84 | 85 | logger.info(f"Found {len(checkpoints)} valid checkpoints") 86 | return checkpoints 87 | 88 | def parse_experiment_name(self, experiment_name: str) -> Optional[Tuple[str, str]]: 89 | """ 90 | Parse the experiment name and extract the base-model and teacher-model size information 91 | 92 | Args: 93 | experiment_name: Experiment name, e.g. "0.6B+0.5B", "1.7B+3B" 94 | 95 | Returns: 96 | (base_size_key, teacher_size_key) or None 97 | """ 98 | if "+" not in experiment_name: 99 | logger.error(f"Invalid experiment name format: {experiment_name}") 100 | return None 101 | 102 | try: 103 | base_size, teacher_size = experiment_name.split("+") 104 | 105 | # Look up the corresponding size keys 106 | base_key = self.size_name_to_key.get(base_size) 107 | teacher_key = self.size_name_to_key.get(teacher_size) 108 | 109 | if base_key is None or teacher_key is None: 110 | logger.error(f"Unknown model sizes in experiment: {experiment_name}") 111 | logger.error(f" Base size '{base_size}' -> {base_key}") 112 | logger.error(f" Teacher size '{teacher_size}' -> {teacher_key}") 113 | return None 114 | 115 | return base_key, teacher_key 116 | 117 | except ValueError: 118 | logger.error(f"Failed to parse experiment name: {experiment_name}") 119 | return None 120 | 121 | def load_eval_template(self) -> Dict: 122 | """Load the evaluation template config""" 123 | try: 124 | with open(self.eval_template_path, 'r', encoding='utf-8') as f: 125 | return yaml.safe_load(f) 126 | except Exception as e: 127 | logger.error(f"Failed to load eval template {self.eval_template_path}: {e}") 128 | raise 129 | 130 | def create_eval_config(self, experiment_name: str, checkpoint_path: Path, 131 | base_model_path: str, teacher_model_path: str) -> str: 132 | """ 133 | Create an evaluation config file 134 | 135 | Args: 136 | experiment_name: Experiment name 137 | checkpoint_path: Checkpoint path 138 | base_model_path: Base model path 139 | teacher_model_path: Teacher model path 140 | 141 | Returns: 142 | Path of the generated config file 143 | """ 144 | # Load the template 145 | config = self.load_eval_template() 146 | 147 | # Modify the config 148 | config["model"]["model_name"] = "Rosetta" 149 | config["model"]["rosetta_config"]["base_model"] = base_model_path 150 | config["model"]["rosetta_config"]["teacher_model"] = teacher_model_path 151 | config["model"]["rosetta_config"]["checkpoints_dir"] = str(checkpoint_path) 152 | 153 | # Set the output directory 154 | output_dir = f"local/scaling_new_prompt_results/{experiment_name}_mmlu-redux" 155 | config["output"]["output_dir"] = output_dir 156 | 157 | # Save the config file 158 | config_filename = f"eval_recipe/scaling_new_prompt_{experiment_name}_eval.yaml" 159 | with open(config_filename, 'w', encoding='utf-8') as f: 160 | yaml.dump(config, f, default_flow_style=False, allow_unicode=True) 161 | 162 | logger.info(f"Created eval config: {config_filename}") 163 | logger.info(f" Base model: {base_model_path}") 164 | logger.info(f" Teacher model: {teacher_model_path}") 165 | logger.info(f" Checkpoint: {checkpoint_path}") 166 | logger.info(f" Output dir: {output_dir}") 167 | 168 | return config_filename 169 | 170 | def run_evaluation(self, config_path: str, experiment_name: str) -> bool: 171 | """ 172 | Run the evaluation for a single experiment 173 | 174 | Args: 175 | config_path: Path to the evaluation config file 176 | experiment_name: Experiment name 177 | 178 | Returns: 179 | Whether the evaluation succeeded 180 | """ 181 | logger.info(f"Starting evaluation for experiment: {experiment_name}") 182 | 183 | try: 184 | # Build the evaluation command 185 | cmd = [ 186 | "python", self.evaluation_script, 187 | "--config", config_path 188 | ] 189 | 190 | logger.info(f"Running command: {' '.join(cmd)}") 191 | 192 | # Run the evaluation (with streaming output) 193 | process = subprocess.Popen( 194 | cmd, 195 | cwd=os.getcwd(), 196 | stdout=subprocess.PIPE, 197 | stderr=subprocess.STDOUT, 198 | text=True, 199 | bufsize=1, 200 | ) 201 | 202 | # Stream the log output 203 | if process.stdout: 204 | for line in process.stdout: 205 | sys.stdout.write(line) 206 | sys.stdout.flush() 207 | 208 | process.wait() 209 | 210 | if process.returncode == 0: 211 | logger.info(f"Evaluation completed successfully for {experiment_name}") 212 | return True 213 | else: 214 | logger.error(f"Evaluation failed for {experiment_name}") 215 | logger.error(f"Process returned code {process.returncode}") 216 | return False 217 | 218 | except Exception as e: 219 | logger.error(f"Exception during evaluation {experiment_name}: {e}") 220 | return False 221 | 222 | def cleanup_temp_configs(self, temp_configs: List[str]): 223 | """Clean up temporarily generated config files""" 224 | for config_path in temp_configs: 225 | try: 226 | if os.path.exists(config_path): 227 | os.remove(config_path) 228 | logger.info(f"Removed temporary config: {config_path}") 229 | except Exception as e: 230 | logger.warning(f"Failed to remove temporary config {config_path}: {e}") 231 | 232 | def check_existing_results(self, experiment_name: str) -> bool: 233 | """Check whether results for an experiment already exist""" 234 | output_dir = Path(f"local/scaling_new_prompt_results/{experiment_name}_mmlu-redux") 235 | if output_dir.exists() and any(output_dir.iterdir()): 236 | logger.info(f"Results already exist for {experiment_name}, skipping") 237 | return True 238 | return False 239 | 240 | def run_batch_evaluation(self, skip_existing: bool = True, dry_run: bool = False): 241 | """ 242 | Run the batch evaluation 243 | 244 | Args: 245 | skip_existing: Whether to skip experiments that already have results 246 | dry_run: Whether this is a dry run (no evaluations are actually executed) 247 | """ 248 | logger.info("=" * 80) 249 | logger.info("STARTING BATCH CHECKPOINT EVALUATION") 250 | logger.info("=" * 80) 251 | 252 | # Find all checkpoints 253 | checkpoints = self.find_checkpoints() 254 | 255 | if not checkpoints: 256 | logger.error("No valid checkpoints found!") 257 | return 258 | 259 | logger.info(f"Found {len(checkpoints)} checkpoints to evaluate:") 260 | for exp_name, checkpoint_path in checkpoints: 261 | logger.info(f" ✓ {exp_name} -> {checkpoint_path}") 262 | 263 | if dry_run: 264 | logger.info("DRY RUN MODE - No actual evaluations will be performed") 265 | 266 | successful_evaluations = 0 267 | failed_evaluations = 0 268 | skipped_evaluations = 0 269 | temp_configs = [] 270 | 271 | for i, (experiment_name, checkpoint_path) in enumerate(checkpoints, 1): 272 | logger.info(f"\n[{i}/{len(checkpoints)}] Processing experiment: {experiment_name}") 273 | 274 | # Check whether results already exist 275 | if skip_existing and self.check_existing_results(experiment_name): 276 | skipped_evaluations += 1 277 | continue 278 | 279 | # Parse the experiment name 280 | parsed = self.parse_experiment_name(experiment_name) 281 | if parsed is None: 282 | logger.error(f"Failed to parse experiment name: {experiment_name}") 283 | failed_evaluations += 1 284 | continue 285 | 286 | base_key, teacher_key = parsed 287 | base_model_path = self.family1[base_key] 288 | teacher_model_path = self.family2[teacher_key] 289 | 290 | try: 291 | # Create the evaluation config 292 | config_path = self.create_eval_config( 293 | experiment_name, checkpoint_path, 294 | base_model_path, teacher_model_path 295 | ) 296 | temp_configs.append(config_path) 297 | 298 | if dry_run: 299 | logger.info(f"DRY RUN: Would evaluate {experiment_name}") 300 | successful_evaluations += 1 301 | else: 302 | # Run the evaluation 303 | success = self.run_evaluation(config_path, experiment_name) 304 | 305 | if success: 306 | successful_evaluations += 1 307 | else: 308 | failed_evaluations += 1 309 | 310 | except Exception as e: 311 | logger.error(f"Failed to process {experiment_name}: {e}") 312 | failed_evaluations += 1 313 | 314 | # Clean up temporary config files 315 | if not dry_run: 316 | self.cleanup_temp_configs(temp_configs) 317 | 318 | # Print the final summary 319 | logger.info(f"\n" + "=" * 80) 320 | logger.info("BATCH EVALUATION SUMMARY") 321 | logger.info(f"=" * 80) 322 | logger.info(f"Total checkpoints found: {len(checkpoints)}") 323 | logger.info(f"Successful evaluations: {successful_evaluations}") 324 | logger.info(f"Failed evaluations: {failed_evaluations}") 325 | logger.info(f"Skipped evaluations: {skipped_evaluations}") 326 | 327 | if dry_run: 328 | logger.info("\nThis was a DRY RUN. No actual evaluations were performed.") 329 | logger.info("Run with --no-dry-run to perform actual evaluations.") 330 | 331 | 332 | def main(): 333 | """Main entry point""" 334 | import argparse 335 | 336 | parser = argparse.ArgumentParser(description='Batch evaluate trained checkpoints') 337 | parser.add_argument('--skip-existing', action='store_true', default=True, 338 | help='Skip experiments that already have results (default: True)') 339 | parser.add_argument('--no-skip-existing', action='store_false', dest='skip_existing', 340 | help='Evaluate all experiments, even if results exist') 341 | parser.add_argument('--dry-run', action='store_true', default=False, 342 | help='Show what would be evaluated without running actual evaluations') 343 | parser.add_argument('--no-dry-run', action='store_false', dest='dry_run', 344 | help='Perform actual evaluations (default)') 345 | 346 | args = parser.parse_args() 347 | 348 | # Create the evaluator and run it 349 | evaluator = BatchCheckpointEvaluator() 350 | evaluator.run_batch_evaluation( 351 | skip_existing=args.skip_existing, 352 | dry_run=args.dry_run 353 | ) 354 | 355 | 356 | if __name__ == "__main__": 357 | main() 358 |
-------------------------------------------------------------------------------- /rosetta/model/ablation_projector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ablation Projector: A configurable projector for ablation studies based on C2CProjector. 3 | Allows gradual removal of components to study their individual contributions. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | from typing import Optional, Tuple, Literal 10 | 11 | from rosetta.utils.registry import register_model, capture_init_args 12 | from rosetta.model.projector import Projector 13 | from rosetta.model.projector import RegularMLP 14 | 15 | 16 | @register_model 17 | @capture_init_args 18 | class AblationProjector(Projector): 19 | """ 20 | Ablation study projector based on C2CProjector with configurable component removal. 21 | 22 | Ablation levels: 23 | 0. Full C2C (baseline) 24 | 1. Remove scalar weights (set to 1.0) 25 | 2. Remove gates (set to 1.0) 26 | 3. Remove target contribution (only use source) 27 | 4. Remove gates only (gates=1.0), keep scalars and target 28 | 29 | Each level builds on the previous one, allowing gradual degradation study.
30 | """ 31 | 32 | def __init__( 33 | self, 34 | source_dim: int, 35 | target_dim: int, 36 | source_num_heads: int = 1, 37 | target_num_heads: int = 1, 38 | intermediate_dim: int = 1024, 39 | hidden_dim: int = 1024, 40 | num_layers: int = 3, 41 | dropout: float = 0.1, 42 | initial_temperature: float = 1.0, 43 | final_temperature: float = 0.001, 44 | anneal_steps: int = 1929, 45 | dtype: torch.dtype = torch.float32, 46 | 47 | # Ablation configuration 48 | ablation_level: int = 0, # 0=full, 1=no_scalar, 2=no_gate+no_scalar, 3=no_target, 4=no_gate_only 49 | use_scalar_weights: bool = True, # Can be overridden by ablation_level 50 | use_gates: bool = True, # Can be overridden by ablation_level 51 | use_target: bool = True, # Can be overridden by ablation_level 52 | ): 53 | super().__init__() 54 | 55 | assert 0 <= ablation_level <= 4, "ablation_level must be 0, 1, 2, 3, or 4" 56 | 57 | # Dimensions 58 | self.source_dim = source_dim 59 | self.target_dim = target_dim 60 | self.source_num_heads = source_num_heads 61 | self.target_num_heads = target_num_heads 62 | self.ablation_level = ablation_level 63 | 64 | # Override component usage based on ablation level 65 | if ablation_level == 4: 66 | # Special case: disable gates only, keep scalars and target 67 | use_scalar_weights = True 68 | use_gates = False 69 | use_target = True 70 | else: 71 | if ablation_level >= 1: 72 | use_scalar_weights = False 73 | if ablation_level >= 2: 74 | use_gates = False 75 | if ablation_level >= 3: 76 | use_target = False 77 | 78 | self.use_scalar_weights = use_scalar_weights 79 | self.use_gates = use_gates 80 | self.use_target = use_target 81 | 82 | # Sizes 83 | in_dim = source_dim * source_num_heads 84 | out_dim = target_dim * target_num_heads 85 | 86 | # 1) concat(source_X, target_X) then project to hidden_dim 87 | # If not using target, only use source features 88 | if self.use_target: 89 | self.key_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype) 90 | self.value_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype) 91 | else: 92 | # Only use source features 93 | self.key_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype) 94 | self.value_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype) 95 | 96 | # 2) one-layer common embedding MLP to get intermediate representation (at hidden_dim) 97 | self.key_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype) 98 | self.value_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype) 99 | 100 | # 3a) intermediate representation → (L-2)-layer MLP for weights → project to head dim 101 | # Only build if using scalar weights 102 | if self.use_scalar_weights: 103 | self.key_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype) 104 | self.value_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype) 105 | self.key_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype) 106 | self.value_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype) 107 | 108 | # 3b) intermediate representation → (L-2)-layer MLP for projected_X → finally project hidden_dim → out_dim 109 | self.key_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype) 110 | self.value_proj_mlp2 = 
RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype) 111 | self.key_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype) 112 | self.value_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype) 113 | 114 | # Scalar key/value gate parameters and temperature schedule 115 | # Only build if using gates 116 | if self.use_gates: 117 | self.key_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype)) 118 | self.value_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype)) 119 | self.use_gumbel = True 120 | self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype)) 121 | self.initial_temperature = initial_temperature 122 | self.final_temperature = final_temperature 123 | self.anneal_steps = anneal_steps 124 | 125 | # Temperature for weight normalization 126 | self.scalar_temperature = 1.0 127 | 128 | def update_temperature(self, step: int): 129 | """Update temperature using exponential annealing schedule for gates.""" 130 | if self.use_gates: 131 | ratio = min(step / self.anneal_steps, 1.0) 132 | temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio 133 | self.gate_temperature.fill_(temp) 134 | 135 | def forward( 136 | self, 137 | source_kv: Tuple[Tensor, Tensor], 138 | target_kv: Tuple[Tensor, Tensor], 139 | position_ids: Optional[Tensor] = None, 140 | max_pos: Optional[Tensor] = None, 141 | ) -> Tuple[Tensor, Tensor]: 142 | source_key, source_value = source_kv 143 | target_key, target_value = target_kv 144 | 145 | B, Hs, N, Ds = source_key.shape 146 | _, Ht, _, Dt = target_key.shape 147 | 148 | # Flatten heads 149 | source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, Hs * Ds) 150 | source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, Hs * Ds) 151 | target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, Ht * Dt) 152 | target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, Ht * Dt) 153 | 154 | # 1) Prepare input features based on ablation level 155 | if self.use_target: 156 | # Full C2C: concat source and target features 157 | key_cat = torch.cat([source_key_flat, target_key_flat], dim=-1) 158 | value_cat = torch.cat([source_value_flat, target_value_flat], dim=-1) 159 | else: 160 | # Ablation level 3: only use source features 161 | key_cat = source_key_flat 162 | value_cat = source_value_flat 163 | 164 | # 2) project to hidden dim 165 | key_hidden = self.key_in(key_cat) 166 | value_hidden = self.value_in(value_cat) 167 | 168 | # 3) one-layer common embedding MLP to get intermediate representation (at hidden_dim) 169 | key_hidden = self.key_mlp1(key_hidden) 170 | value_hidden = self.value_mlp1(value_hidden) 171 | 172 | # 4b) intermediate representation -> projected feature path 173 | key_proj_hidden = self.key_proj_out(self.key_proj_mlp2(key_hidden)) # (B, N, Ht * Dt) 174 | value_proj_hidden = self.value_proj_out(self.value_proj_mlp2(value_hidden)) # (B, N, Ht * Dt) 175 | projected_key = key_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt) 176 | projected_value = value_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt) 177 | 178 | # 4a) intermediate representation -> scalar path (if using scalar weights) 179 | if self.use_scalar_weights: 180 | key_scalar = self.key_scalar_head(self.key_scalar_mlp2(key_hidden)) # (B, N, Ht) 181 | value_scalar = self.value_scalar_head(self.value_scalar_mlp2(value_hidden)) # (B, N, Ht) 182 | key_scalar = 
key_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1) 183 | value_scalar = value_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1) 184 | # Normalize scalars 185 | norm_key_scalar = torch.sigmoid(key_scalar) 186 | norm_value_scalar = torch.sigmoid(value_scalar) 187 | else: 188 | # Ablation level 1+: set scalar weights to 1.0 189 | norm_key_scalar = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype) 190 | norm_value_scalar = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype) 191 | 192 | # Key/value gates (if using gates) 193 | if self.use_gates: 194 | key_gate_logit = self.key_gate_logit.view(1, 1, 1, 1) 195 | value_gate_logit = self.value_gate_logit.view(1, 1, 1, 1) 196 | if self.training and self.use_gumbel: 197 | u1 = torch.rand(B, Ht, N, 1, device=key_gate_logit.device, dtype=key_gate_logit.dtype) 198 | u2 = torch.rand(B, Ht, N, 1, device=value_gate_logit.device, dtype=value_gate_logit.dtype) 199 | g1 = -torch.log(-torch.log(u1 + 1e-20) + 1e-20) 200 | g2 = -torch.log(-torch.log(u2 + 1e-20) + 1e-20) 201 | key_gate = torch.sigmoid((key_gate_logit + g1) / self.gate_temperature) 202 | value_gate = torch.sigmoid((value_gate_logit + g2) / self.gate_temperature) 203 | else: 204 | key_gate = (key_gate_logit > 0).float() 205 | value_gate = (value_gate_logit > 0).float() 206 | else: 207 | # Gates disabled: set gates to 1.0 (always open) 208 | key_gate = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype) 209 | value_gate = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype) 210 | 211 | # Compute projected contribution 212 | projected_key_term = key_gate * norm_key_scalar * projected_key 213 | projected_value_term = value_gate * norm_value_scalar * projected_value 214 | 215 | # Compute target contribution (if using target) 216 | if self.use_target: 217 | # Full C2C: add target with projected 218 | output_key = target_key + projected_key_term 219 | output_value = target_value + projected_value_term 220 | else: 221 | # Ablation level 3: only use projected (no target) 222 | output_key = projected_key_term 223 | output_value = projected_value_term 224 | 225 | return output_key, output_value 226 | 227 | def get_ablation_info(self) -> dict: 228 | """Return information about current ablation configuration.""" 229 | return { 230 | 'ablation_level': self.ablation_level, 231 | 'use_scalar_weights': self.use_scalar_weights, 232 | 'use_gates': self.use_gates, 233 | 'use_target': self.use_target, 234 | 'description': self._get_ablation_description() 235 | } 236 | 237 | def _get_ablation_description(self) -> str: 238 | """Get human-readable description of current ablation level.""" 239 | descriptions = { 240 | 0: "Full C2C (baseline)", 241 | 1: "No scalar weights (scalars=1.0)", 242 | 2: "No gates (gates=1.0) + No scalar weights", 243 | 3: "No target (source-only) + No gates + No scalar weights", 244 | 4: "No gates (gates=1.0), keep scalars and target" 245 | } 246 | return descriptions.get(self.ablation_level, "Unknown ablation level") 247 | 248 | 249 | # Convenience functions for creating specific ablation levels 250 | def create_ablation_projector( 251 | source_dim: int, 252 | target_dim: int, 253 | source_num_heads: int = 1, 254 | target_num_heads: int = 1, 255 | ablation_level: int = 0, 256 | **kwargs 257 | ) -> AblationProjector: 258 | """Create an AblationProjector with specified ablation level.""" 259 | return AblationProjector( 260 | source_dim=source_dim, 261 | 
target_dim=target_dim, 262 | source_num_heads=source_num_heads, 263 | target_num_heads=target_num_heads, 264 | ablation_level=ablation_level, 265 | **kwargs 266 | ) 267 | 268 | 269 | def create_full_c2c_projector(**kwargs) -> AblationProjector: 270 | """Create full C2C projector (ablation level 0).""" 271 | return create_ablation_projector(ablation_level=0, **kwargs) 272 | 273 | 274 | def create_no_scalar_projector(**kwargs) -> AblationProjector: 275 | """Create projector without scalar weights (ablation level 1).""" 276 | return create_ablation_projector(ablation_level=1, **kwargs) 277 | 278 | 279 | def create_no_gate_projector(**kwargs) -> AblationProjector: 280 | """Create projector without gates (ablation level 2).""" 281 | return create_ablation_projector(ablation_level=2, **kwargs) 282 | 283 | 284 | def create_source_only_projector(**kwargs) -> AblationProjector: 285 | """Create source-only projector (ablation level 3).""" 286 | return create_ablation_projector(ablation_level=3, **kwargs) 287 | 288 | 289 | def create_no_gate_only_projector(**kwargs) -> AblationProjector: 290 | """Create projector without gates but with scalar weights and target (ablation level 4).""" 291 | return create_ablation_projector(ablation_level=4, **kwargs) 292 | --------------------------------------------------------------------------------