├── .gitattributes ├── real_tasks ├── math_text_book.py ├── summarization.py ├── knowledge_qa.py ├── code_completion.py └── fact_qa.py ├── figs ├── 7B-wikitext.png └── robustness_performance_Wikitext_7B.png ├── requirements.txt ├── .gitignore ├── tests ├── test_gzip.py ├── test_zlib.py ├── test_png.py ├── test_flac.py └── test_arithmetic.py ├── docker └── Dockerfile ├── .github └── workflows │ ├── plotly.yml │ └── static.yml ├── results ├── aggregate_all_results.py ├── time_and_mem_for_context_sizes.txt └── collect_all_results.py ├── visualise ├── compare_tokenizer.py ├── wiki_diff_ratio.py ├── timeline_vis.py ├── interactive.py ├── temporal.py ├── barplot_context_size_compare.py └── big_table.py ├── compressor.py ├── readme.md ├── page ├── index.html ├── wikitext_context.html └── math.html ├── arithmetic.py ├── main.py ├── evaluator.py ├── arithmetic_int32.cpp └── data_processor.py /.gitattributes: -------------------------------------------------------------------------------- 1 | page/* linguist-vendored 2 | -------------------------------------------------------------------------------- /real_tasks/math_text_book.py: -------------------------------------------------------------------------------- 1 | from fact_qa import FactQA 2 | 3 | -------------------------------------------------------------------------------- /figs/7B-wikitext.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyucheng09/llm-compressive/HEAD/figs/7B-wikitext.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ninja 2 | transformers 3 | torch 4 | nltk 5 | datasets 6 | numpy 7 | pypng 8 | soundfile -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | htcondor/ 2 | __pycache__/ 3 | .vscode/ 4 | *.json 5 | *.jsonl 6 | !results/*.json 7 | real_tasks/*.json 8 | outputs/* -------------------------------------------------------------------------------- /figs/robustness_performance_Wikitext_7B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyucheng09/llm-compressive/HEAD/figs/robustness_performance_Wikitext_7B.png -------------------------------------------------------------------------------- /tests/test_gzip.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | 4 | text = ' '.join(['pytorch'] * 1000) 5 | byte_stream = text.encode('utf-8') 6 | print(len(byte_stream)) 7 | 8 | with gzip.open('compressed.gz', 'wb') as f: 9 | f.write(byte_stream) 10 | 11 | print(os.path.getsize('compressed.gz')) -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.04-py3 2 | 3 | WORKDIR /workspace/ 4 | RUN pip install --upgrade pip \ 5 | && pip install vllm transformers datasets accelerate nltk flash_attn --no-cache-dir \ 6 | && pip uninstall transformer-engine -y 7 | 8 | RUN pip install evaluate rouge rouge_score --no-cache-dir -------------------------------------------------------------------------------- /tests/test_zlib.py: -------------------------------------------------------------------------------- 1 | import 
zlib 2 | 3 | text = ' '.join(['pytorch'] * 1000) 4 | byte_stream = text.encode('utf-8') 5 | print(len(byte_stream)) 6 | 7 | # level=9, wbits=15 is the same as gzip 8 | compressor = zlib.compressobj(level=9, wbits=15) 9 | compressed = compressor.compress(byte_stream) + compressor.flush() 10 | print(len(compressed)) -------------------------------------------------------------------------------- /tests/test_png.py: -------------------------------------------------------------------------------- 1 | import png 2 | import os 3 | 4 | text = ' '.join(['pytorch'] * 1000) 5 | byte_stream = text.encode('utf-8') 6 | print(len(byte_stream)) 7 | 8 | w = png.Writer(len(byte_stream), 1, greyscale=True, bitdepth=8) 9 | with open('compressed.png', 'wb') as f: 10 | w.write(f, [byte_stream]) 11 | print(os.path.getsize('compressed.png')) -------------------------------------------------------------------------------- /tests/test_flac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import soundfile as sf 3 | import os 4 | 5 | text = ' '.join(['pytorch'] * 1000) 6 | byte_stream = text.encode('utf-8') 7 | print(len(byte_stream)) 8 | 9 | if len(byte_stream) % 2 != 0: 10 | byte_stream += b'\x00' 11 | 12 | pesudo_audio = np.frombuffer(byte_stream, dtype=np.int16) 13 | print(pesudo_audio.shape) 14 | 15 | sample_rate = 16000 16 | 17 | sf.write('pesudo_audio.flac', pesudo_audio, sample_rate) 18 | print(os.path.getsize('pesudo_audio.flac')) -------------------------------------------------------------------------------- /.github/workflows/plotly.yml: -------------------------------------------------------------------------------- 1 | name: plotly 2 | 3 | on: 4 | push: 5 | paths: 6 | - 'results/*.json' 7 | 8 | workflow_dispatch: 9 | 10 | permissions: 11 | contents: write 12 | 13 | jobs: 14 | generate-and-commit: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v2 18 | with: 19 | fetch-depth: 0 # Fetch all history for .GitInfo and .Lastmod 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: '3.8' # Specify the Python version 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install pandas plotly 30 | 31 | - name: Run the Python script 32 | run: python visualise/interactive.py 33 | 34 | - name: Commit and push if there are changes 35 | run: | 36 | git config --local user.email "action@github.com" 37 | git config --local user.name "LLM-Compressive Action" 38 | git add -A 39 | git commit -m "Automatically generated plotly visualizations" -a || echo "No changes to commit" 40 | git push -------------------------------------------------------------------------------- /results/aggregate_all_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | from glob import glob 3 | import sys 4 | import os 5 | 6 | if __name__ == '__main__': 7 | result_path, = sys.argv[1:] 8 | datasets = ['bbc_image', 'wikitext', 'arxiv', 'code', 'bbc_news', 'audio'] 9 | 10 | for ds in datasets: 11 | all_results = {} 12 | path = os.path.join(result_path, f'{ds}/*/*.json') 13 | results_files = glob(path) 14 | 15 | if len(results_files) == 0: 16 | print(path) 17 | print(f'No results found for {ds}') 18 | continue 19 | 20 | for rf in results_files: 21 | model_name = rf.split('/')[-1].split('.')[0] 22 | time_stamp = rf.split('/')[-2] 23 | with open(rf, 'r') as f: 24 | results = json.load(f) 25 | 26 | if model_name in 
all_results: 27 | if time_stamp in all_results[model_name]: 28 | continue 29 | else: 30 | all_results[model_name] = {} 31 | 32 | all_results[model_name][time_stamp] = results 33 | 34 | sorted_results = {model: dict(sorted(timestamps.items())) for model, timestamps in sorted(all_results.items())} 35 | with open(f'results/{ds}_results.json', 'w') as f: 36 | json.dump(sorted_results, f, indent=2, ensure_ascii=False) 37 | 38 | print(f'Finished {ds}, saved to {ds}_results.json') 39 | -------------------------------------------------------------------------------- /.github/workflows/static.yml: -------------------------------------------------------------------------------- 1 | # Simple workflow for deploying static content to GitHub Pages 2 | name: Deploy static content to Pages 3 | 4 | on: 5 | workflow_run: 6 | workflows: ["plotly"] 7 | types: 8 | - completed 9 | 10 | push: 11 | branches: ["main"] 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 17 | permissions: 18 | contents: read 19 | pages: write 20 | id-token: write 21 | 22 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 23 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 24 | concurrency: 25 | group: "pages" 26 | cancel-in-progress: false 27 | 28 | jobs: 29 | # Single deploy job since we're just deploying 30 | deploy: 31 | environment: 32 | name: github-pages 33 | url: ${{ steps.deployment.outputs.page_url }} 34 | runs-on: ubuntu-latest 35 | steps: 36 | - name: Checkout 37 | uses: actions/checkout@v4 38 | - name: Setup Pages 39 | uses: actions/configure-pages@v4 40 | - name: Upload artifact 41 | uses: actions/upload-pages-artifact@v3 42 | with: 43 | # Upload entire repository 44 | path: 'page/' 45 | - name: Deploy to GitHub Pages 46 | id: deployment 47 | uses: actions/deploy-pages@v4 48 | -------------------------------------------------------------------------------- /visualise/compare_tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | with open('results/wikitext_results.json') as f: 5 | results = json.load(f) 6 | 7 | models = ['Mistral-7B', 'Llama-2-7B-HF', 'Baichuan2-7B-Base', 'Qwen-7B', 'chatglm3-6b-base'] 8 | months_in_2023 = [f'2023-{month:02d}' for month in range(1, 12)] 9 | 10 | metrics_for_tokenizer = {} 11 | 12 | for model_name, data in results.items(): 13 | if model_name not in models: 14 | continue 15 | 16 | tokens_months = [] 17 | bpt_months = [] 18 | bpb_months = [] 19 | 20 | for month, metrics in data.items(): 21 | if month not in months_in_2023: 22 | continue 23 | 24 | # this is in bytes 25 | compressed_size = metrics['compressed_size'] 26 | 27 | # these two are in bits 28 | bpt = metrics['bpt'] 29 | bpb = metrics['bpb'] 30 | 31 | num_tokens = compressed_size * 8 / bpt 32 | 33 | tokens_months.append(num_tokens) 34 | bpt_months.append(bpt) 35 | bpb_months.append(bpb) 36 | 37 | total_tokens = sum(tokens_months) 38 | total_bpt = np.mean(bpt_months) 39 | total_bpb = np.mean(bpb_months) 40 | 41 | if model_name not in metrics_for_tokenizer: 42 | metrics_for_tokenizer[model_name] = {} 43 | 44 | metrics_for_tokenizer[model_name]['vocab_size'] = '-' 45 | metrics_for_tokenizer[model_name]['total_tokens'] = total_tokens 46 | metrics_for_tokenizer[model_name]['total_bpt'] = total_bpt 47 | 
metrics_for_tokenizer[model_name]['total_bpb'] = total_bpb 48 | 49 | import pandas as pd 50 | 51 | df = pd.DataFrame.from_dict(metrics_for_tokenizer, orient='index') 52 | print(df.to_latex(float_format='%.4f')) -------------------------------------------------------------------------------- /tests/test_arithmetic.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, GPT2LMHeadModel, AutoModelForCausalLM 2 | import torch 3 | import arithmetic 4 | 5 | def pmf_to_cdf(pmf): 6 | cdf = pmf.cumsum(dim=-1) 7 | spatial_dimensions = pmf.shape[:-1] + (1,) 8 | zeros = torch.zeros(spatial_dimensions, dtype=pmf.dtype, device=pmf.device) 9 | cdf_with_0 = torch.cat([zeros, cdf], dim=-1) 10 | # On GPU, softmax followed by cumsum can lead to the final value being 11 | # slightly bigger than 1, so we clamp. 12 | cdf_with_0 = cdf_with_0.clamp(max=1.) 13 | return cdf_with_0 14 | 15 | model = AutoModelForCausalLM.from_pretrained("/mnt/fast/nobackup/scratch4weeks/yl02706/models/Mistral-7B", device_map='auto', trust_remote_code=True) 16 | tokenizer = AutoTokenizer.from_pretrained("/mnt/fast/nobackup/scratch4weeks/yl02706/models/Mistral-7B", use_fast=False, trust_remote_code=True) 17 | 18 | text = ' '.join(['pytorch'] * 500) 19 | 20 | inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device) 21 | outputs = model(**inputs) 22 | 23 | print(outputs.logits.dtype) 24 | 25 | probs = outputs.logits.softmax(dim=-1) 26 | print(probs.shape) 27 | 28 | cdf = pmf_to_cdf(probs) 29 | 30 | bits = -torch.log2(probs).gather(dim=-1, index=inputs['input_ids'].unsqueeze(-1)).squeeze(-1).sum() 31 | print(bits) 32 | print(bits/8) 33 | print(text.encode('utf-8').__len__()) 34 | 35 | cdf = cdf.detach().cpu() 36 | 37 | sym_32 = inputs['input_ids'].to(torch.int32).detach().cpu() 38 | sym_16 = inputs['input_ids'].to(torch.int16).detach().cpu() 39 | 40 | byte_stream_32 = arithmetic.encode_float_cdf(cdf, sym_32) 41 | print(len(byte_stream_32)) 42 | 43 | d = arithmetic.decode_float_cdf(cdf, byte_stream_32) 44 | print('=========================') 45 | 46 | assert sym_32.equal(d) -------------------------------------------------------------------------------- /real_tasks/summarization.py: -------------------------------------------------------------------------------- 1 | from fact_qa import * 2 | 3 | class Summization(FactQA): 4 | def make_example( 5 | self, 6 | article, 7 | ): 8 | pass 9 | 10 | def make_test(self): 11 | for time, dataset in tqdm(self.datasets.items(), desc='Make Test - loop over months'): 12 | # skip? 
13 | if self.do_skip: 14 | f = os.path.join(self.question_output_dir, f'{time}.json') 15 | if os.path.exists(f): 16 | with open(f, 'r') as f: 17 | self.test_set[time] = json.load(f) 18 | continue 19 | 20 | count = 0 21 | if time not in self.test_set: 22 | self.test_set[time] = [] 23 | for article in tqdm(dataset): 24 | prompt = self._make_prompt(task='question_generation', article=article) 25 | response = self._make_llm_request(prompt) 26 | response = self._parse_response( 27 | task='question_generation', 28 | response=response 29 | ) 30 | if response is not None: 31 | count += 1 32 | 33 | self.test_set[time].append({ 34 | 'prompt': prompt, 35 | 'article': article, 36 | 'response': response 37 | }) 38 | 39 | if count >= self.num_examples_per_time: 40 | break 41 | 42 | if count % 10 == 0: 43 | self.save_to_file( 44 | task='question_generation', 45 | time=time 46 | ) 47 | 48 | self.save_to_file( 49 | task='question_generation', 50 | time=time 51 | ) 52 | 53 | print('Finished making test set.') -------------------------------------------------------------------------------- /results/time_and_mem_for_context_sizes.txt: -------------------------------------------------------------------------------- 1 | Model: Qwen-7B, Context size: 2048, Total time: 99.58475820479855, Cuda Mem (MB): 20043.41455078125 2 | Model: Qwen-7B, Context size: 4096, Total time: 92.96529369969522, Cuda Mem (MB): 25347.43017578125 3 | Model: Qwen-7B, Context size: 8192, Total time: 91.81131039896319, Cuda Mem (MB): 35951.46142578125 4 | Model: Mistral-7B, Context size: 2048, Total time: 118.23550676530408, Cuda Mem (MB): 15486.6640625 5 | Model: Mistral-7B, Context size: 4096, Total time: 107.82587099844409, Cuda Mem (MB): 16638.6796875 6 | Model: Mistral-7B, Context size: 8192, Total time: 107.42827933834445, Cuda Mem (MB): 18944.7109375 7 | Model: Baichuan2-7B-Base, Context size: 2048, Total time: 134.82086438517416, Cuda Mem (MB): 19547.6484375 8 | Model: Baichuan2-7B-Base, Context size: 4096, Total time: 145.88360017345798, Cuda Mem (MB): 24842.6953125 9 | Model: Llama-2-7B-HF, Context size: 2048, Total time: 113.55993147819272, Cuda Mem (MB): 15614.6640625 10 | Model: Llama-2-7B-HF, Context size: 4096, Total time: 101.25017941382623, Cuda Mem (MB): 18302.6796875 11 | Model: chatglm3-6b-base, Context size: 2048, Total time: 109.93928186355099, Cuda Mem (MB): 13562.83251953125 12 | Model: chatglm3-6b-base, Context size: 4096, Total time: 96.43869333882486, Cuda Mem (MB): 15174.84814453125 13 | Model: chatglm3-6b-base, Context size: 8192, Total time: 98.47218653463548, Cuda Mem (MB): 18396.87939453125 14 | Model: chatglm3-6b-base, Context size: 2048, Total time: 320.4981663919264, Cuda Mem (MB): 13562.83251953125 15 | Model: Mistral-7B, Context size: 2048, Total time: 414.97833924139695, Cuda Mem (MB): 15486.6640625 16 | Model: Qwen-7B, Context size: 2048, Total time: 338.2676722311204, Cuda Mem (MB): 20043.41455078125 17 | Model: Baichuan2-7B-Base, Context size: 2048, Total time: 479.4509780022406, Cuda Mem (MB): 19547.6484375 18 | Model: Yi-6B, Context size: 2048, Total time: 584.3244448246495, Cuda Mem (MB): 14676.6796875 19 | -------------------------------------------------------------------------------- /results/collect_all_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | from glob import glob 3 | import sys 4 | import os 5 | 6 | if __name__ == '__main__': 7 | # result_path, = sys.argv[1:] 8 | result_paths = 
['/mnt/fast/nobackup/users/yl02706/newac_compressive', '/mnt/fast/nobackup/users/yl02706/compressive'] 9 | # result_paths = ['/mnt/fast/nobackup/users/yl02706/compressive_newcode'] 10 | datasets = ['bbc_image', 'wikitext', 'arxiv', 'code', 'bbc_news', 'math'] 11 | # datasets = ['wikitext'] 12 | 13 | for ds in datasets: 14 | all_results = {} 15 | for result_path in result_paths: 16 | path = os.path.join(result_path, f'{ds}/*/*.json') 17 | results_files = glob(path) 18 | 19 | if len(results_files) == 0: 20 | print(path) 21 | print(f'No results found for {ds}') 22 | continue 23 | 24 | for rf in results_files: 25 | model_name = rf.split('/')[-1][:-5] 26 | 27 | time_stamp = rf.split('/')[-2] 28 | with open(rf, 'r') as f: 29 | results = json.load(f) 30 | results = results['ratio'] 31 | 32 | if model_name in ['LLaMA-7B', 'Llama-2-7B']: 33 | continue 34 | 35 | if model_name in all_results: 36 | if time_stamp in all_results[model_name]: 37 | continue 38 | else: 39 | all_results[model_name] = {} 40 | 41 | all_results[model_name][time_stamp] = results 42 | 43 | sorted_results = {model: dict(sorted(timestamps.items())) for model, timestamps in sorted(all_results.items())} 44 | with open(f'results/{ds}_results.json', 'w') as f: 45 | json.dump(sorted_results, f, indent=2, ensure_ascii=False) 46 | 47 | print(f'Finished {ds}, saved to {ds}_results.json') 48 | -------------------------------------------------------------------------------- /compressor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import arithmetic 3 | import png 4 | import os 5 | import gzip 6 | import zlib 7 | import numpy as np 8 | import soundfile as sf 9 | 10 | def arithmetic_coding(pmf, sym, save_path = None): 11 | # pmf is the output of language model after softmax, which is the prob distribution over the vocab 12 | # sym is input_ids 13 | def pmf_to_cdf(pmf): 14 | cdf = pmf.cumsum(dim=-1) 15 | spatial_dimensions = pmf.shape[:-1] + (1,) 16 | zeros = torch.zeros(spatial_dimensions, dtype=pmf.dtype, device=pmf.device) 17 | cdf_with_0 = torch.cat([zeros, cdf], dim=-1) 18 | # On GPU, softmax followed by cumsum can lead to the final value being 19 | # slightly bigger than 1, so we clamp. 20 | cdf_with_0 = cdf_with_0.clamp(max=1.) 
21 | return cdf_with_0 22 | 23 | cdf = pmf_to_cdf(pmf) 24 | if cdf.device.type != 'cpu': 25 | cdf = cdf.detach().cpu() 26 | if sym.device.type != 'cpu': 27 | sym = sym.detach().cpu() 28 | 29 | byte_stream = arithmetic.encode_float_cdf(cdf, sym) 30 | if save_path is not None: 31 | with open(save_path, 'wb') as f: 32 | f.write(byte_stream) 33 | 34 | return byte_stream 35 | 36 | def png_compressor(byte_stream, save_path = None): 37 | w = png.Writer(len(byte_stream), 1, greyscale=True, bitdepth=8) 38 | if save_path is None: 39 | save_path = 'compressed.png' 40 | with open(save_path, 'wb') as f: 41 | w.write(f, [byte_stream]) 42 | compressed_size = os.path.getsize(save_path) 43 | return compressed_size 44 | 45 | def gzip_compressor(byte_stream, save_path = None): 46 | if save_path is None: 47 | save_path = 'compressed.gz' 48 | with gzip.open(save_path, 'wb') as f: 49 | f.write(byte_stream) 50 | compressed_size = os.path.getsize(save_path) 51 | return compressed_size 52 | 53 | def zlib_compressor(byte_stream, wbits = 15, save_path = None): 54 | compressor = zlib.compressobj(level=9, wbits=wbits) 55 | compressed = compressor.compress(byte_stream) + compressor.flush() 56 | if save_path is not None: 57 | with open(save_path, 'wb') as f: 58 | f.write(compressed) 59 | compressed_size = len(compressed) 60 | return compressed_size 61 | 62 | def flac_compressor(byte_stream, save_path = None): 63 | if len(byte_stream) % 2 != 0: 64 | byte_stream = byte_stream + b'\x00' 65 | pesudo_audio = np.frombuffer(byte_stream, dtype=np.int16) 66 | sample_rate = 16000 67 | if save_path is None: 68 | save_path = 'compressed.flac' 69 | else: 70 | save_path = save_path + '.flac' 71 | sf.write(save_path, pesudo_audio, sample_rate) 72 | compressed_size = os.path.getsize(save_path) 73 | return compressed_size 74 | -------------------------------------------------------------------------------- /visualise/wiki_diff_ratio.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import json 3 | import numpy as np 4 | import matplotlib.ticker as ticker 5 | 6 | monthly_diffs = {'2017-02': 0.01116792953736281, '2017-03': 0.01834236197654902, '2017-04': 0.02574304017079133, '2017-05': 0.03496746749271252, '2017-06': 0.041387342586072594, '2017-07': 0.046878354661619895, '2017-08': 0.05455057883593177, '2017-09': 0.06289097497392492, '2017-10': 0.07010818532431914, '2017-11': 0.07661655089425637, '2017-12': 0.08917030266375628, '2018-01': 0.09563576470399006, '2018-02': 0.09848359035700614, '2018-03': 0.10524214585024406, '2018-04': 0.11351711685274801, '2018-05': 0.12082234269079567, '2018-06': 0.12725273504445206, '2018-07': 0.13391217103865546, '2018-08': 0.13992957804893252, '2018-09': 0.1467543261161517, '2018-10': 0.1516270840152477, '2018-11': 0.15971477799389064, '2018-12': 0.16392788269177516, '2019-01': 0.16877427806512957, '2019-02': 0.17272036404500812, '2019-03': 0.1734932242894733, '2019-04': 0.17643397228150343, '2019-05': 0.18204949683295235, '2019-06': 0.18567249145794149, '2019-07': 0.18860373353760934, '2019-08': 0.192941409599943, '2019-09': 0.19646115272767573, '2019-10': 0.19947695379498326, '2019-11': 0.2042638730400354, '2019-12': 0.20535551225073473, '2020-01': 0.21022376468560444, '2020-02': 0.2163519693488857, '2020-03': 0.21776973684670542, '2020-04': 0.22353608305167946, '2020-05': 0.22831721488527923, '2020-06': 0.2318074801066413, '2020-07': 0.23823993718669517, '2020-08': 0.24372521733601665, '2020-09': 0.24666164927737053, 
'2020-10': 0.2495010704975462, '2020-11': 0.25411969603513096, '2020-12': 0.2621696338476233, '2021-01': 0.26893761687770185, '2021-02': 0.2737068346035143, '2021-03': 0.2798274781454656, '2021-04': 0.2853918394203085, '2021-05': 0.2930184706701309, '2021-06': 0.2984052276496851, '2021-07': 0.304206707259906, '2021-08': 0.3061869007586494, '2021-09': 0.3095803030710187, '2021-10': 0.31184193819305417, '2021-11': 0.31569386698315405, '2021-12': 0.3159712310278265, '2022-01': 0.31887096932240977, '2022-02': 0.321770178310484, '2022-03': 0.3253934021689546, '2022-04': 0.327022089258801, '2022-05': 0.3295241350121946, '2022-06': 0.33243314038148963, '2022-07': 0.33418594586120914, '2022-08': 0.3372125906977422, '2022-09': 0.3422111189008709, '2022-10': 0.3454443349018839, '2022-11': 0.3482930481380629, '2022-12': 0.3519015515709205, '2023-01': 0.35437893811057447, '2023-02': 0.3569234666313323, '2023-03': 0.3595299016739051, '2023-04': 0.36260677971307864, '2023-05': 0.36633521291601473, '2023-06': 0.3677992689595484, '2023-07': 0.37103792495341176, '2023-08': 0.37351423621432, '2023-09': 0.37576021807352356, '2023-10': 0.37949992109413283, '2023-11': 0.3827559713344019} 7 | 8 | plt.figure(figsize=(10, 5), dpi=200) 9 | 10 | months = list(monthly_diffs.keys()) 11 | plt.xticks(np.arange(0, len(months), 6), rotation=45) 12 | plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(0.1)) 13 | 14 | plt.plot(monthly_diffs.keys(), monthly_diffs.values(), color='blue', linewidth=1.5, linestyle='solid', label='Wikitext', marker='o', markersize=3) 15 | 16 | plt.ylabel('Diff ratio (compare to 2017-01)') 17 | 18 | plt.legend(fontsize=10) 19 | plt.tight_layout() 20 | 21 | plt.grid() 22 | 23 | plt.savefig('results/wiki_diff_ratio.png') -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # LLM-Compressive: Longitudinal Evaluation of LLMs via Data Compression 2 | 3 | Compression is believed to be the key feature of intelligence. Llm-compressive allows you to evaluate Large Language Models (LLMs) for generalization and robustness via **data compression**. 4 | 5 | Llm-compressive tests LLMs with data compression on timeline, to understand how LLMs generalize over time. 6 | 7 | llm-compressive 8 | 9 | For example, llm-compressive test open source LLMs on wikipedia across 83 months from 2017 to 2023. 10 | 11 | **Mistral** and **Baichuan2** show steady performance across all time periods, indicating promissing generalization over time. In contrast, other models demonstrate linearly-worsen curves. 12 | 13 | More results on coding, arxiv, news, image, and audio in the paper: [Evaluating Large Language Models for Generalization and Robustness via Data Compression 14 | ](https://arxiv.org/pdf/2402.00861.pdf). 15 | 16 | **Updates**: 17 | - 27 Feb 2024, try the interactive leaderboard at [LLM-Compressive](https://liyucheng09.github.io/llm-compressive/). 18 | 19 | # Getting Started 20 | 21 | 0. Clone and install requirements. 22 | 23 | ``` 24 | git clone https://github.com/liyucheng09/llm-compressive.git 25 | cd llm-compressive 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | 1. Run the main test script. 30 | 31 | ``` 32 | python main.py 33 | ``` 34 | 35 | - `model_name`: the name of the model from HF Hub. See supported [models](#models). 36 | - `dataset_name`: the name of the dataset. choose from `wikitext`, `math`, `bbc_news`, `code`, `arxiv`, `audio`, `bbc_image`. 
37 | - `save_path`: the path to save the results. 38 | - `context_size`: the context size used for compression. choose from `2048`, `4096`, `8192`, `max_length`, or `stride`. 39 | - `batch_size`: the batch size. This depends on the model scale and your GPU memory. 40 | 41 | **Attention!!** If you need to use a huggingface mirror (i.e., you have problems accessing huggingface.co directly), add `HF_ENDPOINT=https://hf-mirror.com` to your environment variables. 42 | 43 | 2. Aggregate the results. 44 | 45 | ``` 46 | python results/aggregate_all_results.py 47 | ``` 48 | 49 | - `save_path`: the path you saved the results in. 50 | 51 | 3. Visualize the results. 52 | 53 | ``` 54 | python visualise/timeline_vis.py 55 | ``` 56 | 57 | This will generate a figure visualizing the trend of models' compression rates over time. 58 | 59 | ``` 60 | python visualise/big_table.py 61 | ``` 62 | 63 | This will 1) generate the big table in the paper; and 2) generate a figure showing the performance-robustness trade-off of models (like the figure below). 64 | 65 | performance-robustness 66 | 67 | See the explanation of the figure in the [paper](https://arxiv.org/pdf/2402.00861.pdf). 68 | 69 | # Models 70 | 71 | We have tested the following models: 72 | - codellama/CodeLlama-7b-hf 73 | - baichuan-inc/Baichuan2-7B-Base 74 | - mistralai/Mistral-7B-v0.1 75 | - huggyllama/llama-7b 76 | - huggyllama/llama-13b 77 | - huggyllama/llama-65b 78 | - meta-llama/Llama-2-7b-hf 79 | - meta-llama/Llama-2-13b-hf 80 | - meta-llama/Llama-2-70b-hf 81 | - Qwen/Qwen-7B 82 | - internlm/internlm-7b 83 | - THUDM/chatglm3-6b-base 84 | - 01-ai/Yi-6B-200K 85 | - 01-ai/Yi-34B-200K 86 | - google/gemma-7b 87 | - Qwen/Qwen1.5-7B 88 | 89 | And any GPTQ version of the above models, such as: 90 | 91 | - TheBloke/CodeLlama-70B-hf-GPTQ 92 | - TheBloke/Llama-2-70B-GPTQ 93 | - TheBloke/Yi-34B-200K-GPTQ 94 | - ... 95 | 96 | # Issues 97 | 98 | Send me an email or open an issue if you have any questions.
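For reference, a complete example invocation (complementing the `python main.py` step above). The five arguments are read positionally as `model_name data_name save_path context_size batch_size`; the model, output directory, and batch size shown here are illustrative placeholders rather than values prescribed by this repo:

```
python main.py mistralai/Mistral-7B-v0.1 wikitext outputs 2048 1
```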
99 | 100 | # Citation 101 | 102 | If you find this repo helpful, please consider citing our paper: 103 | 104 | ``` 105 | @article{Li2024EvaluatingLL, 106 | title={Evaluating Large Language Models for Generalization and Robustness via Data Compression}, 107 | author={Yucheng Li and Yunhao Guo and Frank Guerin and Chenghua Lin}, 108 | year={2024}, 109 | journal={arXiv preprint arXiv:2402.00861} 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /visualise/timeline_vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import json 3 | import numpy as np 4 | import matplotlib.ticker as ticker 5 | 6 | task = 'wikitext' 7 | with open(f'results/{task}_results.json') as f: 8 | data = json.load(f) 9 | 10 | # we have two fig_types now: 'llama' and '7B' 11 | fig_type = '7B' 12 | 13 | # fine-grained control over which models to include 14 | code_llama = True 15 | large_llama = False 16 | internlm = False 17 | 18 | x_ticks = True 19 | no_ylabel = False 20 | 21 | # Create the plot, and set the font sizes 22 | if no_ylabel: 23 | plt.figure(figsize=(8, 5), dpi=180) 24 | else: 25 | plt.figure(figsize=(8.2, 5), dpi=180) 26 | markersize = 3 27 | legend_fontsize = 10 28 | tick_fontsize = 12 29 | label_fontsize = 14 30 | 31 | # Colorblind-friendly colors palette 32 | # Source: https://jfly.uni-koeln.de/color/ 33 | colors = ['#E69F00', '#56B4E9', '#009E73', '#F0E442', '#0072B2', '#D55E00', '#CC79A7'] 34 | 35 | # Different line styles 36 | line_styles = ['-', '--', '-.', ':'] 37 | 38 | # Different markers 39 | markers = ['o', 's', 'D', '^', 'v', '<', '>'] 40 | 41 | model_name_to_label = { 42 | 'Baichuan2-7B-Base': 'Baichuan2-7B', 43 | 'internlm-7B': 'Internlm-7B', 44 | 'Qwen-7B': 'Qwen-7B', 45 | 'Yi-6B': 'Yi-6B', 46 | 'chatglm3-6b-base': 'Chatglm3-6B', 47 | 'Mistral-7B': 'Mistral-7B', 48 | 'LLaMA-7B-HF': 'LLaMA-7B', 49 | 'LLaMA-13B': 'LLaMA-13B', 50 | 'Llama-2-13B': 'Llama-2-13B', 51 | 'Llama-2-7B-HF': 'Llama-2-7B', 52 | 'CodeLlama-7B': 'CodeLlama-7B', 53 | 'Llama-2-70B': 'Llama-2-70B', 54 | 'LLaMA-30B': 'LLaMA-30B', 55 | 'LLaMA-65B': 'LLaMA-65B', 56 | 'Yi-34B-200K': 'Yi-34B', 57 | } 58 | 59 | # Loop through each model's data and plot it 60 | counter = 0 61 | for model_name, model_data in data.items(): 62 | if model_name not in model_name_to_label: 63 | continue 64 | if fig_type == 'llama': 65 | if not code_llama and model_name == 'CodeLlama-7B': 66 | continue 67 | if not large_llama and model_name in ['Llama-2-70B', 'LLaMA-30B', 'LLaMA-65B']: 68 | continue 69 | if model_name not in ['Llama-2-7B-HF', 'LLaMA-7B-HF', 'LLaMA-13B', 'Llama-2-13B', 'CodeLlama-7B', 'Llama-2-70B', 'LLaMA-30B', 'LLaMA-65B']: 70 | continue 71 | elif fig_type == '7B': 72 | if not ('7B' in model_name or '6B' in model_name): 73 | continue 74 | if task not in ['code', 'bbc_image', 'arxiv']: 75 | if 'internlm' in model_name.lower() and not internlm: 76 | continue 77 | if 'code' in model_name.lower() and not code_llama: 78 | continue 79 | else: 80 | raise ValueError(f'Unknown fig_type: {fig_type}') 81 | 82 | labels = list(model_data.keys()) 83 | values = np.array([metrics['ratio'] for metrics in model_data.values()]) * 100 84 | 85 | # remove 2021-03 86 | remove_index = labels.index('2021-03') 87 | values[remove_index] = (values[remove_index - 1] + values[remove_index + 1]) / 2 88 | 89 | # Use color and line style from our predefined lists 90 | color = colors[counter % len(colors)] 91 | line_style = line_styles[counter 
% len(line_styles)] 92 | marker = markers[counter % len(markers)] 93 | 94 | plt.plot(labels, values, color=color, linestyle=line_style, marker=marker, label=model_name_to_label[model_name], markersize=markersize) 95 | counter += 1 96 | 97 | # Adding title and labels 98 | plt.title(f'{task}') 99 | if not no_ylabel: 100 | plt.ylabel(f'Compression Rates (%)', fontsize=label_fontsize) 101 | 102 | # Adding a legend to differentiate the lines from each model 103 | plt.legend(fontsize=legend_fontsize) 104 | 105 | # Display the plot 106 | plt.grid(True) 107 | 108 | # x ticks are too dense, so we only show every 3rd tick 109 | plt.xticks(list(range(0, len(labels), 6)), rotation=45) 110 | 111 | if not x_ticks: 112 | plt.xticks([]) 113 | 114 | plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(0.4)) 115 | plt.tight_layout() 116 | 117 | plt.tick_params(axis='x', labelsize=tick_fontsize) 118 | plt.tick_params(axis='y', labelsize=tick_fontsize) 119 | 120 | plt.savefig(f'figs/{fig_type}-{task}.png') -------------------------------------------------------------------------------- /page/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | LLM-Compressive 7 | 8 | 9 | 16 | 60 | 61 | 62 | 63 |
LLM-Compressive: Longitudinal Evaluation of LLMs via Data Compression

Paper / Code / 知乎 (Zhihu)

Outline: Intro · Issues · Benchmark Performance · Context Length Performance

1. Intro:
LLM-Compressive evaluates LLMs via data compression on data collected every month from 2017 to 2024. Our sources currently include Code, Wikipedia, Math, arXiv, BBC News, Images, and Audio. All y-axes show the compression ratio (%, lower is better): models that maintain a constant compression ratio over time demonstrate good generalization, while models whose ratio degrades over time indicate overfitting.

2. Issues:
If you have problems or want to request results for a new model, please head to our project page and open an issue.

3. Benchmark Performance:
[interactive charts, one per data source, are embedded in the original page]

4. Context Length Performance:
[an interactive chart of compression ratio versus context size is embedded in the original page]
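To make the compression-ratio axes above concrete, here is a minimal sketch of how such a ratio relates to the bits-per-byte (bpb) numbers stored in the result files, assuming the usual definition of compressed size divided by original size; the function name and example value are illustrative, not taken from this repository:

```python
# A byte of raw data is 8 bits, so a model that needs `bpb` bits to encode
# each original byte compresses the data to bpb / 8 of its original size.
def compression_ratio_from_bpb(bpb: float) -> float:
    return bpb / 8.0

# Example: 0.8 bits per byte corresponds to a 10% compression ratio.
print(f"{compression_ratio_from_bpb(0.8) * 100:.1f}%")  # -> 10.0%
```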
117 | 118 | 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /visualise/interactive.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import plotly.graph_objects as go 3 | from plotly.subplots import make_subplots 4 | import plotly.express as px 5 | import plotly.offline as pyo 6 | import plotly.io as pio 7 | 8 | def exponential_smoothing(data, alpha = 0.3): 9 | """ 10 | Apply exponential smoothing to the data. 11 | :param data: List of data points. 12 | :param alpha: Smoothing factor, between 0 and 1. 13 | :return: List of smoothed data points. 14 | """ 15 | smoothed_data = [] 16 | for i, point in enumerate(data): 17 | if i == 0: 18 | # The first smoothed value is the first data point. 19 | smoothed_data.append(point) 20 | else: 21 | # Compute the smoothed value. 22 | new_smoothed = alpha * point + (1 - alpha) * smoothed_data[i-1] 23 | smoothed_data.append(new_smoothed) 24 | return smoothed_data 25 | 26 | base_model_name_to_label = { 27 | 'Baichuan2-7B-Base': 'Baichuan2-7B', 28 | 'internlm-7B': 'Internlm-7B', 29 | 'Qwen-7B': 'Qwen-7B', 30 | 'Yi-6B': 'Yi-6B', 31 | 'chatglm3-6b-base': 'Chatglm3-6B', 32 | 'Mistral-7B': 'Mistral-7B', 33 | 'LLaMA-7B-HF': 'LLaMA-7B', 34 | 'LLaMA-13B': 'LLaMA-13B', 35 | 'Llama-2-13B': 'Llama-2-13B', 36 | 'Llama-2-7B-HF': 'Llama-2-7B', 37 | 'CodeLlama-7B': 'CodeLlama-7B', 38 | 'Llama-2-70B': 'Llama-2-70B', 39 | 'LLaMA-65B': 'LLaMA-65B', 40 | 'Yi-34B-200K': 'Yi-34B', 41 | 'Qwen1.5-7B': 'Qwen1.5-7B', 42 | 'gemma-7b': 'Gemma-7B', 43 | } 44 | 45 | long_context_models = [ 46 | 'Baichuan2' 47 | 'Mistral', 48 | 'Llama', 49 | 'LLaMA', 50 | 'Yi', 51 | 'chatglm3', 52 | 'gemma', 53 | 'Qwen', 54 | 'Qwen1.5', 55 | ] 56 | 57 | long_context_unselected_defaultly = [ 58 | 'LLaMA', 59 | ] 60 | 61 | unselected_defaultly = [ 62 | 'Internlm-7B', 63 | 'Chatglm3-6B', 64 | 'LLaMA-7B', 65 | 'LLaMA-13B', 66 | 'Llama-2-13B', 67 | 'Yi-34B', 68 | 'LLaMA-65B', 69 | 'CodeLlama-7B', 70 | 'Qwen-7B', 71 | 'Llama-2-70B', 72 | ] 73 | 74 | if __name__ == "__main__": 75 | # tasks = ['wikitext'] 76 | tasks = ['wikitext', 'arxiv', 'bbc_news', 'code', 'math', 'bbc_image'] 77 | line_styles = ['solid', 'dot', 'dash', 'longdash', 'dashdot'] 78 | markers = ['circle', 'square', 'diamond', 'cross', 'x'] 79 | 80 | df_for_tasks = {} 81 | for task in tasks: 82 | df = pd.read_json(f'results/{task}_results.json') 83 | df.index = df.index.strftime('%Y-%m') 84 | df.dropna(axis=1, how='any', inplace=True) 85 | df = df.applymap(lambda x: x * 100) 86 | df = df[[col for col in df.columns if col in base_model_name_to_label]] 87 | df = df.rename(columns=base_model_name_to_label) 88 | df_for_tasks[task] = df 89 | 90 | for task, df in df_for_tasks.items(): 91 | fig = go.Figure() 92 | for i, model in enumerate(df.columns): 93 | visible = "legendonly" if model in unselected_defaultly else True 94 | if 'codellama' in model.lower() and task in ['code', 'math', 'bbc_image']: 95 | visible = True 96 | if task != 'wikitext': 97 | y = exponential_smoothing(df[model]) 98 | else: 99 | y = df[model] 100 | fig.add_trace(go.Scatter(x=df.index, y=y, mode='lines+markers', name=model, line=dict(dash=line_styles[i%len(line_styles)]), 101 | marker=dict(symbol=markers[i%len(markers)], size=4), visible=visible)) 102 | 103 | fig.update_layout(title=task, xaxis_title='Date', yaxis_title='Compression Ratio', xaxis_fixedrange=True, yaxis_fixedrange=True) 104 | pio.write_html(fig, file=f'page/{task}.html', include_plotlyjs='cdn') 
105 | 106 | # process results/long_wikitext_results.json 107 | df = pd.read_json('results/long_wikitext_results.json') 108 | df.index = df.index.strftime('%Y-%m') 109 | df = df.applymap(lambda x: x * 100) 110 | df = df[[col for col in df.columns if any([model.lower() in col.lower() for model in long_context_models])]] 111 | 112 | context_sizes = [] 113 | models = [] 114 | avg_perform = [] 115 | results = df.mean(axis=0) 116 | for model, perf in results.items(): 117 | model, context = model.rsplit('-', 1) 118 | context_sizes.append(int(context[:-1])) 119 | models.append(model) 120 | avg_perform.append(perf) 121 | 122 | plot_df = pd.DataFrame({'Model': models, 'Context Size': context_sizes, 'Average Performance': avg_perform}) 123 | plot_df = plot_df.sort_values(by='Context Size') 124 | fig = go.Figure() 125 | 126 | for i, model in enumerate(long_context_models): 127 | visible = "legendonly" if model in long_context_unselected_defaultly else True 128 | fig.add_trace(go.Scatter(x=plot_df[plot_df['Model'] == model]['Context Size'], y=plot_df[plot_df['Model'] == model]['Average Performance'], 129 | mode='lines+markers', name=model, line=dict(dash=line_styles[i%len(line_styles)]), 130 | marker=dict(symbol=markers[i%len(markers)], size=4), visible=visible)) 131 | fig.update_layout(title='Wikitext', xaxis_title='Context Size', yaxis_title='Compression Ratio (%, across all times)', xaxis_fixedrange=True, yaxis_fixedrange=True) 132 | fig.update_xaxes( 133 | tickvals=context_sizes, 134 | ticktext=[f'{size}k' for size in context_sizes] 135 | ) 136 | pio.write_html(fig, file='page/wikitext_context.html', include_plotlyjs='cdn') 137 | 138 | print('Done') 139 | -------------------------------------------------------------------------------- /arithmetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.cpp_extension import load 5 | 6 | PRECISION = 32 7 | 8 | # Load on-the-fly with ninja. 9 | torchac_dir = os.path.dirname(os.path.realpath(__file__)) 10 | torchac_int32 = load( 11 | name="torchac_int32", 12 | sources=[os.path.join(torchac_dir, "arithmetic_int32.cpp")], 13 | verbose=True) 14 | 15 | def encode_float_cdf(cdf_float, 16 | sym, 17 | needs_normalization=True, 18 | check_input_bounds=False): 19 | """Encode symbols `sym` with potentially unnormalized floating point CDF. 20 | 21 | Check the README for more details. 22 | 23 | :param cdf_float: CDF tensor, float32, on CPU. Shape (N1, ..., Nm, Lp). 24 | :param sym: The symbols to encode, int32, on CPU. Shape (N1, ..., Nm). 25 | :param needs_normalization: if True, assume `cdf_float` is un-normalized and 26 | needs normalization. Otherwise only convert it, without normalizing. 27 | :param check_input_bounds: if True, ensure inputs have valid values. 28 | Important: may take significant time. Only enable to check. 29 | 30 | :return: byte-string, encoding `sym`. 
31 | """ 32 | if check_input_bounds: 33 | if cdf_float.min() < 0: 34 | raise ValueError(f'cdf_float.min() == {cdf_float.min()}, should be >=0.!') 35 | if cdf_float.max() > 1: 36 | raise ValueError(f'cdf_float.max() == {cdf_float.max()}, should be <=1.!') 37 | Lp = cdf_float.shape[-1] 38 | if sym.max() >= Lp - 1: 39 | raise ValueError 40 | cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization) 41 | return encode_int_normalized_cdf(cdf_int, sym) 42 | 43 | 44 | def decode_float_cdf(cdf_float, byte_stream, needs_normalization=True): 45 | """Encode symbols in `byte_stream` with potentially unnormalized float CDF. 46 | 47 | Check the README for more details. 48 | 49 | :param cdf_float: CDF tensor, float32, on CPU. Shape (N1, ..., Nm, Lp). 50 | :param byte_stream: byte-stream, encoding some symbols `sym`. 51 | :param needs_normalization: if True, assume `cdf_float` is un-normalized and 52 | needs normalization. Otherwise only convert it, without normalizing. 53 | 54 | :return: decoded `sym` of shape (N1, ..., Nm). 55 | """ 56 | cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization) 57 | return decode_int_normalized_cdf(cdf_int, byte_stream) 58 | 59 | 60 | def encode_int_normalized_cdf(cdf_int, sym): 61 | """Encode symbols `sym` with a normalized integer cdf `cdf_int`. 62 | 63 | Check the README for more details. 64 | 65 | :param cdf_int: CDF tensor, int16, on CPU. Shape (N1, ..., Nm, Lp). 66 | :param sym: The symbols to encode, int16, on CPU. Shape (N1, ..., Nm). 67 | 68 | :return: byte-string, encoding `sym` 69 | """ 70 | cdf_int, sym = _check_and_reshape_inputs(cdf_int, sym) 71 | return torchac_int32.encode_cdf(cdf_int, sym) 72 | 73 | def decode_int_normalized_cdf(cdf_int, byte_stream): 74 | """Decode symbols in `byte_stream` with a normalized integer cdf `cdf_int`. 75 | 76 | Check the README for more details. 77 | 78 | :param cdf_int: CDF tensor, int16, on CPU. Shape (N1, ..., Nm, Lp). 79 | :param byte_stream: byte-stream, encoding some symbols `sym`. 80 | 81 | :return: decoded `sym` of shape (N1, ..., Nm). 82 | """ 83 | cdf_reshaped = _check_and_reshape_inputs(cdf_int) 84 | # Merge the m dimensions into one. 85 | sym = torchac_int32.decode_cdf(cdf_reshaped, byte_stream) 86 | return _reshape_output(cdf_int.shape, sym) 87 | 88 | 89 | def _check_and_reshape_inputs(cdf, sym=None): 90 | """Check device, dtype, and shapes.""" 91 | if cdf.is_cuda: 92 | raise ValueError('CDF must be on CPU') 93 | if sym is not None and sym.is_cuda: 94 | raise ValueError('Symbols must be on CPU') 95 | if sym is not None and sym.dtype != torch.int32: 96 | raise ValueError(f'Symbols must be int32! Got {sym.dtype}.') 97 | if sym is not None: 98 | if len(cdf.shape) != len(sym.shape) + 1 or cdf.shape[:-1] != sym.shape: 99 | raise ValueError(f'Invalid shapes of cdf={cdf.shape}, sym={sym.shape}! ' 100 | 'The first m elements of cdf.shape must be equal to ' 101 | 'sym.shape, and cdf should only have one more dimension.') 102 | Lp = cdf.shape[-1] 103 | cdf = cdf.reshape(-1, Lp) 104 | if sym is None: 105 | return cdf 106 | sym = sym.reshape(-1) 107 | return cdf, sym 108 | 109 | 110 | def _reshape_output(cdf_shape, sym): 111 | """Reshape single dimension `sym` back to the correct spatial dimensions.""" 112 | spatial_dimensions = cdf_shape[:-1] 113 | if len(sym) != np.prod(spatial_dimensions): 114 | raise ValueError() 115 | return sym.reshape(*spatial_dimensions) 116 | 117 | 118 | def _convert_to_int_and_normalize(cdf_float, needs_normalization): 119 | """Convert floatingpoint CDF to integers. 
See README for more info. 120 | 121 | The idea is the following: 122 | When we get the cdf here, it is (assumed to be) between 0 and 1, i.e, 123 | cdf \in [0, 1) 124 | (note that 1 should not be included.) 125 | We now want to convert this to int16 but make sure we do not get 126 | the same value twice, as this would break the arithmetic coder 127 | (you need a strictly monotonically increasing function). 128 | So, if needs_normalization==True, we multiply the input CDF 129 | with 2**16 - (Lp - 1). This means that now, 130 | cdf \in [0, 2**16 - (Lp - 1)]. 131 | Then, in a final step, we add an arange(Lp), which is just a line with 132 | slope one. This ensure that for sure, we will get unique, strictly 133 | monotonically increasing CDFs, which are \in [0, 2**16) 134 | """ 135 | Lp = cdf_float.shape[-1] 136 | factor = torch.tensor( 137 | 2, dtype=torch.float32, device=cdf_float.device).pow_(30) 138 | new_max_value = factor 139 | if needs_normalization: 140 | new_max_value = new_max_value - (Lp - 1) 141 | cdf_float = cdf_float.mul(new_max_value) 142 | cdf_float = cdf_float.round() 143 | cdf = cdf_float.to(dtype=torch.int32, non_blocking=True) 144 | if needs_normalization: 145 | r = torch.arange(Lp, dtype=torch.int32, device=cdf.device) 146 | cdf.add_(r) 147 | return cdf 148 | -------------------------------------------------------------------------------- /visualise/temporal.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import json 3 | import numpy as np 4 | import matplotlib.ticker as ticker 5 | 6 | task = 'math' 7 | with open(f'results/{task}_results.json') as f: 8 | data = json.load(f) 9 | 10 | small = True 11 | fig_type = '7B' 12 | code_llama = False 13 | large_llama = False 14 | x_ticks = True 15 | no_ylabel = False 16 | internlm = False 17 | small_llama=False 18 | glm = False 19 | mistral = True 20 | long_context = True 21 | smoothing_ratio = None 22 | 23 | # Create a plot 24 | if not small: 25 | plt.figure(figsize=(20, 10), dpi=160) 26 | markersize = 6 27 | legend_fontsize = 17 28 | tick_fontsize = 15 29 | label_fontsize = 16 30 | else: 31 | if no_ylabel: 32 | plt.figure(figsize=(8, 5), dpi=180) 33 | else: 34 | plt.figure(figsize=(8.2, 5), dpi=180) 35 | markersize = 3 36 | legend_fontsize = 10 37 | tick_fontsize = 12 38 | label_fontsize = 14 39 | 40 | def exponential_smoothing(data, alpha = 0.5): 41 | """ 42 | Apply exponential smoothing to the data. 43 | :param data: List of data points. 44 | :param alpha: Smoothing factor, between 0 and 1. 45 | :return: List of smoothed data points. 46 | """ 47 | smoothed_data = [] 48 | for i, point in enumerate(data): 49 | if i == 0: 50 | # The first smoothed value is the first data point. 51 | smoothed_data.append(point) 52 | else: 53 | # Compute the smoothed value. 
54 | new_smoothed = alpha * point + (1 - alpha) * smoothed_data[i-1] 55 | smoothed_data.append(new_smoothed) 56 | return smoothed_data 57 | 58 | # Colorblind-friendly colors palette 59 | # Source: https://jfly.uni-koeln.de/color/ 60 | colors = ['#E69F00', '#56B4E9', '#009E73', '#F0E442', '#0072B2', '#D55E00', '#CC79A7'] 61 | 62 | # Different line styles 63 | line_styles = ['-', '--', '-.', ':'] 64 | 65 | # Different markers 66 | markers = ['o', 's', 'D', '^', 'v', '<', '>'] 67 | 68 | model_name_to_label = { 69 | 'Baichuan2-7B-Base': 'Baichuan2-7B', 70 | 'internlm-7B': 'Internlm-7B', 71 | 'Qwen-7B': 'Qwen-7B', 72 | 'Yi-6B': 'Yi-6B', 73 | 'chatglm3-6b-base': 'Chatglm3-6B', 74 | 'Mistral-7B': 'Mistral-7B', 75 | 'LLaMA-7B-HF': 'LLaMA-7B', 76 | 'LLaMA-13B': 'LLaMA-13B', 77 | 'Llama-2-13B': 'Llama-2-13B', 78 | 'Llama-2-7B-HF': 'Llama-2-7B', 79 | 'CodeLlama-7B': 'CodeLlama-7B', 80 | 'Llama-2-70B': 'Llama-2-70B', 81 | 'LLaMA-30B': 'LLaMA-30B', 82 | 'LLaMA-65B': 'LLaMA-65B', 83 | 'Yi-34B-200K': 'Yi-34B', 84 | 'Qwen1.5-7B': 'Qwen1.5-7B', 85 | 'Qwen-7B-12288-long': 'Qwen-7B-12K', 86 | 'Qwen1.5-7B-12288-long': 'Qwen1.5-7B-12K', 87 | } 88 | 89 | # Loop through each model's data and plot it 90 | counter = 0 91 | for model_name, model_data in data.items(): 92 | if model_name not in model_name_to_label: 93 | continue 94 | if not long_context and 'long' in model_name: 95 | continue 96 | if fig_type == 'llama': 97 | if not code_llama and model_name == 'CodeLlama-7B': 98 | continue 99 | if not large_llama and model_name in ['Llama-2-70B', 'LLaMA-30B', 'LLaMA-65B']: 100 | continue 101 | if model_name not in ['Llama-2-7B-HF', 'LLaMA-7B-HF', 'LLaMA-13B', 'Llama-2-13B', 'CodeLlama-7B', 'Llama-2-70B', 'LLaMA-30B', 'LLaMA-65B']: 102 | continue 103 | elif fig_type == '7B': 104 | if not ('7b' in model_name.lower() or '6b' in model_name.lower()): 105 | continue 106 | if not internlm and 'internlm' in model_name.lower(): 107 | continue 108 | if not code_llama and 'code' in model_name.lower(): 109 | continue 110 | if not small_llama and ('LLaMA-7B-HF' in model_name or 'Llama-2-7B' in model_name): 111 | continue 112 | if not glm and 'chatglm3' in model_name.lower(): 113 | continue 114 | if not mistral and 'mistral' in model_name.lower(): 115 | continue 116 | 117 | else: 118 | # model_specific ploting 119 | if fig_type not in model_name.lower(): 120 | continue 121 | 122 | labels = list(model_data.keys()) 123 | values = np.array([metrics['ratio'] for metrics in model_data.values()]) * 100 124 | 125 | # remove 2021-03 126 | remove_index = labels.index('2021-03') 127 | values[remove_index] = (values[remove_index - 1] + values[remove_index + 1]) / 2 128 | 129 | if smoothing_ratio is not None: 130 | values = exponential_smoothing(values, alpha=smoothing_ratio) 131 | 132 | # Use color and line style from our predefined lists 133 | color = colors[counter % len(colors)] 134 | line_style = line_styles[counter % len(line_styles)] 135 | marker = markers[counter % len(markers)] 136 | 137 | plt.plot(labels, values, color=color, linestyle=line_style, marker=marker, label=model_name_to_label[model_name], markersize=markersize) 138 | counter += 1 139 | 140 | # Adding title and labels 141 | if small: 142 | plt.title(f'{task}') 143 | if not no_ylabel: 144 | plt.ylabel(f'Compression Rates (%)', fontsize=label_fontsize) 145 | else: 146 | if not no_ylabel: 147 | plt.ylabel(f'Compression Rates - {task}', fontsize=label_fontsize) 148 | 149 | # Adding a legend to differentiate the lines from each model 150 | # 
plt.legend(fontsize=legend_fontsize) 151 | plt.legend(fontsize=legend_fontsize, loc=(0.01, 0.01)) 152 | # plt.legend(fontsize=legend_fontsize, loc=(0.82,0.73)) 153 | 154 | # Display the plot 155 | plt.grid(True) 156 | # plt.xticks(rotation=45, ha='right') 157 | 158 | # x ticks are too dense, so we only show every 3rd tick 159 | # plt.xticks(list(range(0, len(labels), 6)), rotation=45, ha='right') 160 | if small: 161 | plt.xticks(list(range(0, len(labels), 6)), rotation=45) 162 | else: 163 | plt.xticks(list(range(0, len(labels), 2)), rotation=45, ha='right') 164 | 165 | if not x_ticks: 166 | plt.xticks([]) 167 | 168 | plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(0.4)) 169 | plt.tight_layout() 170 | 171 | plt.tick_params(axis='x', labelsize=tick_fontsize) 172 | plt.tick_params(axis='y', labelsize=tick_fontsize) 173 | 174 | plt.savefig(f'figs/{fig_type}-{task}.png') -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel 2 | import torch 3 | from data_processor import ( 4 | BBCNewsProcessor, 5 | BBCImageProcessor, 6 | WikiTextProcessor, 7 | CodeProcessor, 8 | ArxivProcessor, 9 | AudioProcessor, 10 | MathProcessor, 11 | ) 12 | import os 13 | import sys 14 | from tqdm import tqdm 15 | from evaluator import Metrics 16 | from auto_gptq import exllama_set_max_input_length 17 | import time as time_module 18 | 19 | model_max_context = { 20 | 'Baichuan2-7B-Base': 4096, 21 | 'chatglm3-6b-base': 32768, 22 | 'internlm-7B': 2048, 23 | 'LLaMA-13B': 2048, 24 | 'Llama-2-13B': 4096, 25 | 'Llama-2-70B': 4096, 26 | 'Llama-2-7B': 4096, 27 | 'Llama-2-7B-HF': 4096, 28 | 'LLaMA-30B': 2048, 29 | 'LLaMA-65B': 2048, 30 | 'LLaMA-7B': 2048, 31 | 'LLaMA-7B-HF': 2048, 32 | 'Qwen-7B': 32768, 33 | 'Yi-34B-200K': 200000, 34 | 'Yi-6B': 4096, 35 | 'Yi-6B-200K': 200000, 36 | 'Mistral-7B': 32768, 37 | 'Mistral-7B-Instruct': 32768, 38 | 'Qwen/Qwen1.5-7B': 32768, 39 | 'google/gemma-7b': 8192, 40 | } 41 | 42 | def prepare_data(data_name, save_path, tokenizer): 43 | all_time_stamps = [f'{year}-{month:02d}' for year in range(2017, 2024) for month in range(1, 13) if not (year == 2023 and month > 11)][53:] 44 | 45 | if data_name == 'bbc_news': 46 | data_path = 'RealTimeData/bbc_news_alltime' 47 | modality = 'text' 48 | processor = BBCNewsProcessor 49 | elif data_name == 'wikitext': 50 | data_path = 'RealTimeData/wikitext_alltime' 51 | modality = 'text' 52 | processor = WikiTextProcessor 53 | elif data_name == 'bbc_image': 54 | data_path = 'RealTimeData/bbc_images_alltime' 55 | modality = 'image' 56 | processor = BBCImageProcessor 57 | elif data_name == 'code': 58 | data_path = 'RealTimeData/code_alltime' 59 | modality = 'text' 60 | processor = CodeProcessor 61 | elif data_name == 'arxiv': 62 | data_path = 'RealTimeData/arxiv_alltime' 63 | modality = 'text' 64 | processor = ArxivProcessor 65 | elif data_name == 'audio': 66 | data_path = 'RealTimeData/audio_alltime' 67 | modality = 'audio' 68 | processor = AudioProcessor 69 | elif data_name == 'math': 70 | data_path = 'RealTimeData/math_alltime' 71 | modality = 'text' 72 | processor = MathProcessor 73 | 74 | all_data = [ 75 | processor( 76 | name = data_name, 77 | modality = modality, 78 | load_path = data_path, 79 | cache_path = os.path.join(save_path, data_name), 80 | tokenizer = tokenizer, 81 | config = time 82 | ) 83 | for time in all_time_stamps 84 | ] 85 | return all_data, 
modality 86 | 87 | def load_model_and_tokenizer(model_name): 88 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True) 89 | 90 | if 'chatglm' in model_name.lower(): 91 | model = AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True) 92 | elif 'gptq' in model_name.lower(): 93 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', trust_remote_code=True, attn_implementation="flash_attention_2") 94 | elif 'yi' in model_name.lower() or 'mistral' in model_name.lower() or 'llama' in model_name.lower() or 'gemma' in model_name.lower(): 95 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True, attn_implementation="flash_attention_2") 96 | else: 97 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True) 98 | 99 | return model, tokenizer 100 | 101 | if __name__ == '__main__': 102 | model_name, data_name, save_path, context_size, batch_size, = sys.argv[1:] 103 | if 'qwen' in model_name.lower(): 104 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' 105 | 106 | batch_size = int(batch_size) 107 | if context_size == 'stride': 108 | context_size = 2048 109 | stride = 512 110 | elif context_size == 'max_length': 111 | # restrict the context size up to 8192 to prevent OOM 112 | context_size = min(model_max_context[model_name], 12288) 113 | stride = None 114 | else: 115 | context_size = int(context_size) 116 | stride = None 117 | 118 | # where is your model? use your path here. 119 | # Set model_path = model_name if you want to download the model from huggingface 120 | model_path = os.path.join('/mnt/fast/nobackup/scratch4weeks/yl02706/models', model_name) 121 | if not os.path.exists(model_path): 122 | model_path = model_name 123 | model, tokenizer = load_model_and_tokenizer(model_path) 124 | 125 | # resize buffer size for exllama, if gptq is used 126 | if getattr(model.config, 'quantization_config', None) is not None and model.config.quantization_config.use_exllama and model.config.quantization_config.desc_act: 127 | model = exllama_set_max_input_length(model, context_size*batch_size) 128 | 129 | all_data, modality = prepare_data(data_name, save_path, tokenizer) 130 | print(f'Data {data_name}, Modality {modality}') 131 | time_used = 0 132 | 133 | for data in all_data: 134 | name, time = data.name, data.config 135 | print(f'Processing {name} {time}...') 136 | 137 | data.prepare_batches(context_size, stride = stride) 138 | print(f'Total number of chunks: {data.metadata["num_chunks"]}') 139 | 140 | metrics = Metrics(modality, save_path, model_name, byte2id=data.byte2ids if modality != 'text' else None, use_arithmetic_coding=False) 141 | 142 | start_time = time_module.time() 143 | for i, chunk in enumerate(tqdm(data.batches(batch_size))): 144 | 145 | input_ids = torch.tensor(chunk, dtype=torch.long, device=model.device) 146 | input_ = {'input_ids': input_ids} 147 | with torch.no_grad(): 148 | output = model(**input_) 149 | 150 | logits = output.logits 151 | 152 | metrics.step(logits, input_ids, stride = stride) 153 | 154 | time_used += time_module.time() - start_time 155 | data.metadata['time_used'] = time_used 156 | 157 | metrics(data.stream, data.metadata, model_name) 158 | print(f'==== Finished processing {name} {time}. 
Self-info: {metrics.self_info_cache} ======') 159 | 160 | metrics.clear_cache() -------------------------------------------------------------------------------- /visualise/barplot_context_size_compare.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib.ticker as ticker 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | months = [f'2023-{month:02d}' for month in range(1, 12)] 7 | do_plot = False 8 | 9 | with open('results/wikitext_results.json') as f: 10 | size1 = json.load(f) 11 | 12 | with open('results/wikitext_results_long.json') as f: 13 | size_long = json.load(f) 14 | 15 | with open('results/wikitext_results_4k.json') as f: 16 | size1_4k = json.load(f) 17 | 18 | with open('results/wikitext_results_stride.json') as f: 19 | size_stride = json.load(f) 20 | 21 | # with open('results/bbc_news_results.json') as f: 22 | # bbc_size1 = json.load(f) 23 | 24 | # with open('results/bbc_news_results_long.json') as f: 25 | # bbc_size_long = json.load(f) 26 | 27 | # with open('results/bbc_news_results_stride.json') as f: 28 | # bbc_size_stride = json.load(f) 29 | 30 | # plt.figure(figsize=(5, 2), dpi=180) 31 | 32 | # Colorblind-friendly colors palette 33 | # Source: https://jfly.uni-koeln.de/color/ 34 | colors = ['#E69F00', '#56B4E9', '#009E73', '#F0E442', '#0072B2', '#D55E00', '#CC79A7'] 35 | 36 | colors = ['lightgreen', 'salmon', 'lavender'] 37 | # colors = ['linen', 'salmon', 'lavender'] 38 | 39 | # (num_models, num_context_sizes) 40 | ratio_matrix = np.zeros((5, 4)) 41 | bbc_ratio_matrix = np.zeros((5, 3)) 42 | 43 | models = [] 44 | """ 45 | Baichuan2-7B-Base 46 | Qwen-7B 47 | chatglm3-6b-base 48 | Mistral-7B 49 | Llama-2-7B-HF 50 | """ 51 | context_label = { 52 | "Baichuan2-7B": ['2K', '2K+SW', '4K'], 53 | "Qwen-7B": ['2K', '2K+SW', '8K'], 54 | "chatglm3-6b": ['2K', '2K+SW', '8K'], 55 | "Mistral-7B": ['2K', '2K+SW', '8K'], 56 | "Llama-2-7B": ['2K', '2K+SW', '4K'], 57 | } 58 | 59 | fig, axs = plt.subplots(2,1, figsize=(5, 4), dpi=200, sharex=True) 60 | 61 | # dpi 62 | fig.dpi = 180 63 | 64 | for index, (model_name, data) in enumerate(size_long.items()): 65 | if model_name in ['flac', 'gzip', 'png', 'zlib']: 66 | continue 67 | long_values = [data[month]['ratio'] for month in months] 68 | avg_long_values = np.mean(long_values) 69 | 70 | stride_values = [size_stride[model_name][month]['ratio'] for month in months] 71 | avg_stride_values = np.mean(stride_values) 72 | 73 | size1_values = [size1[model_name][month]['ratio'] for month in months] 74 | avg_size1_values = np.mean(size1_values) 75 | 76 | if model_name in size1_4k: 77 | size1_4k_values = [size1_4k[model_name][month]['ratio'] for month in months] 78 | avg_size1_4k_values = np.mean(size1_4k_values) 79 | 80 | size1_8k_values = [data[month]['ratio'] for month in months] 81 | avg_size1_8k_values = np.mean(size1_8k_values) 82 | else: 83 | size1_4k_values = [data[month]['ratio'] for month in months] 84 | avg_size1_4k_values = np.mean(size1_4k_values) 85 | 86 | size1_8k_values = '-' 87 | avg_size1_8k_values = '-' 88 | 89 | print(model_name, '-', index) 90 | print('2K:', avg_size1_values) 91 | print('2K+SW:', avg_stride_values) 92 | print('4K:', avg_size1_4k_values) 93 | print('8K:', avg_size1_8k_values) 94 | 95 | ratio_matrix[index, 0] = avg_size1_values 96 | ratio_matrix[index, 1] = avg_stride_values 97 | ratio_matrix[index, 2] = avg_size1_4k_values 98 | ratio_matrix[index, 3] = avg_size1_8k_values if avg_size1_8k_values != '-' else 0 99 | 100 | # bbc_long_values 
= [bbc_size_long[model_name][month]['ratio'] for month in months] 101 | # bbc_avg_long_values = np.mean(bbc_long_values) 102 | 103 | # bbc_stride_values = [bbc_size_stride[model_name][month]['ratio'] for month in months] 104 | # bbc_avg_stride_values = np.mean(bbc_stride_values) 105 | 106 | # bbc_size1_values = [bbc_size1[model_name][month]['ratio'] for month in months] 107 | # bbc_avg_size1_values = np.mean(bbc_size1_values) 108 | 109 | # bbc_ratio_matrix[index, 0] = bbc_avg_size1_values 110 | # bbc_ratio_matrix[index, 1] = bbc_avg_stride_values 111 | # bbc_ratio_matrix[index, 2] = bbc_avg_long_values 112 | 113 | if model_name == 'Llama-2-7B-HF': 114 | model_name = 'Llama-2-7B' 115 | if model_name == 'chatglm3-6b-base': 116 | model_name = 'chatglm3-6b' 117 | if model_name == 'Baichuan2-7B-Base': 118 | model_name = 'Baichuan2-7B' 119 | 120 | models.append(model_name) 121 | 122 | if do_plot: 123 | 124 | x = np.arange(len(models)) # the label locations 125 | width = 0.25 # the width of the bars 126 | 127 | # for i, size in enumerate(['2K', '2K+SW', 'Max']): 128 | # # which model got this size? 129 | # plt.bar(x + i*width, ratio_matrix[:, i], width, label=size, color=colors[i]) 130 | 131 | for i, size in enumerate(['2K', '2K+SW', 'Max']): 132 | axs[0].bar(x + i*width, ratio_matrix[:, i], width, label=size, color=colors[i], edgecolor='black') 133 | 134 | for i, size in enumerate(['2K', '2K+SW', 'Max']): 135 | axs[1].bar(x + i*width, bbc_ratio_matrix[:, i] + ((i+1)%2)*np.random.random()*0.001, width, label=size, color=colors[i], edgecolor='black') 136 | 137 | # plt.title('Wikitext') 138 | # plt.xticks(x + width, models) 139 | # plt.legend() 140 | 141 | axs[0].set_title('Wikitext') 142 | axs[0].set_xticks(x + width) 143 | axs[0].set_xticklabels(models, fontsize=8) 144 | axs[0].legend(fontsize=8) 145 | 146 | axs[1].set_title('News') 147 | axs[1].set_xticks(x + width) 148 | axs[1].set_xticklabels(models, fontsize=8) 149 | # axs[1].legend(fontsize=8) 150 | 151 | 152 | # plt.ylim(0.07, 0.086) 153 | axs[0].set_ylim(0.07, 0.086) 154 | axs[1].set_ylim(0.072, 0.09) 155 | 156 | plt.savefig('figs/context_size.png') 157 | 158 | rate = { 159 | 'Baichuan2-7B': { 160 | '2K': ratio_matrix[0, 0], 161 | '2K+SW': ratio_matrix[0, 1], 162 | '4K': ratio_matrix[0, 2], 163 | '8K': ratio_matrix[0, 3], 164 | }, 165 | 'Qwen-7B': { 166 | '2K': ratio_matrix[3, 0], 167 | '2K+SW': ratio_matrix[3, 1], 168 | '4K': ratio_matrix[3, 2], 169 | '8K': ratio_matrix[3, 3], 170 | }, 171 | 'chatglm3-6b': { 172 | '2K': ratio_matrix[4, 0], 173 | '2K+SW': ratio_matrix[4, 1], 174 | '4K': ratio_matrix[4, 2], 175 | '8K': ratio_matrix[4, 3], 176 | }, 177 | 'Mistral-7B': { 178 | '2K': ratio_matrix[2, 0], 179 | '2K+SW': ratio_matrix[2, 1], 180 | '4K': ratio_matrix[2, 2], 181 | '8K': ratio_matrix[2, 3], 182 | }, 183 | 'Llama-2-7B': { 184 | '2K': ratio_matrix[1, 0], 185 | '2K+SW': ratio_matrix[1, 1], 186 | '4K': ratio_matrix[1, 2], 187 | '8K': ratio_matrix[1, 3], 188 | }, 189 | } 190 | 191 | import pandas as pd 192 | df = pd.DataFrame.from_dict(rate, orient='index') 193 | 194 | print(df.to_latex(float_format=lambda x: "{:.5f}".format(x).lstrip('0') if x != '-' else x)) -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | from compressor import ( 2 | arithmetic_coding, 3 | png_compressor, 4 | gzip_compressor, 5 | zlib_compressor, 6 | flac_compressor 7 | ) 8 | import os 9 | import json 10 | import torch 11 | import copy 12 
| 13 | class Metrics: 14 | 15 | baselines = { 16 | 'png': png_compressor, 17 | 'gzip': gzip_compressor, 18 | 'zlib': zlib_compressor, 19 | 'flac': flac_compressor 20 | } 21 | 22 | def __init__(self, modality, save_path, model_name, baselines = ['png', 'zlib', 'flac'], byte2id = None, use_arithmetic_coding = True): 23 | self.baselines = {baseline: Metrics.baselines[baseline] for baseline in baselines} 24 | self.metrics = { 25 | 'bpb': Metrics._bpb, 26 | 'original_size': Metrics._original_size, 27 | 'ratio': Metrics._ratio, 28 | 'compressed_size': Metrics._compressed_size, 29 | 'context_size': Metrics._context_size, 30 | 'stride': Metrics._stride, 31 | 'batches': Metrics._num_chunks, 32 | 'time_used': Metrics._time_used, 33 | } 34 | self.modality = modality 35 | if modality == 'text': 36 | self.metrics['bpt'] = Metrics._bpt 37 | self.metrics['bpc'] = Metrics._bpc 38 | else: 39 | assert byte2id is not None 40 | self.byte2id = torch.tensor(byte2id, dtype=torch.long) 41 | 42 | self.save_path = save_path 43 | self.use_arithmetic_coding = use_arithmetic_coding 44 | 45 | def __call__(self, data_stream, metadata, model_name): 46 | # we will create a file for every model, under the save_dir. 47 | name = metadata['name'] 48 | time = metadata['time'] 49 | save_dir = os.path.join(self.save_path, name, time) 50 | if not os.path.exists(save_dir): 51 | os.makedirs(save_dir) 52 | 53 | if '/' in model_name: 54 | model_name = model_name.split('/')[-1] 55 | if metadata['context_size'] != 2048: 56 | model_name += f'-{metadata["context_size"]}' 57 | # check whether baselines have been computed before 58 | for baseline in self.baselines: 59 | baseline_result_path = os.path.join(save_dir, baseline + '.json') 60 | baseline_compressed_path = os.path.join(save_dir, baseline + '.compressed') 61 | if not os.path.exists(baseline_result_path): 62 | # compute baseline 63 | compressed_size = self.baselines[baseline](data_stream, save_path = baseline_compressed_path) 64 | baseline_metadata = copy.deepcopy(metadata) 65 | baseline_metadata['context_size'] = 'None' 66 | 67 | baseline_metrics = self._compute_metrics(compressed_size, baseline_metadata) 68 | with open(baseline_result_path, 'w') as f: 69 | json.dump(baseline_metrics, f, ensure_ascii=False, indent=2) 70 | 71 | # Now compute metrics for the model 72 | model_result_path = os.path.join(save_dir, model_name + '.json') 73 | if self.use_arithmetic_coding: 74 | model_compressed_path = os.path.join(save_dir, model_name + '.compressed') 75 | with open(model_compressed_path, 'wb') as f: 76 | f.write(self.arithmetic_coding_cache) 77 | 78 | compressed_size = self.self_info_cache / 8 79 | model_metrics = self._compute_metrics(compressed_size, metadata) 80 | 81 | if self.use_arithmetic_coding: 82 | compressed_size = len(self.arithmetic_coding_cache) 83 | ac_metrics = self._compute_metrics(compressed_size, metadata) 84 | ac_metrics = {'ac_' + metric: ac_metrics[metric] for metric in ac_metrics} 85 | model_metrics.update(ac_metrics) 86 | 87 | with open(model_result_path, 'w') as f: 88 | json.dump(model_metrics, f, ensure_ascii=False, indent=2) 89 | 90 | print('Metrics computed for model {} on dataset {}'.format(model_name, name)) 91 | 92 | def _cache_arithmetic_coding(self, pmf, sym, stride = None): 93 | # Due to pmf is extremely memory consuming, we thus do the 94 | # cache the arithmetic coding result 95 | if getattr(self, 'arithmetic_coding_cache', None) is None or self.arithmetic_coding_cache == b'': 96 | self.arithmetic_coding_cache = b'' 97 | elif stride is not None: 98 | # 
if stride is not None and the cache is not empty 99 | # then we only need to use pmf[:, -stride:, :] 100 | pmf = pmf[:, -stride:, :] 101 | sym = sym[:, -stride:] 102 | self.arithmetic_coding_cache += arithmetic_coding(pmf[:, :-1, :], sym[:, 1:]) 103 | 104 | def _cache_self_info(self, pmf, sym, stride = None): 105 | if getattr(self, 'self_info_cache', None) is None or self.self_info_cache == 0: 106 | self.self_info_cache = 0 107 | elif stride is not None: 108 | pmf = pmf[:, -stride:, :] 109 | sym = sym[:, -stride:] 110 | pmf = torch.clamp(pmf, min=1e-32) 111 | self_info = -torch.log2(pmf[:, :-1, :]).gather(dim=-1, index=sym[:, 1:].unsqueeze(-1)).squeeze(-1) 112 | self.self_info_cache += self_info.sum().item() 113 | 114 | def clear_cache(self): 115 | self.arithmetic_coding_cache = b'' 116 | self.self_info_cache = 0 117 | 118 | @torch.no_grad() 119 | def step(self, logits, sym, stride = None): 120 | # if logits in dtype torch.float16, torch.bfloat16, then is it very important to convert it to torch.float32 121 | if logits.dtype in [torch.float16, torch.bfloat16]: 122 | logits = logits.to(torch.float32) 123 | if self.modality != 'text': 124 | if self.byte2id.device != logits.device: 125 | self.byte2id = self.byte2id.to(logits.device) 126 | 127 | # we restrict the output space in the byte space 128 | true_logits = logits.index_select(dim=-1, index=self.byte2id) 129 | pmf = torch.softmax(true_logits, dim=-1) 130 | 131 | # map the byte symbol to the index in the pure byte space coordinating the true_logits 132 | _, _, new_sym = torch.nonzero(self.byte2id == sym.unsqueeze(-1), as_tuple=True) 133 | sym=new_sym.view(sym.shape) 134 | else: 135 | pmf = torch.softmax(logits, dim=-1) 136 | del logits 137 | 138 | if self.use_arithmetic_coding: 139 | self._cache_arithmetic_coding(pmf, sym.to(torch.int32), stride = stride) 140 | self._cache_self_info(pmf, sym, stride = stride) 141 | 142 | @staticmethod 143 | def _bpb(compressed_size, metadata): 144 | # bits per byte 145 | num_bytes = metadata['num_bytes'] 146 | return compressed_size * 8 / num_bytes 147 | 148 | @staticmethod 149 | def _original_size(compressed_size, metadata): 150 | # bits per byte 151 | num_bytes = metadata['num_bytes'] 152 | return num_bytes 153 | 154 | @staticmethod 155 | def _stride(compressed_size, metadata): 156 | return metadata['stride'] 157 | 158 | @staticmethod 159 | def _compressed_size(compressed_size, metadata): 160 | # bits per byte 161 | return compressed_size 162 | 163 | @staticmethod 164 | def _bpt(compressed_size, metadata): 165 | # bits per token 166 | num_tokens = metadata['num_tokens'] 167 | return compressed_size * 8 / num_tokens 168 | 169 | @staticmethod 170 | def _bpc(compressed_size, metadata): 171 | # bits per character 172 | num_chars = metadata['num_chars'] 173 | return compressed_size * 8 / num_chars 174 | 175 | @staticmethod 176 | def _ratio(compressed_size, metadata): 177 | # compression ratio 178 | num_bytes = metadata['num_bytes'] 179 | return compressed_size / num_bytes 180 | 181 | @staticmethod 182 | def _context_size(compressed_size, metadata): 183 | # context size in the compression 184 | return metadata['context_size'] 185 | 186 | @staticmethod 187 | def _num_chunks(compressed_size, metadata): 188 | # number of chunks 189 | return metadata['num_chunks'] 190 | 191 | @staticmethod 192 | def _time_used(compressed_size, metadata): 193 | # time used 194 | return metadata['time_used'] 195 | 196 | def _compute_metrics(self, compressed_size, metadata): 197 | metrics = {} 198 | for metric in self.metrics: 199 | 
metrics[metric] = self.metrics[metric](compressed_size, metadata) 200 | return metrics -------------------------------------------------------------------------------- /page/wikitext_context.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 |
6 | 7 | -------------------------------------------------------------------------------- /visualise/big_table.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | do_latex = True 5 | do_plot = False 6 | task = 'Wikitext' 7 | fig_type = '7B' 8 | 9 | with open('results/wikitext_results_v2.json') as f: 10 | wiki_results = json.load(f) 11 | 12 | with open('results/bbc_news_results_v2.json') as f: 13 | bbc_results = json.load(f) 14 | 15 | with open('results/bbc_image_results_v2.json') as f: 16 | bbc_image_results = json.load(f) 17 | 18 | with open('results/code_results_v2.json') as f: 19 | code_results = json.load(f) 20 | 21 | with open('results/arxiv_results_v2.json') as f: 22 | arxiv_results = json.load(f) 23 | 24 | with open('results/audio_results.json') as f: 25 | audio_results = json.load(f) 26 | 27 | model_name_to_label = { 28 | 'Baichuan2-7B-Base': 'Baichuan2-7B', 29 | 'internlm-7B': 'Internlm-7B', 30 | 'Qwen-7B': 'Qwen-7B', 31 | 'Yi-6B': 'Yi-6B', 32 | 'chatglm3-6b-base': 'Chatglm3-6B', 33 | 'Mistral-7B': 'Mistral-7B', 34 | 'LLaMA-7B-HF': 'LLaMA-7B', 35 | 'LLaMA-13B': 'LLaMA-13B', 36 | 'Llama-2-13B': 'Llama-2-13B', 37 | 'Llama-2-7B-HF': 'Llama-2-7B', 38 | 'CodeLlama-7B': 'CodeLlama-7B', 39 | 'Llama-2-70B': 'Llama-2-70B', 40 | 'LLaMA-30B': 'LLaMA-30B', 41 | 'LLaMA-65B': 'LLaMA-65B', 42 | 'Yi-34B-200K': 'Yi-34B', 43 | 'flac': 'Flac', 44 | 'png': 'Png', 45 | 'zlib': 'Zlib', 46 | } 47 | 48 | # 2017-2023, except 2023-12 49 | all_months = [f'{year}-{month:02d}' for year in range(2017, 2024) for month in range(1, 13) if not (year == 2023 and month == 12)] 50 | 51 | # 2023 52 | months_in_2023 = [f'2023-{month:02d}' for month in range(1, 12)] 53 | 54 | # pre-2023 55 | pre_2023_months = [f'{year}-{month:02d}' for year in range(2017, 2023) for month in range(1, 13)] 56 | 57 | models = {} 58 | 59 | def compute_results(data, model, task): 60 | ratios_all_months = [] 61 | for month, metrics in data.items(): 62 | if month in all_months: 63 | ratios_all_months.append(metrics['ratio']) 64 | 65 | if len(ratios_all_months) < 75: print(model, task, len(ratios_all_months)) 66 | avg_all_months = np.mean(ratios_all_months) 67 | 68 | ratios_2023 = [ data[month]['ratio'] for month in months_in_2023 if month in data] 69 | avg_2023 = np.mean(ratios_2023) 70 | 71 | pre_2023_avg = np.mean([ data[month]['ratio'] for month in pre_2023_months if month in data]) 72 | diff_train_test = avg_2023 - pre_2023_avg 73 | 74 | diff = avg_2023 - avg_all_months 75 | 76 | def float_formatter(x, digits=4): 77 | if x <1: 78 | num_str = ('{:.' + str(digits) + 'f}').format(x).lstrip('0') 79 | else: 80 | num_str = ('{:.' + str(digits) + 'g}').format(x) 81 | 82 | # pad 0 if not enough digits 83 | num_digits = len(num_str.replace('.', '')) 84 | if num_digits < digits: 85 | digits_to_pad = digits - num_digits 86 | if '.' not in num_str: 87 | num_str += '.' 
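            # pad with trailing zeros so every cell prints exactly `digits` significant
            # figures, e.g. '12.5' -> '12.50' and '100' -> '100.0'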
88 | num_str += '0' * digits_to_pad 89 | return num_str 90 | 91 | avg_all_months_str = '{:.4f}'.format(avg_all_months).lstrip('0') 92 | avg_all_months_str_100 = float_formatter(avg_all_months * 100) 93 | avg_2023_str = '{:.4f}'.format(avg_2023).lstrip('0') 94 | avg_2023_str_100 = float_formatter(avg_2023 * 100) 95 | pre_2023_avg_str = '{:.4f}'.format(pre_2023_avg).lstrip('0') 96 | 97 | arrow = '↑' if diff > 0 else '↓' 98 | diff_str = '{:.4f}'.format(abs(diff)).lstrip('0') 99 | avg_2023_with_diff = f'{avg_2023_str} {arrow} {diff_str}' 100 | 101 | arrow2 = '↑' if diff_train_test > 0 else '↓' 102 | diff_train_test_str = '{:.4f}'.format(abs(diff_train_test)).lstrip('0') 103 | diff_train_test_str_100 = float_formatter(abs(diff_train_test) * 100, digits=3) 104 | 105 | pre_2023_avg_with_diff = f'{avg_2023_str} {arrow2} {diff_train_test_str}' 106 | pre_2023_avg_with_diff_ = f'{avg_2023_str_100} {arrow2} {diff_train_test_str_100}' 107 | 108 | return { 109 | # 'Avg.': avg_all_months_str, 110 | 'Avg.': avg_all_months_str_100, 111 | # '2023': pre_2023_avg_with_diff, 112 | '2023': pre_2023_avg_with_diff_, 113 | # # 'performance': avg_all_months, 114 | # 'train': pre_2023_avg_str, 115 | # 'test': pre_2023_avg_with_diff, 116 | # 'Avg.': avg_all_months, 117 | 'performance': avg_2023 * 100, 118 | 'robustness': diff_train_test * 100, 119 | } 120 | 121 | for model_name, data in wiki_results.items(): 122 | if model_name in ['LLaMA-7B', 'Llama-2-7B']: 123 | continue 124 | print(model_name) 125 | model_name_ = model_name_to_label[model_name] 126 | 127 | if model_name_ not in models: 128 | models[model_name_] = {} 129 | 130 | if model_name_ not in bbc_results: 131 | bbc_results[model_name_] = {} 132 | 133 | if model_name_ not in bbc_image_results: 134 | bbc_image_results[model_name_] = {} 135 | 136 | if model_name_ not in code_results: 137 | code_results[model_name_] = {} 138 | 139 | if model_name_ not in arxiv_results: 140 | arxiv_results[model_name_] = {} 141 | 142 | if model_name not in audio_results: 143 | audio_results[model_name] = {} 144 | 145 | wiki_results_ = compute_results(data, model_name, 'Wikitext') 146 | models[model_name_]['Wikitext'] = wiki_results_ 147 | 148 | news_results = compute_results(bbc_results[model_name], model_name, 'BBC News') 149 | # models[model_name]['BBC News'] = news_results 150 | models[model_name_]['BBC News'] = news_results 151 | 152 | image_results = compute_results(bbc_image_results[model_name], model_name, 'BBC Image') 153 | # models[model_name]['Image'] = image_results 154 | models[model_name_]['Image'] = image_results 155 | 156 | code_results_ = compute_results(code_results[model_name], model_name, 'Code') 157 | # models[model_name]['Code'] = code_results_ 158 | models[model_name_]['Code'] = code_results_ 159 | 160 | arxiv_results_ = compute_results(arxiv_results[model_name], model_name, 'Arxiv') 161 | # models[model_name]['Arxiv'] = arxiv_results_ 162 | models[model_name_]['Arxiv'] = arxiv_results_ 163 | 164 | audio_results_ = compute_results(audio_results[model_name], model_name, 'Audio') 165 | models[model_name_]['Audio'] = audio_results_ 166 | 167 | if do_latex: 168 | import pandas as pd 169 | 170 | df = pd.DataFrame.from_dict({(i, j): models[i][j] 171 | for i in models.keys() 172 | for j in models[i].keys()}, 173 | orient='index') 174 | df.index = pd.MultiIndex.from_tuples(df.index) 175 | df = df.unstack().swaplevel(0, 1, axis=1).sort_index(axis=1) 176 | 177 | # new_order = ['Avg.', '2023'] 178 | new_order = ['Avg.', '2023'] 179 | df = df.reindex(columns=[(col, 
subcol) for col in df.columns.levels[0] for subcol in new_order]) 180 | 181 | top_level_order = ['Wikitext', 'BBC News', 'Code', 'Arxiv', 'Image', 'Audio'] 182 | df = df[top_level_order] 183 | 184 | print(df) 185 | print(df.to_latex()) 186 | 187 | if do_plot: 188 | import matplotlib.pyplot as plt 189 | 190 | # Clamping robustness to a minimum of 0 191 | new_models = {} 192 | plt.figure(figsize=(8, 5), dpi=200) 193 | 194 | print('='*30) 195 | performances = [] 196 | robustnesses = [] 197 | for model in models: 198 | print(model) 199 | if model in ['flac', 'png', 'zlib']: 200 | continue 201 | # if not ('7b' not in model.lower() and '6b' not in model.lower()): 202 | if fig_type == '7B' and '7b' not in model.lower() and '6b' not in model.lower(): 203 | continue 204 | if fig_type == 'large' and '13b' not in model.lower() and '34b' not in model.lower() and '70b' not in model.lower() and '65b' not in model.lower(): 205 | continue 206 | # if 'llama' not in model.lower(): 207 | # continue 208 | models[model][task]['robustness'] = max(models[model][task]['robustness'], 0) 209 | new_models[model] = models[model] 210 | performances.append(models[model][task]['performance']) 211 | robustnesses.append(models[model][task]['robustness']) 212 | models = new_models 213 | 214 | max_performance = max(performances) 215 | min_performance = min(performances) 216 | max_robustness = max(robustnesses) 217 | min_robustness = min(robustnesses) 218 | 219 | performance_range = max_performance - min_performance 220 | robustness_range = max_robustness - min_robustness 221 | 222 | performance_padding = performance_range * 0.05 # Adding 5% padding 223 | robustness_padding = robustness_range * 0.05 # Adding 5% padding 224 | 225 | # Adjusting the axis limits with added padding 226 | plt.xlim(max_performance + performance_padding, min_performance - performance_padding) # Inverting x-axis 227 | plt.ylim(max_robustness + robustness_padding, min_robustness - robustness_padding) # Inverting y-axis 228 | 229 | mid_performance = (max_performance + min_performance) / 2 230 | mid_robustness = (max_robustness + min_robustness) / 2 231 | 232 | # Drawing lines to divide the plot into quadrants 233 | plt.axvline(x=mid_performance, color='grey', linestyle='--') 234 | plt.axhline(y=mid_robustness, color='grey', linestyle='--') 235 | 236 | # Plotting 237 | for model in models: 238 | plt.scatter(models[model][task]['performance'], models[model][task]['robustness'], label=model, marker='o') 239 | text_x_offset = 0.2 240 | text_y_offset = -0.002 241 | # if model == 'Baichuan2-7B': 242 | # text_x_offset = 0.4 243 | # text_y_offset = .004 244 | # if model == 'Yi-6B': 245 | # text_x_offset = 0 246 | # Check if label is close to the right edge 247 | if models[model][task]['performance'] > max_performance - performance_padding: 248 | text_x_offset = 0 # Move text to the left 249 | 250 | plt.text(models[model][task]['performance'] + text_x_offset, models[model][task]['robustness'] + text_y_offset, model, fontsize=12) 251 | 252 | plt.xlabel(f'Compression Rate - {task} (%, lower is better, axis inverted)', fontsize=12) 253 | plt.ylabel('Robustness (gap of train/test period, %)', fontsize=12) 254 | # plt.title('Model Performance vs Robustness') 255 | # plt.gca().invert_xaxis() # Inverting x-axis to have the best models on top-right 256 | # plt.gca().invert_yaxis() # Inverting y-axis as well 257 | plt.legend() 258 | # plt.legend(loc='lower left') 259 | # plt.legend(loc='upper left', fontsize='small') 260 | plt.grid(True) 261 | plt.tight_layout() 262 | 
plt.savefig(f'figs/robustness_performance_{task}_{fig_type}.png') -------------------------------------------------------------------------------- /arithmetic_int32.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | 16 | using cdf_t = uint32_t; 17 | 18 | /** Encapsulates a pointer to a CDF tensor */ 19 | struct cdf_ptr { 20 | const cdf_t* data; // expected to be a N_sym x Lp matrix, stored in row major. 21 | const int N_sym; // Number of symbols stored by `data`. 22 | const int Lp; // == L+1, where L is the number of possible values a symbol can take. 23 | cdf_ptr(const cdf_t* data, 24 | const int N_sym, 25 | const int Lp) : data(data), N_sym(N_sym), Lp(Lp) {}; 26 | }; 27 | 28 | /** Class to save output bit by bit to a byte string */ 29 | class OutCacheString { 30 | private: 31 | public: 32 | std::string out=""; 33 | uint8_t cache=0; 34 | uint8_t count=0; 35 | void append(const int bit) { 36 | cache <<= 1; 37 | cache |= bit; 38 | count += 1; 39 | if (count == 8) { 40 | out.append(reinterpret_cast(&cache), 1); 41 | count = 0; 42 | } 43 | } 44 | void flush() { 45 | if (count > 0) { 46 | for (int i = count; i < 8; ++i) { 47 | append(0); 48 | } 49 | assert(count==0); 50 | } 51 | } 52 | void append_bit_and_pending(const int bit, uint64_t &pending_bits) { 53 | append(bit); 54 | while (pending_bits > 0) { 55 | append(!bit); 56 | pending_bits -= 1; 57 | } 58 | } 59 | }; 60 | 61 | /** Class to read byte string bit by bit */ 62 | class InCacheString { 63 | private: 64 | const std::string& in_; 65 | 66 | public: 67 | explicit InCacheString(const std::string& in) : in_(in) {}; 68 | 69 | uint8_t cache=0; 70 | uint8_t cached_bits=0; // num 71 | size_t in_ptr=0; 72 | 73 | void get(uint32_t& value) { 74 | if (cached_bits == 0) { 75 | if (in_ptr == in_.size()){ 76 | value <<= 1; 77 | return; 78 | } 79 | /// Read 1 byte 80 | cache = (uint8_t) in_[in_ptr]; 81 | in_ptr++; 82 | cached_bits = 8; 83 | } 84 | value <<= 1; 85 | value |= (cache >> (cached_bits - 1)) & 1; 86 | cached_bits--; 87 | } 88 | 89 | void initialize(uint32_t& value) { 90 | for (int i = 0; i < 32; ++i) { 91 | get(value); 92 | } 93 | } 94 | }; 95 | 96 | const void check_sym(const torch::Tensor& sym) { 97 | TORCH_CHECK(sym.sizes().size() == 1, 98 | "Invalid size for sym. Expected just 1 dim.") 99 | } 100 | 101 | /** Get an instance of the `cdf_ptr` struct. */ 102 | const struct cdf_ptr get_cdf_ptr(const torch::Tensor& cdf) 103 | { 104 | TORCH_CHECK(!cdf.is_cuda(), "cdf must be on CPU!") 105 | const auto s = cdf.sizes(); 106 | TORCH_CHECK(s.size() == 2, "Invalid size for cdf! 
Expected (N, Lp)") 107 | 108 | const int N_sym = s[0]; 109 | const int Lp = s[1]; 110 | const auto cdf_acc = cdf.accessor(); 111 | const cdf_t* cdf_ptr = (uint32_t*)cdf_acc.data(); 112 | 113 | const struct cdf_ptr res(cdf_ptr, N_sym, Lp); 114 | return res; 115 | } 116 | 117 | py::bytes encode( 118 | const cdf_ptr& cdf_ptr, 119 | const torch::Tensor& sym){ 120 | 121 | OutCacheString out_cache; 122 | 123 | uint32_t low = 0; 124 | uint32_t high = 0xFFFFFFFFU; 125 | uint64_t pending_bits = 0; 126 | 127 | const int norm = 30; 128 | 129 | const cdf_t* cdf = cdf_ptr.data; 130 | const int N_sym = cdf_ptr.N_sym; 131 | const int Lp = cdf_ptr.Lp; 132 | const int max_symbol = Lp - 2; 133 | 134 | auto sym_ = sym.accessor(); 135 | 136 | for (int i = 0; i < N_sym; ++i) { 137 | const int32_t sym_i = sym_[i]; 138 | 139 | const uint64_t span = static_cast(high) - static_cast(low) + 1; 140 | 141 | const int offset = i * Lp; 142 | // Left boundary is at offset + sym_i 143 | const uint32_t c_low = cdf[offset + sym_i]; 144 | // Right boundary is at offset + sym_i + 1, except for the `max_symbol` 145 | // For which we hardcode the maxvalue. So if e.g. 146 | // L == 4, it means that Lp == 5, and the allowed symbols are 147 | // {0, 1, 2, 3}. The max symbol is thus Lp - 2 == 3. It's probability 148 | // is then given by c_max - 149 | const uint32_t c_high = sym_i == max_symbol ? 0x40000000U : cdf[offset + sym_i + 1]; 150 | 151 | high = (low - 1) + ((span * static_cast(c_high)) >> norm); 152 | low = (low) + ((span * static_cast(c_low)) >> norm); 153 | 154 | while (true) { 155 | if (high < 0x80000000U) { 156 | out_cache.append_bit_and_pending(0, pending_bits); 157 | low <<= 1; 158 | high <<= 1; 159 | high |= 1; 160 | } else if (low >= 0x80000000U) { 161 | out_cache.append_bit_and_pending(1, pending_bits); 162 | low <<= 1; 163 | high <<= 1; 164 | high |= 1; 165 | } else if (low >= 0x40000000U && high < 0xC0000000U) { 166 | pending_bits++; 167 | low <<= 1; 168 | low &= 0x7FFFFFFF; 169 | high <<= 1; 170 | high |= 0x80000001; 171 | } else { 172 | break; 173 | } 174 | } 175 | } 176 | 177 | pending_bits += 1; 178 | 179 | if (pending_bits) { 180 | if (low < 0x40000000U) { 181 | out_cache.append_bit_and_pending(0, pending_bits); 182 | } else { 183 | out_cache.append_bit_and_pending(1, pending_bits); 184 | } 185 | } 186 | 187 | out_cache.flush(); 188 | 189 | #ifdef VERBOSE 190 | std::chrono::steady_clock::time_point end= std::chrono::steady_clock::now(); 191 | std::cout << "Time difference (sec) = " << (std::chrono::duration_cast(end - begin).count()) /1000000.0 <((left + right) / 2); 222 | const auto v = cdf[offset + m]; 223 | if (v < target) { 224 | left = m; 225 | } else if (v > target) { 226 | right = m; 227 | } else { 228 | return m; 229 | } 230 | } 231 | 232 | return left; 233 | } 234 | 235 | 236 | torch::Tensor decode( 237 | const cdf_ptr& cdf_ptr, 238 | const std::string& in) { 239 | 240 | #ifdef VERBOSE 241 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 242 | #endif 243 | 244 | const cdf_t* cdf = cdf_ptr.data; 245 | const int N_sym = cdf_ptr.N_sym; // To know the # of syms to decode. is encoded in file! 
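    // The decoder mirrors encode(): it keeps a 32-bit [low, high] interval, streams input
    // bits into `value`, and for every symbol rescales `value` into the 30-bit CDF domain
    // (c_count = 2^30, norm = 30) to obtain `count`, then binary-searches the i-th CDF row
    // for the symbol whose [c_low, c_high) range contains `count`.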
246 | const int Lp = cdf_ptr.Lp; // To calculate offset 247 | const int max_symbol = Lp - 2; 248 | 249 | // 32 bit 250 | auto out = torch::empty({N_sym}, torch::kInt32); 251 | auto out_ = out.accessor(); 252 | 253 | uint32_t low = 0; 254 | uint32_t high = 0xFFFFFFFFU; 255 | uint32_t value = 0; 256 | const uint64_t c_count = 0x40000000U; 257 | const int norm = 30; 258 | 259 | InCacheString in_cache(in); 260 | in_cache.initialize(value); 261 | 262 | for (int i = 0; i < N_sym; ++i) { 263 | const uint64_t span = static_cast(high) - static_cast(low) + 1; 264 | // always < 0x10000 ??? 265 | const uint32_t count = ((static_cast(value) - static_cast(low) + 1) * c_count - 1) / span; 266 | 267 | const int offset = i * Lp; 268 | auto sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset); 269 | 270 | out_[i] = (int32_t)sym_i; 271 | 272 | if (i == N_sym-1) { 273 | break; 274 | } 275 | 276 | const uint32_t c_low = cdf[offset + sym_i]; 277 | const uint32_t c_high = sym_i == max_symbol ? 0x40000000U : cdf[offset + sym_i + 1]; 278 | 279 | high = (low - 1) + ((span * static_cast(c_high)) >> norm); 280 | low = (low) + ((span * static_cast(c_low)) >> norm); 281 | 282 | while (true) { 283 | if (low >= 0x80000000U || high < 0x80000000U) { 284 | low <<= 1; 285 | high <<= 1; 286 | high |= 1; 287 | in_cache.get(value); 288 | } else if (low >= 0x40000000U && high < 0xC0000000U) { 289 | /** 290 | * 0100 0000 ... <= value < 1100 0000 ... 291 | * <=> 292 | * 0100 0000 ... <= value <= 1011 1111 ... 293 | * <=> 294 | * value starts with 01 or 10. 295 | * 01 - 01 == 00 | 10 - 01 == 01 296 | * i.e., with shifts 297 | * 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in 298 | * near convergence 299 | */ 300 | low <<= 1; 301 | low &= 0x7FFFFFFFU; // make MSB 0 302 | high <<= 1; 303 | high |= 0x80000001U; // add 1 at the end, retain MSB = 1 304 | value -= 0x40000000U; 305 | in_cache.get(value); 306 | } else { 307 | break; 308 | } 309 | } 310 | } 311 | 312 | #ifdef VERBOSE 313 | std::chrono::steady_clock::time_point end= std::chrono::steady_clock::now(); 314 | std::cout << "Time difference (sec) = " << (std::chrono::duration_cast(end - begin).count()) /1000000.0 <<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 68 | } 69 | model_name = model_name if model_name is not None else 'meta-llama/Llama-2-7b-hf' 70 | 71 | prompt = templates[model_name].format( 72 | question=question 73 | ) 74 | 75 | return prompt 76 | 77 | elif task == 'llm_scorer': 78 | templates = { 79 | 'meta-llama/Meta-Llama-3-8B-Instruct': ( 80 | "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" 81 | "Here is a question: \"{question}\". And its reference answer is \"{ref_answer}\". " 82 | "Now, your job is to judge whether the following student answer contains the reference answer? " 83 | "Student answer is: \"{out}\" " 84 | "You should reply \"Yes\" if the student answer includes reference answer; otherwise reply \"No\". " 85 | "Return only \"Yes\" or \"No\" as your judgement and return nothing else. " 86 | "<|eot_id|><|start_header_id|>assistant<|end_header_id|> My answer is: " 87 | ), 88 | 'gpt-4': ( 89 | "The question is: {question}. And its reference answer is {ref_answer}. " 90 | "Now, your job is to judge whether the following answer is correct? Student answer: {out} " 91 | "You should only reply \"Yes\" or \"No\" as your answer and return nothing else." 
92 | ) 93 | } 94 | 95 | prompt = templates[model_name].format( 96 | question=question, 97 | ref_answer=ref_answer, 98 | out=out 99 | ) 100 | 101 | return prompt 102 | 103 | def make_test( 104 | self, 105 | ): 106 | tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', use_fast=True) 107 | 108 | for time, dataset in tqdm(self.datasets.items(), desc='loop over months'): 109 | # skip? 110 | if self.do_skip: 111 | f = os.path.join(self.question_output_dir, f'{time}.json') 112 | if os.path.exists(f): 113 | with open(f, 'r') as f: 114 | self.test_set[time] = json.load(f) 115 | continue 116 | 117 | count = 0 118 | if time not in self.test_set: 119 | self.test_set[time] = [] 120 | for article in tqdm(dataset): 121 | text = article['text'] 122 | 123 | chunked_text = self._chunked_text(text, tokenizer, max_len=3072) 124 | prompt = self._make_prompt(task='question_generation', article={'title': article['title'], 'content': chunked_text}) 125 | response = self._make_llm_request(prompt) 126 | response = self._parse_response( 127 | task='question_generation', 128 | response=response 129 | ) 130 | if response is not None: 131 | count += 1 132 | 133 | self.test_set[time].append({ 134 | 'title': article['title'], 135 | 'prompt': prompt, 136 | 'response': response 137 | }) 138 | 139 | if count >= self.num_examples_per_time: 140 | break 141 | 142 | if count % 10 == 0: 143 | self.save_to_file( 144 | task='question_generation', 145 | time=time 146 | ) 147 | 148 | self.save_to_file( 149 | task='question_generation', 150 | time=time 151 | ) 152 | 153 | print('Finished making test set.') 154 | 155 | def get_answer( 156 | self, 157 | model_name: str 158 | ): 159 | self.answer_set = {} 160 | 161 | for time, test_set in tqdm(self.test_set.items(), desc='loop over months'): 162 | # skip? 163 | if self.do_skip: 164 | f = os.path.join(self.answer_output_dir, f'{time}.json') 165 | if os.path.exists(f): 166 | with open(f, 'r') as f: 167 | self.answer_set[time] = json.load(f) 168 | continue 169 | 170 | if time not in self.answer_set: 171 | self.answer_set[time] = [] 172 | 173 | questions = [] 174 | ref_answers = [] 175 | titles = [] 176 | prompts = [] 177 | for example in test_set: 178 | if example['response'] is None: 179 | continue 180 | 181 | title = example['title'] 182 | for qa in example['response']: 183 | try: 184 | query = qa['question'] 185 | ref_answer = qa['answer'] 186 | except: 187 | continue 188 | 189 | prompts.append(self._make_prompt( 190 | task='answer_generation', 191 | question=query, 192 | model_name=model_name 193 | )) 194 | questions.append(query) 195 | ref_answers.append(ref_answer) 196 | titles.append(title) 197 | 198 | outs = self._inference_with_vllm( 199 | task='answer_generation', 200 | prompts=prompts, 201 | # model_name='meta-llama/Llama-2-7b-hf' 202 | model_name=model_name, 203 | ) 204 | 205 | for question, ref_answer, out, title in zip(questions, ref_answers, outs, titles): 206 | gen_answers = [o.text for o in out.outputs] 207 | self.answer_set[time].append({ 208 | 'question': question, 209 | 'ref_answer': ref_answer, 210 | 'prompt': out.prompt, 211 | 'out': gen_answers, 212 | 'title': title, 213 | 'accuracy': self.auto_scorer(gen_answers, ref_answer), 214 | }) 215 | 216 | self.save_to_file( 217 | task='answer_generation', 218 | time=time 219 | ) 220 | 221 | def llm_scorer( 222 | self, 223 | model_name, 224 | ): 225 | for time, answer_set in tqdm(self.answer_set.items(), desc='loop over months'): 226 | # skip? 
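            # NOTE: unlike make_test/get_answer, the resume block below is commented out,
            # so the judge model re-scores every month on each run.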
227 | # if self.do_skip: 228 | # f = os.path.join(self.scored_output_dir, f'{time}.json') 229 | # if os.path.exists(f): 230 | # with open(f, 'r') as f: 231 | # self.answer_set[time] = json.load(f) 232 | # continue 233 | 234 | for example in answer_set: 235 | ref_answer = example['ref_answer'] 236 | query = example['question'] 237 | 238 | prompts = [ 239 | self._make_prompt( 240 | task='llm_scorer', 241 | question=query, 242 | ref_answer=ref_answer, 243 | out=o.split('\n')[0], 244 | model_name=model_name 245 | ) 246 | for o in example['out'] 247 | ] 248 | 249 | responses = self._inference_with_vllm( 250 | task='llm_scorer', 251 | prompts=prompts, 252 | model_name=model_name 253 | ) 254 | 255 | num_correct = 0 256 | num_wrong = 0 257 | judges = [] 258 | for r in responses: 259 | tr = r.outputs[0].text 260 | judges.append(tr.strip()) 261 | if 'yes' in tr.lower(): 262 | num_correct += 1 263 | elif 'no' in tr.lower(): 264 | num_wrong += 1 265 | 266 | if num_correct + num_wrong == 0: 267 | example['llm_acc'] = None 268 | else: 269 | example['llm_acc'] = num_correct / (num_correct + num_wrong) 270 | example['judges'] = judges 271 | 272 | self.save_to_file( 273 | task='llm_scorer', 274 | time=time 275 | ) 276 | 277 | print('Finished getting answers.') 278 | 279 | if __name__ == '__main__': 280 | # every months from 2017-01 to 2024-02 281 | times = [f'{year}-{month:02}' for year in range(2017, 2025) for month in range(1, 13)] 282 | times = times[:86] 283 | 284 | factqa = KnowledgeQA( 285 | task_name='knowledge_qa', 286 | dataset_name='RealTimeData/wikitext_alltime', 287 | times=times, 288 | data_dir='real_tasks/knowledge_qa', 289 | output_dir='real_tasks/knowledge_qa/knowledge_qa_llama_2_7b_chat', 290 | num_examples_per_time=50, 291 | # do_skip=False 292 | ) 293 | 294 | factqa.make_test() 295 | # factqa.postprocess() 296 | # factqa.get_answer(model_name='meta-llama/Llama-2-7b-hf') 297 | factqa.get_answer(model_name='meta-llama/Llama-2-7b-chat-hf') 298 | # # factqa.get_answer(model_name='meta-llama/Meta-Llama-3-8B') 299 | factqa.llm_scorer(model_name='meta-llama/Meta-Llama-3-8B-Instruct') -------------------------------------------------------------------------------- /real_tasks/code_completion.py: -------------------------------------------------------------------------------- 1 | from fact_qa import FactQA, edit_distance_scorer 2 | import os 3 | import json 4 | import numpy as np 5 | from tqdm import tqdm 6 | from transformers import AutoTokenizer 7 | 8 | class CodeCompletion(FactQA): 9 | def _make_prompt( 10 | self, 11 | task, # choose from ['question_generation', 'post_processing', 'answer_generation'] 12 | article = None, 13 | question = None, 14 | ref_answer = None, 15 | last_line = None, 16 | out = None, 17 | model_name = None, 18 | ): 19 | if task == 'post_processing': 20 | assert last_line is not None 21 | 22 | last_line = last_line.strip() 23 | prompt = ( 24 | "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{last_line}\n\n" 25 | "Is the above line real code or just comment? Reply \"Yes\" if it's code, reply \"No\" if it's comment (for example has # or // at the begining). 
" 26 | "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 27 | ) 28 | prompt = prompt.format( 29 | last_line=last_line 30 | ) 31 | 32 | return prompt 33 | 34 | elif task == 'answer_generation': 35 | templates = { 36 | 'meta-llama/Llama-2-7b-hf': "Question: {question}\nAnswer: ", 37 | 'meta-llama/Llama-2-7b-chat-hf': "[INST] {question} [/INST]", 38 | 'meta-llama/Meta-Llama-3-8B': "Question: {question}\nAnswer: ", 39 | 'meta-llama/Meta-Llama-3-8B-Instruct': "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 40 | } 41 | model_name = model_name if model_name is not None else 'meta-llama/Llama-2-7b-hf' 42 | 43 | prompt = templates[model_name].format( 44 | question=question 45 | ) 46 | 47 | return prompt 48 | 49 | elif task == 'llm_scorer': 50 | assert ref_answer is not None and out is not None 51 | 52 | ref_answer = ref_answer.strip() 53 | out = out.strip() 54 | templates = { 55 | 'meta-llama/Meta-Llama-3-8B-Instruct': ( 56 | "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" 57 | "Here are a line of code: the reference code A. \"{ref_answer}\"; and the student code B. \"{out}\". " 58 | "Now, your job is to judge whether the student answer (B) reflect the reference answer (A)? " 59 | "You should reply \"Yes\" if the student answer make sense according to the reference answer; otherwise reply \"No\". " 60 | "Return only \"Yes\" or \"No\" as your judgement and return nothing else. " 61 | "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 62 | ), 63 | 'gpt-4': ( 64 | "The question is: {question}. And its reference answer is {ref_answer}. " 65 | "Now, your job is to judge whether the following answer is correct? Student answer: {out} " 66 | "You should only reply \"Yes\" or \"No\" as your answer and return nothing else." 
67 | ) 68 | } 69 | 70 | prompt = templates[model_name].format( 71 | question=question, 72 | ref_answer=ref_answer, 73 | out=out 74 | ) 75 | 76 | return prompt 77 | 78 | else: 79 | raise ValueError(f'Unknown task: {task}') 80 | 81 | def make_example( 82 | self, 83 | code, 84 | tokenizer 85 | ): 86 | # retrieve a random block from the code (about 100 lines of code) 87 | # in the block, use the last line as the answer to be generated, the rest as the context 88 | 89 | # block 90 | all_lines = [i for i in code.split('\n') if i.strip()] 91 | if len(all_lines) < 120: 92 | return None 93 | 94 | random_block_start = np.random.randint(0, len(all_lines) - 110) 95 | block = all_lines[random_block_start:random_block_start+100] 96 | 97 | # find a last line that is longer than 30 characters 98 | while True: 99 | if not block: 100 | return None 101 | last_line = block.pop() 102 | tokenized_last_line = tokenizer.encode(last_line, add_special_tokens=False) 103 | if len(tokenized_last_line) > 12: 104 | break 105 | 106 | prompt_len = 0 107 | num_lines_in_prompt = 0 108 | for i in range(len(block), 0, -1): 109 | i-=1 110 | line = block[i] 111 | prompt_len += len(tokenizer.encode(line, add_special_tokens=False)) 112 | if prompt_len > 2048: 113 | break 114 | num_lines_in_prompt += 1 115 | 116 | block = block[-num_lines_in_prompt:] 117 | block = '\n'.join(block) 118 | 119 | # add a small prompt from the last line to the block 120 | prompting_tokens = tokenizer.decode(tokenized_last_line[:3]) 121 | left_tokens = tokenizer.decode(tokenized_last_line[3:]) 122 | 123 | return { 124 | 'code': block, 125 | 'last_line': last_line, 126 | 'prompting_tokens': prompting_tokens, 127 | 'left_tokens': left_tokens 128 | } 129 | 130 | def make_test( 131 | self 132 | ): 133 | tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', use_fast=True) 134 | for time, dataset in tqdm(self.datasets.items(), desc='loop over months'): 135 | # skip? 136 | if self.do_skip: 137 | f = os.path.join(self.question_output_dir, f'{time}.json') 138 | if os.path.exists(f): 139 | with open(f, 'r') as f: 140 | self.test_set[time] = json.load(f) 141 | continue 142 | 143 | count = 0 144 | if time not in self.test_set: 145 | self.test_set[time] = [] 146 | for example in tqdm(dataset): 147 | code = example['code'] 148 | example = self.make_example(code, tokenizer) 149 | 150 | if example is None: 151 | continue 152 | 153 | count += 1 154 | self.test_set[time].append(example) 155 | 156 | if count >= self.num_examples_per_time: 157 | break 158 | 159 | if count % 50 == 0: 160 | self.save_to_file( 161 | task='question_generation', 162 | time=time 163 | ) 164 | 165 | self.save_to_file( 166 | task='question_generation', 167 | time=time 168 | ) 169 | 170 | print('Finished making test set.') 171 | 172 | def get_answer( 173 | self, 174 | model_name: str 175 | ): 176 | self.answer_set = {} 177 | 178 | for time, test_set in tqdm(self.test_set.items(), desc='loop over months'): 179 | # skip? 
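            # resume: reuse this month's cached completions if they already exist on disk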
180 | if self.do_skip: 181 | f = os.path.join(self.answer_output_dir, f'{time}.json') 182 | if os.path.exists(f): 183 | with open(f, 'r') as f: 184 | self.answer_set[time] = json.load(f) 185 | continue 186 | 187 | if time not in self.answer_set: 188 | self.answer_set[time] = [] 189 | 190 | blocks = [example['code'] for example in test_set if example is not None] 191 | prompting = [example['prompting_tokens'] for example in test_set if example is not None] 192 | last_lines = [example['last_line'] for example in test_set if example is not None] 193 | prompts = [example['code'] + '\n' + example['prompting_tokens'] for example in test_set if example is not None] 194 | ref_answers = [example['left_tokens'] for example in test_set if example is not None] 195 | 196 | outs = self._inference_with_vllm( 197 | task='answer_generation', 198 | prompts=prompts, 199 | model_name=model_name 200 | ) 201 | 202 | for block, left_tokens, last_line, prompting_tokens, out in zip(blocks, ref_answers, last_lines, prompting, outs): 203 | self.answer_set[time].append({ 204 | 'code': block, 205 | 'prompt': out.prompt, 206 | 'prompting_tokens': prompting_tokens, 207 | 'last_line': last_line, 208 | 'left_tokens': left_tokens, 209 | 'out': [o.text for o in out.outputs], 210 | }) 211 | 212 | self.save_to_file( 213 | task='answer_generation', 214 | time=time 215 | ) 216 | 217 | print('Finished getting answers.') 218 | 219 | def postprocess( 220 | self, 221 | model_name: str 222 | ): 223 | for time, test_set in tqdm(self.test_set.items(), desc='loop over months'): 224 | # skip? 225 | if self.do_skip: 226 | f = os.path.join(self.postprocessed_output_dir, f'{time}.json') 227 | if os.path.exists(f): 228 | with open(f, 'r') as f: 229 | self.test_set[time] = json.load(f) 230 | continue 231 | 232 | finalized_test_set = [] 233 | prompts = [] 234 | for index, example in tqdm(enumerate(test_set), desc='loop over examples'): 235 | last_line = example['last_line'] 236 | prompt = self._make_prompt(task='post_processing', last_line=last_line) 237 | prompts.append(prompt) 238 | 239 | outs = self._inference_with_vllm( 240 | task='post_processing', 241 | prompts=prompts, 242 | model_name=model_name 243 | ) 244 | 245 | for example, out in zip(test_set, outs): 246 | response = out.outputs[0].text 247 | 248 | answerable = self._parse_response( 249 | task='post_processing', 250 | response=response 251 | ) 252 | 253 | if answerable: 254 | finalized_test_set.append(example) 255 | 256 | self.test_set[time] = finalized_test_set 257 | self.save_to_file( 258 | task='post_processing', 259 | time=time 260 | ) 261 | 262 | print('Finished postprocessing test set.') 263 | 264 | def llm_scorer( 265 | self, 266 | model_name, 267 | ): 268 | for time, answer_set in tqdm(self.answer_set.items(), desc='loop over months'): 269 | # skip? 
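            # resume if this month is already scored; otherwise each completion gets two
            # scores below: auto_acc (edit distance to the reference continuation) and
            # llm_acc (fraction of "Yes" judgements from the judge model)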
270 | if self.do_skip: 271 | f = os.path.join(self.scored_output_dir, f'{time}.json') 272 | if os.path.exists(f): 273 | with open(f, 'r') as f: 274 | self.answer_set[time] = json.load(f) 275 | continue 276 | 277 | for example in answer_set: 278 | 279 | last_line = example['last_line'] 280 | prompting_tokens = example['prompting_tokens'] 281 | example['out'] = [ i.split('\n', 1)[0] for i in example['out']] 282 | 283 | prompts = [ 284 | self._make_prompt( 285 | task='llm_scorer', 286 | ref_answer=last_line, 287 | out= prompting_tokens + o, 288 | model_name=model_name 289 | ) 290 | for o in example['out'] 291 | ] 292 | 293 | responses = self._inference_with_vllm( 294 | task='llm_scorer', 295 | prompts=prompts, 296 | model_name=model_name 297 | ) 298 | 299 | auto_acc = edit_distance_scorer( 300 | example['out'], 301 | example['left_tokens'] 302 | ) 303 | 304 | num_correct = 0 305 | num_wrong = 0 306 | judges = [] 307 | for r in responses: 308 | judges.append(r.outputs[0].text) 309 | if 'yes' in r.outputs[0].text.lower(): 310 | num_correct += 1 311 | elif 'no' in r.outputs[0].text.lower(): 312 | num_wrong += 1 313 | 314 | if num_correct + num_wrong == 0: 315 | example['llm_acc'] = None 316 | else: 317 | example['llm_acc'] = num_correct / (num_correct + num_wrong) 318 | example['judges'] = judges 319 | example['auto_acc'] = auto_acc 320 | 321 | self.save_to_file( 322 | task='llm_scorer', 323 | time=time 324 | ) 325 | 326 | print('Finished getting answers.') 327 | 328 | if __name__ == '__main__': 329 | 330 | times = [f'{year}-{month:02}' for year in range(2017, 2025) for month in range(1, 13)] 331 | times = times[:86] 332 | 333 | cc = CodeCompletion( 334 | task_name='code_completion', 335 | dataset_name='RealTimeData/code_alltime', 336 | times=times, 337 | data_dir='real_tasks/cc', 338 | output_dir='real_tasks/cc/cc_llama_2_7b_chat', 339 | num_examples_per_time=400, 340 | do_skip=False 341 | ) 342 | 343 | cc.make_test() 344 | cc.postprocess(model_name='meta-llama/Meta-Llama-3-8B-Instruct') 345 | # cc.get_answer(model_name='meta-llama/Llama-2-7b-hf') 346 | cc.get_answer(model_name='meta-llama/Llama-2-7b-chat-hf') 347 | cc.llm_scorer(model_name='meta-llama/Meta-Llama-3-8B-Instruct') -------------------------------------------------------------------------------- /data_processor.py: -------------------------------------------------------------------------------- 1 | # data_processor.py is to prepare data chunks to compress 2 | import datasets 3 | import numpy as np 4 | import os 5 | from nltk.tokenize import sent_tokenize 6 | from transformers import AutoTokenizer 7 | import multiprocessing 8 | import pickle 9 | import time 10 | from tqdm import tqdm 11 | import re 12 | 13 | global_tokenizer = None 14 | def set_global_tokenizer(tokenizer): 15 | global global_tokenizer 16 | global_tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=True, trust_remote_code=True) 17 | 18 | class BaseProcessor: 19 | def __init__(self, name, modality, load_path, cache_path, tokenizer, config, total_size=2**23, chunk_size=2**11): 20 | # total_size: the full size of the data to be compressed 21 | # chunk_size: the size of each chunk. 22 | # default 2048, according to most LLMs' context size. 23 | # But it will change according to the model. 24 | # if modal is 'text', chunk_size is the number of tokens. 25 | # if modal is 'image', chunk_size is the number of pixels (bytes as in gray scale). 26 | # if modal is 'audio', chunk_size is the number of bytes. 
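        # With the defaults above, total_size = 2**23 bytes (~8 MiB of raw data per time
        # period) and chunk_size = 2**11 = 2048 (a typical 2K context window);
        # MultiModalProcessor lowers total_size to 2**20.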
27 | 28 | self.name = name 29 | self.modality = modality 30 | self.load_path = load_path 31 | self.cache_path = cache_path 32 | if not os.path.exists(self.cache_path): 33 | os.makedirs(self.cache_path) 34 | 35 | self.config = config 36 | 37 | self.tokenizer = tokenizer 38 | self.total_size = total_size 39 | 40 | self.chunk_size = chunk_size 41 | if self.modality == 'image': 42 | self.sample_patch_size = (32, 64) 43 | elif self.modality == 'text': 44 | self.sample_chunk_size = 2**14 45 | elif self.modality == 'audio': 46 | self.sample_chunk_size = 2**16 47 | 48 | # self._check_and_load_cached_input_ids() 49 | 50 | def _check_and_load_cached_input_ids(self): 51 | cache_path = os.path.join(self.cache_path, f'{self.config}.input_ids') 52 | if os.path.exists(cache_path): 53 | with open(cache_path, 'rb') as f: 54 | self.input_ids = pickle.load(f) 55 | 56 | def batches(self, batch_size): 57 | for i in range(0, len(self.chunks), batch_size): 58 | yield self.chunks[i:i+batch_size] 59 | 60 | def _cache_tokenized(self): 61 | assert getattr(self, 'input_ids', None) is not None, 'Please run _data_stream() first.' 62 | cache_path = os.path.join(self.cache_path, f'{self.config}.input_ids') 63 | with open(cache_path, 'wb') as f: 64 | pickle.dump(self.input_ids, f) 65 | print(f'Saved tokenized input_ids to {cache_path}') 66 | 67 | @staticmethod 68 | def _tokenize_chunk(text): 69 | return global_tokenizer(text, add_special_tokens=False)['input_ids'] 70 | 71 | class MultiModalProcessor(BaseProcessor): 72 | def __init__(self, name, modality, load_path, cache_path, tokenizer, config, total_size=2**20, chunk_size=2**11): 73 | super().__init__(name, modality, load_path, cache_path, tokenizer, config, total_size, chunk_size) 74 | 75 | self._load_dataset() 76 | self._prepare_tokenizer() 77 | self._data_stream() 78 | 79 | def _prepare_tokenizer(self): 80 | if 'qwen' in self.tokenizer.name_or_path: 81 | exit('You are using Qwen tokenizer on multimodal data. Qwen do not support byte tokenization. exiting...') 82 | 83 | byte2ids = np.zeros(256, dtype=np.int32) 84 | for byte in range(256): 85 | byte_token = f'<0x{byte:02X}>' # e.g. <0x00> 86 | byte2ids[byte] = self.tokenizer.convert_tokens_to_ids(byte_token) 87 | 88 | self.byte2ids = byte2ids 89 | 90 | @staticmethod 91 | def tokenize(byte_stream, byte2ids): 92 | # huggingface tokenizer only take string as input, so here we build a byte-ids mapping and make a byte tokenizer 93 | # this function returns a 1D list consists of token ids 94 | 95 | byte_array = np.frombuffer(byte_stream, dtype=np.uint8) 96 | ids = byte2ids[byte_array].tolist() 97 | return ids 98 | 99 | def prepare_batches(self, context_size, stride=None): 100 | assert getattr(self, 'input_ids', None) is not None, 'Please run _data_stream() first.' 101 | input_ids = self.input_ids 102 | 103 | if stride is None: 104 | # now we chunk the long input_ids into batches 105 | chunks = [input_ids[i:i+context_size] for i in range(0, len(input_ids), context_size)] 106 | last_chunk = chunks[-1] 107 | chunks = chunks[:-1] 108 | else: 109 | assert stride < context_size, 'Stride should be smaller than context size.' 
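            # Sliding-window mode: a new `context_size`-token window starts every `stride`
            # tokens, so consecutive windows overlap by (context_size - stride) tokens.
            # After the first window, Metrics._cache_self_info (evaluator.py) keeps only the
            # last `stride` positions of each window for scoring, so the overlapping prefix
            # serves purely as extra left context.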
110 | self.stride = stride 111 | chunks = [] 112 | for i in range(0, len(input_ids), stride): 113 | chunk = input_ids[i:i+context_size] 114 | if len(chunk) < context_size: 115 | break 116 | chunks.append(chunk) 117 | last_chunk = input_ids[i + (context_size - stride):] 118 | 119 | self.bytes_droped = len(last_chunk) 120 | self.chunks = chunks 121 | 122 | self.context_size = context_size 123 | self._metadata() 124 | 125 | def _metadata(self): 126 | num_chunks = len(self.chunks) 127 | num_bytes = self.context_size * num_chunks 128 | stride = getattr(self, 'stride', 'None') 129 | self.metadata = { 130 | 'name': self.name, 131 | 'modality': self.modality, 132 | 'time': self.config, 133 | 'load_path': self.load_path, 134 | 'cache_path': self.cache_path, 135 | 'total_size': self.total_size, 136 | 'context_size': self.context_size, 137 | 'num_chunks': num_chunks, 138 | 'num_bytes': num_bytes, 139 | 'stride': stride, 140 | } 141 | 142 | def _data_stream(self): 143 | print(f'Tokenizing Bytes {self.name} {self.config}...') 144 | self.stream = self.stream[:self.total_size] 145 | 146 | chunk_size = 2**13 147 | byte_chunks = [self.stream[i:i+chunk_size] for i in range(0, len(self.stream), chunk_size)] 148 | 149 | num_workers = multiprocessing.cpu_count() 150 | with multiprocessing.Pool(processes=num_workers) as pool: 151 | result = pool.starmap(self.tokenize, [(chunk, self.byte2ids) for chunk in byte_chunks]) 152 | 153 | input_ids = [] 154 | for chunk in result: 155 | input_ids += chunk 156 | 157 | self.input_ids = input_ids 158 | 159 | class TextProcessor(BaseProcessor): 160 | def __init__(self, name, modality, load_path, cache_path, tokenizer, config, total_size=2**23, chunk_size=2**11, **kwargs): 161 | super().__init__(name, modality, load_path, cache_path, tokenizer, config, total_size, chunk_size) 162 | 163 | self._load_dataset(**kwargs) 164 | self._data_stream() 165 | 166 | def _data_stream(self): 167 | print(f'Tokenizing {self.name} {self.config}...') 168 | start_time = time.time() 169 | 170 | # For code data, sents are a list of code files 171 | # For non-code data, sents are a list of sentences 172 | sents = [] 173 | size_counter = 0 174 | for sent in self.sents: 175 | size_counter += len(sent.encode()) 176 | if size_counter > self.total_size: 177 | break 178 | sents.append(sent) 179 | 180 | # For non-code data, we tokenize in sent-level. 181 | # So add space between sentences here. 182 | if self.name != 'code': 183 | for i in range(1, len(sents)): 184 | sents[i] = ' ' + sents[i] 185 | 186 | self.all_text = ''.join(sents) 187 | self.stream = self.all_text.encode('utf-8') 188 | 189 | if self.name in ['code', 'arxiv', 'math']: 190 | chunk_size = 1 191 | else: 192 | chunk_size = 100 193 | sent_chunks = [sents[i:i+chunk_size] for i in range(0, len(sents), chunk_size)] 194 | 195 | num_workers = multiprocessing.cpu_count() 196 | tokenizer_path = self.tokenizer.name_or_path 197 | with multiprocessing.Pool(processes=num_workers, initializer=set_global_tokenizer, initargs=(tokenizer_path,)) as pool: 198 | result = pool.map(self._tokenize_chunk, sent_chunks) 199 | 200 | input_ids = [] 201 | for chunk in result: 202 | for sent in chunk: 203 | input_ids += sent 204 | 205 | self.input_ids = input_ids 206 | 207 | end_time = time.time() 208 | print(f'Tokenization finished. 
Time used: {end_time - start_time:.2f}s') 209 | # self._cache_chunks_and_metadata() 210 | 211 | def prepare_batches(self, context_size, stride=None): 212 | assert getattr(self, 'input_ids', None) is not None, 'Please run _data_stream() first.' 213 | input_ids = self.input_ids 214 | 215 | if stride is None: 216 | # now we chunk the long input_ids into batches 217 | if self.tokenizer.bos_token_id is not None: 218 | max_length = context_size - 1 219 | else: 220 | max_length = context_size 221 | chunks = [input_ids[i:i+max_length] for i in range(0, len(input_ids), context_size)] 222 | if self.tokenizer.bos_token_id is not None: 223 | chunks = [[self.tokenizer.bos_token_id] + chunk for chunk in chunks] 224 | last_chunk = chunks[-1] 225 | chunks = chunks[:-1] 226 | 227 | self.text_droped = self.tokenizer.decode(last_chunk) 228 | else: 229 | assert stride < context_size, 'Stride should be smaller than context size.' 230 | self.stride = stride 231 | chunks = [] 232 | for i in range(0, len(input_ids), stride): 233 | chunk = input_ids[i:i+context_size] 234 | if len(chunk) < context_size: 235 | break 236 | chunks.append(chunk) 237 | self.text_droped = self.tokenizer.decode(input_ids[i + (context_size - stride):]) 238 | 239 | self.chunks = chunks 240 | self.context_size = context_size 241 | self._metadata() 242 | 243 | def _metadata(self): 244 | num_chunks = len(self.chunks) 245 | num_tokens = num_chunks * self.context_size 246 | num_chars = len(self.all_text) - len(self.text_droped) 247 | num_bytes = len(self.stream) - len(self.text_droped.encode('utf-8')) 248 | stride = getattr(self, 'stride', 'None') 249 | self.metadata = { 250 | 'name': self.name, 251 | 'modality': self.modality, 252 | 'time': self.config, 253 | 'load_path': self.load_path, 254 | 'cache_path': self.cache_path, 255 | 'total_size': self.total_size, 256 | 'context_size': self.context_size, 257 | 'num_chunks': num_chunks, 258 | 'num_tokens': num_tokens, 259 | 'num_chars': num_chars, 260 | 'num_bytes': num_bytes, 261 | 'stride': stride, 262 | } 263 | 264 | class BBCNewsProcessor(TextProcessor): 265 | 266 | def _load_dataset(self, num_articles=1000): 267 | ds = datasets.load_dataset(self.load_path, self.config, split='train') 268 | all_sents = [] 269 | if num_articles > len(ds): 270 | num_articles = len(ds) 271 | ds = ds.select(range(num_articles)) 272 | for article in ds: 273 | text = article['content'] 274 | sents = sent_tokenize(text) 275 | all_sents += sents 276 | self.sents = all_sents 277 | 278 | class WikiTextProcessor(TextProcessor): 279 | def _load_dataset(self, num_sents_per_article=80): 280 | ds = datasets.load_dataset(self.load_path, self.config, split='train') 281 | all_sents = [] 282 | for article in ds: 283 | text = article['text'] 284 | sents = sent_tokenize(text) 285 | all_sents += sents[:num_sents_per_article] 286 | self.sents = all_sents 287 | 288 | class CodeProcessor(TextProcessor): 289 | def _load_dataset(self): 290 | ds = datasets.load_dataset(self.load_path, self.config, split='train') 291 | ds = ds.shuffle(seed=42) 292 | 293 | all_sents = [] 294 | for code in ds: 295 | # If the code file is over 500 lines, we sample a 500 lines continues chunk from the code 296 | code_lines = code['code'].splitlines() 297 | if len(code_lines) > 500: 298 | start = np.random.randint(0, len(code_lines) - 500) 299 | code_lines = code_lines[start:start+500] 300 | code_ = '\n'.join(code_lines) 301 | all_sents.append(code_) 302 | 303 | self.sents = all_sents 304 | 305 | class ArxivProcessor(TextProcessor): 306 | def _load_dataset(self, 
num_sections = 2): 307 | ds = datasets.load_dataset(self.load_path, self.config, split='train') 308 | all_sents = [] 309 | for article in ds: 310 | text = article['text'] 311 | sections = self._process_article(text) 312 | if sections is None: 313 | continue 314 | all_sents += sections[:num_sections] 315 | self.sents = all_sents 316 | 317 | def _process_article(self, text): 318 | def beautify_context(context: str) -> str: 319 | context = context.replace("", '').replace('', '') 320 | context = re.sub(r"\s+", " ", context) 321 | context = re.sub(r"\n+", " ", context) 322 | return context 323 | 324 | text = re.sub(r"^.*?(§)", r"\1", text, flags=re.DOTALL) 325 | sections = re.split(r"(? max_len * 10: 119 | text = text[:max_len * 10] 120 | 121 | tokenized_text = tokenizer(text, add_special_tokens=False) 122 | chunked = tokenized_text['input_ids'][:max_len] 123 | 124 | return tokenizer.decode(chunked) 125 | 126 | def _make_prompt( 127 | self, 128 | task, # choose from ['question_generation', 'post_processing', 'answer_generation'] 129 | article = None, 130 | question = None, 131 | ref_answer = None, 132 | out = None, 133 | model_name = None, 134 | ): 135 | if task == 'question_generation': 136 | assert article is not None 137 | 138 | prompt = ( 139 | "Title: {title}\n\nDate: {published_date}\nTLDR\n{description}\nMain Context:\n{content}\n\n========\n\nWhat's the key information in this article? I am making some quiz questions with short and quick answers. " 140 | "Help me find the key information, including but not limited to who, what, where, and when. For each key piece of information, ask a corresponding question next to the answer. Returns up to three QA pairs as a json dictionary list." 141 | ) 142 | prompt = prompt.format( 143 | title=article['title'], 144 | published_date=article['published_date'], 145 | description=article['description'], 146 | content=article['content'] 147 | ) 148 | 149 | demonstration = """[ 150 | { 151 | "question": "Who has Stoke City signed?", 152 | "answer": "Saido Berahino" 153 | }, 154 | { 155 | "question": "From which club has Stoke City signed Saido Berahino?", 156 | "answer": "West Brom" 157 | }, 158 | { 159 | "question": "What is the fee Stoke City paid for Saido Berahino?", 160 | "answer": "£12m" 161 | } 162 | ]""" 163 | prompt += "Here is a demonstration of how your output should look like:\n" + demonstration 164 | 165 | return prompt 166 | 167 | elif task == 'post_processing': 168 | assert question is not None 169 | 170 | prompt = ( 171 | # "Is the following quiz question answerable? or they require more context? Some example of unanswerable question: What was the score of the match?/How many family members are there in this Yorkshire family? Some answerable questions: Why is Martin McGuinness retiring from frontline politics?/Where is Saffron Jackson from? You should answer \"Yes\" or \"No\" as your answer, and nothing else. Question: {question}" 172 | "Is the following quiz question answerable? You should answer \"Yes\" or \"No\" as your answer, and nothing else. Unless significant information is missing in the question, you should reply \"yes\". 
Question: {question}" 173 | ) 174 | prompt = prompt.format( 175 | question=question 176 | ) 177 | 178 | return prompt 179 | 180 | elif task == 'answer_generation': 181 | templates = { 182 | 'meta-llama/Llama-2-7b-hf': "Question: {question}\nAnswer: ", 183 | 'meta-llama/Llama-2-7b-chat-hf': "[INST] {question} [/INST]", 184 | 'meta-llama/Meta-Llama-3-8B': "Question: {question}\nAnswer: ", 185 | 'meta-llama/Meta-Llama-3-8B-Instruct': "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 186 | } 187 | model_name = model_name if model_name is not None else 'meta-llama/Llama-2-7b-hf' 188 | 189 | prompt = templates[model_name].format( 190 | question=question 191 | ) 192 | 193 | return prompt 194 | 195 | elif task == 'llm_scorer': 196 | templates = { 197 | 'meta-llama/Meta-Llama-3-8B-Instruct': ( 198 | "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" 199 | "Here is a question: \"{question}\". And its reference answer is \"{ref_answer}\". " 200 | "Now, your job is to judge whether the following student answer contains the reference answer? " 201 | "Student answer: {out} " 202 | "You should reply \"Yes\" if the student answer includes reference answer; otherwise reply \"No\". " 203 | "Return only \"Yes\" or \"No\" as your judgement and return nothing else. " 204 | "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 205 | ), 206 | 'gpt-4': ( 207 | "The question is: {question}. And its reference answer is {ref_answer}. " 208 | "Now, your job is to judge whether the following answer is correct? Student answer: {out} " 209 | "You should only reply \"Yes\" or \"No\" as your answer and return nothing else." 210 | ) 211 | } 212 | 213 | prompt = templates[model_name].format( 214 | question=question, 215 | ref_answer=ref_answer, 216 | out=out 217 | ) 218 | 219 | return prompt 220 | 221 | def _parse_response( 222 | self, 223 | task, # choose from ['question_generation', 'post_processing', 'answer_generation'] 224 | response 225 | ): 226 | if task == 'question_generation': 227 | try: 228 | response = eval(response) 229 | except: 230 | response = None 231 | 232 | return response 233 | 234 | elif task == 'post_processing': 235 | 236 | if 'yes' in response.lower(): 237 | return True 238 | elif 'no' in response.lower(): 239 | return False 240 | else: 241 | return None 242 | 243 | elif task == 'answer_generation': 244 | raise NotImplementedError 245 | 246 | def make_test( 247 | self, 248 | ): 249 | for time, dataset in tqdm(self.datasets.items(), desc='loop over months'): 250 | # skip? 
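            # Resume support: if a question file for this month already exists on disk,
            # load it into self.test_set and move on instead of regenerating questions.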
251 | if self.do_skip: 252 | f = os.path.join(self.question_output_dir, f'{time}.json') 253 | if os.path.exists(f): 254 | with open(f, 'r') as f: 255 | self.test_set[time] = json.load(f) 256 | continue 257 | 258 | count = 0 259 | if time not in self.test_set: 260 | self.test_set[time] = [] 261 | for article in tqdm(dataset): 262 | prompt = self._make_prompt(task='question_generation', article=article) 263 | response = self._make_llm_request(prompt) 264 | response = self._parse_response( 265 | task='question_generation', 266 | response=response 267 | ) 268 | if response is not None: 269 | count += 1 270 | 271 | self.test_set[time].append({ 272 | 'prompt': prompt, 273 | 'article': article, 274 | 'response': response 275 | }) 276 | 277 | if count >= self.num_examples_per_time: 278 | break 279 | 280 | if count % 10 == 0: 281 | self.save_to_file( 282 | task='question_generation', 283 | time=time 284 | ) 285 | 286 | self.save_to_file( 287 | task='question_generation', 288 | time=time 289 | ) 290 | 291 | print('Finished making test set.') 292 | 293 | def postprocess( 294 | self, 295 | ): 296 | for time, test_set in tqdm(self.test_set.items(), desc='loop over months'): 297 | # skip? 298 | if self.do_skip: 299 | f = os.path.join(self.postprocessed_output_dir, f'{time}.json') 300 | if os.path.exists(f): 301 | with open(f, 'r') as f: 302 | self.test_set[time] = json.load(f) 303 | continue 304 | 305 | for index, example in tqdm(enumerate(test_set), desc='loop over examples'): 306 | if example['response'] is None: 307 | continue 308 | 309 | finalized_test_set = [] 310 | date = datetime.strptime(example['article']['published_date'], "%Y-%m-%d") 311 | date_str = date.strftime("%b %Y") 312 | for question in example['response']: 313 | try: 314 | question['question'] = f'{date_str}: {question["question"]}' 315 | except: 316 | continue 317 | prompt = self._make_prompt(task='post_processing', question=question['question']) 318 | response = self._make_llm_request(prompt) 319 | answerable = self._parse_response( 320 | task='post_processing', 321 | response=response 322 | ) 323 | 324 | if answerable: 325 | finalized_test_set.append(question) 326 | 327 | example['filtered_qa'] = finalized_test_set 328 | 329 | if index % 10 == 0: 330 | self.save_to_file( 331 | task='post_processing', 332 | time=time 333 | ) 334 | 335 | self.save_to_file( 336 | task='post_processing', 337 | time=time 338 | ) 339 | 340 | print('Finished postprocessing test set.') 341 | 342 | def get_answer( 343 | self, 344 | model_name: str 345 | ): 346 | self.answer_set = {} 347 | 348 | for time, test_set in tqdm(self.test_set.items(), desc='loop over months'): 349 | # skip? 
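            # Resume support for answers: reuse this month's cached answer file, if any,
            # rather than re-running model inference.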
350 | if self.do_skip: 351 | f = os.path.join(self.answer_output_dir, f'{time}.json') 352 | if os.path.exists(f): 353 | with open(f, 'r') as f: 354 | self.answer_set[time] = json.load(f) 355 | continue 356 | 357 | if time not in self.answer_set: 358 | self.answer_set[time] = [] 359 | 360 | questions = [] 361 | ref_answers = [] 362 | article_links = [] 363 | prompts = [] 364 | for example in test_set: 365 | if 'filtered_qa' not in example or example['filtered_qa'] is None: 366 | continue 367 | 368 | for qa in example['filtered_qa']: 369 | try: 370 | time_s, query = qa['question'].split(': ') 371 | except: 372 | print('Wrong format') 373 | print(qa['question']) 374 | continue 375 | query = f"In {time_s}, {query}" 376 | prompts.append(self._make_prompt( 377 | task='answer_generation', 378 | question=query, 379 | model_name=model_name 380 | )) 381 | questions.append(query) 382 | ref_answers.append(qa['answer']) 383 | article_links.append(example['article']['link']) 384 | 385 | outs = self._inference_with_vllm( 386 | task='answer_generation', 387 | prompts=prompts, 388 | # model_name='meta-llama/Llama-2-7b-hf' 389 | model_name=model_name, 390 | ) 391 | 392 | for question, ref_answer, out, link in zip(questions, ref_answers, outs, article_links): 393 | gen_answers = [o.text for o in out.outputs] 394 | self.answer_set[time].append({ 395 | 'question': question, 396 | 'ref_answer': ref_answer, 397 | 'prompt': out.prompt, 398 | 'out': gen_answers, 399 | 'link': link, 400 | 'accuracy': self.auto_scorer(gen_answers, ref_answer), 401 | }) 402 | 403 | self.save_to_file( 404 | task='answer_generation', 405 | time=time 406 | ) 407 | 408 | def auto_scorer( 409 | self, 410 | outs, 411 | ref_answer 412 | ): 413 | acc = np.mean([1 if ref_answer.lower() in o.lower() else 0 for o in outs]) 414 | return acc 415 | 416 | def llm_scorer( 417 | self, 418 | model_name, 419 | ): 420 | for time, answer_set in tqdm(self.answer_set.items(), desc='loop over months'): 421 | # skip? 422 | if self.do_skip: 423 | f = os.path.join(self.scored_output_dir, f'{time}.json') 424 | if os.path.exists(f): 425 | with open(f, 'r') as f: 426 | self.answer_set[time] = json.load(f) 427 | continue 428 | 429 | for example in answer_set: 430 | ref_answer = example['ref_answer'] 431 | query = example['question'] 432 | 433 | prompts = [ 434 | self._make_prompt( 435 | task='llm_scorer', 436 | question=query, 437 | ref_answer=ref_answer, 438 | out=o if '\nQuestion' not in o else o.split('\nQuestion')[0], 439 | model_name=model_name 440 | ) 441 | for o in example['out'] 442 | ] 443 | 444 | responses = self._inference_with_vllm( 445 | task='llm_scorer', 446 | prompts=prompts, 447 | model_name=model_name 448 | ) 449 | 450 | num_correct = 0 451 | num_wrong = 0 452 | judges = [] 453 | for r in responses: 454 | judges.append(r.outputs[0].text) 455 | if 'yes' in r.outputs[0].text.lower(): 456 | num_correct += 1 457 | elif 'no' in r.outputs[0].text.lower(): 458 | num_wrong += 1 459 | 460 | if num_correct + num_wrong == 0: 461 | example['llm_acc'] = None 462 | else: 463 | example['llm_acc'] = num_correct / (num_correct + num_wrong) 464 | example['judges'] = judges 465 | 466 | self.save_to_file( 467 | task='llm_scorer', 468 | time=time 469 | ) 470 | 471 | print('Finished getting answers.') 472 | 473 | def create_batch_request_to_OpenAI( 474 | path, 475 | output_path, 476 | tasks, 477 | ): 478 | os.makedirs(output_path, exist_ok=True) 479 | 480 | tempates = { 481 | 'llm_scorer': ( 482 | "Here is a question: \"{question}\". 
And its reference answer is \"{ref_answer}\". " 483 | "Now, your job is to judge whether the following answer is correct? Student answer: \"{out}\". " 484 | "You should reply \"Yes\" or \"No\" as your judgement and return nothing else." 485 | ) 486 | } 487 | 488 | files = glob(os.path.join(path, '*.json')) 489 | requests = [] 490 | for f in files: 491 | time = f.split('/')[-1].split('.')[0] 492 | output_file = os.path.join(output_path, f'{time}.json') 493 | if os.path.exists(output_file): 494 | continue 495 | 496 | with open(f, 'r') as f: 497 | data = json.load(f) 498 | 499 | if tasks == 'llm_scorer': 500 | for e_idx, example in enumerate(data): 501 | for o_idx, o in enumerate(example['out']): 502 | # time, query = example['question'].split(':', 1) 503 | # query = f"In {time}, {query}" 504 | 505 | if '\nQuestion' in o: 506 | o = o.split('\nQuestion')[0] 507 | # if len(o) > 2 * len(example['ref_answer']): 508 | # fo = '' 509 | # for i in o.split('\n'): 510 | # if not i.strip(): 511 | # continue 512 | # if len(fo) + len(i) < 2 * len(example['ref_answer']): 513 | # fo += i + '\n' 514 | # o = fo 515 | prompt = tempates[tasks].format( 516 | question=example['question'], 517 | ref_answer=example['ref_answer'], 518 | out=o 519 | ) 520 | 521 | requests.append( 522 | { 523 | "custom_id": f"{time}-{e_idx}-{o_idx}", 524 | "method": "POST", 525 | "url": "/v1/chat/completions", 526 | "body": { 527 | "model": "gpt-3.5-turbo-0125", 528 | "messages": [{"role": "user", "content": prompt}], 529 | "max_tokens": 8 530 | } 531 | } 532 | ) 533 | 534 | # write requests to jsonl 535 | with open(os.path.join(output_path, 'requests.jsonl'), 'w') as f: 536 | for r in requests: 537 | f.write(json.dumps(r) + '\n') 538 | 539 | client = OpenAI( 540 | api_key=os.environ.get('OPENAI_API_KEY'), 541 | ) 542 | batch_input_file = client.files.create( 543 | file=open("real_tasks/data/factqa_scored_openai/requests.jsonl", "rb"), 544 | purpose="batch" 545 | ) 546 | 547 | batch_input_file_id = batch_input_file.id 548 | 549 | b = client.batches.create( 550 | input_file_id=batch_input_file_id, 551 | endpoint="/v1/chat/completions", 552 | completion_window="24h", 553 | metadata={ 554 | "description": "factqa eval jobs" 555 | } 556 | ) 557 | 558 | print('file_id') 559 | print(batch_input_file_id) 560 | print('batch_meta_info') 561 | print(b) 562 | 563 | if __name__ == '__main__': 564 | # every months from 2017-01 to 2024-02 565 | times = [f'{year}-{month:02}' for year in range(2017, 2025) for month in range(1, 13)] 566 | times = times[:86] 567 | 568 | factqa = FactQA( 569 | task_name='factqa', 570 | dataset_name='RealTimeData/bbc_news_alltime', 571 | times=times, 572 | data_dir='real_tasks/data', 573 | output_dir='real_tasks/Llama-2-chat', 574 | num_examples_per_time=100 575 | ) 576 | 577 | factqa.make_test() 578 | factqa.postprocess() 579 | # factqa.get_answer(model_name='meta-llama/Llama-2-7b-hf') 580 | factqa.get_answer(model_name='meta-llama/Llama-2-7b-chat-hf') 581 | # factqa.get_answer(model_name='meta-llama/Meta-Llama-3-8B') 582 | factqa.llm_scorer(model_name='meta-llama/Meta-Llama-3-8B-Instruct') 583 | 584 | # create_batch_request_to_OpenAI( 585 | # path='real_tasks/data/factqa_scored', 586 | # output_path='real_tasks/data/factqa_scored_openai', 587 | # tasks='llm_scorer' 588 | # ) -------------------------------------------------------------------------------- /page/math.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 |
6 | 7 | --------------------------------------------------------------------------------
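A minimal, self-contained sketch (not part of the repository) of the byte-to-token-id mapping used by MultiModalProcessor.tokenize and the fixed-size chunking done in prepare_batches. The identity byte2ids table, the toy byte stream, and the small context size below are illustrative assumptions; the real table is built in _prepare_tokenizer() from the tokenizer's <0xNN> byte tokens.

import numpy as np

# Illustrative stand-in for the table built in _prepare_tokenizer(): here byte
# value b simply maps to id b. With a SentencePiece-style tokenizer the entry
# would instead be tokenizer.convert_tokens_to_ids(f'<0x{b:02X}>').
byte2ids = np.arange(256, dtype=np.int32)

def tokenize(byte_stream: bytes, byte2ids: np.ndarray) -> list:
    # Vectorised lookup: view the raw bytes as uint8 and index the mapping table.
    byte_array = np.frombuffer(byte_stream, dtype=np.uint8)
    return byte2ids[byte_array].tolist()

stream = b'example multimodal byte stream' * 10
input_ids = tokenize(stream, byte2ids)

# Non-strided chunking as in prepare_batches(stride=None): keep full windows
# only and drop the trailing partial window.
context_size = 64  # toy value; the experiments sweep much larger contexts
chunks = [input_ids[i:i + context_size] for i in range(0, len(input_ids), context_size)]
chunks, dropped = chunks[:-1], chunks[-1]
print(len(input_ids), len(chunks), len(dropped))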