├── gsm8k └── test-00000-of-00001.parquet ├── src ├── __pycache__ │ └── utils.cpython-39.pyc ├── data │ ├── __pycache__ │ │ ├── binary.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ └── my_dataset.cpython-39.pyc │ ├── __init__.py │ ├── my_dataset.py │ └── binary.py ├── model │ ├── __pycache__ │ │ ├── gpt2.cpython-39.pyc │ │ └── __init__.cpython-39.pyc │ ├── __init__.py │ └── gpt2.py ├── generate_data.py ├── utils.py └── train.py ├── config ├── data_config.py └── gpt2_tiny_wpetrain.py ├── huggingface_transformer_model └── gpt2 │ └── config.json ├── scripts └── train.sh ├── LICENSE ├── Figures ├── Fig9 │ └── fig9.py ├── Fig3 │ ├── fig3.1.py │ └── fig3.2.py ├── Fig2 │ ├── fig2.2.py │ ├── fig2.3.py │ └── fig2.1.py ├── Fig1 │ ├── fig1.1.py │ └── fig1.2.py ├── Fig5 │ ├── flg5.2.py │ └── fig5.1.py ├── Fig11 │ └── fig11.py ├── Fig10 │ └── fig10.py ├── Fig6-7 │ ├── fig7.py │ ├── fig6.py │ └── work.py ├── Fig8 │ └── fig8.py └── Fig4 │ └── fig4.py └── README.md /gsm8k/test-00000-of-00001.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhqwqwq/Learning-Parity-with-CoT/HEAD/gsm8k/test-00000-of-00001.parquet -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhqwqwq/Learning-Parity-with-CoT/HEAD/src/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/binary.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhqwqwq/Learning-Parity-with-CoT/HEAD/src/data/__pycache__/binary.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/gpt2.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhqwqwq/Learning-Parity-with-CoT/HEAD/src/model/__pycache__/gpt2.cpython-39.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhqwqwq/Learning-Parity-with-CoT/HEAD/src/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhqwqwq/Learning-Parity-with-CoT/HEAD/src/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/my_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhqwqwq/Learning-Parity-with-CoT/HEAD/src/data/__pycache__/my_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /config/data_config.py: -------------------------------------------------------------------------------- 1 | { 2 | 'dataset_type': 'BinaryDataset', 3 | 'n_samples': 10000, 4 | 'val_samples': 2048, 5 | 'n_digit': 30, 6 | 'n_secret': 3, 7 | 'use_cot': True, 8 | } -------------------------------------------------------------------------------- /config/gpt2_tiny_wpetrain.py: -------------------------------------------------------------------------------- 1 | { 2 | 'model_type': 'gpt2', 3 | 'from_config' : True, 4 | 'hidden_size' : 720, 5 | 'intermediate_size' : 3072, 6 | 'num_hidden_layers' : 1, 7 | 'num_attention_heads' : 1, 8 | 'max_position_embeddings' : 2048, 9 | 'vocab_size' : 2, 10 | 'onehot_embed': False, 11 | 'wpe_train': 
# Names of the dataset classes that load_dataset() is allowed to instantiate.
dataset_type_list = ['BinaryDataset']


def load_dataset(input_dir, dataset_type='BinaryDataset'):
    """Rebuild a previously saved dataset from ``input_dir``.

    The constructor arguments are read from ``<input_dir>/kwargs.pt``, the
    dataset is created with ``generate=False`` and its ``load`` method is
    then called with ``input_dir``.

    Args:
        input_dir: Directory containing ``kwargs.pt`` and whatever the
            dataset's ``load`` method reads.
        dataset_type: Name of the dataset class to instantiate. Must be one
            of the names in ``dataset_type_list``.

    Returns:
        The loaded dataset instance.

    Raises:
        ValueError: If ``dataset_type`` is not listed in
            ``dataset_type_list``.
    """
    # The class used to be resolved with eval(dataset_type), which executes
    # any expression passed in. Validate against the allow-list and look the
    # class up by name in this module instead.
    if dataset_type not in dataset_type_list:
        raise ValueError(
            f'Unknown dataset_type {dataset_type!r}; '
            f'expected one of {dataset_type_list}'
        )
    dataset_cls = globals()[dataset_type]
    kwargs = torch.load(os.path.join(input_dir, 'kwargs.pt'))
    dataset = dataset_cls(kwargs, generate=False)
    dataset.load(input_dir)
    return dataset
31 | } -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | PROJECT=Istree 3 | MODEL=transformer 4 | total_training_sample=100000 5 | training_samples=10000 6 | n_digits=30 7 | k=3 8 | CoT=True 9 | lr=6e-5 10 | num_layers=4 11 | num_heads=3 12 | python src/train.py \ 13 | --world_size 1 \ 14 | --total_training_samples ${training_samples} \ 15 | --model_type transformer \ 16 | --model_config_path config/gpt2_tiny_wpetrain.py \ 17 | --dataset_dir data/Nonintersect_Binary/binary_${training_samples}_${n_digits}_${k}_${CoT}_False_False \ 18 | --dataset_type BinaryDataset \ 19 | --output_dir model/ \ 20 | --batch_size 512 \ 21 | --lr ${lr} \ 22 | --weight_decay 0 \ 23 | --log_interval 2048 \ 24 | --save_interval 2048 \ 25 | --eval_interval 2048 \ 26 | --report_to_wandb \ 27 | --num_hidden_layers ${num_layers} \ 28 | --num_attention_heads ${num_heads} \ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Huaqing Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/generate_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import * 3 | import numpy as np 4 | import argparse 5 | import copy 6 | import random 7 | 8 | if __name__ == '__main__': 9 | np.random.seed(0) 10 | random.seed(0) 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--config_path', default = '', type = str) 13 | args = parser.parse_args() 14 | config = eval(open(args.config_path, 'r').read()) 15 | assert config['dataset_type'] == 'BinaryDataset' # only for Binary now. 
16 | n_samples = config['n_samples'] 17 | n_digit = config['n_digit'] 18 | 19 | 20 | if n_digit <= 20: 21 | input_list = list(range(1<= thre: 26 | print(result_step["epoch"] * 1000000) 27 | return result_step["epoch"] * 1000000 28 | return 1000000 29 | 30 | num_layers = 1 31 | num_heads = 1 32 | n_digits = 100 33 | n_samples = 1000000 34 | total_samples = 1000000 35 | CoT = "True" 36 | 37 | for embedding in ["gpt2_tiny_wpetrain"]: 38 | sample_complexity = [1000000 for i in range(9)] 39 | for idx in range(9): 40 | k = [20, 30, 40, 50, 60, 70, 80, 90, 100][idx] 41 | for lr in [6e-5, 8e-5, 1e-4]: 42 | # for lr in [1e-4]: 43 | path = f"model/binary_{n_samples}_{n_digits}_{k}_{CoT}_False_False_{total_samples}_LR={lr}_WD=0.0_1GPU*512Batch_{embedding}.py_#layer={num_layers}_#head={num_heads}" 44 | result = get_result(path) 45 | if result != None: 46 | sample_complexity[idx] = min(sample_complexity[idx], get_sample_complexity(result,0.995)) 47 | print(sample_complexity) 48 | 49 | x = [20, 30, 40, 50, 60, 70, 80, 90, 100] 50 | 51 | plt.figure(figsize=(6, 5.5)) 52 | 53 | plt.plot(x, sample_complexity, marker='x', label="with CoT") 54 | 55 | def y_format_func(value, tick_number): 56 | return f'{value / 1e5:.1f}' 57 | 58 | plt.gca().yaxis.set_major_formatter(ticker.FuncFormatter(y_format_func)) 59 | 60 | plt.gca().set_ylabel("Sample Complexity ($\\times10^5$)", fontsize=24) 61 | 62 | plt.xlabel("Number of Secret Variables $k$", fontsize=24) 63 | plt.xticks([20, 30, 40, 50, 60, 70, 80, 90, 100], fontsize=18) 64 | plt.yticks(fontsize=18) 65 | 66 | plt.grid(True, which="both", ls="--", linewidth=0.5) 67 | plt.legend(fontsize=24) 68 | plt.tight_layout() 69 | plt.show() 70 | plt.savefig("Figures/Figs/fig3.1.pdf") 71 | plt.savefig("Figures/Figs/fig3.1.svg") 72 | 73 | -------------------------------------------------------------------------------- /Figures/Fig2/fig2.2.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as 
def get_result(path):
    """Load ``trainer_state.json`` for the run stored under ``path``.

    If ``path`` contains a sub-directory, the state file is looked up inside
    the first sub-directory returned by ``os.listdir``; otherwise it is
    looked up directly in ``path``.

    Args:
        path: Directory of a single training run.

    Returns:
        The parsed JSON content, or ``None`` when the run directory or the
        state file does not exist.
    """
    # Callers skip runs for which this returns None, so a run directory that
    # was never created must not abort the whole sweep.
    try:
        entries = os.listdir(path)
    except (FileNotFoundError, NotADirectoryError):
        return None

    # Select the sub-directory without rebinding `path` mid-iteration; the
    # old loop joined the remaining entries onto the already-changed path.
    run_dir = path
    for entry in entries:
        candidate = os.path.join(path, entry)
        if os.path.isdir(candidate):
            run_dir = candidate
            break

    state_file = os.path.join(run_dir, "trainer_state.json")
    if not os.path.isfile(state_file):
        return None
    with open(state_file, 'r') as f:
        return json.load(f)
def get_sample_complexity(result):
    """Return the sample count at which exact-match accuracy first hits 1.

    Only the entries at odd positions of ``result["log_history"]`` are
    inspected. For the first of them whose ``"eval_exact_match"`` equals 1,
    its ``"epoch"`` value times 10000000 is returned. If no such entry
    exists, 10000000 is returned.
    """
    budget = 10000000
    eval_records = result["log_history"][1::2]
    for record in eval_records:
        if record["eval_exact_match"] != 1:
            continue
        return record["epoch"] * budget
    return budget
def get_result(path):
    """Load ``trainer_state.json`` for the run stored under ``path``.

    If ``path`` contains a sub-directory, the state file is looked up inside
    the first sub-directory returned by ``os.listdir``; otherwise it is
    looked up directly in ``path``.

    Args:
        path: Directory of a single training run.

    Returns:
        The parsed JSON content, or ``None`` when the run directory or the
        state file does not exist.
    """
    # Callers skip runs for which this returns None, so a run directory that
    # was never created must not abort the whole sweep.
    try:
        entries = os.listdir(path)
    except (FileNotFoundError, NotADirectoryError):
        return None

    # Select the sub-directory without rebinding `path` mid-iteration; the
    # old loop joined the remaining entries onto the already-changed path.
    run_dir = path
    for entry in entries:
        candidate = os.path.join(path, entry)
        if os.path.isdir(candidate):
            run_dir = candidate
            break

    state_file = os.path.join(run_dir, "trainer_state.json")
    if not os.path.isfile(state_file):
        return None
    with open(state_file, 'r') as f:
        return json.load(f)
# Settings of the runs summarised by this figure; they are interpolated into
# the run-directory name below.
num_layers = 4
num_heads = 4
n_digits = 30
n_samples = 10000000
total_samples = 10000000

for embedding in ["gpt2_tiny_wpetrain"]:
    # Every (CoT, k) entry starts at 10^7, so a setting without any usable
    # run is drawn at that value (labelled '≥ 10⁷' on the y axis).
    sample_complexity = {
        "True": [10000000] * 4,
        "False": [10000000] * 4,
    }
    for CoT in ["True", "False"]:
        for k in [1, 2, 3, 4]:
            for lr in [6e-5, 8e-5, 1e-4]:
                path = f"model/binary_{n_samples}_{n_digits}_{k}_{CoT}_False_False_{total_samples}_LR={lr}_WD=0.0_1GPU*512Batch_{embedding}.py_#layer={num_layers}_#head={num_heads}"
                result = get_result(path)
                if result is not None:
                    complexity = get_sample_complexity(result)
                    print(CoT, k, lr, complexity)
                    # Keep the smallest value over the learning rates tried.
                    sample_complexity[CoT][k-1] = min(sample_complexity[CoT][k-1], complexity)
        # NOTE(review): the nesting level of this summary print was
        # reconstructed from a whitespace-collapsed copy; confirm.
        print(sample_complexity[CoT])

plt.figure(figsize=(6.5, 6))
x = [1, 2, 3, 4]
plt.plot(x, sample_complexity["True"], label="with CoT", marker='o')
plt.plot(x, sample_complexity["False"], label="without CoT", marker='x')
plt.yscale('log')


def y_formatter(y, pos):
    """Format a y tick as a power of ten; the 10^7 tick is shown as '≥ 10⁷'."""
    if y == 1e7:
        return '≥ 10⁷'
    return f'$10^{{{int(np.log10(y))}}}$'


plt.gca().yaxis.set_major_formatter(FuncFormatter(y_formatter))

# Add titles and labels
plt.title("Sample complexity of 4-layer 4-head\ntransformer on parity with n=30", fontsize=20)
plt.xlabel("Number of Secret Variables $k$", fontsize=24)
plt.ylabel("Sample Complexity", fontsize=24)
plt.xticks([1, 2, 3, 4], fontsize=18)
plt.yticks(fontsize=18)

# Add legend
plt.legend(fontsize=20)

plt.grid(True, which="major", ls="--", linewidth=0.5)

plt.tight_layout()
# Save before showing: once the window opened by plt.show() is closed the
# current figure is gone and savefig() would write an empty canvas.
os.makedirs("Figures/Figs", exist_ok=True)
plt.savefig("Figures/Figs/fig1.1.pdf")
plt.savefig("Figures/Figs/fig1.1.svg")
plt.show()
def get_data(result):
    """Split ``result["log_history"]`` into per-metric series.

    Entries at even positions contribute their ``"loss"`` to the training
    loss series; entries at odd positions contribute ``"eval_exact_match"``,
    ``"eval_loss"``, ``"mean_min_attention_entropy_over_token"`` and
    ``"epoch"``.

    Returns:
        A tuple ``(accuracy, training_loss, evaluation_loss, entropy,
        num_epoch)`` of lists.
    """
    history = result["log_history"]
    train_records = history[0::2]
    eval_records = history[1::2]

    training_loss_ = [record["loss"] for record in train_records]
    accuracy_ = [record["eval_exact_match"] for record in eval_records]
    evaluation_loss_ = [record["eval_loss"] for record in eval_records]
    entropy_ = [
        record["mean_min_attention_entropy_over_token"] for record in eval_records
    ]
    num_epoch_ = [record["epoch"] for record in eval_records]
    return accuracy_, training_loss_, evaluation_loss_, entropy_, num_epoch_
class MyDataset(Dataset):
    """Base class for list-backed datasets that can be generated, saved and loaded.

    Samples are kept in parallel Python lists (``input_ids``,
    ``attention_mask``, ``labels``, ``data`` and the two ``auxiliary_*``
    lists). Subclasses provide ``_generate`` and ``get_name``.
    """

    def __init__(self, kwargs, generate = True):
        """Store ``kwargs`` as attributes and optionally generate the samples.

        Args:
            kwargs: Dict of settings; must contain ``'n_samples'``. Every
                entry becomes an attribute of the instance.
            generate: When true, ``self.generate(self.n_samples)`` is called.
        """
        assert 'n_samples' in kwargs, 'n_samples is required'
        self.init_from_args(kwargs)
        self.input_ids = []
        self.attention_mask = []
        self.data = []
        self.labels = []
        self.auxiliary_inputs = []
        self.auxiliary_labels = []
        if generate:
            self.generate(self.n_samples)

    def init_from_args(self, kwargs):
        """Copy every entry of ``kwargs`` onto the instance and keep the dict.

        If ``'tokenizer_name'`` is present, a tokenizer is loaded with
        ``AutoTokenizer.from_pretrained`` and configured to pad on the left
        with its EOS token.
        """
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.kwargs = kwargs
        if 'tokenizer_name' in kwargs:
            self.tokenizer = AutoTokenizer.from_pretrained(kwargs['tokenizer_name'])
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"

    def __len__(self):
        """Return the configured ``n_samples``."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of its per-sample fields.

        The ``auxiliary_*`` fields are included only when auxiliary labels
        have been stored.
        """
        item = {'input_ids': self.input_ids[idx],
                'attention_mask': self.attention_mask[idx],
                'labels': self.labels[idx],
                'data': self.data[idx]}
        if len(self.auxiliary_labels) > 0:
            item['auxiliary_labels'] = self.auxiliary_labels[idx]
            item['auxiliary_inputs'] = self.auxiliary_inputs[idx]
        return item

    def generate(self, n_samples):
        """Call ``_generate`` ``n_samples`` times, showing a progress bar."""
        for _ in tqdm(range(n_samples)):
            self._generate()

    def _generate(self):
        """Append one sample to the per-sample lists; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Was `assert NotImplementedError`, which always passes (the class
        # object is truthy) and therefore silently generated nothing.
        raise NotImplementedError

    def save(self, output_dir):
        """Write ``tensor_data.pt`` and ``kwargs.pt`` into ``output_dir``.

        The directory is created if it does not exist yet.
        """
        # exist_ok avoids the check-then-create race of the previous version.
        os.makedirs(output_dir, exist_ok=True)
        torch.save({
            'input_ids': self.input_ids,
            'attention_mask': self.attention_mask,
            'labels': self.labels,
            'data': self.data,
            'auxiliary_labels': self.auxiliary_labels,
            'auxiliary_inputs': self.auxiliary_inputs,
        }, f'{output_dir}/tensor_data.pt')
        torch.save(self.kwargs, f'{output_dir}/kwargs.pt')

    def load(self, input_dir):
        """Restore settings and samples previously written by ``save``.

        The ``auxiliary_*`` lists are only replaced when the stored file
        contains them.
        """
        kwargs = torch.load(f'{input_dir}/kwargs.pt')
        self.init_from_args(kwargs)
        tensor_data = torch.load(f'{input_dir}/tensor_data.pt')
        self.input_ids = tensor_data['input_ids']
        self.attention_mask = tensor_data['attention_mask']
        self.labels = tensor_data['labels']
        self.data = tensor_data['data']
        if 'auxiliary_labels' in tensor_data:
            self.auxiliary_labels = tensor_data['auxiliary_labels']
            self.auxiliary_inputs = tensor_data['auxiliary_inputs']

    def get_name(self):
        """Hook for subclasses (presumably a dataset identifier); must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Was `assert NotImplementedError`, a no-op that returned None.
        raise NotImplementedError
def get_result(path):
    """Load ``trainer_state.json`` for the run stored under ``path``.

    If ``path`` contains a sub-directory, the state file is read from the
    first sub-directory returned by ``os.listdir``; otherwise it is read
    directly from ``path``.

    Args:
        path: Directory of a single training run.

    Returns:
        The parsed JSON content of the state file.

    Raises:
        FileNotFoundError: If ``path`` or the state file does not exist.
    """
    # Select the sub-directory without rebinding `path` mid-iteration; the
    # old loop joined the remaining entries onto the already-changed path.
    run_dir = path
    for entry in os.listdir(path):
        candidate = os.path.join(path, entry)
        if os.path.isdir(candidate):
            run_dir = candidate
            break

    with open(os.path.join(run_dir, "trainer_state.json"), 'r') as f:
        return json.load(f)
def get_data(result):
    """Collect the metric series stored in ``result["log_history"]``.

    Even-positioned entries supply ``"loss"`` (training loss). Odd-positioned
    entries supply ``"eval_exact_match"``, ``"eval_loss"``,
    ``"mean_min_attention_entropy_over_token"`` and ``"epoch"``.

    Returns:
        ``(accuracy, training_loss, evaluation_loss, entropy, num_epoch)``,
        each a list in log order.
    """
    accuracy_, evaluation_loss_, entropy_, num_epoch_ = [], [], [], []
    training_loss_ = []
    # Each eval field is paired with the list it feeds.
    eval_columns = (
        ("eval_exact_match", accuracy_),
        ("eval_loss", evaluation_loss_),
        ("mean_min_attention_entropy_over_token", entropy_),
        ("epoch", num_epoch_),
    )
    for position, record in enumerate(result["log_history"]):
        if position % 2 == 0:
            training_loss_.append(record["loss"])
            continue
        for field, column in eval_columns:
            column.append(record[field])
    return accuracy_, training_loss_, evaluation_loss_, entropy_, num_epoch_
class BinaryDataset(MyDataset):
    """Sparse-parity dataset over random bit strings.

    Each sample is an ``n_digit``-bit input ``q``. The target is the parity
    of ``q`` restricted to the positions where ``secret`` is 1. When
    ``use_cot`` is set, the running parities are appended to the sequence as
    intermediate tokens before the final answer.

    NOTE(review): attributes such as ``n_digit``, ``n_secret``, ``secret``,
    ``use_cot``, ``kwargs``, ``input_ids``, ``labels``, ``auxiliary_labels``,
    ``auxiliary_inputs``, ``attention_mask`` and ``data`` are presumably set
    up by ``MyDataset.__init__`` from ``kwargs`` -- confirm against the base
    class.
    """

    def __init__(self, kwargs, generate = True):
        """Sample a secret (unless one is supplied) and build the dataset.

        Args:
            kwargs: configuration dict. ``n_digit`` and ``n_secret`` are read
                when no ``secret`` is present. The sampled ``secret`` and
                ``secret_idx`` are written back into this dict, presumably so
                that another dataset built from the same dict shares the
                secret.
            generate: forwarded unchanged to ``MyDataset.__init__``.
        """
        if 'secret' not in kwargs:
            n_digit = kwargs['n_digit']
            n_secret = kwargs['n_secret']
            # Rejection-sample until the last bit is not part of the secret
            # (impossible to satisfy when every bit is secret, hence the
            # second condition).
            while True:
                secret = torch.zeros([n_digit])
                secret_idx = np.random.choice(n_digit, n_secret, replace=False)
                secret[secret_idx] = 1
                if secret[-1] != 1 or n_digit == n_secret:
                    break
            kwargs['secret'] = secret
            kwargs['secret_idx'] = secret_idx
        self.auxiliary_level = kwargs.get('auxiliary_level', 1)
        self.use_long_cot = False
        self.position_sensitive = False
        super().__init__(kwargs, generate)

    def generate(self, n_samples):
        """Generate ``n_samples`` samples, showing a progress bar."""
        # The source of the inputs does not change during generation.
        from_list = 'input_list' in self.kwargs
        for sample_idx in tqdm(range(n_samples)):
            self._generate(is_fromlist = from_list, idx = sample_idx)

    def _generate(self, is_fromlist = False, idx = None):
        """Append one sample to the dataset's internal lists.

        Args:
            is_fromlist: when true, the input bits are the ``n_digit`` lowest
                bits (least significant first) of ``self.input_list[idx]``;
                otherwise they are drawn uniformly at random.
            idx: index into ``self.input_list``; only used with
                ``is_fromlist``.
        """
        n_digit = self.n_digit
        if is_fromlist:
            packed = self.input_list[idx]
            q = torch.tensor([(packed >> shift) & 1 for shift in range(n_digit)])
        else:  # random
            q = torch.randint(0, 2, [n_digit])

        bits = q.tolist()
        if not self.position_sensitive:
            self.input_ids.append(bits)
        else:
            # Give every position its own pair of token ids.
            self.input_ids.append([bit + 2 * pos for pos, bit in enumerate(bits)])
        # The input bits themselves are never a training target.
        self.labels.append([-100] * len(bits))

        # Every auxiliary bucket starts with the token 2.
        self.auxiliary_labels.append([[2] for _ in range(self.auxiliary_level)])
        auxiliary_length = (self.n_secret + 2) // self.auxiliary_level
        buckets = self.auxiliary_labels[-1]
        cnt = 0
        for pos in range(n_digit):
            if self.use_long_cot or self.secret[pos]:
                cnt += 1
                # Parity of the secret bits strictly before ``pos``.
                partial_y = (torch.sum(q[:pos] * self.secret[:pos]) % 2).int().item()
                if self.use_cot:
                    self.input_ids[-1].append(partial_y)
                    self.labels[-1].append(partial_y)
                # NOTE(review): for some (n_secret, auxiliary_level) pairs
                # this index can reach auxiliary_level -- confirm the
                # supported configurations.
                buckets[cnt // auxiliary_length].append(partial_y)
        # Right-pad short buckets with 0 up to auxiliary_length.
        for level in range(self.auxiliary_level):
            buckets[level] = buckets[level] + [0] * (auxiliary_length - len(buckets[level]))

        # Final answer: parity over all secret positions.
        y = (torch.sum(q * self.secret) % 2).int().item()
        self.input_ids[-1].append(y)
        self.labels[-1].append(y)
        self.attention_mask.append([1] * len(self.input_ids[-1]))
        self.auxiliary_inputs.append(self.auxiliary_labels[-1])
        self.data.append(q)

    def save(self, output_dir):
        """Save the dataset via the base class and write the secret to ``metadata.json``."""
        super().save(output_dir)
        metadata = {
            'secret': [bit.item() for bit in list(self.secret)]
        }
        with open(f'{output_dir}/metadata.json', 'w') as f:
            json.dump(metadata, f, indent=2)

    def get_name(self):
        """Return the directory-style name that encodes this dataset's settings."""
        return f'binary_{self.n_samples}_{self.n_digit}_{self.n_secret}_{self.use_cot}_{self.use_long_cot}_{self.position_sensitive}'
f"model/binary_{n_samples}_{n_digits}_{k}_{CoT}_False_False_{total_samples}_LR={lr}_WD=0.0_1GPU*512Batch_{embedding}.py_#layer={num_hidden_layers}_#head={num_attention_heads}" 49 | result = get_result(path) 50 | accuracy_, training_loss_, evaluation_loss_, entropy_, num_epoch_ = get_data(result) 51 | entropy_ = np.array(entropy_) 52 | 53 | axs[i].set_title("with CoT" if CoT == "True" else "without CoT", fontsize=17) 54 | total_samples_ = [item * n_samples for item in num_epoch_] 55 | 56 | idlayer = 0 57 | lth = len(total_samples_) // 10 58 | ax2 = axs[i].twinx() 59 | 60 | line_accuracy, = ax2.plot(total_samples_[:lth], accuracy_[:lth], label='Evaluation Accuracy', 61 | color='red', linewidth=2, linestyle='--', alpha = 0.6) 62 | 63 | ax2.set_ylim(0.45, 1.1) 64 | ax2.set_ylabel('Evaluation Accuracy', fontsize=17) 65 | ax2.tick_params(axis='y', labelsize=14) 66 | 67 | colors = plt.cm.viridis(np.linspace(0, 1, num_attention_heads)) 68 | for j, color in zip(range(num_attention_heads), colors): 69 | line, = axs[i].plot(total_samples_[:lth], entropy_[:, idlayer, j][:lth], label=f'head {j+1}', 70 | linewidth=2, color=color, alpha=0.8) 71 | 72 | if i == 0 and idlayer == 0: 73 | all_lines.append(line) 74 | all_labels.append(f'head {j+1}') 75 | 76 | axs[i].set_ylabel('Normalized Attention Entropy', fontsize=17) 77 | axs[i].set_xlabel('Iterations × Batch Size', fontsize=17) 78 | axs[i].set_ylim(0, 1.1) 79 | axs[i].set_xticks([0, 500000, 1000000]) 80 | axs[i].tick_params(axis='both', labelsize=14) 81 | axs[i].grid(True, which='both', linestyle='--', linewidth=0.5) 82 | 83 | all_lines.append(line_accuracy) 84 | all_labels.append('Accuracy') 85 | 86 | if i == 0: 87 | axs[i].legend(all_lines, all_labels, fontsize=12, loc='right') 88 | 89 | plt.tight_layout() 90 | plt.show() 91 | plt.savefig(f'Figures/Figs/fig1.2.pdf') 92 | plt.savefig(f'Figures/Figs/fig1.2.svg') -------------------------------------------------------------------------------- /Figures/Fig10/fig10.py: 
def get_data(result):
    """Split a trainer state's ``log_history`` into per-metric series.

    Entries are classified by the keys they carry rather than by their
    position in the list, so a trailing summary entry or an irregular
    train/eval interleaving no longer raises ``KeyError``.

    Args:
        result: dict with a ``"log_history"`` list of log-entry dicts.

    Returns:
        Tuple ``(accuracy_, training_loss_, evaluation_loss_, entropy_,
        num_epoch_)``. All but ``training_loss_`` come from evaluation
        entries (those carrying ``"eval_exact_match"``); ``training_loss_``
        comes from entries carrying ``"loss"``.
    """
    accuracy_, training_loss_, evaluation_loss_, entropy_, num_epoch_ = [], [], [], [], []
    for entry in result["log_history"]:
        if "eval_exact_match" in entry:
            accuracy_.append(entry["eval_exact_match"])
            evaluation_loss_.append(entry["eval_loss"])
            entropy_.append(entry["mean_min_attention_entropy_over_token"])
            num_epoch_.append(entry["epoch"])
        elif "loss" in entry:
            training_loss_.append(entry["loss"])
        # Anything else (e.g. an end-of-training summary) is ignored.
    return accuracy_, training_loss_, evaluation_loss_, entropy_, num_epoch_
import matplotlib.pyplot as plt
import numpy as np
import json

# Figure 7: per-layer bar/line plots of mean attention entropy per head for
# three setups (Qwen2-7B without CoT, Qwen2-7B with CoT, Qwen2-Math-7B with
# CoT). Heads are sorted by entropy within each layer.

num_layer = 28
num_head = 28


def _load_entropy_records(model_tag, text_type, average_type):
    """Load the per-example records for one model/prompting setup."""
    path = f'Experiment_Section/GSM8K/Additional_Experiment/{model_tag}_{text_type}_{average_type}.json'
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def _sorted_mean_entropy(records, total_data):
    """Average ``"attention_maps"`` over the first ``total_data`` records.

    Returns one list per layer holding the per-head means in ascending order.
    """
    per_layer = []
    for i in range(num_layer):
        heads = [
            sum(records[sample]["attention_maps"][i][j] for sample in range(total_data)) / total_data
            for j in range(num_head)
        ]
        heads.sort()
        per_layer.append(heads)
    return per_layer


for text_type in ['notonlyanswer','onlyanswer']:
    for average_type in ['min', 'average']:
        qwen_CoT = _load_entropy_records('qwen_CoT', text_type, average_type)
        qwen_noCoT = _load_entropy_records('qwen_noCoT', text_type, average_type)
        math_CoT = _load_entropy_records('math_CoT', text_type, average_type)
        total_data = len(qwen_CoT)
        print(total_data)

        qwenCoT_entropy = _sorted_mean_entropy(qwen_CoT, total_data)
        qwennoCoT_entropy = _sorted_mean_entropy(qwen_noCoT, total_data)
        mathCoT_entropy = _sorted_mean_entropy(math_CoT, total_data)

        # One x position per head. The original sized this from the number of
        # layers, which only worked because both counts are 28.
        indices = np.arange(1, num_head + 1)

        bar_width = 1
        fig, axs = plt.subplots(7, 4, figsize=(31, 45))

        for i in range(num_layer):
            layer = i
            ax = axs[i//4,i%4]

            ax.bar(indices, qwennoCoT_entropy[layer], width=bar_width, label='Qwen2-7B+\nNo CoT', color='#f6bebf')
            ax.bar(indices, qwenCoT_entropy[layer], width=bar_width, label='Qwen2-7B+\nWith CoT', color='#4187A2')
            ax.bar(indices, mathCoT_entropy[layer], width=bar_width, label='Qwen2-Math-7B+\nWith CoT', color='#C6DBAD')
            ax.plot(indices, mathCoT_entropy[layer], marker='o', color='#59ac50', linestyle='-', linewidth=2, markersize=8)
            ax.plot(indices, qwenCoT_entropy[layer], marker='D', color='#244c7e', linestyle='-', linewidth=2, markersize=8)
            ax.plot(indices, qwennoCoT_entropy[layer], marker='x', color='#C95762', linestyle='-', linewidth=2, markersize=8)

            ax.set_title(f'Layer {layer+1}',fontsize=36)
            ax.set_xlabel('head',fontsize=36)
            if i%4 == 0:
                ax.set_ylabel('Normalized Attention Entropy',fontsize=28)
            ax.tick_params(axis='both', which='major', labelsize=30)
            ax.set_xticks([1, num_head])
        handles, labels = axs[0,0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', fontsize=30, bbox_to_anchor=(0.93, 0.5), frameon=False)
        plt.tight_layout(rect=[0, 0, 0.83, 1])
        # Save before showing: once the window opened by show() is closed the
        # figure is gone and savefig would write an empty page.
        # NOTE(review): every (text_type, average_type) pair writes to the
        # same file, so only the last combination survives -- confirm intent.
        plt.savefig(f'Figures/Figs/fig7r.pdf')
        plt.show()
def get_gpt_transformer_model_from_config(
    onehot_embed = False,
    wpe_train = False,
    vocab_size = 50264,
    hidden_size=32,
    intermediate_size=128,
    num_hidden_layers=16,
    num_attention_heads=4,
    max_position_embeddings=4096
):
    """Build a randomly initialised GPT-2 style causal LM from a config.

    Args:
        onehot_embed: when true, replace the token embedding with a fresh
            bfloat16 ``Embedding`` whose weights are the (rectangular)
            identity, and load a copy of the same matrix into ``lm_head``.
        wpe_train: when true, mark the position-embedding weights as
            trainable.
        vocab_size, hidden_size, intermediate_size, num_hidden_layers,
        num_attention_heads, max_position_embeddings: forwarded to
            ``get_gpt_transformer_config``.

    Returns:
        The model created by ``AutoModelForCausalLM.from_config``.
    """
    config = get_gpt_transformer_config(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        max_position_embeddings=max_position_embeddings
    )
    model = AutoModelForCausalLM.from_config(config)
    position_weight = model.transformer.wpe.weight
    if onehot_embed:
        model.transformer.wte = torch.nn.Embedding(vocab_size, hidden_size, dtype = torch.bfloat16)
        # Set the flag on the Parameter itself: ``weight.data`` is a detached
        # view, so setting ``requires_grad`` on it (as the original did)
        # never reached the parameter.
        position_weight.requires_grad_(True)
        manual_embed = torch.nn.init.eye_(torch.empty(vocab_size, hidden_size)).to(dtype = torch.bfloat16)
        # NOTE(review): the original also set ``manual_embed.requires_grad =
        # False`` before assigning it to ``.data``. That does not freeze the
        # embedding/lm_head Parameters; they stay trainable here exactly as
        # before. Freeze them explicitly if that was the intent.
        model.transformer.wte.weight.data = manual_embed
        model.lm_head.weight.data = manual_embed.clone()
    else:
        model.lm_head.weight.data = model.lm_head.weight.data.clone()
    if wpe_train:
        position_weight.requires_grad_(True)
    return model
torch_dtype = torch.bfloat16, resume_download = True) 82 | print(f'Transformer Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9:.2f}B') 83 | return model 84 | 85 | 86 | if __name__ == '__main__': 87 | # pretrained 88 | # model = get_transformer_model(from_config = False) 89 | # tokenizer = AutoTokenizer.from_pretrained(PATH) 90 | # tokenized_input = tokenizer("Hello, who are you?", return_tensors = 'pt') 91 | # input_ids = tokenized_input['input_ids'] 92 | # attention_mask = tokenized_input['attention_mask'] 93 | # output = model(input_ids, attention_mask = attention_mask) 94 | # scratch 95 | model = get_gpt_transformer_model(from_config = True, onehot_embed = True, vocab_size = 2) 96 | from IPython import embed; embed() 97 | -------------------------------------------------------------------------------- /Figures/Fig6-7/fig6.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import json 4 | 5 | 6 | 7 | num_layer = 28 8 | num_head = 28 9 | 10 | for text_type in ['notonlyanswer','onlyanswer']: 11 | for average_type in ['min', 'average']: 12 | with open(f'Experiment_Section/GSM8K/Additional_Experiment/qwen_CoT_{text_type}_{average_type}.json', 'r', encoding='utf-8') as file: 13 | qwen_CoT = json.load(file) 14 | with open(f'Experiment_Section/GSM8K/Additional_Experiment/qwen_noCoT_{text_type}_{average_type}.json', 'r', encoding='utf-8') as file: 15 | qwen_noCoT = json.load(file) 16 | with open(f'Experiment_Section/GSM8K/Additional_Experiment/math_CoT_{text_type}_{average_type}.json', 'r', encoding='utf-8') as file: 17 | math_CoT = json.load(file) 18 | total_data = len(qwen_CoT) 19 | print(total_data) 20 | sum_qwenCoT_entropy = [[0 for _ in range(num_head)] for _ in range(num_layer)] 21 | sum_qwennoCoT_entropy = [[0 for _ in range(num_head)] for _ in range(num_layer)] 22 | sum_mathCoT_entropy = [[0 for _ in range(num_head)] for _ in 
range(num_layer)] 23 | for id in range(total_data): 24 | for i in range(num_layer): 25 | for j in range(num_head): 26 | sum_qwenCoT_entropy[i][j] += qwen_CoT[id]["attention_maps"][i][j] 27 | sum_qwennoCoT_entropy[i][j] += qwen_noCoT[id]["attention_maps"][i][j] 28 | sum_mathCoT_entropy[i][j] += math_CoT[id]["attention_maps"][i][j] 29 | qwenCoT_entropy = [[0 for _ in range(num_head)] for _ in range(num_layer)] 30 | qwennoCoT_entropy = [[0 for _ in range(num_head)] for _ in range(num_layer)] 31 | mathCoT_entropy = [[0 for _ in range(num_head)] for _ in range(num_layer)] 32 | for i in range(num_layer): 33 | for j in range(num_head): 34 | qwenCoT_entropy[i][j] = sum_qwenCoT_entropy[i][j]/total_data 35 | qwennoCoT_entropy[i][j] = sum_qwennoCoT_entropy[i][j]/total_data 36 | mathCoT_entropy[i][j] = sum_mathCoT_entropy[i][j]/total_data 37 | for i in range(num_layer): 38 | qwenCoT_entropy[i].sort() 39 | qwennoCoT_entropy[i].sort() 40 | mathCoT_entropy[i].sort() 41 | 42 | indices = np.arange(1,len(qwennoCoT_entropy)+1) 43 | 44 | bar_width = 1 45 | fig, axs = plt.subplots(1, 4, figsize=(31, 6.5)) # 设置图表大小 46 | 47 | for i in range (4): 48 | layer = [0,10,20,27][i] 49 | ax = axs[i] 50 | 51 | ax.bar(indices, qwennoCoT_entropy[layer], width=bar_width, label='Qwen2-7B+\nNo CoT', color='#f6bebf') 52 | ax.bar(indices, qwenCoT_entropy[layer], width=bar_width, label='Qwen2-7B+\nWith CoT', color='#4187A2') 53 | ax.bar(indices, mathCoT_entropy[layer], width=bar_width, label='Qwen2-Math-7B+\nWith CoT', color='#C6DBAD') 54 | ax.plot(indices, mathCoT_entropy[layer], marker='o', color='#59ac50', linestyle='-', linewidth=2, markersize=8) 55 | ax.plot(indices, qwenCoT_entropy[layer], marker='D', color='#244c7e', linestyle='-', linewidth=2, markersize=8) 56 | ax.plot(indices, qwennoCoT_entropy[layer], marker='x', color='#C95762', linestyle='-', linewidth=2, markersize=8) 57 | 58 | ax.set_title(f'Layer {layer+1}',fontsize=36) 59 | ax.set_xlabel('head',fontsize=36) 60 | if i == 0: 61 | 
def get_sample_complexity(result, n_samples=10000000):
    """Return the number of samples seen when evaluation first becomes perfect.

    Args:
        result: dict with a ``"log_history"`` list of log-entry dicts.
        n_samples: number of samples in one epoch. Also the value returned
            when no evaluation entry reaches an exact match of 1.

    Returns:
        ``epoch * n_samples`` for the first evaluation entry whose
        ``"eval_exact_match"`` equals 1, otherwise ``n_samples``.
    """
    for entry in result["log_history"]:
        # Pick evaluation entries by key instead of by list position, so
        # irregular logging does not raise KeyError.
        if entry.get("eval_exact_match") == 1:
            return entry["epoch"] * n_samples
    return n_samples
cmap="YlGnBu", norm=log_norm, cbar=False) 44 | axes[i].set_title(f'k = {i+1}', fontsize=24) 45 | axes[i].set_xlabel("Head Number", fontsize=24) 46 | axes[i].set_ylabel("Layer Number", fontsize=24) 47 | axes[i].set_xticks([0.5, 1.5, 2.5, 3.5]) 48 | axes[i].set_xticklabels([1, 2, 3, 4],fontsize=18) 49 | axes[i].set_yticks([0.5, 1.5, 2.5, 3.5]) 50 | axes[i].set_yticklabels([1, 2, 3, 4],fontsize=18) 51 | if i == 0: 52 | axes[i].text(-1.3, 2, "with CoT", fontsize=26, rotation=90, va="center") 53 | 54 | for i in range(4): 55 | sns.heatmap(sample_complexity["False"][i], ax=axes[i+4], cmap="YlGnBu", norm=log_norm, cbar=False) 56 | axes[i+4].set_xlabel("Head Number", fontsize=24) 57 | axes[i+4].set_ylabel("Layer Number", fontsize=24) 58 | axes[i+4].set_xticks([0.5, 1.5, 2.5, 3.5]) 59 | axes[i+4].set_xticklabels([1, 2, 3, 4],fontsize=18) 60 | axes[i+4].set_yticks([0.5, 1.5, 2.5, 3.5]) 61 | axes[i+4].set_yticklabels([1, 2, 3, 4],fontsize=18) 62 | if i == 0: 63 | axes[i+4].text(-1.3, 2, "without CoT", fontsize=26, rotation=90, va="center") 64 | 65 | cbar_ax = fig.add_subplot(gs[:, 4]) 66 | cbar = fig.colorbar(axes[-1].collections[0], cax=cbar_ax) 67 | cbar.set_label('Sample Complexity', fontsize=26) 68 | cbar.ax.yaxis.set_tick_params(labelsize=22) 69 | 70 | def custom_formatter(x, pos): 71 | if x == 10**7: 72 | return '≥ 10⁷' 73 | else: 74 | return f'$10^{{{int(np.log10(x))}}}$' 75 | 76 | cbar.formatter = ticker.FuncFormatter(custom_formatter) 77 | cbar.update_ticks() 78 | 79 | plt.tight_layout() 80 | plt.show() 81 | plt.savefig("Figures/Figs/fig2.1.pdf") 82 | plt.savefig("Figures/Figs/fig2.1.svg") 83 | 84 | 85 | for embedding in ["gpt2_tiny_wpetrain"]: 86 | sample_complexity = {} 87 | sample_complexity["True"] = [[[10000000 for _ in range(4)] for _ in range(4)] for _ in range(4)] 88 | sample_complexity["False"] = [[[10000000 for _ in range(4)] for _ in range(4)] for _ in range(4)] 89 | for CoT in ["True","False"]: 90 | for k in [1,2,3,4]: 91 | for num_layers in [1,2,3,4]: 92 
def load_model_(embedding, model_dir, num_layers, num_heads):
    """Build the model described by ``config/<embedding>.py`` and load its weights.

    Args:
        embedding: base name of the config file under ``config/``.
        model_dir: run directory. The weights are taken from the first
            sub-directory (in sorted name order) that contains
            ``pytorch_model.bin``, falling back to ``model_dir`` itself.
        num_layers: overrides ``num_hidden_layers`` from the config.
        num_heads: overrides ``num_attention_heads`` from the config.

    Returns:
        The model returned by ``get_model`` with the state dict loaded.

    Raises:
        FileNotFoundError: if no ``pytorch_model.bin`` can be found.
    """
    model_config_path = f"config/{embedding}.py"
    with open(model_config_path) as config_file:
        # NOTE(review): eval() executes the config file; only use it on
        # trusted, repository-local configs.
        model_args = eval(config_file.read())
    model_args["num_hidden_layers"] = num_layers
    model_args["num_attention_heads"] = num_heads
    print(model_args)
    model = get_model(
        **model_args
    )

    weights_name = 'pytorch_model.bin'
    # Deterministic search. The original followed whichever directory
    # os.listdir happened to return first and rewrote model_dir mid-loop.
    subdirs = sorted(
        os.path.join(model_dir, item)
        for item in os.listdir(model_dir)
        if os.path.isdir(os.path.join(model_dir, item))
    )
    for candidate in subdirs + [model_dir]:
        weights_path = os.path.join(candidate, weights_name)
        if os.path.isfile(weights_path):
            break
    else:
        raise FileNotFoundError(f"no {weights_name} found in {model_dir} or its sub-directories")

    # map_location lets weights saved on a GPU load on a CPU-only machine;
    # load_state_dict copies them into the model's existing parameters.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    return model
num_layers = 2 41 | num_heads = 2 42 | n_samples = 10000000 43 | embedding = "gpt2_tiny_wpetrain" 44 | k = 20 45 | secret = [1,2,3,7,12,13,14,16,18,19,22,24,25,26,27,29,33,34,36,38] 46 | CoT = "True" 47 | lr = 6e-5 48 | 49 | model_dir = f"model/binary_{n_samples}_{n_digits}_{k}_{CoT}_False_False_{total_samples}_LR={lr}_WD=0.0_1GPU*512Batch_{embedding}.py_#layer={num_layers}_#head={num_heads}" 50 | dataset_dir = f"data/Nonintersect_Binary/binary_{n_samples}_{n_digits}_{k}_{CoT}_False_False" 51 | val_dataset = load_dataset_(dataset_dir) 52 | model = load_model_(embedding, model_dir, num_layers, num_heads) 53 | model.eval() 54 | 55 | num = 1 56 | with torch.no_grad(): 57 | attention_sum = None 58 | for i in range(num): 59 | data = val_dataset[i] 60 | print(i) 61 | output = model(input_ids = torch.tensor(data['input_ids']).unsqueeze(dim = 0),output_attentions=True) 62 | attention = output.attentions 63 | attention = torch.stack(attention, dim=0).transpose(0, 1) 64 | if attention_sum is None: 65 | attention_sum = attention[0] 66 | else: 67 | attention_sum += attention[0] 68 | average_attention = attention_sum / num 69 | average_attention = average_attention.to(torch.float32) 70 | 71 | colors = ["#2E004E", '#E60073', '#FF9933', '#FFE4CC'] 72 | cmap = LinearSegmentedColormap.from_list("purple_yellow", colors) 73 | 74 | fig, axs = plt.subplots(num_layers, num_heads, figsize=(20*num_heads, 10*num_layers)) 75 | for idx_layer in range(num_layers): 76 | for idx_head in range(num_heads): 77 | mask = np.zeros_like(average_attention[0][0][n_digits:], dtype=bool) 78 | for i in range(k): 79 | for j in range(k-i): 80 | mask[i, -j-1] = True 81 | ax = axs[idx_layer][idx_head] 82 | heatmap_plot = heatmap( 83 | average_attention[idx_layer][idx_head][n_digits:], 84 | cmap=cmap, 85 | ax=ax, 86 | cbar=True, 87 | square=True, 88 | vmin=0, 89 | vmax=1, 90 | mask=mask, 91 | cbar_kws={'shrink': 0.8, 'ticks': np.linspace(0, 1, 6), 'pad' : 0.04} 92 | ) 93 | 94 | 
def get_result(path):
    """Load ``trainer_state.json`` for the run stored under ``path``.

    The file is taken from the first sub-directory of ``path`` (in sorted
    name order) that contains it, falling back to ``path`` itself. The
    original followed whichever directory ``os.listdir`` returned first and
    rewrote ``path`` while iterating, so the result depended on file-system
    ordering.

    Args:
        path: run directory.

    Returns:
        The parsed JSON content of ``trainer_state.json``.

    Raises:
        FileNotFoundError: if no ``trainer_state.json`` can be found.
    """
    state_name = "trainer_state.json"
    subdirs = sorted(
        os.path.join(path, item)
        for item in os.listdir(path)
        if os.path.isdir(os.path.join(path, item))
    )
    for candidate in subdirs + [path]:
        state_path = os.path.join(candidate, state_name)
        if os.path.isfile(state_path):
            with open(state_path, 'r') as f:
                return json.load(f)
    raise FileNotFoundError(f"no {state_name} found in {path} or its sub-directories")
def my_plot(path, accuracy, num_epoch):
    """Plot evaluation accuracy against epoch for a 4x4 grid of model sizes.

    Args:
        path: output file passed to ``plt.savefig``.
        accuracy: nested list; ``accuracy[i][j]`` is the accuracy curve for
            the i-th layer count and j-th head count.
        num_epoch: nested list with the matching x values.

    The main axis holds the 16 curves (one colour map per layer count, one
    shade per head count). A small inset grid serves as the colour key.
    """
    colors = [plt.cm.Greens, plt.cm.Blues, plt.cm.Purples, plt.cm.Reds, plt.cm.Oranges, plt.cm.Greys]
    sizes = [1, 2, 3, 4, 6, 8]
    n_groups = 4
    n_curves_per_group = 4
    font_size = 18

    def shade(i, j):
        # Shade j of colour map i, kept away from the very light end.
        return colors[i](j / n_curves_per_group * 0.8 + 0.3)

    fig, (ax, ax_color) = plt.subplots(1, 2, figsize=(12, 4), gridspec_kw={'width_ratios': [5, 1]})

    for i in range(n_groups):
        for j in range(n_curves_per_group):
            num_layers = sizes[i]
            num_heads = sizes[j]
            ax.plot(num_epoch[i][j], accuracy[i][j], color=shade(i, j), label=f'{num_layers} layers, {num_heads} heads')
    ax.set_ylim(0.48, 1)
    ax.set_xlabel('epoch number', fontsize=font_size)
    ax.set_ylabel('Evaluation Accuracy', fontsize=font_size)

    ax.tick_params(axis='both', which='major', labelsize=font_size)

    # Colour key: one unit square per (layer count, head count) pair.
    for i in range(n_groups):
        for j in range(n_curves_per_group):
            ax_color.add_patch(plt.Rectangle((j, i), 1, 1, color=shade(i, j)))

    ax_color.set_aspect('equal')
    ax_color.set_xlim(0, n_curves_per_group)
    ax_color.set_ylim(0, n_groups)
    ax_color.set_xticks(np.arange(n_curves_per_group) + 0.5)
    ax_color.set_yticks(np.arange(n_groups) + 0.5)
    ax_color.set_xticklabels([f'{sizes[j]}' for j in range(4)], fontsize=font_size)
    ax_color.set_yticklabels([f'{sizes[i]}' for i in range(4)], fontsize=font_size)
    ax_color.set_xlabel('Head number', fontsize=15)
    ax_color.set_ylabel('Layer number', fontsize=15)

    ax_color.set_position([0.72, 0.65, 0.2, 0.2])
    ax.axvline(x=5, color='blue', linestyle='--', linewidth=2)

    # Make sure the marked epoch (x = 5) has its own tick. The leftover
    # breakpoint() that used to sit here stopped every run in the debugger.
    xticks = list(ax.get_xticks())
    if 5 not in xticks:
        xticks.append(5)
    xticks = xticks[2:]
    ax.set_xticks(xticks)
    # NOTE(review): this colouring runs before the labels are replaced
    # below, and the ``tick != 20`` special case is unexplained -- confirm
    # that the rendered figure matches what is intended.
    for label in ax.get_xticklabels():
        if label.get_text() == '5':
            label.set_color('blue')
    ax.set_xticklabels([f'{int(tick)}' if tick != 20 else '5' for tick in xticks])
    ax.set_xlim(0,1000)

    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.show()
def load_dataset_(dataset_dir):
    """Load the validation split stored under ``<dataset_dir>/val``.

    The split is always read as a ``BinaryDataset``.
    """
    dataset_type = 'BinaryDataset'
    return load_dataset(os.path.join(dataset_dir, 'val'), dataset_type)


def load_model_(embedding, model_dir, num_layers, num_heads):
    """Build a model from ``config/<embedding>.py`` and load saved weights into it.

    Args:
        embedding: Base name of the config file under ``config/``.
        model_dir: Training output directory of the run. If it contains
            sub-directories, the weights are read from one of them; otherwise
            from ``model_dir`` itself.
        num_layers: Overrides ``num_hidden_layers`` in the config.
        num_heads: Overrides ``num_attention_heads`` in the config.

    Returns:
        The model with ``pytorch_model.bin`` loaded into it.
    """
    model_config_path = f"config/{embedding}.py"
    # NOTE(review): eval() executes the config file as Python; only use it on
    # config files that ship with the repository.
    with open(model_config_path) as config_file:
        model_args = eval(config_file.read())
    model_args["num_hidden_layers"] = num_layers
    model_args["num_attention_heads"] = num_heads
    print(model_args)
    model = get_model(
        **model_args
    )
    # The previous loop rebound model_dir while still iterating, so every entry
    # after the first match was joined onto the wrong parent. Resolve the
    # sub-directories against the run directory and take the first one in
    # sorted order so the choice is deterministic.
    subdirs = [
        os.path.join(model_dir, item)
        for item in sorted(os.listdir(model_dir))
        if os.path.isdir(os.path.join(model_dir, item))
    ]
    checkpoint_dir = subdirs[0] if subdirs else model_dir
    weights_path = os.path.join(checkpoint_dir, 'pytorch_model.bin')
    # map_location keeps GPU-saved checkpoints loadable on CPU-only machines.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    return model
# Figure 3.2 driver: load a trained model, average its attention maps over the
# first `num` validation samples and draw them as a heatmap.
# n_samples, num_layers and num_heads below are overwritten by the loop header.
n_digits = 40
total_samples = 10000000
num_layers = 1
num_heads = 1
n_samples = 10000000
embedding = "gpt2_tiny_wpetrain"
CoT = "True"
lr = 6e-5
# NOTE(review): the plotting branch further down only runs for 1 layer / 1 head,
# but this list contains (2,2) and (4,4) only -- confirm which configuration the
# published figure used.
for (n_samples, num_layers, num_heads) in [(1000000,2,2),(1000000,4,4)]:
    for k in [20]:
        # Column indices whose x tick labels are drawn in red.
        secret = [1,2,3,7,12,13,14,16,18,19,22,24,25,26,27,29,33,34,36,38]
        model_dir = f"model/binary_{n_samples}_{n_digits}_{k}_{CoT}_False_False_{total_samples}_LR={lr}_WD=0.0_1GPU*512Batch_{embedding}.py_#layer={num_layers}_#head={num_heads}"
        dataset_dir = f"data/Nonintersect_Binary/binary_{n_samples}_{n_digits}_{k}_{CoT}_False_False"
        val_dataset = load_dataset_(dataset_dir)
        model = load_model_(embedding, model_dir, num_layers, num_heads)
        model.eval()

        # Number of validation samples to average over.
        num = 1
        with torch.no_grad():
            attention_sum = None
            for i in range(num):
                data = val_dataset[i]
                print(i)
                output = model(input_ids = torch.tensor(data['input_ids']).unsqueeze(dim = 0),output_attentions=True)
                attention = output.attentions
                # Stack the per-layer tensors and move the batch axis first, so
                # attention[0] holds every layer for the single sample in the batch.
                attention = torch.stack(attention, dim=0).transpose(0, 1)
                if attention_sum is None:
                    attention_sum = attention[0]
                else:
                    attention_sum += attention[0]
        average_attention = attention_sum / num
        average_attention = average_attention.to(torch.float32)

        colors = ["#2E004E", '#E60073', '#FF9933', '#FFE4CC']
        cmap = LinearSegmentedColormap.from_list("purple_yellow", colors)
        if num_layers == 1 and num_heads == 1:
            fig, axs = plt.subplots(num_layers, num_heads, figsize=(22, 7.5))
            # Only rows from index n_digits onwards are shown.
            mask = np.zeros_like(average_attention[0][0][n_digits:], dtype=bool)
            # Hide the upper-right triangle of the displayed rows: row i masks
            # its last k-i columns.
            for i in range(k):
                for j in range(k-i):
                    mask[i, -j-1] = True
            heatmap_plot = heatmap(
                average_attention[0][0][n_digits:],
                cmap=cmap,
                ax=axs,
                cbar=True,
                square=True,
                vmin=0,
                vmax=1,
                mask=mask,
                cbar_kws={'shrink': 1, 'ticks': np.linspace(0, 1, 6), 'pad' : 0.04}
            )

            # The colorbar is the last axes object of the figure.
            heatmap_plot.figure.axes[-1].tick_params(labelsize=20)

            xticks = axs.get_xticklabels()
            for j, label in enumerate(xticks):
                if j in secret:
                    label.set_color("red")
                else:
                    label.set_color("black")

            yticks = axs.get_yticks()
            # NOTE(review): the labels are shifted by a hard-coded 30 while the
            # displayed rows start at n_digits (40 here) -- confirm the offset.
            new_ytick_labels = [str(int(tick) + 30) for tick in yticks]
            axs.set_yticklabels(new_ytick_labels)
            axs.set_xticklabels(axs.get_xticklabels(), rotation=45, ha='right',rotation_mode='anchor')
            axs.set_yticklabels(axs.get_yticklabels(), rotation=45, ha='right')

            axs.tick_params(axis='both', which='major', labelsize=20)

            fig.tight_layout()
            fig.subplots_adjust(right=1.1)
            fig.suptitle("Attention Pattern", fontsize=28, x=0.48, y=1)
            plt.savefig(f'Figures/Figs/fig3.2.pdf')
            plt.savefig(f'Figures/Figs/fig3.2.svg')
def compute_entropy(attention_vector):
    """Return the Shannon entropy (natural log) along the last dimension.

    A small constant is added inside the logarithm so that zero attention
    weights contribute 0 instead of NaN.
    """
    entropy = -torch.sum(attention_vector * torch.log(attention_vector + 1e-9), dim=-1)
    return entropy


def _normalized_token_entropies(head_attention, input_token_num=0):
    """Return the normalized entropy of each query row of one attention head.

    Args:
        head_attention: Square ``(seq_len, seq_len)`` attention matrix.
        input_token_num: Number of leading entries to drop from the result.

    Returns:
        A 1-D tensor. Row ``i`` (for ``i >= 1``) is divided by ``log(i + 1)``;
        row 0 is skipped because its divisor would be ``log(1) = 0``. The first
        ``input_token_num`` entries of that sequence are then removed.
    """
    seq_len = head_attention.shape[0]
    max_entropy = torch.log(
        torch.arange(2, seq_len + 1, dtype=torch.float, device=head_attention.device)
    )
    normalized = compute_entropy(head_attention)[1:] / max_entropy
    return normalized[input_token_num:]


def _per_head_statistic(attentions, input_token_num, reduce):
    """Apply ``reduce`` to the normalized entropies of every head of every layer.

    Only the first batch element of each layer tensor is used.
    """
    return [
        [
            reduce(_normalized_token_entropies(head_attention, input_token_num)).item()
            for head_attention in layer_attention[0]
        ]
        for layer_attention in attentions
    ]


def compute_min_attention_entropy_over_token(attentions, input_token_num = 0):
    """Return, per layer and head, the minimum normalized attention entropy.

    Args:
        attentions: Iterable of per-layer tensors indexed as
            ``[batch, head, query, key]``; only batch element 0 is read.
        input_token_num: Number of leading normalized entries to ignore.

    Returns:
        A list (layers) of lists (heads) of Python floats.

    Raises:
        RuntimeError: If no entries remain after dropping ``input_token_num``.
    """
    return _per_head_statistic(attentions, input_token_num, torch.min)


def compute_average_attention_entropy_over_token(attentions, input_token_num = 0):
    """Return, per layer and head, the mean normalized attention entropy.

    Takes the same arguments and returns the same layout as
    ``compute_min_attention_entropy_over_token``; an empty selection yields NaN.
    """
    return _per_head_statistic(attentions, input_token_num, torch.mean)
device = "cuda"
# NOTE(review): both 7B models are resident on the GPU at the same time;
# loading the second one after the first pass would lower peak memory.
math_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-Math-7B")
math_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-Math-7B")
math_model.to(device)
math_model.eval()
math_model.config.output_attentions = True

qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B")
qwen_model.to(device)
qwen_model.eval()
qwen_model.config.output_attentions = True

# Load dataset
gsm8k_dataset = pd.read_parquet('gsm8k/test-00000-of-00001.parquet').to_dict(orient='records')
for sample in gsm8k_dataset:
    # CoT text: question followed by the full worked answer.
    sample['CoT_text'] = sample['question'] + ' '+sample['answer']
    # no-CoT text: question followed only by the trailing digits of the answer.
    match = re.search(r'(\d+)\s*$', sample['answer'])
    if match is None:
        raise ValueError(f"answer does not end with a number: {sample['answer']!r}")
    sample['noCoT_text'] = sample['question'] + f" The answer is {match.group(1)}."
    # The loops below read this key, but nothing in the script ever set it.
    # Presumably it is the number of question tokens, so that the entropy
    # statistic skips the prompt -- TODO confirm against the paper's setup.
    sample['math_input_token_num'] = len(math_tokenizer(sample['question'])['input_ids'])
    print(sample['noCoT_text'])


def _dump_min_entropies(model, tokenizer, text_key, path):
    """Run ``model`` on ``sample[text_key]`` for every sample and save the
    per-layer/per-head minimum normalized attention entropy as JSON at ``path``."""
    results = []
    for sample in gsm8k_dataset:
        text = sample[text_key]
        inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        print(inputs)
        # Inference only: no autograd graph is needed for the attention maps.
        with torch.no_grad():
            output_ids = model(**inputs)
        print(output_ids.keys())

        attention = output_ids.attentions
        min_normalized_entropy = compute_min_attention_entropy_over_token(attention, sample["math_input_token_num"])

        print(min_normalized_entropy)
        results.append({
            "test_text": text,
            "attention_maps": min_normalized_entropy
        })

    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=4)


_dump_min_entropies(math_model, math_tokenizer, "CoT_text", "Figures/Fig6-7/math_CoT.json")
_dump_min_entropies(qwen_model, qwen_tokenizer, "CoT_text", "Figures/Fig6-7/qwen_CoT.json")
_dump_min_entropies(qwen_model, qwen_tokenizer, "noCoT_text", "Figures/Fig6-7/qwen_noCoT.json")
def compute_entropy(attention_vector):
    """Return the Shannon entropy (natural log) along the last dimension.

    A small constant is added inside the logarithm so that zero attention
    weights contribute 0 instead of NaN.
    """
    entropy = -torch.sum(attention_vector * torch.log(attention_vector + 1e-9), dim=-1)
    return entropy


# Cached normalizer log(2), ..., log(seq_len) with shape (1, seq_len - 1).
log_tensor = None


def compute_mean_min_attention_entropy_over_token(attentions):
    """Return, per layer and head, the batch mean of the minimum normalized entropy.

    Args:
        attentions: Iterable of per-layer tensors of shape
            (batch_size, num_heads, sequence_length, sequence_length).

    Returns:
        A list (layers) of lists (heads) of Python floats. For each head the
        entropy of query row ``i >= 1`` is divided by ``log(i + 1)``, the
        minimum over rows is taken per sample, and the minima are averaged
        over the batch.
    """
    global log_tensor
    layer_head_entropies = []
    for layer_attention in attentions:  # (batch_size, num_heads, sequence_length, sequence_length)
        seq_len = layer_attention.shape[2]
        # Rebuild the cached normalizer whenever it does not fit the current
        # input; the old cache was built once and broke as soon as the
        # sequence length or device changed between calls.
        if (
            log_tensor is None
            or log_tensor.shape[-1] != seq_len - 1
            or log_tensor.device != layer_attention.device
        ):
            log_tensor = torch.log(
                torch.arange(2, seq_len + 1, dtype=torch.float, device=layer_attention.device)
            ).unsqueeze(0)
        head_entropies = []
        for head_index in range(layer_attention.shape[1]):  # num_heads
            head_attention = layer_attention[:, head_index, :, :]  # (batch_size, sequence_length, sequence_length)
            entropy = compute_entropy(head_attention)  # (batch_size, sequence_length)
            # Row 0 is skipped: its normalizer would be log(1) = 0.
            normalized_entropy = entropy[:, 1:] / log_tensor
            min_entropy_over_token = normalized_entropy.min(dim=1).values
            head_entropies.append(min_entropy_over_token.mean().item())
        layer_head_entropies.append(head_entropies)
    return layer_head_entropies
class CustomTrainer(Trainer):
    """Trainer that adds an attention-entropy metric to every evaluation."""

    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Run the standard evaluation, then a second pass that records attention entropy.

        The extra pass feeds every batch through the model with
        ``output_attentions=True`` and stores the per-layer/per-head result,
        averaged over batches, under
        ``metrics['mean_min_attention_entropy_over_token']``.
        """
        output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)

        all_entropies = []
        for inputs in dataloader:
            inputs = self._prepare_inputs(inputs)
            with torch.no_grad():
                outputs = self.model(**inputs,output_attentions=True)
            if 'attentions' in outputs:
                entropies = compute_mean_min_attention_entropy_over_token(outputs['attentions'])
                all_entropies.append(entropies)

        # Average the (layers x heads) tables over all evaluation batches.
        mean_entropy = torch.tensor(all_entropies).float().mean(dim=0).tolist()
        output.metrics['mean_min_attention_entropy_over_token'] = mean_entropy

        return output


class EarlyStoppingCallback(transformers.TrainerCallback):
    """Stop training as soon as the evaluation exact-match score reaches 1.0."""

    def on_evaluate(self, args, state, control, **kwargs):
        # Metrics may be absent, or may lack the exact-match entry; treat both
        # as "no score yet" instead of raising inside the training loop.
        metrics = kwargs.get("metrics") or {}
        eval_accuracy = metrics.get("eval_exact_match")
        print(eval_accuracy)
        if eval_accuracy is not None and eval_accuracy >= 1.0:
            print(f"Stopping training early. Evaluation accuracy has reached {eval_accuracy}.")
            control.should_training_stop = True
def main():
    """Train a model on the parity dataset described by the command-line arguments.

    Reads the model config, optionally restores weights from
    ``args.model_dir``, trains with ``CustomTrainer`` and saves the final
    model to ``args.output_dir``.
    """
    args = parse_args()
    set_seed(args.seed)
    train_dataset = load_dataset(args.dataset_dir, args.dataset_type)
    val_dataset = load_dataset(os.path.join(args.dataset_dir, 'val'), args.dataset_type)
    # NOTE(review): eval() executes the config file as Python; only point
    # --model_config_path at trusted files.
    with open(args.model_config_path) as config_file:
        model_args = eval(config_file.read())
    model_args["num_hidden_layers"] = args.num_hidden_layers
    model_args["num_attention_heads"] = args.num_attention_heads
    print(model_args)
    model = get_model(
        **model_args
    )
    if args.model_dir:
        # Import the submodule explicitly: a bare `import safetensors` does not
        # guarantee that `safetensors.torch` is loaded.
        import safetensors.torch
        safetensors.torch.load_model(model, os.path.join(args.model_dir, 'model.safetensors'))
    output_dir = f"{args.output_dir}{args.dataset_dir.split('/')[-1]}_{args.total_training_samples}_LR={args.lr}_WD={args.weight_decay}_{args.world_size}GPU*{args.batch_size}Batch_{args.model_config_path.split('/')[-1]}_#layer={args.num_hidden_layers}_#head={args.num_attention_heads}"

    # Intervals are given in samples; convert them to optimizer steps and never
    # let the floor division produce 0 steps.
    samples_per_step = args.batch_size * args.world_size
    logging_steps = max(1, args.log_interval // samples_per_step)
    save_steps = max(1, args.save_interval // samples_per_step)
    eval_steps = max(1, args.eval_interval // samples_per_step)

    training_args = transformers.TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=args.total_training_samples / len(train_dataset),
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        warmup_steps=0,
        weight_decay=args.weight_decay,
        logging_dir='./logs',
        logging_steps=logging_steps,
        save_steps=save_steps,
        save_total_limit = 1,
        evaluation_strategy="steps",
        eval_steps=eval_steps,
        learning_rate = args.lr,
        label_names = ['labels'],
        save_safetensors = False
    )

    def compute_metrics(eval_pred):
        """Return the fraction of sequences whose every scored token is predicted correctly."""
        predictions, labels = eval_pred
        if model_args["model_type"] == 'gpt2_custom_simpler':
            predictions = np.squeeze((predictions >= 0.5).astype(int))
            print(predictions)
        else:
            predictions = np.argmax(predictions, axis=-1)
            # Shift so that the prediction at position t is compared with the
            # label at position t + 1.
            predictions = predictions[:, :-1]
            labels = labels[:, 1:]
        exact_match_cnt = 0
        cnt = 0
        for prediction, label in zip(predictions, labels):
            # Positions labelled -100 always count as correct.
            correct = (prediction == label) + (label == -100)
            cnt += 1
            exact_match_cnt += correct.all()
        return {"exact_match": exact_match_cnt / cnt}

    # NOTE(review): EarlyStoppingCallback is defined in this file but is not
    # passed to the trainer here -- confirm whether that is intended.
    trainer = CustomTrainer(
        model=model, args=training_args, train_dataset=train_dataset,
        compute_metrics=compute_metrics,
        eval_dataset = val_dataset,
    )
    trainer.train(ignore_keys_for_eval = ['past_key_values', 'dreamer_loss_1', 'dreamer_loss_0'])
    # NOTE(review): this saves into args.output_dir, not the run-specific
    # output_dir built above -- confirm that this is the intended location.
    trainer.save_model(output_dir=args.output_dir)


if __name__ == '__main__':
    main()
np.argmax(predictions, axis=-1) 106 | predictions = predictions[:, :-1] 107 | labels = labels[:, 1:] 108 | exact_match_cnt = 0 109 | cnt = 0 110 | for prediction, label in zip(predictions, labels): 111 | correct = (prediction == label) + (label == -100) 112 | cnt += 1 113 | exact_match_cnt += correct.all() 114 | return {"exact_match": exact_match_cnt / cnt} 115 | trainer = CustomTrainer( 116 | model=model, args=training_args, train_dataset=train_dataset, 117 | compute_metrics=compute_metrics, 118 | eval_dataset = val_dataset, 119 | ) 120 | trainer.train(ignore_keys_for_eval = ['past_key_values', 'dreamer_loss_1', 'dreamer_loss_0']) 121 | trainer.save_model(output_dir=args.output_dir) 122 | 123 | if __name__ == '__main__': 124 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Parity with Chain-of-Thought 2 | 3 | This repository contains the code to reproduce the results from our paper *From Sparse Dependence to Sparse Attention: Unveiling How Chain-of-Thought Enhances Transformer Sample Efficiency*. 4 | 5 | ## Prerequisites 6 | 7 | * [PyTorch](https://pytorch.org/get-started/locally/) 8 | * [transformers](https://github.com/huggingface/transformers) 9 | 10 | ## Data Generation 11 | 12 | To generate synthetic parity data, first create a configuration file `data_config.py`. Below is an example configuration (`config/data_config.py`): 13 | 14 | ```python 15 | { 16 | 'dataset_type': 'BinaryDataset', 17 | 'n_samples': 10000, 18 | 'val_samples': 2048, 19 | 'n_digit': 30, 20 | 'n_secret': 3, 21 | 'use_cot': True, 22 | } 23 | ``` 24 | 25 | - `n_samples`: Number of training samples. 26 | - `val_samples`: Number of validation samples. 27 | - `n_digit`: Number of input variables. 28 | - `n_secret`: Number of secret variables. 29 | - `use_cot`: Set to `True` to use Chain-of-Thought (CoT) data.
30 | 31 | After defining the configuration, generate the data by running: 32 | 33 | ```bash 34 | python src/generate_data.py --config path-to-data_config.py 35 | ``` 36 | 37 | 38 | 39 | ## Training 40 | 41 | To train a transformer model on the parity problem, use the following command: 42 | 43 | ```bash 44 | export CUDA_VISIBLE_DEVICES=0 45 | PROJECT=Istree 46 | MODEL=transformer 47 | total_training_samples=100000 48 | training_samples=10000 49 | n_digits=30 50 | k=3 51 | CoT=True 52 | lr=6e-5 53 | num_layers=4 54 | num_heads=3 55 | python src/train.py \ 56 | --world_size 1 \ 57 | --total_training_samples ${total_training_samples} \ 58 | --model_type transformer \ 59 | --model_config_path config/gpt2_tiny_wpetrain.py \ 60 | --dataset_dir data/Nonintersect_Binary/binary_${training_samples}_${n_digits}_${k}_${CoT}_False_False \ 61 | --dataset_type BinaryDataset \ 62 | --output_dir model/ \ 63 | --batch_size 512 \ 64 | --lr ${lr} \ 65 | --weight_decay 0 \ 66 | --log_interval 2048 \ 67 | --save_interval 2048 \ 68 | --eval_interval 2048 \ 69 | --report_to_wandb \ 70 | --num_hidden_layers ${num_layers} \ 71 | --num_attention_heads ${num_heads} 72 | ``` 73 | 74 | Where: 75 | 76 | - `training_samples`: Number of samples in the training set. 77 | - `total_training_samples`: Total number of samples used during training (`iterations = total_training_samples / batch_size`). 78 | - `lr`: Learning rate. 79 | - `num_layers`: Number of hidden layers in the model. 80 | - `num_heads`: Number of attention heads. 81 | - `CoT`: Set to `True` to train with Chain-of-Thought data, or `False` otherwise. 82 | 83 | Training results will be saved in the `model/` directory. 84 | 85 | ## Reproducing the Results 86 | 87 | ### Figure 1, 2 88 | 89 | To reproduce the results shown in Figures 1 and 2 of our paper, follow these steps: 90 | 91 | 1.
Use the following configurations to generate data: 92 | 93 | | n_samples | n_digits | n_secret | use_cot | 94 | | ---------- | -------- | --------- | ----------- | 95 | | $10000000$ | $30$ | $1,2,3,4$ | True, False | 96 | 97 | 2. Train the model with the following settings: 98 | 99 | | total_training_samples | training_samples | n_digits | k | CoT | num_layers | num_heads | lr | 100 | | ---------------------- | ---------------- | -------- | --------- | ----------- | ---------- | --------- | ----------------------------------------------- | 101 | | $10^7$ | $10^7$ | $30$ | $1,2,3,4$ | True, False | $1,2,3,4$ | $1,2,3,4$ | $6\times10^{-5}, 8\times10^{-5},1\times10^{-4}$ | 102 | 103 | 3. To reproduce Figure 1, run `Figures/Fig1/fig1.1.py` and `Figures/Fig1/fig1.2.py`. 104 | 105 | To reproduce Figure 2, run `Figures/Fig2/fig2.1.py`, `Figures/Fig2/fig2.2.py` and `Figures/Fig2/fig2.3.py`. 106 | 107 | ### Figure 3, 8 108 | 109 | 1. Use the following configurations to generate data: 110 | 111 | | n_samples | n_digits | n_secret | use_cot | 112 | | ---------- | -------- | ------------------ | ------- | 113 | | $1000000$ | $100$ | $20,30,\cdots,100$ | True | 114 | | $10000000$ | $40$ | $20$ | True | 115 | 116 | 2. Train the model with the following settings: 117 | 118 | | total_training_samples | training_samples | n_digits | k | CoT | num_layers | num_heads | lr | 119 | | ---------------------- | ---------------- | -------- | ------------------ | ---- | ---------- | --------- | ----------------------------------------------- | 120 | | $10^7$ | $10^6$ | $100$ | $20,30,\cdots,100$ | True | $1$ | $1$ | $6\times10^{-5}, 8\times10^{-5},1\times10^{-4}$ | 121 | | $10^7$ | $10^7$ | $40$ | $20$ | True | $1$ | $1$ | $6\times10^{-5}$ | 122 | | $10^7$ | $10^7$ | $40$ | $20$ | True | $2$ | $2$ | $6\times10^{-5}$ | 123 | 124 | 3. To reproduce Figure 3, run `Figures/Fig3/fig3.1.py` and `Figures/Fig3/fig3.2.py`. To reproduce Figure 8, run `Figures/Fig8/fig8.py`. 125 | 126 | ### Figure 4 127 | 128 | 1.
Use the following configurations to generate data: 129 | 130 | | n_samples | n_digits | n_secret | use_cot | 131 | | --------- | -------- | -------- | ----------- | 132 | | $10000$ | $20$ | $6$ | True, False | 133 | 134 | 2. Train the model with the following settings: 135 | 136 | | total_training_samples | training_samples | n_digits | k | CoT | num_layers | num_heads | lr | 137 | | ---------------------- | ---------------- | -------- | ---- | ----- | ---------- | --------- | ----------------- | 138 | | $10^7$ | $10000$ | $20$ | $6$ | False | $1,2,3,4$ | $1,2,3,4$ | $1\times10^{-4}$ | 139 | | $10^7$ | $10000$ | $20$ | $6$ | True | $1$ | $1$ | $1\times 10^{-5}$ | 140 | 141 | 3. To reproduce Figure 4, run `Figures/Fig4/fig4.py`. 142 | 143 | ### Figure 5, 10, 11 144 | 145 | 1. Use the following configurations to generate data: 146 | 147 | | n_samples | n_digits | n_secret | use_cot | 148 | | ---------------------------- | -------- | -------- | ------- | 149 | | $5000,10000,50000,10^5,10^6$ | $20$ | $6$ | False | 150 | 151 | 2. Train the model with the following settings: 152 | 153 | | total_training_samples | training_samples | n_digits | k | CoT | num_layers | num_heads | lr | 154 | | ---------------------- | ---------------------------- | -------- | ---- | ----- | ---------- | --------- | ---------------- | 155 | | $10^7$ | $5000,10000,50000,10^5,10^6$ | $20$ | $6$ | False | $1$ | $2$ | $1\times10^{-4}$ | 156 | | $10^7$ | $5000,10000,50000,10^5,10^6$ | $20$ | $6$ | False | $2$ | $3$ | $1\times10^{-4}$ | 157 | | $10^7$ | $5000,10000,50000,10^5,10^6$ | $20$ | $6$ | False | $4$ | $4$ | $1\times10^{-4}$ | 158 | 159 | 3. To reproduce Figure 5, run `Figures/Fig5/fig5.1.py` and `Figures/Fig5/flg5.2.py`. 160 | 161 | To reproduce Figure 10, run `Figures/Fig10/fig10.py`. 162 | 163 | To reproduce Figure 11, run `Figures/Fig11/fig11.py`. 164 | 165 | ### Figure 9 166 | 167 | 1.
Use the following configurations to generate data: 168 | 169 | | n_samples | n_digits | n_secret | use_cot | 170 | | ---------------- | -------- | -------- | ------- | 171 | | $10^4,10^5,10^6$ | $20$ | $12$ | False | 172 | | $10^4,10^6$ | $20$ | $12$ | True | 173 | 174 | 2. Train the model with the following settings: 175 | 176 | | total_training_samples | training_samples | n_digits | k | CoT | num_layers | num_heads | lr | 177 | | ---------------------- | ---------------- | -------- | ---- | ---- | ------------- | ------------- | ------------------------------------------------ | 178 | | $10^7$ | $10^4,10^5,10^6$ | $20$ | $12$ | True | $1,2,3,4,6,8$ | $1,2,3,4,6,8$ | $6\times10^{-5}, 8\times10^{-5},1\times10^{-4}$ | 179 | | $10^7$ | $10^4,10^6$ | $20$ | $12$ | True | $1$ | $1$ | $6\times10^{-5},8\times 10 ^{-5},1\times10^{-4}$ | 180 | 181 | 3. To reproduce Figure 9, run `Figures/Fig9/fig9.py`. 182 | 183 | ### Figure 6,7 184 | 185 | To reproduce Figure 6 and 7, first run `Figures/Fig6-7/work.py` to compute the normalized attention entropy. Then run `Figures/Fig6-7/fig6.py` and `Figures/Fig6-7/fig7.py`. --------------------------------------------------------------------------------