├── .gitattributes ├── .gitmodules ├── README.md ├── data ├── 4_by_4_mult │ ├── test_bigbench.txt │ ├── train.txt │ └── valid.txt ├── 5_by_5_mult │ ├── test_bigbench.txt │ ├── train.txt │ └── valid.txt └── gsm8k │ ├── fullcot_400kaugmented_math_scaffolding_formula.tgz │ ├── test.txt │ ├── train.txt │ ├── train_no_aug.txt │ └── valid.txt ├── gpt4_baselines ├── README.md ├── explicit_cot │ ├── 4_by_4_mult │ │ ├── cache.db │ │ └── evaluate.py │ ├── 5_by_5_mult │ │ ├── cache.db │ │ └── evaluate.py │ └── gsm8k │ │ ├── cache.db │ │ └── evaluate.py └── no_cot │ ├── 4_by_4_mult │ ├── cache.db │ └── evaluate.py │ ├── 5_by_5_mult │ ├── cache.db │ └── evaluate.py │ └── gsm8k │ ├── cache.db │ └── evaluate.py ├── imgs ├── training_illustration.png ├── training_illustration_a.png ├── training_illustration_b.png └── training_illustration_c.png ├── logs ├── 4_by_4_mult │ ├── gpt2-medium │ │ └── log.generate │ └── gpt2 │ │ └── log.generate ├── 5_by_5_mult │ ├── gpt2-medium │ │ └── log.generate │ └── gpt2 │ │ └── log.generate └── gsm8k │ ├── gpt2-medium │ └── log.generate │ └── gpt2 │ └── log.generate ├── src ├── data.py ├── generate.py ├── models │ ├── configuration_emulator.py │ ├── configuration_student.py │ ├── configuration_teacher.py │ ├── emulator.py │ ├── modeling_gpt2_implicit.py │ ├── student.py │ └── teacher.py ├── train_coupled_emulator_and_student.py ├── train_mind_reading_student.py ├── train_teacher.py ├── train_thought_emulator.py └── utils.py └── src_autoencoder ├── data.py ├── generate.py ├── models ├── autoencoder.py ├── configuration_autoencoder.py ├── configuration_emulator.py ├── configuration_student.py ├── configuration_teacher.py ├── emulator.py ├── modeling_gpt2_implicit.py ├── student.py └── teacher.py ├── train_autoencoder.py ├── train_coupled_emulator_and_student.py ├── train_mind_reading_student.py ├── train_teacher.py ├── train_thought_emulator.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | data/4_by_4_mult/train.txt filter=lfs diff=lfs merge=lfs -text 2 | data/5_by_5_mult/train.txt filter=lfs diff=lfs merge=lfs -text 3 | data/gsm8k/train.txt filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "transformers"] 2 | path = transformers 3 | url = git@github.com:da03/implicit_transformers.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implicit Chain of Thought Reasoning via Knowledge Distillation 2 | 3 | Here we provide code to reproduce our results. 4 | 5 | ## Prerequisites 6 | 7 | * [PyTorch](https://pytorch.org/get-started/locally/) 8 | * [transformers](https://github.com/huggingface/transformers) (`pip install transformers`) 9 | 10 | ## Datasets & Pretrained Models & Logs 11 | 12 | All dataset files and log files during inference are included in this repo, with the exception of large training files maintained under Git LFS. Model checkpoints are stored on Google Drive. The folder containing all checkpoints can be found at [this link](https://drive.google.com/drive/folders/1Sclr5bmLZIUcktCaFAeWRTevRGLUwlC_?usp=drive_link). 
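Note that the large training files (e.g., `data/4_by_4_mult/train.txt`) are tracked with Git LFS, as listed in `.gitattributes` above, so they may need to be fetched explicitly after cloning. A minimal sketch, assuming [Git LFS](https://git-lfs.com) is installed:

```
git lfs install
git lfs pull
```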
13 | 
14 | * 4 X 4 Mult - GPT-2: [data](data/4_by_4_mult/) [model](https://drive.google.com/drive/folders/1Zp-PFwiHkwq0wuFScjN5R8jDdXdnQYQ_?usp=sharing) [log](logs/4_by_4_mult/gpt2/log.generate)
15 | * 4 X 4 Mult - GPT-2 Medium: [data](data/4_by_4_mult/) [model](https://drive.google.com/drive/folders/1B0e67ifTSTTuUg0Sh-of5135Rh4KQ-2v?usp=sharing) [log](logs/4_by_4_mult/gpt2-medium/log.generate)
16 | * 5 X 5 Mult - GPT-2: [data](data/5_by_5_mult/) [model](https://drive.google.com/drive/folders/1lHa2Xey8jJ3__RsYRhcOFHU7Xfqp7XTG?usp=sharing) [log](logs/5_by_5_mult/gpt2/log.generate)
17 | * 5 X 5 Mult - GPT-2 Medium: [data](data/5_by_5_mult/) [model](https://drive.google.com/drive/folders/18dRIynq0j5EBOnKTpOPaLJWCoMBXZYTi?usp=sharing) [log](logs/5_by_5_mult/gpt2-medium/log.generate)
18 | * GSM8K - GPT-2: [data](data/gsm8k/) [model](https://drive.google.com/drive/folders/1aFBBcUr_vHtaDqgpU5A1ErEvrJyX-cEO?usp=sharing) [log](logs/gsm8k/gpt2/log.generate)
19 | * GSM8K - GPT-2 Medium: [data](data/gsm8k/) [model](https://drive.google.com/drive/folders/1zFXfwq5jDjgKpbUVafY5KC0LmJpYXjQK?usp=sharing) [log](logs/gsm8k/gpt2-medium/log.generate)
20 | 
21 | ## Usage
22 | 
23 | We use 4 X 4 Mult with GPT-2 Small as an example. We assume that the working directory is `implicit_chain_of_thought` throughout this document.
24 | 
25 | ### Data Format
26 | 
27 | The format of the training, validation, and test files is as follows:
28 | 
29 | ```
30 | [input 1]||[chain-of-thought 1] #### [output 1]
31 | [input 2]||[chain-of-thought 2] #### [output 2]
32 | [input 3]||[chain-of-thought 3] #### [output 3]
33 | ...
34 | ```
35 | 
36 | As an example, let's take a look at the first line of the 4 X 4 Mult test set in [data/4_by_4_mult/test_bigbench.txt](data/4_by_4_mult/test_bigbench.txt):
37 | 
38 | ```
39 | 9 1 7 3 * 9 4 3 3||1 7 4 3 3 + 0 6 7 8 4 1 ( 1 3 2 2 8 1 ) + 0 0 7 5 1 1 1 ( 1 3 9 7 9 2 1 ) + 0 0 0 7 5 1 1 1 #### 1 3 9 4 5 4 2 1
40 | ```
41 | 
42 | In this example, the input is `9 1 7 3 * 9 4 3 3` (corresponding to `3719*3349`, with digits written least-significant first), the chain-of-thought is `1 7 4 3 3 + 0 6 7 8 4 1 ( 1 3 2 2 8 1 ) + 0 0 7 5 1 1 1 ( 1 3 9 7 9 2 1 ) + 0 0 0 7 5 1 1 1`, and the output is `1 3 9 4 5 4 2 1` (corresponding to `12454931`).
43 | 
44 | Note that the chain-of-thought steps are used for Teacher Training, (a) Mind-Reading the Teacher, and (b) Thought Emulation, but they are not used for (c) Couple and Optimize.
45 | 
46 | ### Training
47 | 
48 | #### Prerequisite: Teacher Training
49 | 
50 | Our approach is based on distilling a teacher model's horizontal reasoning process into the vertical reasoning process of the emulator and the student. Therefore, we first need to train a teacher on the task of explicit chain-of-thought reasoning.
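The teacher is trained on lines in the `||` / `####` format described in the Data Format section above. For reference, here is a minimal parsing sketch of that convention (this helper is not part of the repo's code; the GPT-4 baseline scripts under `gpt4_baselines/` split lines the same way). The teacher training command itself follows right after.

```
def parse_line(line):
    # Split one dataset line into (input, chain-of-thought, output).
    src, rest = line.strip().split('||')   # input vs. CoT + answer
    cot, output = rest.split(' #### ')     # CoT vs. final answer
    return src.strip(), cot.strip(), output.strip()

with open('data/4_by_4_mult/valid.txt') as f:
    inp, cot, out = parse_line(next(f))
    print(inp, '->', out)
```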
51 | 52 | ``` 53 | export FOLDER=data/4_by_4_mult 54 | export MODEL=gpt2 55 | export EPOCHS=1 56 | export LR=5e-5 57 | export BSZ=32 58 | export SAVE=train_models/4_by_4_mult/gpt2/teacher 59 | echo $SAVE 60 | mkdir -p $SAVE 61 | TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=0 stdbuf -oL -eL python src/train_teacher.py \ 62 | --train_path ${FOLDER}/train.txt \ 63 | --val_path ${FOLDER}/valid.txt \ 64 | --epochs $EPOCHS \ 65 | --lr $LR \ 66 | --batch_size $BSZ \ 67 | --base_model $MODEL \ 68 | --save_model $SAVE \ 69 | > ${SAVE}/log.train 2>&1& 70 | ``` 71 | 72 | #### (a) Mind-Reading the Teacher 73 | 74 | ![](imgs/training_illustration_a.png) 75 | 76 | ``` 77 | export FOLDER=data/4_by_4_mult 78 | export DELTA=dynamic 79 | export MODEL=gpt2 80 | export EPOCHS=40 81 | export LR=5e-5 82 | export BSZ=32 83 | export TEACHER=train_models/4_by_4_mult/gpt2/teacher/checkpoint_0 84 | export SAVE=train_models/4_by_4_mult/gpt2/student_initial 85 | mkdir -p $SAVE 86 | TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=0 stdbuf -oL -eL python src/train_mind_reading_student.py \ 87 | --train_path ${FOLDER}/train.txt \ 88 | --val_path ${FOLDER}/valid.txt \ 89 | --epochs $EPOCHS \ 90 | --lr $LR \ 91 | --batch_size $BSZ \ 92 | --base_model $MODEL \ 93 | --teacher $TEACHER \ 94 | --save_model $SAVE \ 95 | --delta $DELTA \ 96 | > ${SAVE}/log.train 2>&1& 97 | ``` 98 | 99 | #### (b) Thought Emulation 100 | 101 | ![](imgs/training_illustration_b.png) 102 | 103 | ``` 104 | export FOLDER=data/4_by_4_mult 105 | export DELTA=dynamic 106 | export MODEL=gpt2 107 | export EPOCHS=40 108 | export LR=5e-5 109 | export BSZ=32 110 | export MIXTURE_SIZE=1 111 | export TEACHER=train_models/4_by_4_mult/gpt2/teacher/checkpoint_0 112 | export SAVE=train_models/4_by_4_mult/gpt2/emulator_initial 113 | mkdir -p $SAVE 114 | TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=0 stdbuf -oL -eL python src/train_thought_emulator.py \ 115 | --train_path ${FOLDER}/train.txt \ 116 | --val_path ${FOLDER}/valid.txt \ 117 | --epochs $EPOCHS \ 118 | --lr $LR \ 119 | --batch_size $BSZ \ 120 | --base_model $MODEL \ 121 | --teacher $TEACHER \ 122 | --save_model $SAVE \ 123 | --delta $DELTA \ 124 | --mixture_size ${MIXTURE_SIZE} \ 125 | > ${SAVE}/log.train 2>&1& 126 | ``` 127 | 128 | #### (c) Couple and Optimize 129 | 130 | ![](imgs/training_illustration_c.png) 131 | 132 | ``` 133 | export FOLDER=data/4_by_4_mult 134 | export EPOCHS=40 135 | export LR=5e-5 136 | export BSZ=32 137 | export STUDENT=train_models/4_by_4_mult/gpt2/student_initial/checkpoint_6 138 | export EMULATOR=train_models/4_by_4_mult/gpt2/emulator_initial/checkpoint_5 139 | export SAVE=train_models/4_by_4_mult/gpt2/ 140 | mkdir -p $SAVE 141 | TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=0 stdbuf -oL -eL python src/train_coupled_emulator_and_student.py \ 142 | --train_path ${FOLDER}/train.txt \ 143 | --val_path ${FOLDER}/valid.txt \ 144 | --epochs $EPOCHS \ 145 | --lr $LR \ 146 | --batch_size $BSZ \ 147 | --student $STUDENT \ 148 | --emulator $EMULATOR \ 149 | --save_model $SAVE \ 150 | > ${SAVE}/log.train 2>&1& 151 | ``` 152 | 153 | ### Generation & Evaluation 154 | 155 | Here we use a pretrained model as an example. Download the folder `models/4_by_4_mult/gpt2`, then the following command will run inference and evaluate both accuracy and throughput, logged in file `generation_logs/4_by_4_mult/log.generate`. 
156 | 
157 | ```
158 | export FOLDER=data/4_by_4_mult
159 | export STUDENT=models/4_by_4_mult/gpt2/student
160 | export EMULATOR=models/4_by_4_mult/gpt2/emulator
161 | export BSZ=1
162 | export SAVE=generation_logs/4_by_4_mult
163 | mkdir -p $SAVE
164 | TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=0 stdbuf -oL -eL python src/generate.py \
165 | --batch_size $BSZ \
166 | --test_path ${FOLDER}/test_bigbench.txt \
167 | --student_path $STUDENT \
168 | --emulator_path $EMULATOR \
169 | > ${SAVE}/log.generate 2>&1&
170 | ```
171 | 
172 | ## Citation
173 | 
174 | ```
175 | @misc{deng2023implicit,
176 |   title={Implicit Chain of Thought Reasoning via Knowledge Distillation},
177 |   author={Yuntian Deng and Kiran Prasad and Roland Fernandez and Paul Smolensky and Vishrav Chaudhary and Stuart Shieber},
178 |   year={2023},
179 |   eprint={2311.01460},
180 |   archivePrefix={arXiv},
181 |   primaryClass={cs.CL}
182 | }
183 | ```
184 | 
--------------------------------------------------------------------------------
/data/4_by_4_mult/train.txt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:9857d1bba98bafb54da6977d98255656c8bddaa96dae69a4fe67a14eafdeb9fe
3 | size 106656000
4 | 
--------------------------------------------------------------------------------
/data/5_by_5_mult/train.txt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e41642536746292598fb5fa53bf8ebb01358a161a6ba78892d0579ccc1489253
3 | size 158368000
4 | 
--------------------------------------------------------------------------------
/data/gsm8k/fullcot_400kaugmented_math_scaffolding_formula.tgz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/data/gsm8k/fullcot_400kaugmented_math_scaffolding_formula.tgz
--------------------------------------------------------------------------------
/data/gsm8k/train.txt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:6f9a20bc1476ca65eee9bc5117c2d0582b1f1733d5148ffc0ff29cad2a9e9c6b
3 | size 87580525
4 | 
--------------------------------------------------------------------------------
/gpt4_baselines/README.md:
--------------------------------------------------------------------------------
1 | # GPT-4 Turbo vs. GPT-4 Comparison
2 | 
3 | According to OpenAI, their new GPT-4 Turbo model is "3X cheaper for input tokens and 2X cheaper for output tokens compared to the original GPT-4 model" ([Schade, 2023](https://help.openai.com/en/articles/8555510-gpt-4-turbo)). But is GPT-4 Turbo as good as GPT-4? To answer this question, we compared GPT-4 Turbo to GPT-4 on three tasks requiring reasoning: 4-by-4 multiplication ([BIG-bench](https://github.com/google/BIG-bench)), 5-by-5 multiplication ([BIG-bench](https://github.com/google/BIG-bench)), and grade school math problems ([GSM8K](https://github.com/openai/grade-school-math)). Note that this set of experiments is adapted from the baselines of our [Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460.pdf) paper.
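The two settings compared below differ only in how the few-shot demonstrations are formatted: in the no-CoT setting each demonstration shows only the final answer, while in the explicit-CoT setting it also includes the intermediate reasoning steps. A minimal sketch of this difference, simplified from the `construct_prompt` functions in the `evaluate.py` scripts (the helper name here is illustrative, not part of the repo):

```
def make_demo(example, with_cot):
    # example is a dict with 'input', 'cot', and 'output' fields,
    # as produced by read_examples() in the evaluation scripts.
    if with_cot:
        answer = f"{example['cot']} #### {example['output']}"
    else:
        answer = f"#### {example['output']}"
    return f"Q: {example['input']}\nA: {answer}\n"
```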
4 | 
5 | ## Results
6 | 
7 | |                  | **4X4 Mult** |            |   | **5X5 Mult** |            |   | **GSM8K** |            |
8 | |------------------|---------:|------------|---|---------:|------------|---|---------:|------------|
9 | |                  | Accuracy | Throughput |   | Accuracy | Throughput |   | Accuracy | Throughput |
10 | | **No CoT**       |          |            |   |          |            |   |          |            |
11 | | GPT-4            | 3.8%     | 1.04       |   | 0.1%     | 1.02       |   | 42.8%    | 1.05       |
12 | | GPT-4 Turbo      | 6.1%     | 1.84       |   | 0.3%     | 1.71       |   | 43.3%    | 1.79       |
13 | | **Explicit CoT** |          |            |   |          |            |   |          |            |
14 | | GPT-4            | 77.8%    | 0.09       |   | 43.2%    | 0.07       |   | 90.9%    | 0.10       |
15 | | GPT-4 Turbo      | 76.6%    | 0.31       |   | 38.3%    | 0.24       |   | 91.4%    | 0.31       |
16 | 
17 | Note that some results differ slightly from our paper because we reran all baseline evaluations (the current evaluations were run on Nov 12, 2023).
18 | 
19 | ## Usage
20 | 
21 | Evaluation scripts for the different settings (`no_cot` vs. `explicit_cot`) and different tasks (`4_by_4_mult`, `5_by_5_mult`, and `gsm8k`) are organized into separate folders. Only two arguments are required, `--api_key` and `--model`; the optional flag `--overwrite_cache` forces a new query to OpenAI's API and overwrites the cache file even when an identical query has already been made and cached.
22 | 
23 | * `--api_key`: specifies the OpenAI API key
24 | * `--model`: specifies the OpenAI model to run, such as gpt-4 or gpt-4-1106-preview
25 | * `--overwrite_cache` (optional): if specified, always makes new queries to OpenAI's API and overwrites the local cache
26 | 
27 | For example, to evaluate GPT-4 Turbo's performance on GSM8K using explicit chain-of-thought reasoning, use the command below:
28 | 
29 | ```
30 | export APIKEY=your_openai_api_key
31 | python explicit_cot/gsm8k/evaluate.py --api_key $APIKEY --model gpt-4-1106-preview
32 | ```
33 | 
34 | Note that to ensure reproducibility, we have included all cache files in this repo. By default the cache files are used, in which case only the accuracy measurement is meaningful, whereas the throughput measurement is not.
35 | 
36 | ## Acknowledgements
37 | 
38 | Our table above is inspired by [Chain-of-Thought Hub: Measuring LLMs' Reasoning Performance](https://github.com/FranxYao/chain-of-thought-hub), which contains a more comprehensive comparison of LLMs' reasoning performance across a wider range of models and tasks.
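As a final note on the included cache files: each `cache.db` is a plain SQLite database with a single table `main` holding `key`, `prompt`, and `completion` columns (see `create_database` in the `evaluate.py` scripts), so cached responses can be inspected directly. A minimal sketch (the path is just an example, relative to `gpt4_baselines/`):

```
import sqlite3

# Peek at one cached prompt/completion pair from the GSM8K explicit-CoT runs.
conn = sqlite3.connect('explicit_cot/gsm8k/cache.db')
cur = conn.cursor()
cur.execute('SELECT prompt, completion FROM main LIMIT 1')
prompt, completion = cur.fetchone()
print(completion)
conn.close()
```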
39 | -------------------------------------------------------------------------------- /gpt4_baselines/explicit_cot/4_by_4_mult/cache.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/gpt4_baselines/explicit_cot/4_by_4_mult/cache.db -------------------------------------------------------------------------------- /gpt4_baselines/explicit_cot/4_by_4_mult/evaluate.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import numpy as np 3 | import time 4 | import sys, os, json, random, argparse, openai 5 | import tiktoken 6 | import tqdm 7 | import sqlite3 8 | 9 | 10 | def create_database(cache_file): 11 | conn = sqlite3.connect(cache_file) 12 | c = conn.cursor() 13 | # Create table 14 | c.execute('''CREATE TABLE IF NOT EXISTS main 15 | (key TEXT PRIMARY KEY, prompt TEXT, completion TEXT)''') 16 | conn.commit() 17 | conn.close() 18 | 19 | 20 | def insert_or_update(key, prompt, completion, cache_file): 21 | conn = sqlite3.connect(cache_file) 22 | c = conn.cursor() 23 | c.execute('''INSERT OR REPLACE INTO main 24 | (key, prompt, completion) VALUES (?, ?, ?)''', 25 | (key, prompt, completion)) 26 | conn.commit() 27 | conn.close() 28 | 29 | def retrieve(key, cache_file): 30 | conn = sqlite3.connect(cache_file) 31 | c = conn.cursor() 32 | 33 | c.execute("SELECT prompt, completion FROM main WHERE key=?", (key,)) 34 | result = c.fetchone() 35 | conn.close() 36 | if result: 37 | return (True, result) 38 | else: 39 | return (False, None) 40 | 41 | def get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache): 42 | #import pdb; pdb.set_trace() 43 | key = hashlib.sha256(json.dumps({'prompt': prompt, 'model': model, 'temperature': temperature}).encode('utf-8')).hexdigest() 44 | hit, result = retrieve(key, cache_file) 45 | if overwrite_cache: 46 | hit = False 47 | if not hit: 48 | completion = get_completion(prompt, model, temperature, max_tokens) 49 | insert_or_update(key, json.dumps(prompt), completion, cache_file) 50 | else: 51 | print ('hit') 52 | prompt, completion = result 53 | return completion, hit 54 | 55 | 56 | #@retry(wait=wait_exponential(min=60, max=1200), stop=stop_after_attempt(12)) 57 | def chatcompletion_with_backoff(**kwargs): 58 | return openai.ChatCompletion.create(**kwargs) 59 | 60 | 61 | def get_completion(prompt, model, temperature, max_tokens): 62 | """ 63 | Get a completion using the specified language model. 64 | 65 | Args: 66 | prompt (list): The prompt for generating the sentence. 67 | model (str): The name of the language model to use. 68 | temperature (float): Sampling temperature for the model. 69 | max_tokens (int): Maximum number of tokens in the generated sentence. 70 | 71 | Returns: 72 | str: The generated sentence. 73 | """ 74 | generated_text = "" 75 | response = chatcompletion_with_backoff( 76 | model=model, 77 | messages=prompt, 78 | max_tokens=max_tokens, 79 | temperature=temperature, 80 | ) 81 | generated_text = response.choices[0].message.content.strip() 82 | return generated_text 83 | 84 | def construct_prompt(num_shot, examples): 85 | instruction = 'Answer the final question following the exact format of the given examples. 
Do not output anything else.\n\nExample problems:\n\n' 86 | example_demonstrations = random.sample(examples, num_shot) 87 | prompt = instruction 88 | for i, line in enumerate(example_demonstrations): 89 | s = f"Q: {line['input']}\n" 90 | s += f"A: {line['cot']} #### {line['output']}\n" 91 | prompt += s 92 | prompt += '\n' 93 | prompt += f"Question to answer:\n\nQ:" 94 | context = [{"role": 'user', "content": prompt}] 95 | return context 96 | 97 | def read_examples(filename): 98 | lines = [] 99 | with open(filename) as fin: 100 | for line in fin: 101 | input, cot_and_output = line.strip().split('||') 102 | cot, output = cot_and_output.split(' #### ') 103 | input = input.strip() 104 | num1, num2 = input.split('*') 105 | num1_digits = num1.strip().split()[::-1] 106 | num2_digits = num2.strip().split()[::-1] 107 | num1_str = ''.join(num1_digits) 108 | num2_str = ''.join(num2_digits) 109 | input = num1_str + ' * ' + num2_str 110 | cot = '' 111 | i = 0 112 | num1_int = int(num1_str) 113 | partial_sum = 0 114 | for num2_digit in num2_digits[::-1]: 115 | i += 1 116 | num2_digit = int(num2_digit) * 10**(i-1) 117 | mult = num2_digit * num1_int 118 | partial_sum += mult 119 | cot += f'{i}): {num2_digit} * {num1_int} = {mult} (partial sum {partial_sum-mult} + {mult} = {partial_sum}) ' 120 | final_result = partial_sum 121 | #import pdb; pdb.set_trace() 122 | cot = cot.strip() 123 | output = ''.join(output.strip().split()[::-1]) 124 | assert output.lstrip('0') == str(final_result) 125 | output = output.lstrip('0') 126 | lines.append({'input': input, 'output': output, 'cot': cot}) 127 | return lines 128 | 129 | def main(model, temperature, max_tokens, seed, num_shot, train_file, test_file, cache_file, overwrite_cache): 130 | create_database(cache_file) 131 | random.seed(seed) 132 | train_examples = read_examples(train_file) 133 | test_examples = read_examples(test_file) 134 | 135 | prompts = [] 136 | for _ in test_examples: 137 | prompt = construct_prompt(num_shot, train_examples) 138 | prompts.append(prompt) 139 | 140 | i = 0 141 | correct = 0 142 | total = 0 143 | total_time = 0 144 | not_hit = 0 145 | for example in test_examples: 146 | prompt = prompts[i] 147 | i += 1 148 | prompt[0]['content'] += example['input'] 149 | start_time = time.time() 150 | completion, hit = get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache) 151 | answer = completion.split('####')[-1].strip() 152 | if not hit: 153 | not_hit += 1 154 | end_time = time.time() 155 | total_time += end_time - start_time 156 | print (f'throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 157 | if answer == example['output']: 158 | correct += 1 159 | total += 1 160 | print (f'accuracy: {correct / total}, correct: {correct}, total: {total}') 161 | sys.stdout.flush() 162 | if not_hit > 0: 163 | print (f'final throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 164 | print (f'final accuracy: {correct / total}, correct: {correct}, total: {total}') 165 | 166 | 167 | def parse_arguments(): 168 | """ 169 | Parse command-line arguments using argparse. 170 | 171 | Returns: 172 | argparse.Namespace: An object containing the parsed arguments. 
173 | """ 174 | parser = argparse.ArgumentParser(description="Augment") 175 | parser.add_argument("--api_key", type=str, required=True, help="OpenAI API key") 176 | parser.add_argument("--model", type=str, default='gpt-4-1106-preview', help="model to evaluate") 177 | parser.add_argument("--train_file", type=str, default="../data/4_by_4_mult/train.txt") 178 | parser.add_argument("--test_file", type=str, default="../data/4_by_4_mult/test_bigbench.txt") 179 | parser.add_argument("--num_shot", type=int, default=5) 180 | parser.add_argument('--overwrite_cache', action='store_true') 181 | parser.add_argument("--cache_file", type=str, default="explicit_cot/4_by_4_mult/cache.db") 182 | parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for generation") 183 | parser.add_argument("--max_tokens", type=int, default=1200, help="Maximum number of tokens in the generated sentence") 184 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 185 | parser.set_defaults(overwrite_cache=False) 186 | return parser.parse_args() 187 | 188 | 189 | if __name__ == "__main__": 190 | args = parse_arguments() 191 | openai.api_key = args.api_key 192 | 193 | main(model=args.model, \ 194 | temperature=args.temperature, \ 195 | max_tokens=args.max_tokens, \ 196 | seed=args.seed, \ 197 | train_file=args.train_file, \ 198 | test_file=args.test_file, \ 199 | num_shot=args.num_shot, \ 200 | cache_file=args.cache_file, \ 201 | overwrite_cache=args.overwrite_cache) 202 | -------------------------------------------------------------------------------- /gpt4_baselines/explicit_cot/5_by_5_mult/cache.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/gpt4_baselines/explicit_cot/5_by_5_mult/cache.db -------------------------------------------------------------------------------- /gpt4_baselines/explicit_cot/5_by_5_mult/evaluate.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import numpy as np 3 | import time 4 | import sys, os, json, random, argparse, openai 5 | import tiktoken 6 | import tqdm 7 | import sqlite3 8 | 9 | 10 | def create_database(cache_file): 11 | conn = sqlite3.connect(cache_file) 12 | c = conn.cursor() 13 | # Create table 14 | c.execute('''CREATE TABLE IF NOT EXISTS main 15 | (key TEXT PRIMARY KEY, prompt TEXT, completion TEXT)''') 16 | conn.commit() 17 | conn.close() 18 | 19 | 20 | def insert_or_update(key, prompt, completion, cache_file): 21 | conn = sqlite3.connect(cache_file) 22 | c = conn.cursor() 23 | c.execute('''INSERT OR REPLACE INTO main 24 | (key, prompt, completion) VALUES (?, ?, ?)''', 25 | (key, prompt, completion)) 26 | conn.commit() 27 | conn.close() 28 | 29 | def retrieve(key, cache_file): 30 | conn = sqlite3.connect(cache_file) 31 | c = conn.cursor() 32 | 33 | c.execute("SELECT prompt, completion FROM main WHERE key=?", (key,)) 34 | result = c.fetchone() 35 | conn.close() 36 | if result: 37 | return (True, result) 38 | else: 39 | return (False, None) 40 | 41 | def get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache): 42 | #import pdb; pdb.set_trace() 43 | key = hashlib.sha256(json.dumps({'prompt': prompt, 'model': model, 'temperature': temperature}).encode('utf-8')).hexdigest() 44 | hit, result = retrieve(key, cache_file) 45 | if overwrite_cache: 46 | hit = False 47 | if not hit: 48 | completion = 
get_completion(prompt, model, temperature, max_tokens) 49 | insert_or_update(key, json.dumps(prompt), completion, cache_file) 50 | else: 51 | print ('hit') 52 | prompt, completion = result 53 | return completion, hit 54 | 55 | 56 | #@retry(wait=wait_exponential(min=60, max=1200), stop=stop_after_attempt(12)) 57 | def chatcompletion_with_backoff(**kwargs): 58 | return openai.ChatCompletion.create(**kwargs) 59 | 60 | 61 | def get_completion(prompt, model, temperature, max_tokens): 62 | """ 63 | Get a completion using the specified language model. 64 | 65 | Args: 66 | prompt (list): The prompt for generating the sentence. 67 | model (str): The name of the language model to use. 68 | temperature (float): Sampling temperature for the model. 69 | max_tokens (int): Maximum number of tokens in the generated sentence. 70 | 71 | Returns: 72 | str: The generated sentence. 73 | """ 74 | generated_text = "" 75 | response = chatcompletion_with_backoff( 76 | model=model, 77 | messages=prompt, 78 | max_tokens=max_tokens, 79 | temperature=temperature, 80 | ) 81 | generated_text = response.choices[0].message.content.strip() 82 | return generated_text 83 | 84 | def construct_prompt(num_shot, examples): 85 | instruction = 'Answer the final question following the exact format of the given examples. Do not output anything else.\n\nExample problems:\n\n' 86 | example_demonstrations = random.sample(examples, num_shot) 87 | prompt = instruction 88 | for i, line in enumerate(example_demonstrations): 89 | s = f"Q: {line['input']}\n" 90 | s += f"A: {line['cot']} #### {line['output']}\n" 91 | prompt += s 92 | prompt += '\n' 93 | prompt += f"Question to answer:\n\nQ:" 94 | context = [{"role": 'user', "content": prompt}] 95 | return context 96 | 97 | def read_examples(filename): 98 | lines = [] 99 | with open(filename) as fin: 100 | for line in fin: 101 | input, cot_and_output = line.strip().split('||') 102 | cot, output = cot_and_output.split(' #### ') 103 | input = input.strip() 104 | num1, num2 = input.split('*') 105 | num1_digits = num1.strip().split()[::-1] 106 | num2_digits = num2.strip().split()[::-1] 107 | num1_str = ''.join(num1_digits) 108 | num2_str = ''.join(num2_digits) 109 | input = num1_str + ' * ' + num2_str 110 | cot = '' 111 | i = 0 112 | num1_int = int(num1_str) 113 | partial_sum = 0 114 | for num2_digit in num2_digits[::-1]: 115 | i += 1 116 | num2_digit = int(num2_digit) * 10**(i-1) 117 | mult = num2_digit * num1_int 118 | partial_sum += mult 119 | cot += f'{i}): {num2_digit} * {num1_int} = {mult} (partial sum {partial_sum-mult} + {mult} = {partial_sum}) ' 120 | final_result = partial_sum 121 | #import pdb; pdb.set_trace() 122 | cot = cot.strip() 123 | output = ''.join(output.strip().split()[::-1]) 124 | assert output.lstrip('0') == str(final_result) 125 | output = output.lstrip('0') 126 | lines.append({'input': input, 'output': output, 'cot': cot}) 127 | return lines 128 | 129 | def main(model, temperature, max_tokens, seed, num_shot, train_file, test_file, cache_file, overwrite_cache): 130 | create_database(cache_file) 131 | random.seed(seed) 132 | train_examples = read_examples(train_file) 133 | test_examples = read_examples(test_file) 134 | 135 | prompts = [] 136 | for _ in test_examples: 137 | prompt = construct_prompt(num_shot, train_examples) 138 | prompts.append(prompt) 139 | 140 | i = 0 141 | correct = 0 142 | total = 0 143 | total_time = 0 144 | not_hit = 0 145 | for example in test_examples: 146 | prompt = prompts[i] 147 | i += 1 148 | prompt[0]['content'] += example['input'] 149 | 
start_time = time.time() 150 | completion, hit = get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache) 151 | answer = completion.split('####')[-1].strip() 152 | if not hit: 153 | not_hit += 1 154 | end_time = time.time() 155 | total_time += end_time - start_time 156 | print (f'throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 157 | if answer == example['output']: 158 | correct += 1 159 | total += 1 160 | print (f'accuracy: {correct / total}, correct: {correct}, total: {total}') 161 | sys.stdout.flush() 162 | if not_hit > 0: 163 | print (f'final throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 164 | print (f'final accuracy: {correct / total}, correct: {correct}, total: {total}') 165 | 166 | 167 | def parse_arguments(): 168 | """ 169 | Parse command-line arguments using argparse. 170 | 171 | Returns: 172 | argparse.Namespace: An object containing the parsed arguments. 173 | """ 174 | parser = argparse.ArgumentParser(description="Augment") 175 | parser.add_argument("--api_key", type=str, required=True, help="OpenAI API key") 176 | parser.add_argument("--model", type=str, default='gpt-4-1106-preview', help="model to evaluate") 177 | parser.add_argument("--train_file", type=str, default="../data/5_by_5_mult/train.txt") 178 | parser.add_argument("--test_file", type=str, default="../data/5_by_5_mult/test_bigbench.txt") 179 | parser.add_argument("--num_shot", type=int, default=5) 180 | parser.add_argument('--overwrite_cache', action='store_true') 181 | parser.add_argument("--cache_file", type=str, default="explicit_cot/5_by_5_mult/cache.db") 182 | parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for generation") 183 | parser.add_argument("--max_tokens", type=int, default=1200, help="Maximum number of tokens in the generated sentence") 184 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 185 | parser.set_defaults(overwrite_cache=False) 186 | return parser.parse_args() 187 | 188 | 189 | if __name__ == "__main__": 190 | args = parse_arguments() 191 | openai.api_key = args.api_key 192 | 193 | main(model=args.model, \ 194 | temperature=args.temperature, \ 195 | max_tokens=args.max_tokens, \ 196 | seed=args.seed, \ 197 | train_file=args.train_file, \ 198 | test_file=args.test_file, \ 199 | num_shot=args.num_shot, \ 200 | cache_file=args.cache_file, \ 201 | overwrite_cache=args.overwrite_cache) 202 | -------------------------------------------------------------------------------- /gpt4_baselines/explicit_cot/gsm8k/cache.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/gpt4_baselines/explicit_cot/gsm8k/cache.db -------------------------------------------------------------------------------- /gpt4_baselines/explicit_cot/gsm8k/evaluate.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import numpy as np 3 | import time 4 | import sys, os, json, random, argparse, openai 5 | import tiktoken 6 | import tqdm 7 | import sqlite3 8 | 9 | 10 | def create_database(cache_file): 11 | conn = sqlite3.connect(cache_file) 12 | c = conn.cursor() 13 | # Create table 14 | c.execute('''CREATE TABLE IF NOT EXISTS main 15 | (key TEXT PRIMARY KEY, prompt TEXT, completion TEXT)''') 16 | conn.commit() 17 | conn.close() 18 | 19 | 20 | def 
insert_or_update(key, prompt, completion, cache_file): 21 | conn = sqlite3.connect(cache_file) 22 | c = conn.cursor() 23 | c.execute('''INSERT OR REPLACE INTO main 24 | (key, prompt, completion) VALUES (?, ?, ?)''', 25 | (key, prompt, completion)) 26 | conn.commit() 27 | conn.close() 28 | 29 | def retrieve(key, cache_file): 30 | conn = sqlite3.connect(cache_file) 31 | c = conn.cursor() 32 | 33 | c.execute("SELECT prompt, completion FROM main WHERE key=?", (key,)) 34 | result = c.fetchone() 35 | conn.close() 36 | if result: 37 | return (True, result) 38 | else: 39 | return (False, None) 40 | 41 | def get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache): 42 | #import pdb; pdb.set_trace() 43 | key = hashlib.sha256(json.dumps({'prompt': prompt, 'model': model, 'temperature': temperature}).encode('utf-8')).hexdigest() 44 | hit, result = retrieve(key, cache_file) 45 | if overwrite_cache: 46 | hit = False 47 | if not hit: 48 | completion = get_completion(prompt, model, temperature, max_tokens) 49 | insert_or_update(key, json.dumps(prompt), completion, cache_file) 50 | else: 51 | print ('hit') 52 | prompt, completion = result 53 | return completion, hit 54 | 55 | 56 | #@retry(wait=wait_exponential(min=60, max=1200), stop=stop_after_attempt(12)) 57 | def chatcompletion_with_backoff(**kwargs): 58 | return openai.ChatCompletion.create(**kwargs) 59 | 60 | 61 | def get_completion(prompt, model, temperature, max_tokens): 62 | """ 63 | Get a completion using the specified language model. 64 | 65 | Args: 66 | prompt (list): The prompt for generating the sentence. 67 | model (str): The name of the language model to use. 68 | temperature (float): Sampling temperature for the model. 69 | max_tokens (int): Maximum number of tokens in the generated sentence. 70 | 71 | Returns: 72 | str: The generated sentence. 
73 | """ 74 | generated_text = "" 75 | response = chatcompletion_with_backoff( 76 | model=model, 77 | messages=prompt, 78 | max_tokens=max_tokens, 79 | temperature=temperature, 80 | ) 81 | generated_text = response.choices[0].message.content.strip() 82 | return generated_text 83 | 84 | def construct_prompt(num_shot, examples): 85 | instruction = 'Answer the final question following the format of the given examples.\n\nExample problems:\n\n' 86 | example_demonstrations = random.sample(examples, num_shot) 87 | prompt = instruction 88 | for i, line in enumerate(example_demonstrations): 89 | s = f"Q: {line['input']}\n" 90 | s += f"A: {line['cot']} #### {line['output']}\n" 91 | prompt += s 92 | prompt += '\n' 93 | prompt += f"Question to answer:\n\nQ:" 94 | context = [{"role": 'user', "content": prompt}] 95 | return context 96 | 97 | def read_examples(filename): 98 | lines = [] 99 | with open(filename) as fin: 100 | for line in fin: 101 | input, cot_and_output = line.strip().split('||') 102 | cot, output = cot_and_output.split(' #### ') 103 | input = input.strip() 104 | cot = cot.strip() 105 | output = output.strip() 106 | lines.append({'input': input, 'output': output, 'cot': cot}) 107 | return lines 108 | 109 | def main(model, temperature, max_tokens, seed, num_shot, train_file, test_file, cache_file, overwrite_cache): 110 | create_database(cache_file) 111 | random.seed(seed) 112 | train_examples = read_examples(train_file) 113 | test_examples = read_examples(test_file) 114 | 115 | prompts = [] 116 | for _ in test_examples: 117 | prompt = construct_prompt(num_shot, train_examples) 118 | prompts.append(prompt) 119 | 120 | i = 0 121 | correct = 0 122 | total = 0 123 | total_time = 0 124 | not_hit = 0 125 | for example in test_examples: 126 | prompt = prompts[i] 127 | i += 1 128 | prompt[0]['content'] += example['input'] 129 | start_time = time.time() 130 | completion, hit = get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache) 131 | answer = completion.split('####')[-1].strip() 132 | if not hit: 133 | not_hit += 1 134 | end_time = time.time() 135 | total_time += end_time - start_time 136 | print (f'throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 137 | if answer == example['output']: 138 | correct += 1 139 | total += 1 140 | print (f'accuracy: {correct / total}, correct: {correct}, total: {total}') 141 | sys.stdout.flush() 142 | if not_hit > 0: 143 | print (f'final throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 144 | print (f'final accuracy: {correct / total}, correct: {correct}, total: {total}') 145 | 146 | 147 | def parse_arguments(): 148 | """ 149 | Parse command-line arguments using argparse. 150 | 151 | Returns: 152 | argparse.Namespace: An object containing the parsed arguments. 
153 | """ 154 | parser = argparse.ArgumentParser(description="Augment") 155 | parser.add_argument("--api_key", type=str, required=True, help="OpenAI API key") 156 | parser.add_argument("--model", type=str, default='gpt-4-1106-preview', help="model to evaluate") 157 | parser.add_argument("--train_file", type=str, default="../data/gsm8k/train_no_aug.txt") 158 | parser.add_argument("--test_file", type=str, default="../data/gsm8k/test.txt") 159 | parser.add_argument("--num_shot", type=int, default=5) 160 | parser.add_argument('--overwrite_cache', action='store_true') 161 | parser.add_argument("--cache_file", type=str, default="explicit_cot/gsm8k/cache.db") 162 | parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for generation") 163 | parser.add_argument("--max_tokens", type=int, default=600, help="Maximum number of tokens in the generated sentence") 164 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 165 | parser.set_defaults(overwrite_cache=False) 166 | return parser.parse_args() 167 | 168 | 169 | if __name__ == "__main__": 170 | args = parse_arguments() 171 | openai.api_key = args.api_key 172 | 173 | main(model=args.model, \ 174 | temperature=args.temperature, \ 175 | max_tokens=args.max_tokens, \ 176 | seed=args.seed, \ 177 | train_file=args.train_file, \ 178 | test_file=args.test_file, \ 179 | num_shot=args.num_shot, \ 180 | cache_file=args.cache_file, \ 181 | overwrite_cache=args.overwrite_cache) 182 | -------------------------------------------------------------------------------- /gpt4_baselines/no_cot/4_by_4_mult/cache.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/gpt4_baselines/no_cot/4_by_4_mult/cache.db -------------------------------------------------------------------------------- /gpt4_baselines/no_cot/4_by_4_mult/evaluate.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import numpy as np 3 | import time 4 | import sys, os, json, random, argparse, openai 5 | import tiktoken 6 | import tqdm 7 | import sqlite3 8 | 9 | 10 | def create_database(cache_file): 11 | conn = sqlite3.connect(cache_file) 12 | c = conn.cursor() 13 | # Create table 14 | c.execute('''CREATE TABLE IF NOT EXISTS main 15 | (key TEXT PRIMARY KEY, prompt TEXT, completion TEXT)''') 16 | conn.commit() 17 | conn.close() 18 | 19 | 20 | def insert_or_update(key, prompt, completion, cache_file): 21 | conn = sqlite3.connect(cache_file) 22 | c = conn.cursor() 23 | c.execute('''INSERT OR REPLACE INTO main 24 | (key, prompt, completion) VALUES (?, ?, ?)''', 25 | (key, prompt, completion)) 26 | conn.commit() 27 | conn.close() 28 | 29 | def retrieve(key, cache_file): 30 | conn = sqlite3.connect(cache_file) 31 | c = conn.cursor() 32 | 33 | c.execute("SELECT prompt, completion FROM main WHERE key=?", (key,)) 34 | result = c.fetchone() 35 | conn.close() 36 | if result: 37 | return (True, result) 38 | else: 39 | return (False, None) 40 | 41 | def get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache): 42 | #import pdb; pdb.set_trace() 43 | key = hashlib.sha256(json.dumps({'prompt': prompt, 'model': model, 'temperature': temperature}).encode('utf-8')).hexdigest() 44 | hit, result = retrieve(key, cache_file) 45 | if overwrite_cache: 46 | hit = False 47 | if not hit: 48 | completion = get_completion(prompt, model, temperature, max_tokens) 
49 | insert_or_update(key, json.dumps(prompt), completion, cache_file) 50 | else: 51 | print ('hit') 52 | prompt, completion = result 53 | return completion, hit 54 | 55 | 56 | #@retry(wait=wait_exponential(min=60, max=1200), stop=stop_after_attempt(12)) 57 | def chatcompletion_with_backoff(**kwargs): 58 | return openai.ChatCompletion.create(**kwargs) 59 | 60 | 61 | def get_completion(prompt, model, temperature, max_tokens): 62 | """ 63 | Get a completion using the specified language model. 64 | 65 | Args: 66 | prompt (list): The prompt for generating the sentence. 67 | model (str): The name of the language model to use. 68 | temperature (float): Sampling temperature for the model. 69 | max_tokens (int): Maximum number of tokens in the generated sentence. 70 | 71 | Returns: 72 | str: The generated sentence. 73 | """ 74 | generated_text = "" 75 | response = chatcompletion_with_backoff( 76 | model=model, 77 | messages=prompt, 78 | max_tokens=max_tokens, 79 | temperature=temperature, 80 | ) 81 | generated_text = response.choices[0].message.content.strip() 82 | return generated_text 83 | 84 | def construct_prompt(num_shot, examples): 85 | instruction = 'Answer the final question following the exact format of the given examples. Do not break the problem down, directly produce the answer.\n\nExample problems:\n\n' 86 | example_demonstrations = random.sample(examples, num_shot) 87 | prompt = instruction 88 | for i, line in enumerate(example_demonstrations): 89 | s = f"Q: {line['input']}\n" 90 | s += f"A: #### {line['output']}\n" 91 | prompt += s 92 | prompt += '\n' 93 | prompt += f"Question to answer:\n\nQ:" 94 | context = [{"role": 'user', "content": prompt}] 95 | return context 96 | 97 | def read_examples(filename): 98 | lines = [] 99 | with open(filename) as fin: 100 | for line in fin: 101 | input, cot_and_output = line.strip().split('||') 102 | cot, output = cot_and_output.split(' #### ') 103 | input = input.strip() 104 | items = input.split('*') 105 | input = ''.join(items[0].strip().split()[::-1]) + ' * ' + ''.join(items[1].strip().split()[::-1]) 106 | cot = cot.strip() 107 | output = ''.join(output.strip().split()[::-1]) 108 | output = output.lstrip('0') 109 | lines.append({'input': input, 'output': output, 'cot': cot}) 110 | return lines 111 | 112 | def main(model, temperature, max_tokens, seed, num_shot, train_file, test_file, cache_file, overwrite_cache): 113 | create_database(cache_file) 114 | random.seed(seed) 115 | train_examples = read_examples(train_file) 116 | test_examples = read_examples(test_file) 117 | 118 | prompts = [] 119 | for _ in test_examples: 120 | prompt = construct_prompt(num_shot, train_examples) 121 | prompts.append(prompt) 122 | 123 | i = 0 124 | correct = 0 125 | total = 0 126 | total_time = 0 127 | not_hit = 0 128 | for example in test_examples: 129 | prompt = prompts[i] 130 | i += 1 131 | prompt[0]['content'] += example['input'] 132 | start_time = time.time() 133 | completion, hit = get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache) 134 | answer = completion.split('####')[-1].strip() 135 | if not hit: 136 | not_hit += 1 137 | end_time = time.time() 138 | total_time += end_time - start_time 139 | print (f'throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 140 | if answer == example['output']: 141 | correct += 1 142 | total += 1 143 | print (f'accuracy: {correct / total}, correct: {correct}, total: {total}') 144 | sys.stdout.flush() 145 | if not_hit > 0: 146 | print 
(f'final throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 147 | print (f'final accuracy: {correct / total}, correct: {correct}, total: {total}') 148 | 149 | 150 | def parse_arguments(): 151 | """ 152 | Parse command-line arguments using argparse. 153 | 154 | Returns: 155 | argparse.Namespace: An object containing the parsed arguments. 156 | """ 157 | parser = argparse.ArgumentParser(description="Augment") 158 | parser.add_argument("--api_key", type=str, required=True, help="OpenAI API key") 159 | parser.add_argument("--model", type=str, default='gpt-4-1106-preview', help="model to evaluate") 160 | parser.add_argument("--train_file", type=str, default="../data/4_by_4_mult/train.txt") 161 | parser.add_argument("--test_file", type=str, default="../data/4_by_4_mult/test_bigbench.txt") 162 | parser.add_argument("--num_shot", type=int, default=5) 163 | parser.add_argument('--overwrite_cache', action='store_true') 164 | parser.add_argument("--cache_file", type=str, default="no_cot/4_by_4_mult/cache.db") 165 | parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for generation") 166 | parser.add_argument("--max_tokens", type=int, default=600, help="Maximum number of tokens in the generated sentence") 167 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 168 | parser.set_defaults(overwrite_cache=False) 169 | return parser.parse_args() 170 | 171 | 172 | if __name__ == "__main__": 173 | args = parse_arguments() 174 | openai.api_key = args.api_key 175 | 176 | main(model=args.model, \ 177 | temperature=args.temperature, \ 178 | max_tokens=args.max_tokens, \ 179 | seed=args.seed, \ 180 | train_file=args.train_file, \ 181 | test_file=args.test_file, \ 182 | num_shot=args.num_shot, \ 183 | cache_file=args.cache_file, \ 184 | overwrite_cache=args.overwrite_cache) 185 | -------------------------------------------------------------------------------- /gpt4_baselines/no_cot/5_by_5_mult/cache.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/gpt4_baselines/no_cot/5_by_5_mult/cache.db -------------------------------------------------------------------------------- /gpt4_baselines/no_cot/5_by_5_mult/evaluate.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import numpy as np 3 | import time 4 | import sys, os, json, random, argparse, openai 5 | import tiktoken 6 | import tqdm 7 | import sqlite3 8 | 9 | 10 | def create_database(cache_file): 11 | conn = sqlite3.connect(cache_file) 12 | c = conn.cursor() 13 | # Create table 14 | c.execute('''CREATE TABLE IF NOT EXISTS main 15 | (key TEXT PRIMARY KEY, prompt TEXT, completion TEXT)''') 16 | conn.commit() 17 | conn.close() 18 | 19 | 20 | def insert_or_update(key, prompt, completion, cache_file): 21 | conn = sqlite3.connect(cache_file) 22 | c = conn.cursor() 23 | c.execute('''INSERT OR REPLACE INTO main 24 | (key, prompt, completion) VALUES (?, ?, ?)''', 25 | (key, prompt, completion)) 26 | conn.commit() 27 | conn.close() 28 | 29 | def retrieve(key, cache_file): 30 | conn = sqlite3.connect(cache_file) 31 | c = conn.cursor() 32 | 33 | c.execute("SELECT prompt, completion FROM main WHERE key=?", (key,)) 34 | result = c.fetchone() 35 | conn.close() 36 | if result: 37 | return (True, result) 38 | else: 39 | return (False, None) 40 | 41 | def get_completion_with_cache(prompt, 
model, temperature, max_tokens, cache_file, overwrite_cache): 42 | #import pdb; pdb.set_trace() 43 | key = hashlib.sha256(json.dumps({'prompt': prompt, 'model': model, 'temperature': temperature}).encode('utf-8')).hexdigest() 44 | hit, result = retrieve(key, cache_file) 45 | if overwrite_cache: 46 | hit = False 47 | if not hit: 48 | completion = get_completion(prompt, model, temperature, max_tokens) 49 | insert_or_update(key, json.dumps(prompt), completion, cache_file) 50 | else: 51 | print ('hit') 52 | prompt, completion = result 53 | return completion, hit 54 | 55 | 56 | #@retry(wait=wait_exponential(min=60, max=1200), stop=stop_after_attempt(12)) 57 | def chatcompletion_with_backoff(**kwargs): 58 | return openai.ChatCompletion.create(**kwargs) 59 | 60 | 61 | def get_completion(prompt, model, temperature, max_tokens): 62 | """ 63 | Get a completion using the specified language model. 64 | 65 | Args: 66 | prompt (list): The prompt for generating the sentence. 67 | model (str): The name of the language model to use. 68 | temperature (float): Sampling temperature for the model. 69 | max_tokens (int): Maximum number of tokens in the generated sentence. 70 | 71 | Returns: 72 | str: The generated sentence. 73 | """ 74 | generated_text = "" 75 | response = chatcompletion_with_backoff( 76 | model=model, 77 | messages=prompt, 78 | max_tokens=max_tokens, 79 | temperature=temperature, 80 | ) 81 | generated_text = response.choices[0].message.content.strip() 82 | return generated_text 83 | 84 | def construct_prompt(num_shot, examples): 85 | instruction = 'Answer the final question following the exact format of the given examples. Do not break the problem down, directly produce the answer.\n\nExample problems:\n\n' 86 | example_demonstrations = random.sample(examples, num_shot) 87 | prompt = instruction 88 | for i, line in enumerate(example_demonstrations): 89 | s = f"Q: {line['input']}\n" 90 | s += f"A: #### {line['output']}\n" 91 | prompt += s 92 | prompt += '\n' 93 | prompt += f"Question to answer:\n\nQ:" 94 | context = [{"role": 'user', "content": prompt}] 95 | return context 96 | 97 | def read_examples(filename): 98 | lines = [] 99 | with open(filename) as fin: 100 | for line in fin: 101 | input, cot_and_output = line.strip().split('||') 102 | cot, output = cot_and_output.split(' #### ') 103 | input = input.strip() 104 | items = input.split('*') 105 | input = ''.join(items[0].strip().split()[::-1]) + ' * ' + ''.join(items[1].strip().split()[::-1]) 106 | cot = cot.strip() 107 | output = ''.join(output.strip().split()[::-1]) 108 | output = output.lstrip('0') 109 | lines.append({'input': input, 'output': output, 'cot': cot}) 110 | return lines 111 | 112 | def main(model, temperature, max_tokens, seed, num_shot, train_file, test_file, cache_file, overwrite_cache): 113 | create_database(cache_file) 114 | random.seed(seed) 115 | train_examples = read_examples(train_file) 116 | test_examples = read_examples(test_file) 117 | 118 | prompts = [] 119 | for _ in test_examples: 120 | prompt = construct_prompt(num_shot, train_examples) 121 | prompts.append(prompt) 122 | 123 | i = 0 124 | correct = 0 125 | total = 0 126 | total_time = 0 127 | not_hit = 0 128 | for example in test_examples: 129 | prompt = prompts[i] 130 | i += 1 131 | prompt[0]['content'] += example['input'] 132 | start_time = time.time() 133 | completion, hit = get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache) 134 | answer = completion.split('####')[-1].strip() 135 | if not hit: 136 | not_hit += 1 
137 | end_time = time.time() 138 | total_time += end_time - start_time 139 | print (f'throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 140 | if answer == example['output']: 141 | correct += 1 142 | total += 1 143 | print (f'accuracy: {correct / total}, correct: {correct}, total: {total}') 144 | sys.stdout.flush() 145 | if not_hit > 0: 146 | print (f'final throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 147 | print (f'final accuracy: {correct / total}, correct: {correct}, total: {total}') 148 | 149 | 150 | def parse_arguments(): 151 | """ 152 | Parse command-line arguments using argparse. 153 | 154 | Returns: 155 | argparse.Namespace: An object containing the parsed arguments. 156 | """ 157 | parser = argparse.ArgumentParser(description="Augment") 158 | parser.add_argument("--api_key", type=str, required=True, help="OpenAI API key") 159 | parser.add_argument("--model", type=str, default='gpt-4-1106-preview', help="model to evaluate") 160 | parser.add_argument("--train_file", type=str, default="../data/5_by_5_mult/train.txt") 161 | parser.add_argument("--test_file", type=str, default="../data/5_by_5_mult/test_bigbench.txt") 162 | parser.add_argument("--num_shot", type=int, default=5) 163 | parser.add_argument('--overwrite_cache', action='store_true') 164 | parser.add_argument("--cache_file", type=str, default="no_cot/5_by_5_mult/cache.db") 165 | parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for generation") 166 | parser.add_argument("--max_tokens", type=int, default=600, help="Maximum number of tokens in the generated sentence") 167 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 168 | parser.set_defaults(overwrite_cache=False) 169 | return parser.parse_args() 170 | 171 | 172 | if __name__ == "__main__": 173 | args = parse_arguments() 174 | openai.api_key = args.api_key 175 | 176 | main(model=args.model, \ 177 | temperature=args.temperature, \ 178 | max_tokens=args.max_tokens, \ 179 | seed=args.seed, \ 180 | train_file=args.train_file, \ 181 | test_file=args.test_file, \ 182 | num_shot=args.num_shot, \ 183 | cache_file=args.cache_file, \ 184 | overwrite_cache=args.overwrite_cache) 185 | -------------------------------------------------------------------------------- /gpt4_baselines/no_cot/gsm8k/cache.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/gpt4_baselines/no_cot/gsm8k/cache.db -------------------------------------------------------------------------------- /gpt4_baselines/no_cot/gsm8k/evaluate.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import numpy as np 3 | import time 4 | import sys, os, json, random, argparse, openai 5 | import tiktoken 6 | import tqdm 7 | import sqlite3 8 | 9 | 10 | def create_database(cache_file): 11 | conn = sqlite3.connect(cache_file) 12 | c = conn.cursor() 13 | # Create table 14 | c.execute('''CREATE TABLE IF NOT EXISTS main 15 | (key TEXT PRIMARY KEY, prompt TEXT, completion TEXT)''') 16 | conn.commit() 17 | conn.close() 18 | 19 | 20 | def insert_or_update(key, prompt, completion, cache_file): 21 | conn = sqlite3.connect(cache_file) 22 | c = conn.cursor() 23 | c.execute('''INSERT OR REPLACE INTO main 24 | (key, prompt, completion) VALUES (?, ?, ?)''', 25 | (key, prompt, completion)) 26 | 
conn.commit() 27 | conn.close() 28 | 29 | def retrieve(key, cache_file): 30 | conn = sqlite3.connect(cache_file) 31 | c = conn.cursor() 32 | 33 | c.execute("SELECT prompt, completion FROM main WHERE key=?", (key,)) 34 | result = c.fetchone() 35 | conn.close() 36 | if result: 37 | return (True, result) 38 | else: 39 | return (False, None) 40 | 41 | def get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache): 42 | #import pdb; pdb.set_trace() 43 | key = hashlib.sha256(json.dumps({'prompt': prompt, 'model': model, 'temperature': temperature}).encode('utf-8')).hexdigest() 44 | hit, result = retrieve(key, cache_file) 45 | if overwrite_cache: 46 | hit = False 47 | if not hit: 48 | completion = get_completion(prompt, model, temperature, max_tokens) 49 | insert_or_update(key, json.dumps(prompt), completion, cache_file) 50 | else: 51 | print ('hit') 52 | prompt, completion = result 53 | return completion, hit 54 | 55 | 56 | #@retry(wait=wait_exponential(min=60, max=1200), stop=stop_after_attempt(12)) 57 | def chatcompletion_with_backoff(**kwargs): 58 | return openai.ChatCompletion.create(**kwargs) 59 | 60 | 61 | def get_completion(prompt, model, temperature, max_tokens): 62 | """ 63 | Get a completion using the specified language model. 64 | 65 | Args: 66 | prompt (list): The prompt for generating the sentence. 67 | model (str): The name of the language model to use. 68 | temperature (float): Sampling temperature for the model. 69 | max_tokens (int): Maximum number of tokens in the generated sentence. 70 | 71 | Returns: 72 | str: The generated sentence. 73 | """ 74 | generated_text = "" 75 | response = chatcompletion_with_backoff( 76 | model=model, 77 | messages=prompt, 78 | max_tokens=max_tokens, 79 | temperature=temperature, 80 | ) 81 | generated_text = response.choices[0].message.content.strip() 82 | return generated_text 83 | 84 | def construct_prompt(num_shot, examples): 85 | instruction = 'Answer the final question following the exact format of the given examples. 
Do not break the problem down, directly produce the answer.\n\nExample problems:\n\n' 86 | example_demonstrations = random.sample(examples, num_shot) 87 | prompt = instruction 88 | for i, line in enumerate(example_demonstrations): 89 | s = f"Q: {line['input']}\n" 90 | s += f"A: #### {line['output']}\n" 91 | prompt += s 92 | prompt += '\n' 93 | prompt += f"Question to answer:\n\nQ:" 94 | context = [{"role": 'user', "content": prompt}] 95 | return context 96 | 97 | def read_examples(filename): 98 | lines = [] 99 | with open(filename) as fin: 100 | for line in fin: 101 | input, cot_and_output = line.strip().split('||') 102 | cot, output = cot_and_output.split(' #### ') 103 | input = input.strip() 104 | cot = cot.strip() 105 | output = output.strip() 106 | lines.append({'input': input, 'output': output, 'cot': cot}) 107 | return lines 108 | 109 | def main(model, temperature, max_tokens, seed, num_shot, train_file, test_file, cache_file, overwrite_cache): 110 | create_database(cache_file) 111 | random.seed(seed) 112 | train_examples = read_examples(train_file) 113 | test_examples = read_examples(test_file) 114 | 115 | prompts = [] 116 | for _ in test_examples: 117 | prompt = construct_prompt(num_shot, train_examples) 118 | prompts.append(prompt) 119 | 120 | i = 0 121 | correct = 0 122 | total = 0 123 | total_time = 0 124 | not_hit = 0 125 | for example in test_examples: 126 | prompt = prompts[i] 127 | i += 1 128 | prompt[0]['content'] += example['input'] 129 | start_time = time.time() 130 | completion, hit = get_completion_with_cache(prompt, model, temperature, max_tokens, cache_file, overwrite_cache) 131 | answer = completion.split('####')[-1].strip() 132 | if not hit: 133 | not_hit += 1 134 | end_time = time.time() 135 | total_time += end_time - start_time 136 | print (f'throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 137 | if answer == example['output']: 138 | correct += 1 139 | total += 1 140 | print (f'accuracy: {correct / total}, correct: {correct}, total: {total}') 141 | sys.stdout.flush() 142 | if not_hit > 0: 143 | print (f'final throughput: {not_hit / total_time}, total time: {total_time}, total number of examples: {not_hit}') 144 | print (f'final accuracy: {correct / total}, correct: {correct}, total: {total}') 145 | 146 | 147 | def parse_arguments(): 148 | """ 149 | Parse command-line arguments using argparse. 150 | 151 | Returns: 152 | argparse.Namespace: An object containing the parsed arguments. 
153 | """ 154 | parser = argparse.ArgumentParser(description="Augment") 155 | parser.add_argument("--api_key", type=str, required=True, help="OpenAI API key") 156 | parser.add_argument("--model", type=str, default='gpt-4-1106-preview', help="model to evaluate") 157 | parser.add_argument("--train_file", type=str, default="../data/gsm8k/train_no_aug.txt") 158 | parser.add_argument("--test_file", type=str, default="../data/gsm8k/test.txt") 159 | parser.add_argument("--num_shot", type=int, default=5) 160 | parser.add_argument('--overwrite_cache', action='store_true') 161 | parser.add_argument("--cache_file", type=str, default="no_cot/gsm8k/cache.db") 162 | parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for generation") 163 | parser.add_argument("--max_tokens", type=int, default=600, help="Maximum number of tokens in the generated sentence") 164 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 165 | parser.set_defaults(overwrite_cache=False) 166 | return parser.parse_args() 167 | 168 | 169 | if __name__ == "__main__": 170 | args = parse_arguments() 171 | openai.api_key = args.api_key 172 | 173 | main(model=args.model, \ 174 | temperature=args.temperature, \ 175 | max_tokens=args.max_tokens, \ 176 | seed=args.seed, \ 177 | train_file=args.train_file, \ 178 | test_file=args.test_file, \ 179 | num_shot=args.num_shot, \ 180 | cache_file=args.cache_file, \ 181 | overwrite_cache=args.overwrite_cache) 182 | -------------------------------------------------------------------------------- /imgs/training_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/imgs/training_illustration.png -------------------------------------------------------------------------------- /imgs/training_illustration_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/imgs/training_illustration_a.png -------------------------------------------------------------------------------- /imgs/training_illustration_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/imgs/training_illustration_b.png -------------------------------------------------------------------------------- /imgs/training_illustration_c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da03/implicit_chain_of_thought/eb3f2611cc190b665222d21e5f735a15e6bc228f/imgs/training_illustration_c.png -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | import copy 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | def extract_answer(text): 9 | split_pattern = '####' 10 | if split_pattern not in text: 11 | return text.strip().replace(',', '') 12 | else: 13 | _, ans = text.strip().split('####', 1) 14 | ans = '####' + ans 15 | ans = ans.strip().replace(',', '') 16 | return ans 17 | 18 | def extract_cot(text): 19 | split_pattern = '####' 20 | if split_pattern not in text: 21 | return None 
22 | else: 23 | cot, _ = text.strip().split('####', 1) 24 | cot = cot.strip() 25 | return cot 26 | 27 | class CoTDataset(Dataset): 28 | def __init__(self, tokenizer, file_path, max_length): 29 | assert os.path.isfile(file_path), f"Input file path {file_path} not found" 30 | print (f'Creating features from dataset file at {file_path}') 31 | bos_tok = tokenizer.bos_token 32 | eos_tok = tokenizer.eos_token 33 | 34 | with open(file_path, encoding="utf-8") as f: 35 | lines = [line.split('||') for line in f.read().splitlines() if (len(line) > 0 and not line.isspace() 36 | and len(line.split('||')) ==2 )] 37 | src_lines, tgt_lines = list(zip(*lines)) 38 | src_lines = list(src_lines) 39 | tgt_lines = list(tgt_lines) 40 | 41 | edited_sents_cot = [] 42 | edited_sents_only = [] 43 | edited_sents_all = [] 44 | edited_sents_nocot = [] 45 | for src, tgt in zip(src_lines, tgt_lines): 46 | #import pdb; pdb.set_trace() 47 | ans = extract_answer(tgt) 48 | cot = extract_cot(tgt) 49 | sent = ' {} {} '.format(src, bos_tok) + cot + ' {}'.format(eos_tok) 50 | edited_sents_cot.append(sent) 51 | sent = ' {} {} '.format(src, bos_tok) 52 | edited_sents_only.append(sent) 53 | sent = ' {} {} '.format(src, bos_tok) + cot + ' {} '.format(eos_tok) + ans + ' {}'.format(eos_tok) 54 | edited_sents_all.append(sent) 55 | sent = ' {} {} '.format(src, bos_tok) + ans + ' {}'.format(eos_tok) 56 | edited_sents_nocot.append(sent) 57 | 58 | batch_encoding_cot = tokenizer(edited_sents_cot, add_special_tokens=True, truncation=True, max_length=max_length) 59 | batch_encoding_only = tokenizer(edited_sents_only, add_special_tokens=True, truncation=True, max_length=max_length) 60 | batch_encoding_all = tokenizer(edited_sents_all, add_special_tokens=True, truncation=True, max_length=max_length) 61 | batch_encoding_nocot = tokenizer(edited_sents_nocot, add_special_tokens=True, truncation=True, max_length=max_length) 62 | self.examples_cot = batch_encoding_cot["input_ids"] 63 | self.examples_only = batch_encoding_only["input_ids"] 64 | self.examples_all = batch_encoding_all["input_ids"] 65 | self.examples_nocot = batch_encoding_nocot["input_ids"] 66 | 67 | self.labels_cot = copy.deepcopy(self.examples_cot) 68 | self.labels_all = copy.deepcopy(self.examples_all) 69 | self.labels_cot_shift = copy.deepcopy(self.examples_cot) 70 | self.labels_nocot = copy.deepcopy(self.examples_nocot) 71 | 72 | self.src_sent_cot = [] 73 | self.tgt_sent_cot = [] 74 | 75 | temp_src_len = 0 76 | temp_tgt_len = 0 77 | temp_count = 0 78 | separator = tokenizer.eos_token_id #tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0] 79 | for i, elem in enumerate(self.labels_cot): 80 | sep_idx = elem.index(separator) + 1 81 | self.src_sent_cot.append(self.examples_cot[i][:sep_idx-1]) 82 | self.tgt_sent_cot.append(self.examples_cot[i][sep_idx-1:]) 83 | self.labels_cot[i][:sep_idx] = [-100] * sep_idx 84 | assert self.labels_all[i][sep_idx-1] == separator 85 | self.labels_all[i][:sep_idx] = [-100] * sep_idx 86 | self.labels_cot_shift[i][:sep_idx-1] = [-100] * (sep_idx-1) 87 | temp_src_len += sep_idx-1 88 | temp_tgt_len += len(elem) - (sep_idx-1) 89 | temp_count += 1 90 | 91 | print('tgt_avg: ', temp_tgt_len / temp_count) 92 | print('src_avg: ', temp_src_len / temp_count) 93 | print('ratios: ', temp_src_len/temp_tgt_len) 94 | 95 | self.src_sent_nocot = [] 96 | self.tgt_sent_nocot = [] 97 | temp_src_len = 0 98 | temp_tgt_len = 0 99 | temp_count = 0 100 | separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0] 101 | for i, elem in 
enumerate(self.labels_nocot): 102 | sep_idx = elem.index(separator) + 1 103 | self.src_sent_nocot.append(self.examples_nocot[i][:sep_idx-1]) 104 | self.tgt_sent_nocot.append(self.examples_nocot[i][sep_idx-1:]) 105 | self.labels_nocot[i][:sep_idx] = [-100] * sep_idx 106 | temp_src_len += sep_idx-1 107 | temp_tgt_len += len(elem) - (sep_idx-1) 108 | temp_count += 1 109 | 110 | print('tgt_avg: ', temp_tgt_len / temp_count) 111 | print('src_avg: ', temp_src_len / temp_count) 112 | print('ratios: ', temp_src_len/temp_tgt_len) 113 | 114 | 115 | print(edited_sents_all[0]) 116 | print(self.labels_cot[0]) 117 | print(self.labels_nocot[0]) 118 | print(self.examples_nocot[0]) 119 | print(edited_sents_nocot[0]) 120 | print(self.src_sent_nocot[0]) 121 | print(self.tgt_sent_nocot[0]) 122 | 123 | def __len__(self): 124 | return len(self.examples_cot) 125 | 126 | # def __getitem__(self, i) -> torch.Tensor: 127 | def __getitem__(self, i): 128 | return (torch.tensor(self.examples_cot[i], dtype=torch.long), 129 | torch.tensor(self.examples_nocot[i], dtype=torch.long), 130 | torch.tensor(self.labels_cot[i], dtype=torch.long), 131 | torch.tensor(self.labels_cot_shift[i], dtype=torch.long), 132 | torch.tensor(self.labels_nocot[i], dtype=torch.long), 133 | torch.tensor(self.src_sent_cot[i], dtype=torch.long), 134 | torch.tensor(self.src_sent_nocot[i], dtype=torch.long), 135 | torch.tensor(self.tgt_sent_cot[i], dtype=torch.long), 136 | torch.tensor(self.tgt_sent_nocot[i], dtype=torch.long), 137 | torch.tensor(self.examples_only[i], dtype=torch.long), 138 | torch.tensor(self.examples_all[i], dtype=torch.long), 139 | torch.tensor(self.labels_all[i], dtype=torch.long), 140 | ) 141 | @dataclass 142 | class CoTDataCollator: 143 | """ 144 | Data collator used for language modeling. 
- collates lists of example tensors into padded batches 146 | - pads input ids with the tokenizer's eos token and pads labels with -100 so padded positions are ignored by the loss 147 | """ 148 | def __init__(self, tokenizer): 149 | self.tokenizer = tokenizer 150 | 151 | def __call__(self, examples): 152 | #import pdb; pdb.set_trace() 153 | input_ids_cot, input_ids_nocot, labels_cot, labels_cot_shift, labels_nocot, src_cot, src_nocot, tgt_cot, tgt_nocot, input_ids_only, input_ids_all, labels_all = zip(*examples) 154 | input_ids_cot = self._tensorize_batch(input_ids_cot) 155 | input_ids_cot[input_ids_cot.lt(0)] = self.tokenizer.eos_token_id 156 | input_ids_only = self._tensorize_batch(input_ids_only) 157 | input_ids_only[input_ids_only.lt(0)] = self.tokenizer.eos_token_id 158 | input_ids_all = self._tensorize_batch(input_ids_all) 159 | input_ids_all[input_ids_all.lt(0)] = self.tokenizer.eos_token_id 160 | input_ids_nocot = self._tensorize_batch(input_ids_nocot) 161 | input_ids_nocot[input_ids_nocot.lt(0)] = self.tokenizer.eos_token_id 162 | labels_cot = self._tensorize_batch(labels_cot) 163 | labels_all = self._tensorize_batch(labels_all) 164 | labels_cot_shift = self._tensorize_batch(labels_cot_shift) 165 | labels_nocot = self._tensorize_batch(labels_nocot) 166 | return {"input_ids_cot": input_ids_cot, "input_ids_nocot": input_ids_nocot, "labels_cot": labels_cot, "labels_cot_shift": labels_cot_shift, "labels_nocot": labels_nocot, 'input_ids_only': input_ids_only, 'input_ids_all': input_ids_all, 'labels_all': labels_all} 167 | 168 | def _tensorize_batch(self, examples): 169 | # In order to accept both lists of lists and lists of Tensors 170 | if isinstance(examples[0], (list, tuple)): 171 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 172 | length_of_first = examples[0].size(0) 173 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 174 | if are_tensors_same_length: 175 | return torch.stack(examples, dim=0) 176 | else: 177 | return pad_sequence(examples, batch_first=True, padding_value=-100) 178 | -------------------------------------------------------------------------------- /src/generate.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import re 4 | import torch 5 | import sys 6 | from torch.utils.data import DataLoader 7 | from torch.nn import CrossEntropyLoss 8 | from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW 9 | import argparse 10 | import os 11 | import inspect 12 | import tqdm 13 | from data import CoTDataset, CoTDataCollator, extract_answer 14 | import logging 15 | import random 16 | import torch.nn as nn 17 | 18 | from models.emulator import Emulator 19 | from models.student import Student 20 | from utils import get_sep_position 21 | 22 | torch.backends.cuda.matmul.allow_tf32 = True 23 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 24 | 25 | random.seed(1234) 26 | torch.manual_seed(1234) 27 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 28 | 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | 31 | 32 | @torch.no_grad() 33 | def evaluate(dataloader, tokenizer, ctx, emulator, student, max_new_tokens): 34 | total_time = 0 35 | total_instances = 0 36 | total_correct = 0 37 | 38 | for batch in tqdm.tqdm(dataloader): 39 | input_ids_all = batch['input_ids_nocot'].to(device) 40 | # Remove answer part 41 | sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id) 42 | input_ids = input_ids_all[:, 
:sep_positions.max()+1] 43 | start_time = time.time() 44 | with ctx: 45 | emulated_teacher_states = emulator(input_ids) 46 | 47 | # Generate from student 48 | beam_output = student.generate( 49 | input_ids=input_ids, 50 | teacher_states=emulated_teacher_states, 51 | max_new_tokens=max_new_tokens, 52 | ) 53 | 54 | # Evaluate 55 | #import pdb; pdb.set_trace() 56 | for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)): 57 | #sep_position = input_ids_single.tolist().index(tokenizer.eos_token_id) 58 | sep_position = sep_positions[i].item() 59 | tgt = input_ids_all_i[sep_position+1:] 60 | tgt_text = tokenizer.decode(tgt, skip_special_tokens=True) 61 | ans = extract_answer(tgt_text) 62 | pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True) 63 | pred_ans = extract_answer(pred_text) 64 | #import pdb; pdb.set_trace() 65 | total_instances += 1 66 | if ans == pred_ans: 67 | total_correct += 1 68 | end_time = time.time() 69 | total_time += end_time - start_time 70 | 71 | #print (total_time, total_instances, total_instances / total_time) 72 | throughput = total_instances / total_time 73 | accuracy = total_correct / total_instances 74 | return accuracy, throughput 75 | 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--test_path', type=str, required=True) 80 | parser.add_argument('--batch_size', type=int, default=1) 81 | parser.add_argument('--max_new_tokens', type=int, default=128) 82 | parser.add_argument('--student_path', type=str, required=True) 83 | parser.add_argument('--emulator_path', type=str, required=True) 84 | args = parser.parse_args() 85 | 86 | print (args) 87 | 88 | dtype = 'float32' 89 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 90 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 91 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 92 | print (ptdtype, dtype, device) 93 | 94 | 95 | # Load Models 96 | emulator = Emulator.from_pretrained(args.emulator_path).to(device).to(ptdtype) 97 | student = Student.from_pretrained(args.student_path).to(device).to(ptdtype) 98 | emulator.eval() 99 | student.eval() 100 | 101 | # Load data 102 | tokenizer = emulator.tokenizer 103 | collate_fn = CoTDataCollator(tokenizer) 104 | test_dataset = CoTDataset(tokenizer, args.test_path, 1024) 105 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 106 | 107 | accuracy, throughput = evaluate(test_dataloader, tokenizer, ctx, emulator, student, args.max_new_tokens) 108 | print (f"Test Accuracy: {accuracy}. 
Throughput: {throughput}") 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /src/models/configuration_emulator.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class EmulatorConfig(PretrainedConfig): 4 | def __init__( 5 | self, 6 | base_model='gpt2', 7 | tokenizer_name='gpt2', 8 | mixture_size=1, 9 | softmax_temperature=0.05, 10 | **kwargs, 11 | ): 12 | self.base_model = base_model 13 | self.tokenizer_name = tokenizer_name 14 | self.mixture_size = mixture_size 15 | self.softmax_temperature = softmax_temperature 16 | super().__init__(**kwargs) 17 | 18 | -------------------------------------------------------------------------------- /src/models/configuration_student.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class StudentConfig(PretrainedConfig): 4 | def __init__( 5 | self, 6 | base_model='gpt2', 7 | tokenizer_name='gpt2', 8 | mixture_size=1, 9 | **kwargs, 10 | ): 11 | self.base_model = base_model 12 | self.tokenizer_name = tokenizer_name 13 | self.mixture_size = mixture_size 14 | super().__init__(**kwargs) 15 | -------------------------------------------------------------------------------- /src/models/configuration_teacher.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class TeacherConfig(PretrainedConfig): 4 | def __init__( 5 | self, 6 | base_model='gpt2', 7 | tokenizer_name='gpt2', 8 | **kwargs, 9 | ): 10 | self.base_model = base_model 11 | self.tokenizer_name = tokenizer_name 12 | super().__init__(**kwargs) 13 | 14 | -------------------------------------------------------------------------------- /src/models/emulator.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions 8 | 9 | from .configuration_emulator import EmulatorConfig 10 | import sys 11 | sys.path.append("..") 12 | from utils import get_sep_position 13 | from .modeling_gpt2_implicit import GPT2LMHeadImplicitModel 14 | import logging 15 | 16 | class Emulator(nn.Module): 17 | def __init__(self, config): 18 | super().__init__() 19 | self.config = config 20 | self.base_model = GPT2LMHeadImplicitModel.from_pretrained(config.base_model) 21 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) 22 | num_layers = len(self.base_model.transformer.h) 23 | hidden_size = self.base_model.config.hidden_size 24 | 25 | self.mlps = nn.ModuleList([nn.Sequential( 26 | nn.Linear(2*hidden_size, 4*hidden_size), 27 | nn.ReLU(), 28 | nn.Linear(4*hidden_size, hidden_size), 29 | ) for _ in range(num_layers)]) 30 | 31 | self.mixture_components = nn.Embedding(config.mixture_size, hidden_size) 32 | self.rnn = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=1, \ 33 | batch_first=False, dropout=0, bidirectional=False) 34 | self.key_proj = nn.Linear(hidden_size, hidden_size) 35 | self.query_proj = nn.Linear(hidden_size, hidden_size) 36 | self.out_proj = nn.Linear(hidden_size*2, hidden_size) 37 | 38 | def eval(self): 39 | self.base_model.eval() 40 | 41 | def forward(self, input_ids, requires_backward=False): 42 | 
sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 43 | input_ids = input_ids[:, :sep_positions.max()+1] 44 | outputs = self.base_model.forward(mode='forward_emulator', \ 45 | input_ids=input_ids, \ 46 | positions_to_take=sep_positions, \ 47 | softmax_temperature=self.config.softmax_temperature, \ 48 | requires_backward=requires_backward, \ 49 | rnn=self.rnn, \ 50 | mlps=self.mlps, \ 51 | mixture_components=self.mixture_components, \ 52 | key_proj=self.key_proj, \ 53 | query_proj=self.query_proj, \ 54 | out_proj=self.out_proj) 55 | emulated_teacher_states = outputs.f_h_cs 56 | return emulated_teacher_states 57 | 58 | def compute_loss(self, input_ids, teacher_states): 59 | emulated_teacher_states = self.forward(input_ids=input_ids, requires_backward=True) 60 | batch_size = input_ids.shape[0] 61 | 62 | loss_fct = nn.MSELoss(reduction='none') 63 | loss = 0 64 | for teacher_state, emulated_teacher_state in zip(teacher_states, emulated_teacher_states): 65 | loss += loss_fct(teacher_state, emulated_teacher_state).sum(-1) / 2 66 | loss = loss.mean() 67 | outputs = CausalLMOutputWithCrossAttentions(loss=loss) 68 | outputs.total_loss = loss * batch_size 69 | return outputs 70 | 71 | @classmethod 72 | def from_pretrained(self, pretrained_path): 73 | config = EmulatorConfig.from_pretrained(pretrained_path) 74 | model = Emulator(config) 75 | state_dict = torch.load(os.path.join(pretrained_path, 'state_dict.bin')) 76 | try: 77 | model.load_state_dict(state_dict) 78 | except: 79 | model.load_state_dict(state_dict, strict=False) 80 | logging.warning("Some weights of the model Emulator checkpoint not loaded.") 81 | return model 82 | 83 | def save_pretrained(self, save_directory): 84 | print (f'Saving to {save_directory}') 85 | self.config.save_pretrained(save_directory) 86 | state_dict = self.state_dict() 87 | torch.save(state_dict, os.path.join(save_directory, 'state_dict.bin')) 88 | -------------------------------------------------------------------------------- /src/models/modeling_gpt2_implicit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import GPT2Model, GPT2LMHeadModel 4 | from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions 5 | from typing import Optional, Tuple, Union, Dict, Any 6 | from torch.nn import CrossEntropyLoss # used by the labels branch in GPT2LMHeadImplicitModel.forward 7 | class GPT2ImplicitModel(GPT2Model): 8 | def __init__(self, config): 9 | super().__init__(config) 10 | 11 | def forward( 12 | self, 13 | input_ids: Optional[torch.LongTensor] = None, 14 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 15 | attention_mask: Optional[torch.FloatTensor] = None, 16 | token_type_ids: Optional[torch.LongTensor] = None, 17 | position_ids: Optional[torch.LongTensor] = None, 18 | head_mask: Optional[torch.FloatTensor] = None, 19 | inputs_embeds: Optional[torch.FloatTensor] = None, 20 | encoder_hidden_states: Optional[torch.Tensor] = None, 21 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 22 | use_cache: Optional[bool] = None, 23 | output_attentions: Optional[bool] = None, 24 | output_hidden_states: Optional[bool] = None, 25 | return_dict: Optional[bool] = None, 26 | zs=None, 27 | mult_p=0, 28 | no_mixture=0, 29 | softmax_p=0, 30 | softmax_temperature=1, 31 | mlps=None, 32 | relevant_tokens=None, 33 | mixture_components=None, 34 | rnn=None, 35 | key_proj=None, 36 | query_proj=None, 37 | out_proj=None, 38 | attended_to=None, 39 | attended_to_mask=None, 40 | 
positions_to_take=None, 41 | positions_to_substitute=None, 42 | states_to_substitute=None, 43 | mode=None, 44 | residual=False, 45 | requires_backward=False, 46 | phase2=False, 47 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 48 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 49 | output_hidden_states = ( 50 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 51 | ) 52 | use_cache = use_cache if use_cache is not None else self.config.use_cache 53 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 54 | 55 | if input_ids is not None and inputs_embeds is not None: 56 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 57 | elif input_ids is not None: 58 | input_shape = input_ids.size() 59 | input_ids = input_ids.view(-1, input_shape[-1]) 60 | batch_size = input_ids.shape[0] 61 | elif inputs_embeds is not None: 62 | input_shape = inputs_embeds.size()[:-1] 63 | batch_size = inputs_embeds.shape[0] 64 | else: 65 | raise ValueError("You have to specify either input_ids or inputs_embeds") 66 | 67 | device = input_ids.device if input_ids is not None else inputs_embeds.device 68 | 69 | if token_type_ids is not None: 70 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 71 | if position_ids is not None: 72 | position_ids = position_ids.view(-1, input_shape[-1]) 73 | 74 | if past_key_values is None: 75 | past_length = 0 76 | past_key_values = tuple([None] * len(self.h)) 77 | else: 78 | past_length = past_key_values[0][0].size(-2) 79 | if position_ids is None: 80 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 81 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 82 | 83 | # GPT2Attention mask. 84 | if attention_mask is not None: 85 | if batch_size <= 0: 86 | raise ValueError("batch_size has to be defined and > 0") 87 | attention_mask = attention_mask.view(batch_size, -1) 88 | # We create a 3D attention mask from a 2D tensor mask. 89 | # Sizes are [batch_size, 1, 1, to_seq_length] 90 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 91 | # this attention mask is more simple than the triangular masking of causal attention 92 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 93 | attention_mask = attention_mask[:, None, None, :] 94 | 95 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 96 | # masked positions, this operation will create a tensor which is 0.0 for 97 | # positions we want to attend and the dtype's smallest value for masked positions. 98 | # Since we are adding it to the raw scores before the softmax, this is 99 | # effectively the same as removing these entirely. 
100 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 101 | attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min 102 | 103 | # If a 2D or 3D attention mask is provided for the cross-attention 104 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 105 | if self.config.add_cross_attention and encoder_hidden_states is not None: 106 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 107 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 108 | if encoder_attention_mask is None: 109 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 110 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 111 | else: 112 | encoder_attention_mask = None 113 | 114 | # Prepare head mask if needed 115 | # 1.0 in head_mask indicate we keep the head 116 | # attention_probs has shape bsz x n_heads x N x N 117 | # head_mask has shape n_layer x batch x n_heads x N x N 118 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 119 | 120 | if inputs_embeds is None: 121 | inputs_embeds = self.wte(input_ids) 122 | position_embeds = self.wpe(position_ids) 123 | hidden_states = inputs_embeds + position_embeds 124 | 125 | if token_type_ids is not None: 126 | token_type_embeds = self.wte(token_type_ids) 127 | hidden_states = hidden_states + token_type_embeds 128 | 129 | hidden_states = self.drop(hidden_states) 130 | 131 | 132 | output_shape = input_shape + (hidden_states.size(-1),) 133 | 134 | if self.gradient_checkpointing and self.training: 135 | if use_cache: 136 | logger.warning_once( 137 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 138 | ) 139 | use_cache = False 140 | 141 | presents = () if use_cache else None 142 | all_self_attentions = () if output_attentions else None 143 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 144 | all_hidden_states = () if output_hidden_states else None 145 | zs = [] 146 | f_h_cs = [] 147 | #import pdb; pdb.set_trace() 148 | if rnn is not None: 149 | rnn_state = None 150 | if key_proj is not None: 151 | assert rnn is not None 152 | past_keys = None # bsz, len, hidden_size 153 | context = None 154 | if mode == 'forward_emulator': 155 | weight = mixture_components.weight # vocab, hidden_size 156 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 157 | # Model parallel 158 | if self.model_parallel: 159 | torch.cuda.set_device(hidden_states.device) 160 | # Ensure layer_past is on same device as hidden_states (might not be correct) 161 | if layer_past is not None: 162 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 163 | # Ensure that attention_mask is always on the same device as hidden_states 164 | if attention_mask is not None: 165 | attention_mask = attention_mask.to(hidden_states.device) 166 | if isinstance(head_mask, torch.Tensor): 167 | head_mask = head_mask.to(hidden_states.device) 168 | if output_hidden_states: 169 | all_hidden_states = all_hidden_states + (hidden_states,) 170 | #assert zs is None 171 | #assert zs[i] is not None 172 | hidden_size = hidden_states.shape[-1] 173 | #import pdb; pdb.set_trace() 174 | # Gather relevant hidden states at the separator 175 | if mode == 'forward_emulator': 176 | z = hidden_states.gather(1, positions_to_take.view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1) # bsz, hidden_size 177 | #if not phase2: 178 | # 
zs_p.append(zp) 179 | #else: 180 | # zs_q.append(zp) 181 | zs.append(z) 182 | c = z # bsz, hidden_size 183 | 184 | #if softmax_p == 0: 185 | # with torch.no_grad(): 186 | # log_probs = c @ weight.T # bsz, vocab 187 | # relevant_tokens_i_pred = log_probs.argmax(-1) 188 | # if no_mixture == 1: 189 | # relevant_tokens_i_pred = relevant_tokens_i_pred * 0 190 | # #import pdb; pdb.set_trace() 191 | # relevant_proj_pred = mixture_components(relevant_tokens_i_pred) # bsz, hidden_size 192 | #else: 193 | #import pdb; pdb.set_trace() 194 | if weight.shape[0] == 1: 195 | mixture_embedding = weight.expand(batch_size, -1) 196 | else: 197 | log_probs = c @ weight.T # bsz, vocab 198 | log_probs = log_probs / softmax_temperature 199 | probs = log_probs.softmax(dim=-1) # bsz, vocab 200 | #relevant_proj_pred = probs @ weight # bsz, H 201 | mixture_embedding = probs @ weight # bsz, H 202 | f_h_c = mlps[i](torch.cat((z, mixture_embedding), dim=-1)) # bsz, hidden_size 203 | 204 | #if phase2: 205 | # zs_p.append(f_h_c) 206 | f_h_cs.append(f_h_c) 207 | next_input = f_h_c 208 | if rnn is not None: 209 | #import pdb; pdb.set_trace() 210 | if key_proj is not None: 211 | if context is None: 212 | context = next_input.new_zeros(next_input.shape) 213 | output, rnn_state = rnn((next_input+context).unsqueeze(0), rnn_state) 214 | output = output.squeeze(0) 215 | current_key = key_proj(output) 216 | if past_keys is not None: 217 | current_query = query_proj(output) # bsz, hidden_size 218 | attn_weights = torch.bmm(past_keys, current_query.unsqueeze(-1)) # bsz, len, 1 219 | attn_probs = attn_weights.softmax(dim=1) 220 | attn_probs = attn_probs.squeeze(-1).unsqueeze(1) 221 | context = torch.bmm(attn_probs, past_keys).squeeze(1) 222 | past_keys = torch.cat((past_keys, current_key.unsqueeze(1)), dim=1) 223 | else: 224 | past_keys = current_key.unsqueeze(1) 225 | output = out_proj(torch.cat((output, context), dim=-1)) 226 | next_input = output 227 | else: 228 | rnn_output, rnn_state = rnn(next_input.unsqueeze(0), rnn_state) 229 | next_input = rnn_output.squeeze(0) 230 | 231 | #zs_p.append(hidden_states.gather(1, positions_to_take.view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1)) 232 | hidden_states_orig = hidden_states 233 | if requires_backward: 234 | hidden_states = hidden_states.clone() 235 | if positions_to_take.eq(positions_to_take[0]).all(): 236 | hidden_states[:, positions_to_take[0]] = next_input 237 | else: 238 | for batch_id in range(positions_to_take.shape[0]): 239 | hidden_states[batch_id, positions_to_take[batch_id]] = next_input[batch_id] 240 | elif mode == 'forward_student': 241 | assert states_to_substitute is not None 242 | hidden_size = hidden_states.shape[-1] 243 | #zs.append(hidden_states.gather(1, first_ids.view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1)) 244 | hidden_states_orig = hidden_states 245 | if requires_backward: 246 | hidden_states = hidden_states.clone() 247 | if positions_to_substitute.eq(positions_to_substitute[0]).all(): 248 | hidden_states[:, positions_to_substitute[0]] = states_to_substitute[i] 249 | else: 250 | for batch_id in range(batch_size): 251 | hidden_states[batch_id, positions_to_substitute[batch_id]] = states_to_substitute[i][batch_id] 252 | 253 | 254 | if self.gradient_checkpointing and self.training: 255 | 256 | def create_custom_forward(module): 257 | def custom_forward(*inputs): 258 | # None for past_key_value 259 | return module(*inputs, use_cache, output_attentions) 260 | 261 | return custom_forward 262 | 263 | outputs = torch.utils.checkpoint.checkpoint( 264 | 
create_custom_forward(block), 265 | hidden_states, 266 | None, 267 | attention_mask, 268 | head_mask[i], 269 | encoder_hidden_states, 270 | encoder_attention_mask, 271 | ) 272 | else: 273 | outputs = block( 274 | hidden_states, 275 | layer_past=layer_past, 276 | attention_mask=attention_mask, 277 | head_mask=head_mask[i], 278 | encoder_hidden_states=encoder_hidden_states, 279 | encoder_attention_mask=encoder_attention_mask, 280 | use_cache=use_cache, 281 | output_attentions=output_attentions, 282 | ) 283 | 284 | hidden_states = outputs[0] 285 | if use_cache is True: 286 | presents = presents + (outputs[1],) 287 | 288 | if output_attentions: 289 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 290 | if self.config.add_cross_attention: 291 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 292 | 293 | # Model Parallel: If it's the last layer for that device, put things on the next device 294 | if self.model_parallel: 295 | for k, v in self.device_map.items(): 296 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 297 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 298 | 299 | hidden_states = self.ln_f(hidden_states) 300 | 301 | hidden_states = hidden_states.view(output_shape) 302 | # Add last hidden state 303 | if output_hidden_states: 304 | all_hidden_states = all_hidden_states + (hidden_states,) 305 | 306 | if not return_dict: 307 | return tuple( 308 | v 309 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 310 | if v is not None 311 | ) 312 | 313 | outputs = BaseModelOutputWithPastAndCrossAttentions( 314 | last_hidden_state=hidden_states, 315 | past_key_values=presents, 316 | hidden_states=all_hidden_states, 317 | attentions=all_self_attentions, 318 | cross_attentions=all_cross_attentions, 319 | ) 320 | outputs.zs = zs 321 | outputs.f_h_cs = f_h_cs 322 | return outputs 323 | 324 | class GPT2LMHeadImplicitModel(GPT2LMHeadModel): 325 | def __init__(self, config): 326 | super(GPT2LMHeadModel, self).__init__(config) 327 | self.transformer = GPT2ImplicitModel(config) 328 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 329 | 330 | # Model parallel 331 | self.model_parallel = False 332 | self.device_map = None 333 | 334 | # Initialize weights and apply final processing 335 | self.post_init() 336 | 337 | def forward( 338 | self, 339 | mode=None, 340 | input_ids: Optional[torch.LongTensor] = None, 341 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 342 | attention_mask: Optional[torch.FloatTensor] = None, 343 | token_type_ids: Optional[torch.LongTensor] = None, 344 | position_ids: Optional[torch.LongTensor] = None, 345 | head_mask: Optional[torch.FloatTensor] = None, 346 | inputs_embeds: Optional[torch.FloatTensor] = None, 347 | encoder_hidden_states: Optional[torch.Tensor] = None, 348 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 349 | labels: Optional[torch.LongTensor] = None, 350 | use_cache: Optional[bool] = None, 351 | output_attentions: Optional[bool] = None, 352 | output_hidden_states: Optional[bool] = None, 353 | return_dict: Optional[bool] = None, 354 | zs = None, 355 | mult_p=0, 356 | softmax_p=0, 357 | no_mixture=0, 358 | softmax_temperature=1, 359 | mlps = None, 360 | rnn=None, 361 | key_proj=None, 362 | query_proj=None, 363 | out_proj=None, 364 | relevant_tokens=None, 365 | mixture_components=None, 366 | positions_to_take=None, 367 | positions_to_substitute=None, 368 | states_to_substitute=None, 369 | 
residual=False, 370 | requires_backward=False, 371 | phase2=False, 372 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 373 | r""" 374 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 375 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 376 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` 377 | are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` 378 | """ 379 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 380 | 381 | #import pdb; pdb.set_trace() 382 | transformer_outputs = self.transformer.forward( 383 | input_ids, 384 | past_key_values=past_key_values, 385 | attention_mask=attention_mask, 386 | token_type_ids=token_type_ids, 387 | position_ids=position_ids, 388 | head_mask=head_mask, 389 | inputs_embeds=inputs_embeds, 390 | encoder_hidden_states=encoder_hidden_states, 391 | encoder_attention_mask=encoder_attention_mask, 392 | use_cache=use_cache, 393 | output_attentions=output_attentions, 394 | output_hidden_states=output_hidden_states, 395 | return_dict=return_dict, 396 | zs=zs, 397 | mult_p=mult_p, 398 | softmax_p=softmax_p, 399 | no_mixture=no_mixture, 400 | softmax_temperature=softmax_temperature, 401 | mlps=mlps, 402 | key_proj=key_proj, 403 | query_proj=query_proj, 404 | out_proj=out_proj, 405 | relevant_tokens=relevant_tokens, 406 | mixture_components=mixture_components, 407 | rnn=rnn, 408 | phase2=phase2, 409 | positions_to_take=positions_to_take, 410 | positions_to_substitute=positions_to_substitute, 411 | states_to_substitute=states_to_substitute, 412 | residual=residual, 413 | requires_backward=requires_backward, 414 | mode=mode, 415 | ) 416 | zs = transformer_outputs.zs 417 | f_h_cs = transformer_outputs.f_h_cs 418 | hidden_states = transformer_outputs[0] 419 | 420 | # Set device for model parallelism 421 | if self.model_parallel: 422 | torch.cuda.set_device(self.transformer.first_device) 423 | hidden_states = hidden_states.to(self.lm_head.weight.device) 424 | 425 | lm_logits = self.lm_head(hidden_states) 426 | 427 | loss = None 428 | if labels is not None: 429 | # move labels to correct device to enable model parallelism 430 | labels = labels.to(lm_logits.device) 431 | # Shift so that tokens < n predict n 432 | shift_logits = lm_logits[..., :-1, :].contiguous() 433 | shift_labels = labels[..., 1:].contiguous() 434 | # Flatten the tokens 435 | loss_fct = CrossEntropyLoss() 436 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 437 | 438 | if not return_dict: 439 | output = (lm_logits,) + transformer_outputs[1:] 440 | return ((loss,) + output) if loss is not None else output 441 | 442 | outputs = CausalLMOutputWithCrossAttentions( 443 | loss=loss, 444 | logits=lm_logits, 445 | past_key_values=transformer_outputs.past_key_values, 446 | hidden_states=transformer_outputs.hidden_states, 447 | attentions=transformer_outputs.attentions, 448 | cross_attentions=transformer_outputs.cross_attentions, 449 | ) 450 | outputs.zs = zs 451 | outputs.f_h_cs = f_h_cs 452 | return outputs 453 | 454 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, positions_to_substitute=None, states_to_substitute=None, mode=None, **kwargs): 455 | token_type_ids = kwargs.get("token_type_ids", None) 456 | # only last token for inputs_ids if past is defined in kwargs 457 | if past_key_values: 458 | 
input_ids = input_ids[:, -1].unsqueeze(-1) 459 | if token_type_ids is not None: 460 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 461 | 462 | attention_mask = kwargs.get("attention_mask", None) 463 | position_ids = kwargs.get("position_ids", None) 464 | 465 | if attention_mask is not None and position_ids is None: 466 | # create position_ids on the fly for batch generation 467 | position_ids = attention_mask.long().cumsum(-1) - 1 468 | position_ids.masked_fill_(attention_mask == 0, 1) 469 | if past_key_values: 470 | position_ids = position_ids[:, -1].unsqueeze(-1) 471 | else: 472 | position_ids = None 473 | 474 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 475 | if inputs_embeds is not None and past_key_values is None: 476 | model_inputs = {"inputs_embeds": inputs_embeds} 477 | else: 478 | model_inputs = {"input_ids": input_ids} 479 | 480 | model_inputs.update( 481 | { 482 | "past_key_values": past_key_values, 483 | "use_cache": kwargs.get("use_cache"), 484 | "position_ids": position_ids, 485 | "attention_mask": attention_mask, 486 | "token_type_ids": token_type_ids, 487 | } 488 | ) 489 | if positions_to_substitute is not None: 490 | model_inputs['positions_to_substitute'] = positions_to_substitute 491 | model_inputs['states_to_substitute'] = states_to_substitute 492 | model_inputs['mode'] = mode 493 | return model_inputs 494 | 495 | def _update_model_kwargs_for_generation( 496 | self, 497 | outputs, 498 | model_kwargs: Dict[str, Any], 499 | is_encoder_decoder: bool = False, 500 | standardize_cache_format: bool = False, 501 | ) -> Dict[str, Any]: 502 | # Remove positions_to_substitute 503 | if 'positions_to_substitute' in model_kwargs: 504 | del model_kwargs['positions_to_substitute'] 505 | del model_kwargs['states_to_substitute'] 506 | del model_kwargs['mode'] 507 | # update past_key_values 508 | model_kwargs["past_key_values"] = self._extract_past_from_model_output( 509 | outputs, standardize_cache_format=standardize_cache_format 510 | ) 511 | 512 | # update token_type_ids with last value 513 | if "token_type_ids" in model_kwargs: 514 | token_type_ids = model_kwargs["token_type_ids"] 515 | model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) 516 | 517 | if not is_encoder_decoder: 518 | # update attention mask 519 | if "attention_mask" in model_kwargs: 520 | attention_mask = model_kwargs["attention_mask"] 521 | model_kwargs["attention_mask"] = torch.cat( 522 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 523 | ) 524 | else: 525 | # update decoder attention mask 526 | if "decoder_attention_mask" in model_kwargs: 527 | decoder_attention_mask = model_kwargs["decoder_attention_mask"] 528 | model_kwargs["decoder_attention_mask"] = torch.cat( 529 | [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], 530 | dim=-1, 531 | ) 532 | 533 | return model_kwargs 534 | -------------------------------------------------------------------------------- /src/models/student.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import torch 5 | import torch.nn as nn 6 | from transformers import AutoTokenizer 7 | 8 | sys.path.append("..") 9 | from utils import get_sep_position 10 | from .configuration_student import StudentConfig 11 | from .modeling_gpt2_implicit import GPT2LMHeadImplicitModel 12 | 13 | class Student(nn.Module): 14 | def __init__(self, 
config): 15 | super().__init__() 16 | self.config = config 17 | self.base_model = GPT2LMHeadImplicitModel.from_pretrained(config.base_model) 18 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) 19 | num_layers = len(self.base_model.transformer.h) 20 | hidden_size = self.base_model.config.hidden_size 21 | self.num_layers = num_layers 22 | self.hidden_size = hidden_size 23 | 24 | self.mlps = nn.ModuleList([nn.Sequential( 25 | nn.Linear(hidden_size, 4*hidden_size), 26 | nn.ReLU(), 27 | nn.Linear(4*hidden_size, hidden_size), 28 | ) for _ in range(num_layers)]) 29 | 30 | def forward(self, input_ids, positions_to_substitute, teacher_states, output_hidden_states=False): 31 | outputs = self.base_model.forward(mode='forward_student', \ 32 | input_ids=input_ids, \ 33 | positions_to_substitute=positions_to_substitute, \ 34 | states_to_substitute=teacher_states, \ 35 | output_hidden_states=output_hidden_states) 36 | return outputs 37 | 38 | def compute_loss(self, input_ids, labels, teacher_states): 39 | #import pdb; pdb.set_trace() 40 | sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 41 | # First, project teacher states 42 | teacher_states = [self.mlps[l](teacher_states[l]) for l in range(len(teacher_states))] 43 | 44 | # Forward while substituting teacher states 45 | outputs = self.forward(input_ids, sep_positions, teacher_states) 46 | logits = outputs.logits 47 | 48 | labels_pred = logits.argmax(-1) 49 | mask = labels[...,1:].ge(0) 50 | correct_tokens = ((labels_pred[...,:-1] == labels[...,1:]) * mask).sum() 51 | total_tokens = mask.sum() 52 | token_accuracy = correct_tokens / total_tokens 53 | 54 | shift_logits = logits[..., :-1, :].contiguous() 55 | shift_labels = labels[..., 1:].contiguous() 56 | loss_fct = nn.CrossEntropyLoss() 57 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 58 | 59 | outputs.loss = loss 60 | outputs.token_accuracy = token_accuracy 61 | outputs.total_correct = correct_tokens 62 | outputs.total_loss = loss * total_tokens 63 | outputs.total_tokens = total_tokens 64 | return outputs 65 | 66 | def generate(self, input_ids, teacher_states, max_new_tokens=512, num_beams=1): 67 | sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 68 | batch_size = input_ids.shape[0] 69 | beam_output = [] 70 | # First, project teacher states 71 | teacher_states = [self.mlps[l](teacher_states[l]) for l in range(len(teacher_states))] 72 | for i in range(batch_size): 73 | input_ids_i = input_ids[i:i+1] 74 | sep_positions_i = sep_positions[i:i+1] 75 | input_ids_i = input_ids_i[:, :sep_positions_i+1] 76 | beam_output_i = self.base_model.generate( 77 | input_ids=input_ids_i, 78 | max_new_tokens=max_new_tokens, 79 | num_beams=num_beams, 80 | early_stopping=True, 81 | num_return_sequences=1, 82 | positions_to_substitute=sep_positions_i.repeat_interleave(num_beams, dim=0), 83 | states_to_substitute=[z[i:i+1].repeat_interleave(num_beams, dim=0) for z in teacher_states], 84 | mode='forward_student', 85 | ) 86 | beam_output.append(beam_output_i) 87 | return beam_output 88 | 89 | @classmethod 90 | def from_pretrained(self, pretrained_path): 91 | config = StudentConfig.from_pretrained(pretrained_path) 92 | model = Student(config) 93 | state_dict = torch.load(os.path.join(pretrained_path, 'state_dict.bin')) 94 | try: 95 | model.load_state_dict(state_dict) 96 | except: 97 | model.load_state_dict(state_dict, strict=False) 98 | logging.warn("Some weights of the model Student checkpoint not loaded.") 99 | 100 | 
return model 101 | 102 | def save_pretrained(self, save_directory): 103 | print (f'Saving to {save_directory}') 104 | self.config.save_pretrained(save_directory) 105 | state_dict = self.state_dict() 106 | torch.save(state_dict, os.path.join(save_directory, 'state_dict.bin')) 107 | 108 | -------------------------------------------------------------------------------- /src/models/teacher.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import CrossEntropyLoss 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList, GenerationConfig, LogitsProcessorList 8 | 9 | from .configuration_teacher import TeacherConfig 10 | import sys 11 | sys.path.append("..") 12 | from utils import get_sep_position, DoubleEOSStoppingCriteria, DoubleEOSLogitsProcessor 13 | from .modeling_gpt2_implicit import GPT2LMHeadImplicitModel 14 | 15 | 16 | class Teacher(nn.Module): 17 | def __init__(self, config): 18 | super().__init__() 19 | self.config = config 20 | self.base_model = GPT2LMHeadImplicitModel.from_pretrained(config.base_model) 21 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) 22 | num_layers = len(self.base_model.transformer.h) 23 | hidden_size = self.base_model.config.hidden_size 24 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False) 25 | self.num_layers = num_layers 26 | self.hidden_size = hidden_size 27 | 28 | def forward(self, input_ids): 29 | outputs = self.base_model.forward(input_ids=input_ids) 30 | return outputs 31 | 32 | def compute_positions_to_extract_per_layer(self, subset, delta, first_sep_positions, second_sep_positions): 33 | batch_size = first_sep_positions.shape[0] 34 | positions_to_extract_per_layer = first_sep_positions.new_zeros(batch_size, self.num_layers).long() 35 | layer_ids = torch.arange(start=0, end=self.num_layers).to(first_sep_positions.device) 36 | for batch_id in range(batch_size): 37 | first_position_to_extract = first_sep_positions[batch_id] 38 | last_position_to_extract = second_sep_positions[batch_id] 39 | if subset == 'diagonal': 40 | if delta == 'dynamic': # determine actual delta 41 | delta = (last_position_to_extract - first_position_to_extract) / (self.num_layers - 1) 42 | elif subset == 'first_column' or subset == 'last_column': 43 | delta = 0 44 | else: 45 | assert subset == 'last_column', subset 46 | delta = 0 47 | first_position_to_extract = last_position_to_extract 48 | positions_to_extract = torch.round(first_position_to_extract + layer_ids * delta) 49 | positions_to_extract = positions_to_extract.clamp(max=last_position_to_extract) 50 | positions_to_extract_per_layer[batch_id] = positions_to_extract 51 | return positions_to_extract_per_layer 52 | 53 | def extract_states(self, input_ids, delta, subset='diagonal'): 54 | if delta.isnumeric(): 55 | delta = int(delta) 56 | batch_size = input_ids.shape[0] 57 | hidden_size = self.hidden_size 58 | 59 | # Find the boundaries between input and CoT, and CoT and output 60 | # [input] first_sep_position [CoT] second_position [output] eos 61 | first_sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id, skip=0) 62 | second_sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id, skip=1) 63 | input_ids = input_ids[:, :second_sep_positions.max()+1] 64 | 65 | # Forward the teacher to produce all hidden states 66 | outputs = self.base_model.forward(input_ids=input_ids, output_hidden_states=True) 67 | 
hidden_states = outputs.hidden_states[:-1] 68 | 69 | # Compute the positions to extract teacher states (t_l in the paper) 70 | positions_to_extract_per_layer = self.compute_positions_to_extract_per_layer(subset, delta, first_sep_positions, second_sep_positions) 71 | 72 | # Extract teacher states 73 | teacher_states_extracted = [] 74 | for i, hidden_state in enumerate(hidden_states): 75 | if subset == 'diagonal' or subset == 'first_column' or subset == 'last_column': 76 | z = hidden_state.gather(1, positions_to_extract_per_layer[:,i].view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1) 77 | elif subset == 'top_row': 78 | z = hidden_states[-1].gather(1, positions_to_extract_per_layer[:,i].view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1) 79 | else: 80 | assert subset == 'bottom_row', subset 81 | z = hidden_states[0].gather(1, positions_to_extract_per_layer[:,i].view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1) 82 | # Apply layer norm to normalize to 0 mean and 1 std 83 | z = self.layer_norm(z) 84 | teacher_states_extracted.append(z) 85 | return teacher_states_extracted 86 | 87 | def compute_loss(self, input_ids, labels): 88 | #import pdb; pdb.set_trace() 89 | outputs = self.forward(input_ids=input_ids) 90 | logits = outputs.logits 91 | 92 | labels_pred = logits.argmax(-1) 93 | mask = labels[...,1:].ge(0) 94 | correct_tokens = ((labels_pred[...,:-1] == labels[...,1:]) * mask).sum() 95 | total_tokens = mask.sum() 96 | token_accuracy = correct_tokens / total_tokens 97 | 98 | shift_logits = logits[..., :-1, :].contiguous() 99 | shift_labels = labels[..., 1:].contiguous() 100 | loss_fct = CrossEntropyLoss() 101 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 102 | 103 | outputs.loss = loss 104 | outputs.token_accuracy = token_accuracy 105 | outputs.total_correct = correct_tokens 106 | outputs.total_loss = loss * total_tokens 107 | outputs.total_tokens = total_tokens 108 | return outputs 109 | 110 | def generate(self, input_ids, max_new_tokens=512, num_beams=1, stop_on_two_eos=True): 111 | sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 112 | batch_size = input_ids.shape[0] 113 | 114 | # Since there's one eos after CoT and another after final answer, we need to wait for two eos 115 | generation_config = GenerationConfig.from_model_config(self.base_model.config) 116 | if stop_on_two_eos: 117 | generation_config.eos_token_id = -1 118 | logits_processor = LogitsProcessorList([DoubleEOSLogitsProcessor(self.tokenizer.eos_token_id)]) 119 | stopping_criteria = StoppingCriteriaList([DoubleEOSStoppingCriteria(self.tokenizer.eos_token_id)]) 120 | else: 121 | logits_processor = None 122 | stopping_criteria = None 123 | 124 | if sep_positions.eq(sep_positions[0]).all(): 125 | input_ids = input_ids[:, :sep_positions[0]+1] 126 | beam_output = self.base_model.generate( 127 | input_ids=input_ids, 128 | generation_config=generation_config, 129 | max_new_tokens=max_new_tokens, 130 | num_beams=num_beams, 131 | early_stopping=True, 132 | num_return_sequences=1, 133 | logits_processor=logits_processor, 134 | stopping_criteria=stopping_criteria, 135 | ) 136 | beam_output = beam_output.unsqueeze(1) 137 | else: 138 | beam_output = [] 139 | for i in range(batch_size): 140 | input_ids_i = input_ids[i:i+1] 141 | sep_positions_i = sep_positions[i:i+1] 142 | input_ids_i = input_ids_i[:, :sep_positions_i+1] 143 | beam_output_i = self.base_model.generate( 144 | input_ids=input_ids_i, 145 | generation_config=generation_config, 146 | 
max_new_tokens=max_new_tokens, 147 | num_beams=num_beams, 148 | early_stopping=True, 149 | num_return_sequences=1, 150 | logits_processor=logits_processor, 151 | stopping_criteria=stopping_criteria, 152 | ) 153 | beam_output.append(beam_output_i) 154 | return beam_output 155 | 156 | @classmethod 157 | def from_pretrained(self, pretrained_path): 158 | config = TeacherConfig.from_pretrained(pretrained_path) 159 | model = Teacher(config) 160 | state_dict = torch.load(os.path.join(pretrained_path, 'state_dict.bin')) 161 | model.load_state_dict(state_dict) 162 | return model 163 | 164 | def save_pretrained(self, save_directory): 165 | print (f'Saving to {save_directory}') 166 | self.config.save_pretrained(save_directory) 167 | state_dict = self.state_dict() 168 | torch.save(state_dict, os.path.join(save_directory, 'state_dict.bin')) 169 | -------------------------------------------------------------------------------- /src/train_coupled_emulator_and_student.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from transformers import AdamW 5 | import argparse 6 | import os 7 | import inspect 8 | import tqdm 9 | import logging 10 | import random 11 | from itertools import chain 12 | 13 | from data import CoTDataset, CoTDataCollator, extract_answer 14 | from models.student import Student 15 | from models.emulator import Emulator 16 | from utils import get_sep_position 17 | 18 | torch.backends.cuda.matmul.allow_tf32 = True 19 | torch.backends.cudnn.allow_tf32 = True 20 | random.seed(1234) 21 | torch.manual_seed(1234) 22 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | 27 | @torch.no_grad() 28 | def evaluate(dataloader, tokenizer, ctx, emulator, student, max_new_tokens): 29 | total_instances = 0 30 | total_tokens = 0 31 | total_correct = 0 32 | total_correct_tokens = 0 33 | total_loss = 0 34 | for batch in tqdm.tqdm(dataloader): 35 | #import pdb; pdb.set_trace() 36 | input_ids_nocot = batch['input_ids_nocot'].to(device) 37 | labels_nocot = batch['labels_nocot'].to(device) 38 | batch_size = input_ids_nocot.shape[0] 39 | with ctx: 40 | emulated_teacher_states = emulator(input_ids=input_ids_nocot) 41 | outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=emulated_teacher_states) 42 | loss = outputs.loss 43 | token_accuracy = outputs.token_accuracy.item() 44 | total_loss += outputs.total_loss.item() 45 | total_correct_tokens += outputs.total_correct.item() 46 | total_tokens += outputs.total_tokens 47 | total_instances += batch_size 48 | 49 | # Generate 50 | with ctx: 51 | beam_output = student.generate( 52 | input_ids=input_ids_nocot, 53 | teacher_states=emulated_teacher_states, 54 | max_new_tokens=max_new_tokens, 55 | ) 56 | 57 | # Evaluate 58 | sep_positions = get_sep_position(input_ids_nocot, tokenizer.eos_token_id) 59 | for i, (input_ids_i, beam_output_i) in enumerate(zip(input_ids_nocot, beam_output)): 60 | sep_position = sep_positions[i].item() 61 | tgt = input_ids_i[sep_position+1:] 62 | tgt_text = tokenizer.decode(tgt, skip_special_tokens=True) 63 | ans = extract_answer(tgt_text) 64 | pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True) 65 | pred_ans = extract_answer(pred_text) 66 | if ans == pred_ans: 67 | total_correct += 1 68 | if i == 0: 69 | print (f'Input: 
{tokenizer.decode(input_ids_i[:sep_position], skip_special_tokens=True)}') 70 | print (f'Target: {tgt_text}') 71 | print (f'Predicted: {pred_text}') 72 | print ('') 73 | accuracy = total_correct / total_instances 74 | token_accuracy = total_correct_tokens / total_tokens 75 | loss = total_loss / total_tokens 76 | ppl = math.exp(loss) 77 | return accuracy, token_accuracy, ppl 78 | 79 | 80 | def main(): 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--emulator', type=str, required=True) 83 | parser.add_argument('--student', type=str, required=True) 84 | parser.add_argument('--train_path', type=str, required=True) 85 | parser.add_argument('--val_path', type=str, required=True) 86 | parser.add_argument('--save_model', type=str, required=True) 87 | parser.add_argument('--max_new_tokens', type=int, default=128) 88 | parser.add_argument('--epochs', type=int, default=5) 89 | parser.add_argument('--batch_size', type=int, default=32) 90 | parser.add_argument('--lr', type=float, default=5e-5) 91 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 92 | parser.add_argument('--softmax_temperature', type=float, default=0.05) 93 | parser.add_argument('--fix_emulator', dest='fix_emulator', action='store_true') 94 | parser.set_defaults(fix_emulator=False) 95 | args = parser.parse_args() 96 | 97 | print (args) 98 | 99 | dtype = 'float32' 100 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 103 | print (ptdtype, dtype, device) 104 | 105 | # Load Student 106 | student = Student.from_pretrained(args.student).to(device).to(ptdtype) 107 | 108 | # Load Emulator 109 | emulator = Emulator.from_pretrained(args.emulator).to(device).to(ptdtype) 110 | 111 | # Load data 112 | tokenizer = emulator.tokenizer 113 | collate_fn = CoTDataCollator(tokenizer) 114 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 115 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 116 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 117 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 118 | 119 | # Create Optimizer 120 | if args.fix_emulator: 121 | trainable_params = list(student.parameters()) 122 | for p in emulator.parameters(): 123 | p.requires_grad = False 124 | else: 125 | trainable_params = list(student.parameters()) + list(emulator.parameters()) 126 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 127 | extra_args = dict(fused=True) if use_fused else dict() 128 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 129 | 130 | emulator.eval() # to turn off dropout 131 | student.eval() # to turn off dropout 132 | 133 | 134 | # Train 135 | step = 0 136 | for epoch in range(args.epochs): 137 | print(f"Epoch {epoch}") 138 | 139 | for batch in tqdm.tqdm(train_dataloader): 140 | #import pdb; pdb.set_trace() 141 | input_ids_nocot = batch['input_ids_nocot'].to(device) 142 | labels_nocot = batch['labels_nocot'].to(device) 143 | with ctx: 144 | emulated_teacher_states = emulator(input_ids_nocot, requires_backward=not args.fix_emulator) 145 | outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=emulated_teacher_states) 146 | loss = outputs.loss 147 | token_accuracy = outputs.token_accuracy.item() 148 
| 149 | loss.backward() 150 | torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 151 | optimizer.step() 152 | optimizer.zero_grad() 153 | ppl = loss.exp().item() 154 | if step % 100 == 0: 155 | print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}") 156 | step += 1 157 | accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, emulator, student, args.max_new_tokens) 158 | print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.') 159 | student.save_pretrained(os.path.join(args.save_model, 'student', f'checkpoint_{epoch}')) 160 | emulator.save_pretrained(os.path.join(args.save_model, 'emulator', f'checkpoint_{epoch}')) 161 | 162 | 163 | if __name__ == "__main__": 164 | main() 165 | -------------------------------------------------------------------------------- /src/train_mind_reading_student.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from transformers import AdamW 5 | import argparse 6 | import os 7 | import sys 8 | import inspect 9 | import tqdm 10 | import logging 11 | import random 12 | 13 | from data import CoTDataset, CoTDataCollator, extract_answer 14 | from models.teacher import Teacher 15 | from models.student import Student 16 | from models.configuration_student import StudentConfig 17 | from utils import get_sep_position 18 | 19 | torch.backends.cuda.matmul.allow_tf32 = True 20 | torch.backends.cudnn.allow_tf32 = True 21 | 22 | random.seed(1234) 23 | torch.manual_seed(1234) 24 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | @torch.no_grad() 29 | def evaluate(dataloader, tokenizer, ctx, teacher, student, delta, subset, max_new_tokens): 30 | total_instances = 0 31 | total_tokens = 0 32 | total_correct = 0 33 | total_correct_tokens = 0 34 | total_loss = 0 35 | for batch in tqdm.tqdm(dataloader): 36 | input_ids_all = batch['input_ids_all'].to(device) 37 | input_ids_nocot = batch['input_ids_nocot'].to(device) 38 | labels_nocot = batch['labels_nocot'].to(device) 39 | batch_size = input_ids_nocot.shape[0] 40 | with ctx: 41 | teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=delta, subset=subset) 42 | outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=teacher_states) 43 | loss = outputs.loss 44 | token_accuracy = outputs.token_accuracy.item() 45 | total_loss += outputs.total_loss.item() 46 | total_correct_tokens += outputs.total_correct.item() 47 | total_tokens += outputs.total_tokens 48 | total_instances += batch_size 49 | 50 | # Generate 51 | with ctx: 52 | beam_output = student.generate( 53 | input_ids=input_ids_nocot, 54 | teacher_states=teacher_states, 55 | max_new_tokens=max_new_tokens, 56 | ) 57 | 58 | # Evaluate 59 | sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id) 60 | for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)): 61 | sep_position = sep_positions[i].item() 62 | tgt = input_ids_all_i[sep_position+1:] 63 | tgt_text = tokenizer.decode(tgt, skip_special_tokens=True) 64 | ans = extract_answer(tgt_text) 65 | pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True) 66 | pred_ans = extract_answer(pred_text) 67 | if ans == pred_ans: 68 | total_correct += 1 69 | if i == 0: 70 | print (f'Input: 
{tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}') 71 | print (f'Target: {tgt_text}') 72 | print (f'Predicted: {pred_text}') 73 | print ('') 74 | accuracy = total_correct / total_instances 75 | token_accuracy = total_correct_tokens / total_tokens 76 | loss = total_loss / total_tokens 77 | ppl = math.exp(loss) 78 | return accuracy, token_accuracy, ppl 79 | 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--teacher', type=str, required=True) 84 | parser.add_argument('--delta', type=str, required=True) 85 | parser.add_argument('--train_path', type=str, required=True) 86 | parser.add_argument('--val_path', type=str, required=True) 87 | parser.add_argument('--save_model', type=str, required=True) 88 | parser.add_argument('--max_new_tokens', type=int, default=128) 89 | parser.add_argument('--base_model', type=str, default='gpt2') 90 | parser.add_argument('--epochs', type=int, default=5) 91 | parser.add_argument('--batch_size', type=int, default=32) 92 | parser.add_argument('--lr', type=float, default=5e-5) 93 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 94 | parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal') 95 | args = parser.parse_args() 96 | 97 | print (args) 98 | 99 | dtype = 'float32' 100 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 103 | print (ptdtype, dtype, device) 104 | 105 | # Create Student 106 | config = StudentConfig(base_model=args.base_model) 107 | student = Student(config).to(device).to(ptdtype) 108 | 109 | # Load Teacher 110 | teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype) 111 | 112 | # Load data 113 | tokenizer = teacher.tokenizer 114 | collate_fn = CoTDataCollator(tokenizer) 115 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 116 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 117 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 118 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 119 | 120 | # Create Optimizer 121 | trainable_params = list(student.parameters()) 122 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 123 | extra_args = dict(fused=True) if use_fused else dict() 124 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 125 | 126 | teacher.eval() 127 | student.eval() # to turn off dropout 128 | 129 | for p in teacher.parameters(): 130 | p.requires_grad = False 131 | 132 | # Train 133 | step = 0 134 | for epoch in range(args.epochs): 135 | print(f"Epoch {epoch}") 136 | 137 | for batch in tqdm.tqdm(train_dataloader): 138 | input_ids_all = batch['input_ids_all'].to(device) 139 | input_ids_nocot = batch['input_ids_nocot'].to(device) 140 | labels_nocot = batch['labels_nocot'].to(device) 141 | with ctx: 142 | with torch.no_grad(): 143 | teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=args.delta, subset=args.subset) 144 | outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=teacher_states) 145 | loss = outputs.loss 146 | token_accuracy = outputs.token_accuracy.item() 147 | 148 | loss.backward() 149 | 
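# Clip the gradient norm to --max_grad_norm, take the AdamW step, then clear gradients before the next batch.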
torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 150 | optimizer.step() 151 | optimizer.zero_grad() 152 | ppl = loss.exp().item() 153 | if step % 100 == 0: 154 | print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}") 155 | sys.stdout.flush() 156 | step += 1 157 | accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, teacher, student, args.delta, args.subset, args.max_new_tokens) 158 | print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.') 159 | student.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}')) 160 | 161 | if __name__ == "__main__": 162 | main() 163 | -------------------------------------------------------------------------------- /src/train_teacher.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from transformers import AdamW 5 | import argparse 6 | import os 7 | import tqdm 8 | import inspect 9 | import logging 10 | 11 | from models.teacher import Teacher 12 | from models.configuration_teacher import TeacherConfig 13 | from data import CoTDataset, CoTDataCollator, extract_answer 14 | 15 | from utils import get_sep_position 16 | 17 | torch.backends.cuda.matmul.allow_tf32 = True 18 | torch.backends.cudnn.allow_tf32 = True 19 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | 24 | def save_model(model, tokenizer, model_dir): 25 | print ('saving', model_dir) 26 | os.makedirs(model_dir, exist_ok=True) 27 | model.save_pretrained(model_dir) 28 | tokenizer.save_pretrained(model_dir) 29 | 30 | @torch.no_grad() 31 | def evaluate(dataloader, tokenizer, ctx, teacher, max_new_tokens): 32 | teacher.eval() 33 | total_instances = 0 34 | total_tokens = 0 35 | total_correct = 0 36 | total_correct_tokens = 0 37 | total_loss = 0 38 | for batch in tqdm.tqdm(dataloader): 39 | input_ids_all = batch['input_ids_all'].to(device) 40 | labels = batch['labels_all'].to(device) 41 | # Remove answer part 42 | sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id) 43 | input_ids = input_ids_all[:, :sep_positions.max()+1] 44 | batch_size = input_ids.shape[0] 45 | with ctx: 46 | outputs = teacher.compute_loss(input_ids=input_ids_all, labels=labels) 47 | total_loss += outputs.total_loss.item() 48 | total_correct_tokens += outputs.total_correct.item() 49 | total_tokens += outputs.total_tokens 50 | total_instances += batch_size 51 | 52 | # Generate 53 | beam_output = teacher.generate( 54 | input_ids=input_ids, 55 | max_new_tokens=max_new_tokens, 56 | ) 57 | # Evaluate 58 | #import pdb; pdb.set_trace() 59 | for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)): 60 | sep_position = sep_positions[i].item() 61 | tgt = input_ids_all_i[sep_position+1:] 62 | tgt_text = tokenizer.decode(tgt, skip_special_tokens=True) 63 | ans = extract_answer(tgt_text) 64 | pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True) 65 | pred_ans = extract_answer(pred_text) 66 | if ans == pred_ans: 67 | total_correct += 1 68 | if i == 0: 69 | print (f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}') 70 | print (f'Target: {tgt_text}') 71 | print (f'Predicted: {pred_text}') 72 | print ('') 73 | accuracy = total_correct / total_instances 74 | token_accuracy = total_correct_tokens / 
total_tokens 75 | loss = total_loss / total_tokens 76 | ppl = math.exp(loss) 77 | return accuracy, token_accuracy, ppl 78 | 79 | 80 | def main(): 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--train_path', type=str, required=True) 83 | parser.add_argument('--val_path', type=str, required=True) 84 | parser.add_argument('--save_model', type=str, required=True) 85 | parser.add_argument('--max_new_tokens', type=int, default=128) 86 | parser.add_argument('--base_model', type=str, default='gpt2') 87 | parser.add_argument('--epochs', type=int, default=1) 88 | parser.add_argument('--batch_size', type=int, default=32) 89 | parser.add_argument('--lr', type=float, default=5e-5) 90 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 91 | args = parser.parse_args() 92 | 93 | print (args) 94 | 95 | dtype = 'float32' 96 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 97 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 98 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 99 | print (ptdtype, dtype, device) 100 | 101 | # Create Teacher 102 | config = TeacherConfig(base_model=args.base_model) 103 | teacher = Teacher(config).to(device).to(ptdtype) 104 | 105 | # Load data 106 | tokenizer = teacher.tokenizer 107 | collate_fn = CoTDataCollator(tokenizer) 108 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 109 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 110 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 111 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 112 | 113 | # Create Optimizer 114 | trainable_params = list(teacher.parameters()) 115 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 116 | extra_args = dict(fused=True) if use_fused else dict() 117 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 118 | 119 | teacher.train() 120 | 121 | # Train 122 | step = 0 123 | for epoch in range(args.epochs): 124 | print(f"Epoch {epoch}") 125 | teacher.train() 126 | for batch in tqdm.tqdm(train_dataloader): 127 | input_ids = batch['input_ids_all'].to(device) 128 | labels = batch['labels_all'].to(device) 129 | with ctx: 130 | outputs = teacher.compute_loss(input_ids=input_ids, labels=labels) 131 | loss = outputs.loss 132 | token_accuracy = outputs.token_accuracy.item() 133 | 134 | loss.backward() 135 | torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 136 | optimizer.step() 137 | optimizer.zero_grad() 138 | ppl = loss.exp().item() 139 | if step % 100 == 0: 140 | print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}") 141 | step += 1 142 | accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, teacher, args.max_new_tokens) 143 | print (f'Val. 
PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.') 144 | teacher.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}')) 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /src/train_thought_emulator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from transformers import AdamW 4 | import argparse 5 | import os 6 | import inspect 7 | import tqdm 8 | import logging 9 | import random 10 | import torch.nn as nn 11 | 12 | from data import CoTDataset, CoTDataCollator 13 | from models.teacher import Teacher 14 | from models.emulator import Emulator 15 | from models.configuration_emulator import EmulatorConfig 16 | 17 | 18 | torch.backends.cuda.matmul.allow_tf32 = True 19 | torch.backends.cudnn.allow_tf32 = True 20 | random.seed(1234) 21 | torch.manual_seed(1234) 22 | logging.disable(logging.WARNING) 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | @torch.no_grad() 27 | def evaluate(dataloader, tokenizer, ctx, teacher, emulator, delta, subset): 28 | total_instances = 0 29 | total_loss = 0 30 | for batch in tqdm.tqdm(dataloader): 31 | #import pdb; pdb.set_trace() 32 | input_ids_cot = batch['input_ids_cot'].to(device) 33 | batch_size = input_ids_cot.shape[0] 34 | with ctx: 35 | teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=delta, subset=subset) 36 | outputs = emulator.compute_loss(input_ids=input_ids_cot, teacher_states=teacher_states) 37 | loss = outputs.loss 38 | total_loss += outputs.total_loss.item() 39 | total_instances += batch_size 40 | 41 | loss = total_loss / total_instances 42 | return loss 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--teacher', type=str, required=True) 47 | parser.add_argument('--delta', type=str, required=True) 48 | parser.add_argument('--train_path', type=str, required=True) 49 | parser.add_argument('--val_path', type=str, required=True) 50 | parser.add_argument('--save_model', type=str, required=True) 51 | parser.add_argument('--base_model', type=str, default='gpt2') 52 | parser.add_argument('--epochs', type=int, default=5) 53 | parser.add_argument('--batch_size', type=int, default=32) 54 | parser.add_argument('--lr', type=float, default=5e-5) 55 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 56 | parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal') 57 | parser.add_argument('--mixture_size', type=int, default=1) 58 | args = parser.parse_args() 59 | 60 | print (args) 61 | dtype = 'float32' 62 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 63 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 64 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 65 | print (ptdtype, dtype, device) 66 | 67 | # Create Emulator 68 | config = EmulatorConfig(base_model=args.base_model, mixture_size=args.mixture_size) 69 | emulator = Emulator(config).to(device).to(ptdtype) 70 | 71 | # Load Teacher 72 | teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype) 73 | 74 | # Load data 75 | tokenizer = teacher.tokenizer 76 | collate_fn = CoTDataCollator(tokenizer) 77 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 78 | train_dataloader = 
DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 79 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 80 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 81 | 82 | # Create Optimizer 83 | trainable_params = list(emulator.parameters()) 84 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 85 | extra_args = dict(fused=True) if use_fused else dict() 86 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 87 | 88 | teacher.eval() 89 | emulator.eval() # to turn off dropout 90 | 91 | for p in teacher.parameters(): 92 | p.requires_grad = False 93 | 94 | # Train 95 | step = 0 96 | for epoch in range(args.epochs): 97 | print(f"Epoch {epoch}") 98 | 99 | for batch in tqdm.tqdm(train_dataloader): 100 | #import pdb; pdb.set_trace() 101 | input_ids_cot = batch['input_ids_cot'].to(device) 102 | input_ids_nocot = batch['input_ids_nocot'].to(device) 103 | with ctx: 104 | with torch.no_grad(): 105 | teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=args.delta, subset=args.subset) 106 | outputs = emulator.compute_loss(input_ids=input_ids_nocot, teacher_states=teacher_states) 107 | loss = outputs.loss 108 | 109 | loss.backward() 110 | torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 111 | optimizer.step() 112 | optimizer.zero_grad() 113 | if step % 100 == 0: 114 | print (f"Step: {step}. Loss: {loss}.") 115 | step += 1 116 | loss = evaluate(val_dataloader, tokenizer, ctx, teacher, emulator, args.delta, args.subset) 117 | print (f'Val. Loss: {loss}.') 118 | emulator.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}')) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, LogitsProcessor 3 | 4 | def get_sep_position(input_ids, sep_id, skip=0): 5 | batch_size = input_ids.shape[0] 6 | sep_positions = input_ids.new_zeros(batch_size).long() 7 | for batch_id in range(batch_size): 8 | mask = input_ids[batch_id].eq(sep_id) 9 | sep_position = mask.nonzero()[0, -1].item() 10 | for _ in range(skip): 11 | mask[sep_position] = False 12 | sep_position = mask.nonzero()[0, -1].item() 13 | sep_positions[batch_id] = sep_position 14 | return sep_positions 15 | 16 | 17 | # Stop generation only after generating two EOSs, such as z y 18 | class DoubleEOSStoppingCriteria(StoppingCriteria): 19 | def __init__(self, eos_token_id): 20 | super().__init__() 21 | self.eos_token_id = eos_token_id 22 | self.init = False 23 | 24 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 25 | eos_count = (input_ids == self.eos_token_id).sum(dim=-1) 26 | if not self.init: 27 | self.init = True 28 | self.eos_count_init = eos_count 29 | done = (eos_count - self.eos_count_init) >= 2 30 | return done.all() 31 | 32 | class DoubleEOSLogitsProcessor(LogitsProcessor): 33 | def __init__(self, eos_token_id): 34 | super().__init__() 35 | self.eos_token_id = eos_token_id 36 | self.init = False 37 | 38 | def __call__(self, input_ids, scores): 39 | eos_count = (input_ids == self.eos_token_id).sum(dim=-1) 40 | if not self.init: 41 | self.init = True 42 | self.eos_count_init = eos_count 43 | done = (eos_count - self.eos_count_init) >= 2 44 | if done.any(): 45 | scores[done, :] = 
float('-inf') 46 | scores[done, self.eos_token_id] = 0 47 | return scores 48 | -------------------------------------------------------------------------------- /src_autoencoder/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | import copy 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | def extract_answer(text): 9 | split_pattern = '####' 10 | if split_pattern not in text: 11 | return text.strip().replace(',', '') 12 | else: 13 | _, ans = text.strip().split('####', 1) 14 | ans = '####' + ans 15 | ans = ans.strip().replace(',', '') 16 | return ans 17 | 18 | def extract_cot(text): 19 | split_pattern = '####' 20 | if split_pattern not in text: 21 | return None 22 | else: 23 | cot, _ = text.strip().split('####', 1) 24 | cot = cot.strip() 25 | return cot 26 | 27 | class CoTDataset(Dataset): 28 | def __init__(self, tokenizer, file_path, max_length): 29 | assert os.path.isfile(file_path), f"Input file path {file_path} not found" 30 | print (f'Creating features from dataset file at {file_path}') 31 | bos_tok = tokenizer.bos_token 32 | eos_tok = tokenizer.eos_token 33 | 34 | with open(file_path, encoding="utf-8") as f: 35 | lines = [line.split('||') for line in f.read().splitlines() if (len(line) > 0 and not line.isspace() 36 | and len(line.split('||')) ==2 )] 37 | src_lines, tgt_lines = list(zip(*lines)) 38 | src_lines = list(src_lines) 39 | tgt_lines = list(tgt_lines) 40 | 41 | edited_sents_cot = [] 42 | edited_sents_only = [] 43 | edited_sents_all = [] 44 | edited_sents_nocot = [] 45 | for src, tgt in zip(src_lines, tgt_lines): 46 | #import pdb; pdb.set_trace() 47 | ans = extract_answer(tgt) 48 | cot = extract_cot(tgt) 49 | sent = ' {} {} '.format(src, bos_tok) + cot + ' {}'.format(eos_tok) 50 | edited_sents_cot.append(sent) 51 | sent = ' {} {} '.format(src, bos_tok) 52 | edited_sents_only.append(sent) 53 | sent = ' {} {} '.format(src, bos_tok) + cot + ' {} '.format(eos_tok) + ans + ' {}'.format(eos_tok) 54 | edited_sents_all.append(sent) 55 | sent = ' {} {} '.format(src, bos_tok) + ans + ' {}'.format(eos_tok) 56 | edited_sents_nocot.append(sent) 57 | 58 | batch_encoding_cot = tokenizer(edited_sents_cot, add_special_tokens=True, truncation=True, max_length=max_length) 59 | batch_encoding_only = tokenizer(edited_sents_only, add_special_tokens=True, truncation=True, max_length=max_length) 60 | batch_encoding_all = tokenizer(edited_sents_all, add_special_tokens=True, truncation=True, max_length=max_length) 61 | batch_encoding_nocot = tokenizer(edited_sents_nocot, add_special_tokens=True, truncation=True, max_length=max_length) 62 | self.examples_cot = batch_encoding_cot["input_ids"] 63 | self.examples_only = batch_encoding_only["input_ids"] 64 | self.examples_all = batch_encoding_all["input_ids"] 65 | self.examples_nocot = batch_encoding_nocot["input_ids"] 66 | 67 | self.labels_cot = copy.deepcopy(self.examples_cot) 68 | self.labels_all = copy.deepcopy(self.examples_all) 69 | self.labels_cot_shift = copy.deepcopy(self.examples_cot) 70 | self.labels_nocot = copy.deepcopy(self.examples_nocot) 71 | 72 | self.src_sent_cot = [] 73 | self.tgt_sent_cot = [] 74 | 75 | temp_src_len = 0 76 | temp_tgt_len = 0 77 | temp_count = 0 78 | separator = tokenizer.eos_token_id #tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0] 79 | for i, elem in enumerate(self.labels_cot): 80 | sep_idx = elem.index(separator) + 1 81 | 
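# sep_idx is one past the eos token that closes the source prompt; below, the prompt span of each
# label sequence is overwritten with -100 so the LM loss only covers chain-of-thought and answer tokens.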
self.src_sent_cot.append(self.examples_cot[i][:sep_idx-1]) 82 | self.tgt_sent_cot.append(self.examples_cot[i][sep_idx-1:]) 83 | self.labels_cot[i][:sep_idx] = [-100] * sep_idx 84 | assert self.labels_all[i][sep_idx-1] == separator 85 | self.labels_all[i][:sep_idx] = [-100] * sep_idx 86 | self.labels_cot_shift[i][:sep_idx-1] = [-100] * (sep_idx-1) 87 | temp_src_len += sep_idx-1 88 | temp_tgt_len += len(elem) - (sep_idx-1) 89 | temp_count += 1 90 | 91 | print('tgt_avg: ', temp_tgt_len / temp_count) 92 | print('src_avg: ', temp_src_len / temp_count) 93 | print('ratios: ', temp_src_len/temp_tgt_len) 94 | 95 | self.src_sent_nocot = [] 96 | self.tgt_sent_nocot = [] 97 | temp_src_len = 0 98 | temp_tgt_len = 0 99 | temp_count = 0 100 | separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0] 101 | for i, elem in enumerate(self.labels_nocot): 102 | sep_idx = elem.index(separator) + 1 103 | self.src_sent_nocot.append(self.examples_nocot[i][:sep_idx-1]) 104 | self.tgt_sent_nocot.append(self.examples_nocot[i][sep_idx-1:]) 105 | self.labels_nocot[i][:sep_idx] = [-100] * sep_idx 106 | temp_src_len += sep_idx-1 107 | temp_tgt_len += len(elem) - (sep_idx-1) 108 | temp_count += 1 109 | 110 | print('tgt_avg: ', temp_tgt_len / temp_count) 111 | print('src_avg: ', temp_src_len / temp_count) 112 | print('ratios: ', temp_src_len/temp_tgt_len) 113 | 114 | 115 | print(edited_sents_all[0]) 116 | print(self.labels_cot[0]) 117 | print(self.labels_nocot[0]) 118 | print(self.examples_nocot[0]) 119 | print(edited_sents_nocot[0]) 120 | print(self.src_sent_nocot[0]) 121 | print(self.tgt_sent_nocot[0]) 122 | 123 | def __len__(self): 124 | return len(self.examples_cot) 125 | 126 | # def __getitem__(self, i) -> torch.Tensor: 127 | def __getitem__(self, i): 128 | return (torch.tensor(self.examples_cot[i], dtype=torch.long), 129 | torch.tensor(self.examples_nocot[i], dtype=torch.long), 130 | torch.tensor(self.labels_cot[i], dtype=torch.long), 131 | torch.tensor(self.labels_cot_shift[i], dtype=torch.long), 132 | torch.tensor(self.labels_nocot[i], dtype=torch.long), 133 | torch.tensor(self.src_sent_cot[i], dtype=torch.long), 134 | torch.tensor(self.src_sent_nocot[i], dtype=torch.long), 135 | torch.tensor(self.tgt_sent_cot[i], dtype=torch.long), 136 | torch.tensor(self.tgt_sent_nocot[i], dtype=torch.long), 137 | torch.tensor(self.examples_only[i], dtype=torch.long), 138 | torch.tensor(self.examples_all[i], dtype=torch.long), 139 | torch.tensor(self.labels_all[i], dtype=torch.long), 140 | ) 141 | @dataclass 142 | class CoTDataCollator: 143 | """ 144 | VAEData collator used for language modeling. 
145 | - collates batches of tensors, honoring their tokenizer's pad_token 146 | - preprocesses batches for masked language modeling 147 | """ 148 | def __init__(self, tokenizer): 149 | self.tokenizer = tokenizer 150 | 151 | def __call__(self, examples): 152 | #import pdb; pdb.set_trace() 153 | input_ids_cot, input_ids_nocot, labels_cot, labels_cot_shift, labels_nocot, src_cot, src_nocot, tgt_cot, tgt_nocot, input_ids_only, input_ids_all, labels_all = zip(*examples) 154 | input_ids_cot = self._tensorize_batch(input_ids_cot) 155 | input_ids_cot[input_ids_cot.lt(0)] = self.tokenizer.eos_token_id 156 | input_ids_only = self._tensorize_batch(input_ids_only) 157 | input_ids_only[input_ids_only.lt(0)] = self.tokenizer.eos_token_id 158 | input_ids_all = self._tensorize_batch(input_ids_all) 159 | input_ids_all[input_ids_all.lt(0)] = self.tokenizer.eos_token_id 160 | input_ids_nocot = self._tensorize_batch(input_ids_nocot) 161 | input_ids_nocot[input_ids_nocot.lt(0)] = self.tokenizer.eos_token_id 162 | labels_cot = self._tensorize_batch(labels_cot) 163 | labels_all = self._tensorize_batch(labels_all) 164 | labels_cot_shift = self._tensorize_batch(labels_cot_shift) 165 | labels_nocot = self._tensorize_batch(labels_nocot) 166 | return {"input_ids_cot": input_ids_cot, "input_ids_nocot": input_ids_nocot, "labels_cot": labels_cot, "labels_cot_shift": labels_cot_shift, "labels_nocot": labels_nocot, 'input_ids_only': input_ids_only, 'input_ids_all': input_ids_all, 'labels_all': labels_all} 167 | 168 | def _tensorize_batch(self, examples): 169 | # In order to accept both lists of lists and lists of Tensors 170 | if isinstance(examples[0], (list, tuple)): 171 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 172 | length_of_first = examples[0].size(0) 173 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 174 | if are_tensors_same_length: 175 | return torch.stack(examples, dim=0) 176 | else: 177 | return pad_sequence(examples, batch_first=True, padding_value=-100) 178 | -------------------------------------------------------------------------------- /src_autoencoder/generate.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import re 4 | import torch 5 | import sys 6 | from torch.utils.data import DataLoader 7 | from torch.nn import CrossEntropyLoss 8 | from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW 9 | import argparse 10 | import os 11 | import inspect 12 | import tqdm 13 | from data import CoTDataset, CoTDataCollator, extract_answer 14 | import logging 15 | import random 16 | import torch.nn as nn 17 | 18 | from models.emulator import Emulator 19 | from models.student import Student 20 | from utils import get_sep_position 21 | 22 | torch.backends.cuda.matmul.allow_tf32 = True 23 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 24 | 25 | random.seed(1234) 26 | torch.manual_seed(1234) 27 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 28 | 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | 31 | 32 | @torch.no_grad() 33 | def evaluate(dataloader, tokenizer, ctx, emulator, student, max_new_tokens): 34 | total_time = 0 35 | total_instances = 0 36 | total_correct = 0 37 | 38 | for batch in tqdm.tqdm(dataloader): 39 | input_ids_all = batch['input_ids_nocot'].to(device) 40 | # Remove answer part 41 | sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id) 42 | input_ids = 
input_ids_all[:, :sep_positions.max()+1] 43 | start_time = time.time() 44 | with ctx: 45 | emulated_teacher_states = emulator(input_ids) 46 | 47 | # Generate from student 48 | beam_output = student.generate( 49 | input_ids=input_ids, 50 | teacher_states=emulated_teacher_states, 51 | max_new_tokens=max_new_tokens, 52 | ) 53 | 54 | # Evaluate 55 | #import pdb; pdb.set_trace() 56 | for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)): 57 | #sep_position = input_ids_single.tolist().index(tokenizer.eos_token_id) 58 | sep_position = sep_positions[i].item() 59 | tgt = input_ids_all_i[sep_position+1:] 60 | tgt_text = tokenizer.decode(tgt, skip_special_tokens=True) 61 | ans = extract_answer(tgt_text) 62 | pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True) 63 | pred_ans = extract_answer(pred_text) 64 | #import pdb; pdb.set_trace() 65 | total_instances += 1 66 | if ans == pred_ans: 67 | total_correct += 1 68 | end_time = time.time() 69 | total_time += end_time - start_time 70 | 71 | #print (total_time, total_instances, total_instances / total_time) 72 | throughput = total_instances / total_time 73 | accuracy = total_correct / total_instances 74 | return accuracy, throughput 75 | 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--test_path', type=str, required=True) 80 | parser.add_argument('--batch_size', type=int, default=1) 81 | parser.add_argument('--max_new_tokens', type=int, default=128) 82 | parser.add_argument('--student_path', type=str, required=True) 83 | parser.add_argument('--emulator_path', type=str, required=True) 84 | args = parser.parse_args() 85 | 86 | print (args) 87 | 88 | dtype = 'float32' 89 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 90 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 91 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 92 | print (ptdtype, dtype, device) 93 | 94 | 95 | # Load Models 96 | emulator = Emulator.from_pretrained(args.emulator_path).to(device).to(ptdtype) 97 | student = Student.from_pretrained(args.student_path).to(device).to(ptdtype) 98 | emulator.eval() 99 | student.eval() 100 | 101 | # Load data 102 | tokenizer = emulator.tokenizer 103 | collate_fn = CoTDataCollator(tokenizer) 104 | test_dataset = CoTDataset(tokenizer, args.test_path, 1024) 105 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 106 | 107 | accuracy, throughput = evaluate(test_dataloader, tokenizer, ctx, emulator, student, args.max_new_tokens) 108 | print (f"Test Accuracy: {accuracy}. 
Throughput: {throughput}") 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /src_autoencoder/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers import AutoTokenizer, GPT2Model 7 | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions 8 | from transformers.models.bert.modeling_bert import BertModel 9 | 10 | sys.path.append("..") 11 | from utils import get_sep_position 12 | from .configuration_autoencoder import AutoEncoderConfig 13 | from .modeling_gpt2_implicit import GPT2LMHeadImplicitModel 14 | 15 | class AutoEncoder(nn.Module): 16 | def __init__(self, config): 17 | super().__init__() 18 | self.config = config 19 | teacher_num_layers = config.teacher_num_layers 20 | self.base_model_decoder = GPT2Model.from_pretrained(config.base_model) 21 | #self.base_model_encoder = nn.ModuleList([BertModel.from_pretrained('bert-base-uncased') for _ in range(teacher_num_layers)]) #GPT2LMHeadImplicitModel.from_pretrained(config.base_model) 22 | self.base_model_encoder = BertModel.from_pretrained('bert-base-uncased') 23 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) 24 | #import pdb; pdb.set_trace() 25 | #num_layers = len(self.base_model.transformer.h) 26 | #hidden_size = self.base_model[0].config.hidden_size 27 | encoder_hidden_size = self.base_model_encoder.config.hidden_size 28 | decoder_hidden_size = self.base_model_decoder.config.hidden_size 29 | teacher_hidden_size = config.teacher_hidden_size 30 | self.teacher_hidden_size = teacher_hidden_size 31 | self.encoder_hidden_size = encoder_hidden_size 32 | self.decoder_hidden_size = decoder_hidden_size 33 | self.layer_norm = nn.LayerNorm(teacher_hidden_size, elementwise_affine=False) 34 | #self.num_layers = num_layers 35 | #self.hidden_size = hidden_size 36 | self.mlp_in = nn.Linear(teacher_hidden_size*teacher_num_layers, encoder_hidden_size) 37 | self.mlp_out = nn.Linear(encoder_hidden_size, teacher_hidden_size*teacher_num_layers) 38 | self.mlp_in_decoder = nn.Linear(teacher_hidden_size*teacher_num_layers, decoder_hidden_size) 39 | self.mlp_out_decoder = nn.Linear(decoder_hidden_size, teacher_hidden_size*teacher_num_layers) 40 | #self.mlps_in = nn.ModuleList([nn.Sequential( 41 | # nn.Linear(teacher_hidden_size, hidden_size), 42 | # ) for _ in range(teacher_num_layers)]) 43 | #self.mlps_out = nn.ModuleList([nn.Sequential( 44 | # nn.Linear(hidden_size, teacher_hidden_size), 45 | # ) for _ in range(teacher_num_layers)]) 46 | 47 | def forward(self, input_ids, positions_to_substitute, teacher_states, output_hidden_states=False): 48 | outputs = self.base_model.forward(mode='forward_student', \ 49 | input_ids=input_ids, \ 50 | positions_to_substitute=positions_to_substitute, \ 51 | states_to_substitute=teacher_states, \ 52 | output_hidden_states=output_hidden_states) 53 | return outputs 54 | 55 | def encode(self, teacher_states_cat): 56 | #import pdb; pdb.set_trace() 57 | batch_size = teacher_states_cat.shape[0] 58 | seq_len = teacher_states_cat.shape[1] 59 | teacher_states_cat = teacher_states_cat.view(batch_size, seq_len, -1) 60 | state_in = self.mlp_in(teacher_states_cat) 61 | state_out = self.base_model_encoder(inputs_embeds=state_in).last_hidden_state 62 | bottleneck = state_out.mean(1) 63 | bottleneck = self.mlp_out(bottleneck) 64 | bottleneck = 
bottleneck.view(batch_size, -1, self.teacher_hidden_size) 65 | bottleneck = self.layer_norm(bottleneck) # bsz, layers, hidden 66 | return bottleneck 67 | 68 | def compute_loss(self, teacher_states): 69 | teacher_states_cat = torch.stack(teacher_states, dim=-2) # bsz, seq_len, layers, hidden 70 | batch_size = teacher_states_cat.shape[0] 71 | seq_len = teacher_states_cat.shape[1] 72 | bottleneck = self.encode(teacher_states_cat) 73 | teacher_states_cat = teacher_states_cat.view(batch_size, seq_len, -1) 74 | 75 | bottleneck = bottleneck.view(batch_size, 1, -1) # bz, 1, layers*hidden 76 | inputs_embeds = teacher_states_cat + bottleneck # bsz, seq_len, layers*hidden 77 | inputs_embeds = torch.cat((bottleneck, inputs_embeds), dim=1) # bsz, 1+seq_len, layers*hidden 78 | inputs_embeds = self.mlp_in_decoder(inputs_embeds) # bsz, 1+seq_len, hidden 79 | outputs = self.base_model_decoder(inputs_embeds=inputs_embeds) 80 | last_hidden_state = outputs.last_hidden_state 81 | outputs = self.mlp_out_decoder(last_hidden_state) 82 | outputs = outputs[:, :-1] 83 | 84 | loss_fct = nn.MSELoss(reduction='none') 85 | loss = loss_fct(teacher_states_cat, outputs).sum(-1) / 2 86 | loss = loss.mean(0).sum(-1) 87 | outputs = CausalLMOutputWithCrossAttentions(loss=loss) 88 | outputs.total_loss = loss * batch_size 89 | # Decoder 90 | #teacher_states_cat_in = teacher_states_cat + bottleneck.view( 91 | #encoded_states = [] 92 | #for layer_id, state in enumerate(teacher_states): 93 | # state_in = self.mlps_in[layer_id](state) 94 | # state_out = self.base_model_encoder[layer_id](inputs_embeds=state_in).last_hidden_state 95 | # bottleneck = state_out.mean(1, keepdim=True) 96 | # bottleneck = self.layer_norm(bottleneck) 97 | # encoded_states.append(bottleneck) 98 | #encoded_states = torch.cat(encoded_states, 1) 99 | #sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 100 | ## First, project teacher states 101 | #teacher_states = [self.mlps[l](teacher_states[l]) for l in range(len(teacher_states))] 102 | 103 | ## Forward while substituting teacher states 104 | #outputs = self.forward(input_ids, sep_positions, teacher_states) 105 | #logits = outputs.logits 106 | 107 | #labels_pred = logits.argmax(-1) 108 | #mask = labels[...,1:].ge(0) 109 | #correct_tokens = ((labels_pred[...,:-1] == labels[...,1:]) * mask).sum() 110 | #total_tokens = mask.sum() 111 | #token_accuracy = correct_tokens / total_tokens 112 | 113 | #shift_logits = logits[..., :-1, :].contiguous() 114 | #shift_labels = labels[..., 1:].contiguous() 115 | #loss_fct = nn.CrossEntropyLoss() 116 | #loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 117 | 118 | #outputs.loss = loss 119 | #outputs.token_accuracy = token_accuracy 120 | #outputs.total_correct = correct_tokens 121 | #outputs.total_loss = loss * total_tokens 122 | #outputs.total_tokens = total_tokens 123 | return outputs 124 | 125 | def generate(self, input_ids, teacher_states, max_new_tokens=512, num_beams=1): 126 | sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 127 | batch_size = input_ids.shape[0] 128 | beam_output = [] 129 | # First, project teacher states 130 | teacher_states = [self.mlps[l](teacher_states[l]) for l in range(len(teacher_states))] 131 | for i in range(batch_size): 132 | input_ids_i = input_ids[i:i+1] 133 | sep_positions_i = sep_positions[i:i+1] 134 | input_ids_i = input_ids_i[:, :sep_positions_i+1] 135 | beam_output_i = self.base_model.generate( 136 | input_ids=input_ids_i, 137 | max_new_tokens=max_new_tokens, 
138 | num_beams=num_beams, 139 | early_stopping=True, 140 | num_return_sequences=1, 141 | positions_to_substitute=sep_positions_i.repeat_interleave(num_beams, dim=0), 142 | states_to_substitute=[z[i:i+1].repeat_interleave(num_beams, dim=0) for z in teacher_states], 143 | mode='forward_student', 144 | ) 145 | beam_output.append(beam_output_i) 146 | return beam_output 147 | 148 | @classmethod 149 | def from_pretrained(self, pretrained_path): 150 | config = AutoEncoderConfig.from_pretrained(pretrained_path) 151 | model = AutoEncoder(config) 152 | state_dict = torch.load(os.path.join(pretrained_path, 'state_dict.bin')) 153 | model.load_state_dict(state_dict) 154 | return model 155 | 156 | def save_pretrained(self, save_directory): 157 | print (f'Saving to {save_directory}') 158 | self.config.save_pretrained(save_directory) 159 | state_dict = self.state_dict() 160 | torch.save(state_dict, os.path.join(save_directory, 'state_dict.bin')) 161 | 162 | -------------------------------------------------------------------------------- /src_autoencoder/models/configuration_autoencoder.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class AutoEncoderConfig(PretrainedConfig): 4 | def __init__( 5 | self, 6 | base_model='gpt2', 7 | tokenizer_name='gpt2', 8 | mixture_size=1, 9 | teacher_hidden_size=None, 10 | teacher_num_layers=None, 11 | **kwargs, 12 | ): 13 | self.base_model = base_model 14 | self.tokenizer_name = tokenizer_name 15 | self.mixture_size = mixture_size 16 | self.teacher_hidden_size = teacher_hidden_size 17 | self.teacher_num_layers = teacher_num_layers 18 | super().__init__(**kwargs) 19 | -------------------------------------------------------------------------------- /src_autoencoder/models/configuration_emulator.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class EmulatorConfig(PretrainedConfig): 4 | def __init__( 5 | self, 6 | base_model='gpt2', 7 | tokenizer_name='gpt2', 8 | mixture_size=1, 9 | softmax_temperature=0.05, 10 | **kwargs, 11 | ): 12 | self.base_model = base_model 13 | self.tokenizer_name = tokenizer_name 14 | self.mixture_size = mixture_size 15 | self.softmax_temperature = softmax_temperature 16 | super().__init__(**kwargs) 17 | 18 | -------------------------------------------------------------------------------- /src_autoencoder/models/configuration_student.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class StudentConfig(PretrainedConfig): 4 | def __init__( 5 | self, 6 | base_model='gpt2', 7 | tokenizer_name='gpt2', 8 | mixture_size=1, 9 | **kwargs, 10 | ): 11 | self.base_model = base_model 12 | self.tokenizer_name = tokenizer_name 13 | self.mixture_size = mixture_size 14 | super().__init__(**kwargs) 15 | -------------------------------------------------------------------------------- /src_autoencoder/models/configuration_teacher.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class TeacherConfig(PretrainedConfig): 4 | def __init__( 5 | self, 6 | base_model='gpt2', 7 | tokenizer_name='gpt2', 8 | **kwargs, 9 | ): 10 | self.base_model = base_model 11 | self.tokenizer_name = tokenizer_name 12 | super().__init__(**kwargs) 13 | 14 | -------------------------------------------------------------------------------- 
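The configuration classes above are thin `PretrainedConfig` wrappers, and every model wrapper in this codebase (`Teacher`, `Student`, `Emulator`, `AutoEncoder`) persists itself the same way: `save_pretrained` writes the config (as `config.json`) plus a single `state_dict.bin` into the target directory, and the `from_pretrained` classmethod reads both back. Below is a minimal sketch of that round trip, assuming `src_autoencoder` is the working directory so the relative imports resolve; the checkpoint path is illustrative.

```python
# Sketch only: build an emulator from a fresh config, save it, and reload it.
# The checkpoint path is illustrative; `gpt2` is the default base model.
from models.configuration_emulator import EmulatorConfig
from models.emulator import Emulator

config = EmulatorConfig(base_model='gpt2', mixture_size=1, softmax_temperature=0.05)
emulator = Emulator(config)

emulator.save_pretrained('checkpoints/emulator_demo')             # writes config.json + state_dict.bin
emulator = Emulator.from_pretrained('checkpoints/emulator_demo')  # reads both back
```

The training scripts call exactly these two methods when they write `checkpoint_{epoch}` directories, and `generate.py` uses `from_pretrained` to restore models for evaluation.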
/src_autoencoder/models/emulator.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions 8 | 9 | from .configuration_emulator import EmulatorConfig 10 | import sys 11 | sys.path.append("..") 12 | from utils import get_sep_position 13 | from .modeling_gpt2_implicit import GPT2LMHeadImplicitModel 14 | 15 | 16 | class Emulator(nn.Module): 17 | def __init__(self, config): 18 | super().__init__() 19 | self.config = config 20 | self.base_model = GPT2LMHeadImplicitModel.from_pretrained(config.base_model) 21 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) 22 | num_layers = len(self.base_model.transformer.h) 23 | hidden_size = self.base_model.config.hidden_size 24 | 25 | self.mlps = nn.ModuleList([nn.Sequential( 26 | nn.Linear(2*hidden_size, 4*hidden_size), 27 | nn.ReLU(), 28 | nn.Linear(4*hidden_size, hidden_size), 29 | ) for _ in range(num_layers)]) 30 | 31 | self.mixture_components = nn.Embedding(config.mixture_size, hidden_size) 32 | self.rnn = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=1, \ 33 | batch_first=False, dropout=0, bidirectional=False) 34 | self.key_proj = nn.Linear(hidden_size, hidden_size) 35 | self.query_proj = nn.Linear(hidden_size, hidden_size) 36 | self.out_proj = nn.Linear(hidden_size*2, hidden_size) 37 | 38 | def eval(self): 39 | self.base_model.eval() 40 | 41 | def forward(self, input_ids, requires_backward=False): 42 | sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 43 | input_ids = input_ids[:, :sep_positions.max()+1] 44 | outputs = self.base_model.forward(mode='forward_emulator', \ 45 | input_ids=input_ids, \ 46 | positions_to_take=sep_positions, \ 47 | softmax_temperature=self.config.softmax_temperature, \ 48 | requires_backward=requires_backward, \ 49 | rnn=self.rnn, \ 50 | mlps=self.mlps, \ 51 | mixture_components=self.mixture_components, \ 52 | key_proj=self.key_proj, \ 53 | query_proj=self.query_proj, \ 54 | out_proj=self.out_proj) 55 | emulated_teacher_states = outputs.f_h_cs 56 | return emulated_teacher_states 57 | 58 | def compute_loss(self, input_ids, teacher_states): 59 | emulated_teacher_states = self.forward(input_ids=input_ids, requires_backward=True) 60 | batch_size = input_ids.shape[0] 61 | 62 | loss_fct = nn.MSELoss(reduction='none') 63 | loss = 0 64 | for teacher_state, emulated_teacher_state in zip(teacher_states, emulated_teacher_states): 65 | loss += loss_fct(teacher_state, emulated_teacher_state).sum(-1) / 2 66 | loss = loss.mean() 67 | outputs = CausalLMOutputWithCrossAttentions(loss=loss) 68 | outputs.total_loss = loss * batch_size 69 | return outputs 70 | 71 | @classmethod 72 | def from_pretrained(self, pretrained_path): 73 | config = EmulatorConfig.from_pretrained(pretrained_path) 74 | model = Emulator(config) 75 | state_dict = torch.load(os.path.join(pretrained_path, 'state_dict.bin')) 76 | model.load_state_dict(state_dict) 77 | return model 78 | 79 | def save_pretrained(self, save_directory): 80 | print (f'Saving to {save_directory}') 81 | self.config.save_pretrained(save_directory) 82 | state_dict = self.state_dict() 83 | torch.save(state_dict, os.path.join(save_directory, 'state_dict.bin')) 84 | -------------------------------------------------------------------------------- /src_autoencoder/models/student.py: 
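The `Student` defined next receives one teacher state per layer (each of shape `(batch, hidden_size)`), passes each state through its own two-layer MLP, and substitutes the result at the separator position of its forward pass (the `forward_student` mode of `GPT2LMHeadImplicitModel`). The projection step in isolation looks as follows; this is a sketch with dummy tensors, and the GPT-2 sizes (12 layers, hidden size 768) are used purely for illustration.

```python
# Sketch of the per-layer projection the student applies to teacher states
# before injecting them into its own hidden states (dummy tensors only).
import torch
import torch.nn as nn

batch_size, hidden_size, num_layers = 4, 768, 12
mlps = nn.ModuleList([
    nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size),
                  nn.ReLU(),
                  nn.Linear(4 * hidden_size, hidden_size))
    for _ in range(num_layers)
])
teacher_states = [torch.randn(batch_size, hidden_size) for _ in range(num_layers)]
projected = [mlps[l](teacher_states[l]) for l in range(num_layers)]  # one (batch, hidden) tensor per layer
```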
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers import AutoTokenizer 7 | 8 | sys.path.append("..") 9 | from utils import get_sep_position 10 | from .configuration_student import StudentConfig 11 | from .modeling_gpt2_implicit import GPT2LMHeadImplicitModel 12 | 13 | class Student(nn.Module): 14 | def __init__(self, config): 15 | super().__init__() 16 | self.config = config 17 | self.base_model = GPT2LMHeadImplicitModel.from_pretrained(config.base_model) 18 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) 19 | num_layers = len(self.base_model.transformer.h) 20 | hidden_size = self.base_model.config.hidden_size 21 | self.num_layers = num_layers 22 | self.hidden_size = hidden_size 23 | 24 | self.mlps = nn.ModuleList([nn.Sequential( 25 | nn.Linear(hidden_size, 4*hidden_size), 26 | nn.ReLU(), 27 | nn.Linear(4*hidden_size, hidden_size), 28 | ) for _ in range(num_layers)]) 29 | 30 | def forward(self, input_ids, positions_to_substitute, teacher_states, output_hidden_states=False): 31 | outputs = self.base_model.forward(mode='forward_student', \ 32 | input_ids=input_ids, \ 33 | positions_to_substitute=positions_to_substitute, \ 34 | states_to_substitute=teacher_states, \ 35 | output_hidden_states=output_hidden_states) 36 | return outputs 37 | 38 | def compute_loss(self, input_ids, labels, teacher_states): 39 | #import pdb; pdb.set_trace() 40 | sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 41 | # First, project teacher states 42 | teacher_states = [self.mlps[l](teacher_states[l]) for l in range(len(teacher_states))] 43 | 44 | # Forward while substituting teacher states 45 | outputs = self.forward(input_ids, sep_positions, teacher_states) 46 | logits = outputs.logits 47 | 48 | labels_pred = logits.argmax(-1) 49 | mask = labels[...,1:].ge(0) 50 | correct_tokens = ((labels_pred[...,:-1] == labels[...,1:]) * mask).sum() 51 | total_tokens = mask.sum() 52 | token_accuracy = correct_tokens / total_tokens 53 | 54 | shift_logits = logits[..., :-1, :].contiguous() 55 | shift_labels = labels[..., 1:].contiguous() 56 | loss_fct = nn.CrossEntropyLoss() 57 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 58 | 59 | outputs.loss = loss 60 | outputs.token_accuracy = token_accuracy 61 | outputs.total_correct = correct_tokens 62 | outputs.total_loss = loss * total_tokens 63 | outputs.total_tokens = total_tokens 64 | return outputs 65 | 66 | def generate(self, input_ids, teacher_states, max_new_tokens=512, num_beams=1): 67 | sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 68 | batch_size = input_ids.shape[0] 69 | beam_output = [] 70 | # First, project teacher states 71 | teacher_states = [self.mlps[l](teacher_states[l]) for l in range(len(teacher_states))] 72 | for i in range(batch_size): 73 | input_ids_i = input_ids[i:i+1] 74 | sep_positions_i = sep_positions[i:i+1] 75 | input_ids_i = input_ids_i[:, :sep_positions_i+1] 76 | beam_output_i = self.base_model.generate( 77 | input_ids=input_ids_i, 78 | max_new_tokens=max_new_tokens, 79 | num_beams=num_beams, 80 | early_stopping=True, 81 | num_return_sequences=1, 82 | positions_to_substitute=sep_positions_i.repeat_interleave(num_beams, dim=0), 83 | states_to_substitute=[z[i:i+1].repeat_interleave(num_beams, dim=0) for z in teacher_states], 84 | mode='forward_student', 85 | ) 86 | beam_output.append(beam_output_i) 87 | return 
beam_output 88 | 89 | @classmethod 90 | def from_pretrained(self, pretrained_path): 91 | config = StudentConfig.from_pretrained(pretrained_path) 92 | model = Student(config) 93 | state_dict = torch.load(os.path.join(pretrained_path, 'state_dict.bin')) 94 | model.load_state_dict(state_dict) 95 | return model 96 | 97 | def save_pretrained(self, save_directory): 98 | print (f'Saving to {save_directory}') 99 | self.config.save_pretrained(save_directory) 100 | state_dict = self.state_dict() 101 | torch.save(state_dict, os.path.join(save_directory, 'state_dict.bin')) 102 | 103 | -------------------------------------------------------------------------------- /src_autoencoder/models/teacher.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import CrossEntropyLoss 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList, GenerationConfig, LogitsProcessorList 8 | 9 | from .configuration_teacher import TeacherConfig 10 | import sys 11 | sys.path.append("..") 12 | from utils import get_sep_position, DoubleEOSStoppingCriteria, DoubleEOSLogitsProcessor 13 | from .modeling_gpt2_implicit import GPT2LMHeadImplicitModel 14 | 15 | 16 | class Teacher(nn.Module): 17 | def __init__(self, config): 18 | super().__init__() 19 | self.config = config 20 | self.base_model = GPT2LMHeadImplicitModel.from_pretrained(config.base_model) 21 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) 22 | num_layers = len(self.base_model.transformer.h) 23 | hidden_size = self.base_model.config.hidden_size 24 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False) 25 | self.num_layers = num_layers 26 | self.hidden_size = hidden_size 27 | 28 | def forward(self, input_ids): 29 | outputs = self.base_model.forward(input_ids=input_ids) 30 | return outputs 31 | 32 | def compute_positions_to_extract_per_layer(self, subset, delta, first_sep_positions, second_sep_positions): 33 | batch_size = first_sep_positions.shape[0] 34 | positions_to_extract_per_layer = first_sep_positions.new_zeros(batch_size, self.num_layers).long() 35 | layer_ids = torch.arange(start=0, end=self.num_layers).to(first_sep_positions.device) 36 | for batch_id in range(batch_size): 37 | first_position_to_extract = first_sep_positions[batch_id] 38 | last_position_to_extract = second_sep_positions[batch_id] 39 | if subset == 'diagonal': 40 | if delta == 'dynamic': # determine actual delta 41 | delta = (last_position_to_extract - first_position_to_extract) / (self.num_layers - 1) 42 | elif subset == 'first_column' or subset == 'last_column': 43 | delta = 0 44 | else: 45 | assert subset == 'last_column', subset 46 | delta = 0 47 | first_position_to_extract = last_position_to_extract 48 | positions_to_extract = torch.round(first_position_to_extract + layer_ids * delta) 49 | positions_to_extract = positions_to_extract.clamp(max=last_position_to_extract) 50 | positions_to_extract_per_layer[batch_id] = positions_to_extract 51 | return positions_to_extract_per_layer 52 | 53 | def extract_states(self, input_ids, delta, subset='diagonal'): 54 | if delta.isnumeric(): 55 | delta = int(delta) 56 | batch_size = input_ids.shape[0] 57 | hidden_size = self.hidden_size 58 | # Forward the teacher to produce all hidden states 59 | outputs = self.base_model.forward(input_ids=input_ids, output_hidden_states=True) 60 | hidden_states = outputs.hidden_states[:-1] 61 | 62 | if subset == None: 63 | return 
hidden_states 64 | 65 | # Find the boundaries between input and CoT, and CoT and output 66 | # [input] first_sep_position [CoT] second_position [output] eos 67 | first_sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id, skip=0) 68 | second_sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id, skip=1) 69 | input_ids = input_ids[:, :second_sep_positions.max()+1] 70 | 71 | # Compute the positions to extract teacher states (t_l in the paper) 72 | positions_to_extract_per_layer = self.compute_positions_to_extract_per_layer(subset, delta, first_sep_positions, second_sep_positions) 73 | 74 | # Extract teacher states 75 | teacher_states_extracted = [] 76 | for i, hidden_state in enumerate(hidden_states): 77 | if subset == 'diagonal' or subset == 'first_column' or subset == 'last_column': 78 | z = hidden_state.gather(1, positions_to_extract_per_layer[:,i].view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1) 79 | elif subset == 'top_row': 80 | z = hidden_states[-1].gather(1, positions_to_extract_per_layer[:,i].view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1) 81 | else: 82 | assert subset == 'bottom_row', subset 83 | z = hidden_states[0].gather(1, positions_to_extract_per_layer[:,i].view(-1, 1, 1).expand(-1, -1, hidden_size)).squeeze(1) 84 | # Apply layer norm to normalize to 0 mean and 1 std 85 | z = self.layer_norm(z) 86 | teacher_states_extracted.append(z) 87 | return teacher_states_extracted 88 | 89 | def compute_loss(self, input_ids, labels): 90 | #import pdb; pdb.set_trace() 91 | outputs = self.forward(input_ids=input_ids) 92 | logits = outputs.logits 93 | 94 | labels_pred = logits.argmax(-1) 95 | mask = labels[...,1:].ge(0) 96 | correct_tokens = ((labels_pred[...,:-1] == labels[...,1:]) * mask).sum() 97 | total_tokens = mask.sum() 98 | token_accuracy = correct_tokens / total_tokens 99 | 100 | shift_logits = logits[..., :-1, :].contiguous() 101 | shift_labels = labels[..., 1:].contiguous() 102 | loss_fct = CrossEntropyLoss() 103 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 104 | 105 | outputs.loss = loss 106 | outputs.token_accuracy = token_accuracy 107 | outputs.total_correct = correct_tokens 108 | outputs.total_loss = loss * total_tokens 109 | outputs.total_tokens = total_tokens 110 | return outputs 111 | 112 | def generate(self, input_ids, max_new_tokens=512, num_beams=1, stop_on_two_eos=True): 113 | sep_positions = get_sep_position(input_ids, self.tokenizer.eos_token_id) 114 | batch_size = input_ids.shape[0] 115 | 116 | # Since there's one eos after CoT and another after final answer, we need to wait for two eos 117 | generation_config = GenerationConfig.from_model_config(self.base_model.config) 118 | if stop_on_two_eos: 119 | generation_config.eos_token_id = -1 120 | logits_processor = LogitsProcessorList([DoubleEOSLogitsProcessor(self.tokenizer.eos_token_id)]) 121 | stopping_criteria = StoppingCriteriaList([DoubleEOSStoppingCriteria(self.tokenizer.eos_token_id)]) 122 | else: 123 | logits_processor = None 124 | stopping_criteria = None 125 | 126 | if sep_positions.eq(sep_positions[0]).all(): 127 | input_ids = input_ids[:, :sep_positions[0]+1] 128 | beam_output = self.base_model.generate( 129 | input_ids=input_ids, 130 | generation_config=generation_config, 131 | max_new_tokens=max_new_tokens, 132 | num_beams=num_beams, 133 | early_stopping=True, 134 | num_return_sequences=1, 135 | logits_processor=logits_processor, 136 | stopping_criteria=stopping_criteria, 137 | ) 138 | beam_output = 
beam_output.unsqueeze(1) 139 | else: 140 | beam_output = [] 141 | for i in range(batch_size): 142 | input_ids_i = input_ids[i:i+1] 143 | sep_positions_i = sep_positions[i:i+1] 144 | input_ids_i = input_ids_i[:, :sep_positions_i+1] 145 | beam_output_i = self.base_model.generate( 146 | input_ids=input_ids_i, 147 | generation_config=generation_config, 148 | max_new_tokens=max_new_tokens, 149 | num_beams=num_beams, 150 | early_stopping=True, 151 | num_return_sequences=1, 152 | logits_processor=logits_processor, 153 | stopping_criteria=stopping_criteria, 154 | ) 155 | beam_output.append(beam_output_i) 156 | return beam_output 157 | 158 | @classmethod 159 | def from_pretrained(self, pretrained_path): 160 | config = TeacherConfig.from_pretrained(pretrained_path) 161 | model = Teacher(config) 162 | state_dict = torch.load(os.path.join(pretrained_path, 'state_dict.bin')) 163 | model.load_state_dict(state_dict) 164 | return model 165 | 166 | def save_pretrained(self, save_directory): 167 | print (f'Saving to {save_directory}') 168 | self.config.save_pretrained(save_directory) 169 | state_dict = self.state_dict() 170 | torch.save(state_dict, os.path.join(save_directory, 'state_dict.bin')) 171 | -------------------------------------------------------------------------------- /src_autoencoder/train_autoencoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from transformers import AdamW 5 | import argparse 6 | import os 7 | import sys 8 | import inspect 9 | import tqdm 10 | import logging 11 | import random 12 | 13 | from data import CoTDataset, CoTDataCollator, extract_answer 14 | from models.teacher import Teacher 15 | from models.autoencoder import AutoEncoder 16 | from models.autoencoder import AutoEncoder 17 | from models.configuration_autoencoder import AutoEncoderConfig 18 | from utils import get_sep_position 19 | 20 | torch.backends.cuda.matmul.allow_tf32 = True 21 | torch.backends.cudnn.allow_tf32 = True 22 | 23 | random.seed(1234) 24 | torch.manual_seed(1234) 25 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 26 | 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | 29 | @torch.no_grad() 30 | def evaluate(dataloader, tokenizer, ctx, teacher, autoencoder, delta, subset): 31 | total_instances = 0 32 | total_loss = 0 33 | for batch in tqdm.tqdm(dataloader): 34 | #import pdb; pdb.set_trace() 35 | input_ids_cot = batch['input_ids_cot'].to(device) 36 | batch_size = input_ids_cot.shape[0] 37 | with ctx: 38 | teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=delta, subset=None) 39 | outputs = autoencoder.compute_loss(teacher_states=teacher_states) 40 | loss = outputs.loss 41 | total_loss += outputs.total_loss.item() 42 | total_instances += batch_size 43 | 44 | loss = total_loss / total_instances 45 | return loss 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--teacher', type=str, required=True) 51 | parser.add_argument('--delta', type=str, required=True) 52 | parser.add_argument('--train_path', type=str, required=True) 53 | parser.add_argument('--val_path', type=str, required=True) 54 | parser.add_argument('--save_model', type=str, required=True) 55 | parser.add_argument('--max_new_tokens', type=int, default=128) 56 | parser.add_argument('--base_model', type=str, default='gpt2') 57 | parser.add_argument('--epochs', type=int, default=5) 58 | 
parser.add_argument('--batch_size', type=int, default=32) 59 | parser.add_argument('--lr', type=float, default=5e-5) 60 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 61 | parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal') 62 | args = parser.parse_args() 63 | 64 | print (args) 65 | 66 | dtype = 'float32' 67 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 69 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 70 | print (ptdtype, dtype, device) 71 | 72 | # Load Teacher 73 | teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype) 74 | 75 | # Create AutoEncoder 76 | config = AutoEncoderConfig(base_model=args.base_model, teacher_hidden_size=teacher.hidden_size, teacher_num_layers=teacher.num_layers) 77 | autoencoder = AutoEncoder(config).to(device).to(ptdtype) 78 | 79 | # Load data 80 | tokenizer = teacher.tokenizer 81 | collate_fn = CoTDataCollator(tokenizer) 82 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 83 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 84 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 85 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 86 | 87 | # Create Optimizer 88 | trainable_params = list(autoencoder.parameters()) 89 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 90 | extra_args = dict(fused=True) if use_fused else dict() 91 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 92 | 93 | teacher.eval() 94 | autoencoder.eval() # to turn off dropout 95 | 96 | for p in teacher.parameters(): 97 | p.requires_grad = False 98 | 99 | # Train 100 | step = 0 101 | for epoch in range(args.epochs): 102 | print(f"Epoch {epoch}") 103 | 104 | for batch in tqdm.tqdm(train_dataloader): 105 | #import pdb; pdb.set_trace() 106 | input_ids_all = batch['input_ids_all'].to(device) 107 | input_ids_nocot = batch['input_ids_nocot'].to(device) 108 | labels_nocot = batch['labels_nocot'].to(device) 109 | with ctx: 110 | with torch.no_grad(): 111 | teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=args.delta, subset=None) 112 | outputs = autoencoder.compute_loss(teacher_states=teacher_states) 113 | loss = outputs.loss 114 | #token_accuracy = outputs.token_accuracy.item() 115 | 116 | loss.backward() 117 | torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 118 | optimizer.step() 119 | optimizer.zero_grad() 120 | #ppl = loss.exp().item() 121 | if step % 100 == 0: 122 | print (f"Step: {step}. Loss: {loss}.") 123 | sys.stdout.flush() 124 | step += 1 125 | loss = evaluate(val_dataloader, tokenizer, ctx, teacher, autoencoder, args.delta, args.subset) 126 | print (f'Val. 
Loss: {loss}.') 127 | autoencoder.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}')) 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /src_autoencoder/train_coupled_emulator_and_student.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from transformers import AdamW 5 | import argparse 6 | import os 7 | import inspect 8 | import tqdm 9 | import logging 10 | import random 11 | from itertools import chain 12 | 13 | from data import CoTDataset, CoTDataCollator, extract_answer 14 | from models.student import Student 15 | from models.emulator import Emulator 16 | from utils import get_sep_position 17 | 18 | torch.backends.cuda.matmul.allow_tf32 = True 19 | torch.backends.cudnn.allow_tf32 = True 20 | random.seed(1234) 21 | torch.manual_seed(1234) 22 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | 27 | @torch.no_grad() 28 | def evaluate(dataloader, tokenizer, ctx, emulator, student, max_new_tokens): 29 | total_instances = 0 30 | total_tokens = 0 31 | total_correct = 0 32 | total_correct_tokens = 0 33 | total_loss = 0 34 | for batch in tqdm.tqdm(dataloader): 35 | #import pdb; pdb.set_trace() 36 | input_ids_nocot = batch['input_ids_nocot'].to(device) 37 | labels_nocot = batch['labels_nocot'].to(device) 38 | batch_size = input_ids_nocot.shape[0] 39 | with ctx: 40 | emulated_teacher_states = emulator(input_ids=input_ids_nocot) 41 | outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=emulated_teacher_states) 42 | loss = outputs.loss 43 | token_accuracy = outputs.token_accuracy.item() 44 | total_loss += outputs.total_loss.item() 45 | total_correct_tokens += outputs.total_correct.item() 46 | total_tokens += outputs.total_tokens 47 | total_instances += batch_size 48 | 49 | # Generate 50 | with ctx: 51 | beam_output = student.generate( 52 | input_ids=input_ids_nocot, 53 | teacher_states=emulated_teacher_states, 54 | max_new_tokens=max_new_tokens, 55 | ) 56 | 57 | # Evaluate 58 | sep_positions = get_sep_position(input_ids_nocot, tokenizer.eos_token_id) 59 | for i, (input_ids_i, beam_output_i) in enumerate(zip(input_ids_nocot, beam_output)): 60 | sep_position = sep_positions[i].item() 61 | tgt = input_ids_i[sep_position+1:] 62 | tgt_text = tokenizer.decode(tgt, skip_special_tokens=True) 63 | ans = extract_answer(tgt_text) 64 | pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True) 65 | pred_ans = extract_answer(pred_text) 66 | if ans == pred_ans: 67 | total_correct += 1 68 | if i == 0: 69 | print (f'Input: {tokenizer.decode(input_ids_i[:sep_position], skip_special_tokens=True)}') 70 | print (f'Target: {tgt_text}') 71 | print (f'Predicted: {pred_text}') 72 | print ('') 73 | accuracy = total_correct / total_instances 74 | token_accuracy = total_correct_tokens / total_tokens 75 | loss = total_loss / total_tokens 76 | ppl = math.exp(loss) 77 | return accuracy, token_accuracy, ppl 78 | 79 | 80 | def main(): 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--emulator', type=str, required=True) 83 | parser.add_argument('--student', type=str, required=True) 84 | parser.add_argument('--train_path', type=str, required=True) 85 | parser.add_argument('--val_path', type=str, 
required=True) 86 | parser.add_argument('--save_model', type=str, required=True) 87 | parser.add_argument('--max_new_tokens', type=int, default=128) 88 | parser.add_argument('--epochs', type=int, default=5) 89 | parser.add_argument('--batch_size', type=int, default=32) 90 | parser.add_argument('--lr', type=float, default=5e-5) 91 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 92 | parser.add_argument('--softmax_temperature', type=float, default=0.05) 93 | parser.add_argument('--fix_emulator', dest='fix_emulator', action='store_true') 94 | parser.set_defaults(fix_emulator=False) 95 | args = parser.parse_args() 96 | 97 | print (args) 98 | 99 | dtype = 'float32' 100 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 103 | print (ptdtype, dtype, device) 104 | 105 | # Load Student 106 | student = Student.from_pretrained(args.student).to(device).to(ptdtype) 107 | 108 | # Load Emulator 109 | emulator = Emulator.from_pretrained(args.emulator).to(device).to(ptdtype) 110 | 111 | # Load data 112 | tokenizer = emulator.tokenizer 113 | collate_fn = CoTDataCollator(tokenizer) 114 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 115 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 116 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 117 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 118 | 119 | # Create Optimizer (materialize parameters in a list so clip_grad_norm_ below still sees them after AdamW consumes the iterator) 120 | if args.fix_emulator: 121 | trainable_params = list(student.parameters()) 122 | for p in emulator.parameters(): 123 | p.requires_grad = False 124 | else: 125 | trainable_params = list(chain(student.parameters(), emulator.parameters())) 126 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 127 | extra_args = dict(fused=True) if use_fused else dict() 128 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 129 | 130 | emulator.eval() # to turn off dropout 131 | student.eval() # to turn off dropout 132 | 133 | 134 | # Train 135 | step = 0 136 | for epoch in range(args.epochs): 137 | print(f"Epoch {epoch}") 138 | 139 | for batch in tqdm.tqdm(train_dataloader): 140 | #import pdb; pdb.set_trace() 141 | input_ids_nocot = batch['input_ids_nocot'].to(device) 142 | labels_nocot = batch['labels_nocot'].to(device) 143 | with ctx: 144 | emulated_teacher_states = emulator(input_ids_nocot, requires_backward=not args.fix_emulator) 145 | outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=emulated_teacher_states) 146 | loss = outputs.loss 147 | token_accuracy = outputs.token_accuracy.item() 148 | 149 | loss.backward() 150 | torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 151 | optimizer.step() 152 | optimizer.zero_grad() 153 | ppl = loss.exp().item() 154 | if step % 100 == 0: 155 | print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}") 156 | step += 1 157 | accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, emulator, student, args.max_new_tokens) 158 | print (f'Val. 
PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.') 159 | student.save_pretrained(os.path.join(args.save_model, 'student', f'checkpoint_{epoch}')) 160 | emulator.save_pretrained(os.path.join(args.save_model, 'emulator', f'checkpoint_{epoch}')) 161 | 162 | 163 | if __name__ == "__main__": 164 | main() 165 | -------------------------------------------------------------------------------- /src_autoencoder/train_mind_reading_student.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from transformers import AdamW 5 | import argparse 6 | import os 7 | import inspect 8 | import tqdm 9 | import logging 10 | import random 11 | 12 | from data import CoTDataset, CoTDataCollator, extract_answer 13 | from models.teacher import Teacher 14 | from models.autoencoder import AutoEncoder 15 | from models.student import Student 16 | from models.configuration_student import StudentConfig 17 | from utils import get_sep_position 18 | 19 | torch.backends.cuda.matmul.allow_tf32 = True 20 | torch.backends.cudnn.allow_tf32 = True 21 | 22 | random.seed(1234) 23 | torch.manual_seed(1234) 24 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | @torch.no_grad() 29 | def evaluate(dataloader, tokenizer, ctx, teacher, autoencoder, student, delta, subset, max_new_tokens): 30 | total_instances = 0 31 | total_tokens = 0 32 | total_correct = 0 33 | total_correct_tokens = 0 34 | total_loss = 0 35 | for batch in tqdm.tqdm(dataloader): 36 | #import pdb; pdb.set_trace() 37 | input_ids_all = batch['input_ids_all'].to(device) 38 | input_ids_nocot = batch['input_ids_nocot'].to(device) 39 | labels_nocot = batch['labels_nocot'].to(device) 40 | batch_size = input_ids_nocot.shape[0] 41 | with ctx: 42 | teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=delta, subset=None) 43 | teacher_states_cat = torch.stack(teacher_states, dim=-2) # bsz, seq_len, layers, hidden 44 | encoded_states = autoencoder.encode(teacher_states_cat) 45 | teacher_states = encoded_states.transpose(0, 1) 46 | outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=teacher_states) 47 | loss = outputs.loss 48 | token_accuracy = outputs.token_accuracy.item() 49 | total_loss += outputs.total_loss.item() 50 | total_correct_tokens += outputs.total_correct.item() 51 | total_tokens += outputs.total_tokens 52 | total_instances += batch_size 53 | 54 | # Generate 55 | with ctx: 56 | beam_output = student.generate( 57 | input_ids=input_ids_nocot, 58 | teacher_states=teacher_states, 59 | max_new_tokens=max_new_tokens, 60 | ) 61 | 62 | # Evaluate 63 | sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id) 64 | for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)): 65 | sep_position = sep_positions[i].item() 66 | tgt = input_ids_all_i[sep_position+1:] 67 | tgt_text = tokenizer.decode(tgt, skip_special_tokens=True) 68 | ans = extract_answer(tgt_text) 69 | pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True) 70 | pred_ans = extract_answer(pred_text) 71 | if ans == pred_ans: 72 | total_correct += 1 73 | if i == 0: 74 | print (f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}') 75 | print (f'Target: {tgt_text}') 76 | print (f'Predicted: {pred_text}') 77 | 
print ('') 78 | accuracy = total_correct / total_instances 79 | token_accuracy = total_correct_tokens / total_tokens 80 | loss = total_loss / total_tokens 81 | ppl = math.exp(loss) 82 | return accuracy, token_accuracy, ppl 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument('--teacher', type=str, required=True) 88 | parser.add_argument('--autoencoder', type=str, required=True) 89 | parser.add_argument('--delta', type=str, required=True) 90 | parser.add_argument('--train_path', type=str, required=True) 91 | parser.add_argument('--val_path', type=str, required=True) 92 | parser.add_argument('--save_model', type=str, required=True) 93 | parser.add_argument('--max_new_tokens', type=int, default=128) 94 | parser.add_argument('--base_model', type=str, default='gpt2') 95 | parser.add_argument('--epochs', type=int, default=5) 96 | parser.add_argument('--batch_size', type=int, default=32) 97 | parser.add_argument('--lr', type=float, default=5e-5) 98 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 99 | parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal') 100 | args = parser.parse_args() 101 | 102 | print (args) 103 | 104 | dtype = 'float32' 105 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 106 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 107 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 108 | print (ptdtype, dtype, device) 109 | 110 | # Create Student 111 | config = StudentConfig(base_model=args.base_model) 112 | student = Student(config).to(device).to(ptdtype) 113 | 114 | # Load Teacher 115 | teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype) 116 | autoencoder = AutoEncoder.from_pretrained(args.autoencoder).to(device).to(ptdtype) 117 | 118 | # Load data 119 | tokenizer = teacher.tokenizer 120 | collate_fn = CoTDataCollator(tokenizer) 121 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 122 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 123 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 124 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 125 | 126 | # Create Optimizer 127 | trainable_params = list(student.parameters()) 128 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 129 | extra_args = dict(fused=True) if use_fused else dict() 130 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 131 | 132 | teacher.eval() 133 | student.eval() # to turn off dropout 134 | 135 | for p in teacher.parameters(): 136 | p.requires_grad = False 137 | 138 | # Train 139 | step = 0 140 | for epoch in range(args.epochs): 141 | print(f"Epoch {epoch}") 142 | 143 | for batch in tqdm.tqdm(train_dataloader): 144 | #import pdb; pdb.set_trace() 145 | input_ids_all = batch['input_ids_all'].to(device) 146 | input_ids_nocot = batch['input_ids_nocot'].to(device) 147 | labels_nocot = batch['labels_nocot'].to(device) 148 | with ctx: 149 | with torch.no_grad(): 150 | teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=args.delta, subset=None) 151 | teacher_states_cat = torch.stack(teacher_states, dim=-2) # bsz, seq_len, layers, hidden 152 | encoded_states = autoencoder.encode(teacher_states_cat) 153 | teacher_states = encoded_states.transpose(0, 1) 154 | outputs = 
student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=teacher_states) 155 | loss = outputs.loss 156 | token_accuracy = outputs.token_accuracy.item() 157 | 158 | loss.backward() 159 | torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 160 | optimizer.step() 161 | optimizer.zero_grad() 162 | ppl = loss.exp().item() 163 | if step % 100 == 0: 164 | print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}") 165 | step += 1 166 | accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, teacher, autoencoder, student, args.delta, args.subset, args.max_new_tokens) 167 | print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.') 168 | student.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}')) 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /src_autoencoder/train_teacher.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from transformers import AdamW 5 | import argparse 6 | import os 7 | import tqdm 8 | import inspect 9 | import logging 10 | 11 | from models.teacher import Teacher 12 | from models.configuration_teacher import TeacherConfig 13 | from data import CoTDataset, CoTDataCollator, extract_answer 14 | 15 | from utils import get_sep_position 16 | 17 | torch.backends.cuda.matmul.allow_tf32 = True 18 | torch.backends.cudnn.allow_tf32 = True 19 | logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | 24 | def save_model(model, tokenizer, model_dir): 25 | print ('saving', model_dir) 26 | os.makedirs(model_dir, exist_ok=True) 27 | model.save_pretrained(model_dir) 28 | tokenizer.save_pretrained(model_dir) 29 | 30 | @torch.no_grad() 31 | def evaluate(dataloader, tokenizer, ctx, teacher, max_new_tokens): 32 | teacher.eval() 33 | total_instances = 0 34 | total_tokens = 0 35 | total_correct = 0 36 | total_correct_tokens = 0 37 | total_loss = 0 38 | for batch in tqdm.tqdm(dataloader): 39 | input_ids_all = batch['input_ids_all'].to(device) 40 | labels = batch['labels_all'].to(device) 41 | # Remove answer part 42 | sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id) 43 | input_ids = input_ids_all[:, :sep_positions.max()+1] 44 | batch_size = input_ids.shape[0] 45 | with ctx: 46 | outputs = teacher.compute_loss(input_ids=input_ids_all, labels=labels) 47 | total_loss += outputs.total_loss.item() 48 | total_correct_tokens += outputs.total_correct.item() 49 | total_tokens += outputs.total_tokens 50 | total_instances += batch_size 51 | 52 | # Generate 53 | beam_output = teacher.generate( 54 | input_ids=input_ids, 55 | max_new_tokens=max_new_tokens, 56 | ) 57 | # Evaluate 58 | #import pdb; pdb.set_trace() 59 | for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)): 60 | sep_position = sep_positions[i].item() 61 | tgt = input_ids_all_i[sep_position+1:] 62 | tgt_text = tokenizer.decode(tgt, skip_special_tokens=True) 63 | ans = extract_answer(tgt_text) 64 | pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True) 65 | pred_ans = extract_answer(pred_text) 66 | if ans == pred_ans: 67 | total_correct += 1 68 | if i == 0: 69 | print (f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], 
skip_special_tokens=True)}') 70 | print (f'Target: {tgt_text}') 71 | print (f'Predicted: {pred_text}') 72 | print ('') 73 | accuracy = total_correct / total_instances 74 | token_accuracy = total_correct_tokens / total_tokens 75 | loss = total_loss / total_tokens 76 | ppl = math.exp(loss) 77 | return accuracy, token_accuracy, ppl 78 | 79 | 80 | def main(): 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--train_path', type=str, required=True) 83 | parser.add_argument('--val_path', type=str, required=True) 84 | parser.add_argument('--save_model', type=str, required=True) 85 | parser.add_argument('--max_new_tokens', type=int, default=128) 86 | parser.add_argument('--base_model', type=str, default='gpt2') 87 | parser.add_argument('--epochs', type=int, default=1) 88 | parser.add_argument('--batch_size', type=int, default=32) 89 | parser.add_argument('--lr', type=float, default=5e-5) 90 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 91 | args = parser.parse_args() 92 | 93 | print (args) 94 | 95 | dtype = 'float32' 96 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 97 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 98 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 99 | print (ptdtype, dtype, device) 100 | 101 | # Create Teacher 102 | config = TeacherConfig(base_model=args.base_model) 103 | teacher = Teacher(config).to(device).to(ptdtype) 104 | 105 | # Load data 106 | tokenizer = teacher.tokenizer 107 | collate_fn = CoTDataCollator(tokenizer) 108 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 109 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 110 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 111 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 112 | 113 | # Create Optimizer (materialize parameters in a list so clip_grad_norm_ below still sees them after AdamW consumes the iterator) 114 | trainable_params = list(teacher.parameters()) 115 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 116 | extra_args = dict(fused=True) if use_fused else dict() 117 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 118 | 119 | teacher.train() 120 | 121 | # Train 122 | step = 0 123 | for epoch in range(args.epochs): 124 | print(f"Epoch {epoch}") 125 | teacher.train() 126 | for batch in tqdm.tqdm(train_dataloader): 127 | input_ids = batch['input_ids_all'].to(device) 128 | labels = batch['labels_all'].to(device) 129 | with ctx: 130 | outputs = teacher.compute_loss(input_ids=input_ids, labels=labels) 131 | loss = outputs.loss 132 | token_accuracy = outputs.token_accuracy.item() 133 | 134 | loss.backward() 135 | torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 136 | optimizer.step() 137 | optimizer.zero_grad() 138 | ppl = loss.exp().item() 139 | if step % 100 == 0: 140 | print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}") 141 | step += 1 142 | accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, teacher, args.max_new_tokens) 143 | print (f'Val. 
PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.') 144 | teacher.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}')) 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /src_autoencoder/train_thought_emulator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from transformers import AdamW 4 | import argparse 5 | import os 6 | import inspect 7 | import tqdm 8 | import logging 9 | import random 10 | import torch.nn as nn 11 | 12 | from data import CoTDataset, CoTDataCollator 13 | from models.teacher import Teacher 14 | from models.emulator import Emulator 15 | from models.configuration_emulator import EmulatorConfig 16 | 17 | 18 | torch.backends.cuda.matmul.allow_tf32 = True 19 | torch.backends.cudnn.allow_tf32 = True 20 | random.seed(1234) 21 | torch.manual_seed(1234) 22 | logging.disable(logging.WARNING) 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | @torch.no_grad() 27 | def evaluate(dataloader, tokenizer, ctx, teacher, emulator, delta, subset): 28 | total_instances = 0 29 | total_loss = 0 30 | for batch in tqdm.tqdm(dataloader): 31 | #import pdb; pdb.set_trace() 32 | input_ids_cot = batch['input_ids_cot'].to(device) 33 | batch_size = input_ids_cot.shape[0] 34 | with ctx: 35 | teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=delta, subset=subset) 36 | outputs = emulator.compute_loss(input_ids=input_ids_cot, teacher_states=teacher_states) 37 | loss = outputs.loss 38 | total_loss += outputs.total_loss.item() 39 | total_instances += batch_size 40 | 41 | loss = total_loss / total_instances 42 | return loss 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--teacher', type=str, required=True) 47 | parser.add_argument('--delta', type=str, required=True) 48 | parser.add_argument('--train_path', type=str, required=True) 49 | parser.add_argument('--val_path', type=str, required=True) 50 | parser.add_argument('--save_model', type=str, required=True) 51 | parser.add_argument('--base_model', type=str, default='gpt2') 52 | parser.add_argument('--epochs', type=int, default=5) 53 | parser.add_argument('--batch_size', type=int, default=32) 54 | parser.add_argument('--lr', type=float, default=5e-5) 55 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 56 | parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal') 57 | parser.add_argument('--mixture_size', type=int, default=1) 58 | args = parser.parse_args() 59 | 60 | print (args) 61 | dtype = 'float32' 62 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 63 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 64 | ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) 65 | print (ptdtype, dtype, device) 66 | 67 | # Create Emulator 68 | config = EmulatorConfig(base_model=args.base_model, mixture_size=args.mixture_size) 69 | emulator = Emulator(config).to(device).to(ptdtype) 70 | 71 | # Load Teacher 72 | teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype) 73 | 74 | # Load data 75 | tokenizer = teacher.tokenizer 76 | collate_fn = CoTDataCollator(tokenizer) 77 | train_dataset = CoTDataset(tokenizer, args.train_path, 1024) 78 | train_dataloader = 
DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True) 79 | val_dataset = CoTDataset(tokenizer, args.val_path, 1024) 80 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False) 81 | 82 | # Create Optimizer (materialize parameters in a list so clip_grad_norm_ below still sees them after AdamW consumes the iterator) 83 | trainable_params = list(emulator.parameters()) 84 | use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters 85 | extra_args = dict(fused=True) if use_fused else dict() 86 | optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args) 87 | 88 | teacher.eval() 89 | emulator.eval() # to turn off dropout 90 | 91 | for p in teacher.parameters(): 92 | p.requires_grad = False 93 | 94 | # Train 95 | step = 0 96 | for epoch in range(args.epochs): 97 | print(f"Epoch {epoch}") 98 | 99 | for batch in tqdm.tqdm(train_dataloader): 100 | #import pdb; pdb.set_trace() 101 | input_ids_cot = batch['input_ids_cot'].to(device) 102 | input_ids_nocot = batch['input_ids_nocot'].to(device) 103 | with ctx: 104 | with torch.no_grad(): 105 | teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=args.delta, subset=args.subset) 106 | outputs = emulator.compute_loss(input_ids=input_ids_nocot, teacher_states=teacher_states) 107 | loss = outputs.loss 108 | 109 | loss.backward() 110 | torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) 111 | optimizer.step() 112 | optimizer.zero_grad() 113 | if step % 100 == 0: 114 | print (f"Step: {step}. Loss: {loss}.") 115 | step += 1 116 | loss = evaluate(val_dataloader, tokenizer, ctx, teacher, emulator, args.delta, args.subset) 117 | print (f'Val. Loss: {loss}.') 118 | emulator.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}')) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /src_autoencoder/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, LogitsProcessor 3 | 4 | def get_sep_position(input_ids, sep_id, skip=0): 5 | batch_size = input_ids.shape[0] 6 | sep_positions = input_ids.new_zeros(batch_size).long() 7 | for batch_id in range(batch_size): 8 | mask = input_ids[batch_id].eq(sep_id) 9 | sep_position = mask.nonzero()[0, -1].item() 10 | for _ in range(skip): 11 | mask[sep_position] = False 12 | sep_position = mask.nonzero()[0, -1].item() 13 | sep_positions[batch_id] = sep_position 14 | return sep_positions 15 | 16 | 17 | # Stop generation only after generating two EOSs, such as z <eos> y <eos> 18 | class DoubleEOSStoppingCriteria(StoppingCriteria): 19 | def __init__(self, eos_token_id): 20 | super().__init__() 21 | self.eos_token_id = eos_token_id 22 | self.init = False 23 | 24 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 25 | eos_count = (input_ids == self.eos_token_id).sum(dim=-1) 26 | if not self.init: 27 | self.init = True 28 | self.eos_count_init = eos_count 29 | done = (eos_count - self.eos_count_init) >= 2 30 | return done.all() 31 | 32 | class DoubleEOSLogitsProcessor(LogitsProcessor): 33 | def __init__(self, eos_token_id): 34 | super().__init__() 35 | self.eos_token_id = eos_token_id 36 | self.init = False 37 | 38 | def __call__(self, input_ids, scores): 39 | eos_count = (input_ids == self.eos_token_id).sum(dim=-1) 40 | if not self.init: 41 | self.init = True 42 | self.eos_count_init = eos_count 43 | done = (eos_count - self.eos_count_init) >= 2 44 | if done.any(): 45 | scores[done, :] = 
float('-inf') 46 | scores[done, self.eos_token_id] = 0 47 | return scores 48 | --------------------------------------------------------------------------------