├── .gitignore ├── LICENSE ├── README.md ├── ZIP.py └── assets └── motivation.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Star-Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ZIP 2 | 3 | ![Motivation](assets/motivation.png) 4 | 5 | This is the repository for our paper [Entropy Law: The Story Behind Data Compression and LLM Performance](https://arxiv.org/abs/2407.06645). 6 | 7 | 8 | ## Quick start 9 | 10 | 11 | ### Data pool 12 | 13 | The data pool used in the paper can be found [here](https://huggingface.co/datasets/AndrewZeng/deita_sota_pool); it is provided by [DEITA](https://github.com/hkust-nlp/deita), and we appreciate their contribution.
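For convenience, the pool can also be downloaded programmatically from the Hugging Face Hub. Below is a minimal sketch using `huggingface_hub` (only the repo id comes from the link above; how you pass the downloaded JSON file to `ZIP.py` via `--data_path` is up to you):

```python
# Minimal sketch (assumes huggingface_hub is installed): download the dataset
# snapshot locally, then point ZIP.py's --data_path at the JSON file inside it.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="AndrewZeng/deita_sota_pool", repo_type="dataset")
print(local_dir)  # directory containing the downloaded data pool files
```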
If you want to use ZIP to select your own data, note that we currently only support the following ShareGPT format: 14 | 15 | ```json 16 | [ 17 | { 18 | "id": 0, 19 | "conversations":[ 20 | { 21 | "from": "human", 22 | "value": "XXX" 23 | }, 24 | { 25 | "from": "gpt", 26 | "value": "XXX" 27 | } 28 | ], 29 | "source": "ShareGPT" 30 | }, 31 | { 32 | "id": 1, 33 | "conversations":[ 34 | { 35 | "from": "human", 36 | "value": "XXX" 37 | }, 38 | { 39 | "from": "gpt", 40 | "value": "XXX" 41 | } 42 | ], 43 | "source": "ShareGPT" 44 | } 45 | ] 46 | ``` 47 | ### Perform data selection 48 | 49 | ```shell 50 | python ZIP.py --data_path data_pool.json --save_path selected_data.json --budget 10000 51 | ``` 52 | 53 | ### LLM alignment & evaluation 54 | 55 | - We use [Axolotl](https://github.com/axolotl-ai-cloud/axolotl) to align LLMs with the selected data. 56 | - Then we use MT-bench in [FastChat](https://github.com/lm-sys/FastChat) to evaluate the aligned LLMs. 57 | 58 | ## Citation 59 | If you find the content of this project helpful, please cite our paper as follows: 60 | ``` 61 | @ARTICLE{2024arXiv240706645Y, 62 | author = {{Yin}, Mingjia and {Wu}, Chuhan and {Wang}, Yufei and {Wang}, Hao and {Guo}, Wei and {Wang}, Yasheng and {Liu}, Yong and {Tang}, Ruiming and {Lian}, Defu and {Chen}, Enhong}, 63 | title = "{Entropy Law: The Story Behind Data Compression and LLM Performance}", 64 | journal = {arXiv e-prints}, 65 | keywords = {Computer Science - Machine Learning, Computer Science - Computation and Language}, 66 | year = 2024, 67 | month = jul, 68 | doi = {10.48550/arXiv.2407.06645}, 69 | eprint = {2407.06645}, 70 | } 71 | ``` -------------------------------------------------------------------------------- /ZIP.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import json 4 | import zlib 5 | import torch 6 | import time 7 | import argparse 8 | from collections import defaultdict 9 | 10 | def load_original_data_pool(path): 11 | with open(path, 'r') as f: 12 | data = json.load(f) 13 | return data 14 | 15 | def get_compression_ratio(input_data): 16 | data_str = str(input_data).encode('utf-8') 17 | compressed_data = zlib.compress(data_str, level=9) 18 | compressed_ratio = len(data_str) / len(compressed_data) 19 | return compressed_ratio 20 | 21 | def selec_data_from_corpus( 22 | anchor_data, 23 | processed_data_index, 24 | budget, 25 | selection_num, 26 | candidate_budget='all', 27 | pool=None, 28 | turn_print=True, 29 | data_pool=None, 30 | global_information_redundancy_state=None, 31 | ): 32 | data_list = [data_pool[_] for _ in processed_data_index] 33 | 34 | selected_data = [] 35 | selected_index = [] 36 | if not turn_print: 37 | start_time = time.time() 38 | while True: 39 | if turn_print: 40 | start_time = time.time() 41 | # greedily pick the top-k instances whose addition yields the lowest compression ratio (least redundancy w.r.t. the already-selected data) 42 | if candidate_budget == 'all': 43 | group_information_redundancy_state = pool.map(get_compression_ratio, [anchor_data + selected_data + [part] for part in data_list]) 44 | group_information_redundancy_state = torch.tensor(group_information_redundancy_state) 45 | group_information_redundancy_state[selected_index] = 1000000 46 | _, min_index = torch.topk(group_information_redundancy_state, k=selection_num, largest=False) 47 | new_index = min_index.tolist() 48 | selected_instance_list = [] 49 | for _ in new_index: 50 | selected_instance = data_list[_] 51 | selected_instance_list.append(selected_instance) 52 | 
selected_index.extend(new_index) 53 | selected_data.extend(selected_instance_list) 54 | else: 55 | # global view 56 | _, cur_index = torch.topk(global_information_redundancy_state, k=candidate_budget, largest=False) 57 | group_list = [data_pool[idx] for idx in cur_index] 58 | # compute compression ratio 59 | group_information_redundancy_state = pool.map(get_compression_ratio, [anchor_data + selected_data + [part] for part in group_list]) 60 | group_information_redundancy_state = torch.tensor(group_information_redundancy_state) 61 | global_information_redundancy_state[cur_index] = group_information_redundancy_state 62 | _, min_index = torch.topk(group_information_redundancy_state, k=selection_num, largest=False) 63 | new_index = cur_index[min_index].tolist() 64 | global_information_redundancy_state[new_index] = 1000000 65 | selected_instance_list = [] 66 | for _ in new_index: 67 | selected_instance = data_pool[_] 68 | selected_instance_list.append(selected_instance) 69 | selected_index.extend(new_index) 70 | selected_data.extend(selected_instance_list) 71 | if turn_print: 72 | end_time = time.time() 73 | execution_time = end_time - start_time 74 | print(f"Code execution time: {execution_time} seconds") 75 | 76 | cur_len = len(selected_data) 77 | if cur_len >= budget: 78 | if candidate_budget == 'all': 79 | selected_global_index = [processed_data_index[_] for _ in selected_index] 80 | else: 81 | selected_global_index = selected_index 82 | if not turn_print: 83 | end_time = time.time() 84 | execution_time = end_time - start_time 85 | print(f"Code execution time: {execution_time} seconds") 86 | return selected_global_index, selected_data 87 | 88 | def ZIP_select(data_pool, save_path, budget, k1=10000, k2=200, k3=100, n_jobs=1): 89 | pool = multiprocessing.Pool(processes=min(multiprocessing.cpu_count(), n_jobs)) 90 | global_information_redundancy_state = pool.map(get_compression_ratio, [[part] for part in data_pool]) 91 | global_information_redundancy_state = torch.tensor(global_information_redundancy_state) 92 | 93 | final_selected_data = [] 94 | cur_data_index = list(range(len(data_pool))) 95 | while len(final_selected_data) < budget: 96 | print('stage 1 & stage 2') 97 | second_stage_index, _ = selec_data_from_corpus( 98 | final_selected_data, cur_data_index, k2, k2, k1, pool, turn_print=True, 99 | data_pool=data_pool, 100 | global_information_redundancy_state=global_information_redundancy_state 101 | ) 102 | print('stage 3') 103 | third_stage_index, third_stage_data = selec_data_from_corpus( 104 | [], second_stage_index, k3, 1, 'all', pool, turn_print=False, 105 | data_pool=data_pool, 106 | global_information_redundancy_state=global_information_redundancy_state 107 | ) 108 | cur_data_index = [_ for _ in cur_data_index if _ not in third_stage_index] 109 | final_selected_data.extend(third_stage_data) 110 | source_list = defaultdict(int) 111 | for _ in final_selected_data: 112 | source_list[_['source']] += 1 113 | print(f'selected {len(final_selected_data)}, including {source_list}') 114 | with open(save_path, 'w+', encoding='utf-8') as f: 115 | json.dump(final_selected_data, f) 116 | pool.close() 117 | 118 | 119 | if __name__ == '__main__': 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('--data_path', type=str, default='./pool.json', help='The path to the original data pool with sharegpt format.') 122 | parser.add_argument('--save_path', type=str, default='./zip_selected_data.json', help='The path to the selected dataset.') 123 | parser.add_argument('--budget', type=int, 
default=1000, help='The number of selected instances') 124 | parser.add_argument('--k1', type=int, default=10000, help='Stage 1: the number of candidates with the lowest global compression ratio considered in each iteration.') 125 | parser.add_argument('--k2', type=int, default=200, help='Stage 2: the number of instances greedily picked from the k1 candidates in each iteration.') 126 | parser.add_argument('--k3', type=int, default=100, help='Stage 3: the number of instances selected one by one from the k2 candidates and added to the final set in each iteration.') 127 | parser.add_argument('--n_jobs', type=int, default=64, help='The number of jobs to use for calculating compression ratios.') 128 | args = parser.parse_args() 129 | original_data_pool = load_original_data_pool(args.data_path) 130 | ZIP_select(original_data_pool, args.save_path, args.budget, args.k1, args.k2, args.k3, args.n_jobs) 131 | -------------------------------------------------------------------------------- /assets/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTC-StarTeam/ZIP/f77b0b5e2cc28414fc9d51f5d08bf34ad0353527/assets/motivation.png --------------------------------------------------------------------------------