├── experiments ├── __init__.py ├── sampling │ ├── __init__.py │ ├── test_api_tokenization.py │ ├── cache_local_samples.py │ ├── cache_logprobs.py │ ├── cache_local_samples_vllm.py │ └── cache_api_samples.py ├── testing │ ├── __init__.py │ ├── bootstrap_manager.py │ ├── cache_one_sample_bootstrap.py │ ├── cache_two_sample_bootstrap.py │ ├── simulation.py │ ├── simulate_one_sample_power.py │ ├── simulate_two_sample_power.py │ └── simulate_two_sample_power_composite.py ├── constants │ ├── wikipedia_prompt_indices_test │ │ ├── 1 │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 6.json │ │ │ ├── 7.json │ │ │ ├── 8.json │ │ │ ├── 9.json │ │ │ └── 3.json │ │ ├── 5 │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 6.json │ │ │ ├── 7.json │ │ │ ├── 9.json │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ └── 8.json │ │ ├── 10 │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 6.json │ │ │ ├── 7.json │ │ │ ├── 9.json │ │ │ └── 8.json │ │ ├── 25 │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 6.json │ │ │ ├── 7.json │ │ │ ├── 8.json │ │ │ └── 9.json │ │ ├── 50 │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 6.json │ │ │ ├── 7.json │ │ │ ├── 8.json │ │ │ └── 9.json │ │ ├── 75 │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 6.json │ │ │ ├── 7.json │ │ │ ├── 8.json │ │ │ └── 9.json │ │ └── 100 │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 6.json │ │ │ ├── 7.json │ │ │ ├── 8.json │ │ │ └── 9.json │ ├── wikipedia_prompt_indices_val │ │ ├── 1 │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ ├── 6.json │ │ │ ├── 7.json │ │ │ ├── 8.json │ │ │ ├── 9.json │ │ │ └── 
5.json │ │ ├── 5 │ │ │ ├── 4.json │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ ├── 5.json │ │ │ ├── 6.json │ │ │ ├── 9.json │ │ │ ├── 7.json │ │ │ └── 8.json │ │ ├── 25 │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 7.json │ │ │ ├── 8.json │ │ │ ├── 6.json │ │ │ └── 9.json │ │ ├── 50 │ │ │ ├── 6.json │ │ │ ├── 1.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ ├── 5.json │ │ │ ├── 7.json │ │ │ ├── 9.json │ │ │ ├── 2.json │ │ │ ├── 8.json │ │ │ └── 0.json │ │ ├── 75 │ │ │ ├── 6.json │ │ │ ├── 0.json │ │ │ ├── 1.json │ │ │ ├── 2.json │ │ │ ├── 5.json │ │ │ ├── 7.json │ │ │ ├── 9.json │ │ │ ├── 3.json │ │ │ ├── 4.json │ │ │ └── 8.json │ │ └── 100 │ │ │ └── 0.json │ ├── vanilla_api.json │ ├── vanilla_local.json │ ├── llama3-chat-template.txt │ └── minimum_prompt_indices.json ├── utils.py └── prompts.py ├── hero.png ├── setup.png ├── model_equality_testing ├── src │ ├── __init__.py │ ├── algorithm.py │ ├── utils.py │ ├── dataset.py │ ├── pvalue.py │ ├── distribution.py │ └── tests.py ├── pyproject.toml └── LICENSE ├── .gitignore └── README.md /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/testing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/i-gao/model-equality-testing/HEAD/hero.png -------------------------------------------------------------------------------- /setup.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/i-gao/model-equality-testing/HEAD/setup.png -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [75]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [70], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [69], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [38], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [43], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/6.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [67]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [34], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [44], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [45], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [18], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [2], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/2.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [14], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [9]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [13]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/9.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [1]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} 2 | -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [6], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} 2 | -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [9], "wikipedia_de": [], "wikipedia_fr": [4], "wikipedia_ru": [4], "wikipedia_es": [11, 13]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [87], "wikipedia_de": [43], "wikipedia_fr": [44, 63], "wikipedia_ru": [], "wikipedia_es": [90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [98], "wikipedia_de": [46], "wikipedia_fr": [], "wikipedia_ru": [99], "wikipedia_es": [56, 58]} -------------------------------------------------------------------------------- 
/experiments/constants/wikipedia_prompt_indices_test/5/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [95], "wikipedia_de": [], "wikipedia_fr": [48], "wikipedia_ru": [25, 96], "wikipedia_es": [28]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [30, 54], "wikipedia_de": [], "wikipedia_fr": [88], "wikipedia_ru": [41], "wikipedia_es": [64]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [36, 90], "wikipedia_de": [74], "wikipedia_fr": [95], "wikipedia_ru": [], "wikipedia_es": [28]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [19], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [2, 9], "wikipedia_es": [4, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [8, 15, 18], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [2, 13], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [11], "wikipedia_fr": [19], "wikipedia_ru": [3, 17, 19], "wikipedia_es": []} 
-------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [0, 3, 6], "wikipedia_ru": [12], "wikipedia_es": [18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0], "wikipedia_de": [8, 17], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [3, 12]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [9, 15], "wikipedia_de": [10], "wikipedia_fr": [], "wikipedia_ru": [14, 15], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [14], "wikipedia_de": [], "wikipedia_fr": [11], "wikipedia_ru": [6], "wikipedia_es": [10, 16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [79], "wikipedia_fr": [], "wikipedia_ru": [21, 48], "wikipedia_es": [74, 84]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [55, 56], "wikipedia_fr": 
[42, 74], "wikipedia_ru": [78], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [48, 81, 92], "wikipedia_fr": [68], "wikipedia_ru": [], "wikipedia_es": [33]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [60, 90, 95], "wikipedia_ru": [], "wikipedia_es": [69, 77]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 79], "wikipedia_de": [87], "wikipedia_fr": [38, 72], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [6, 12, 13, 18, 19], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [6], "wikipedia_ru": [], "wikipedia_es": [10, 12, 13, 16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/0.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [28], "wikipedia_de": [45, 58], "wikipedia_fr": [69, 89], "wikipedia_ru": [39, 48, 81, 85], "wikipedia_es": [95]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [31], "wikipedia_de": [], "wikipedia_fr": [30, 88], "wikipedia_ru": [27, 40, 41, 47], "wikipedia_es": [46, 67, 78]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [29, 30, 44], "wikipedia_de": [26, 31, 40, 46, 61], "wikipedia_fr": [], "wikipedia_ru": [22], "wikipedia_es": [64]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [58, 79], "wikipedia_fr": [35, 45], "wikipedia_ru": [30, 31, 55, 63, 91], "wikipedia_es": [34]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [39, 63, 85], "wikipedia_de": [20, 31, 39, 68], "wikipedia_fr": [49], "wikipedia_ru": [], "wikipedia_es": [52, 57]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [42, 61], "wikipedia_de": [83], "wikipedia_fr": [46, 63], "wikipedia_ru": [36, 52, 
90, 96], "wikipedia_es": [70]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [55, 62, 99], "wikipedia_de": [72], "wikipedia_fr": [26, 43, 54, 80], "wikipedia_ru": [93], "wikipedia_es": [88]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [33, 45, 79, 89], "wikipedia_de": [56, 89], "wikipedia_fr": [], "wikipedia_ru": [25, 71], "wikipedia_es": [38, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [42], "wikipedia_de": [51], "wikipedia_fr": [57, 69], "wikipedia_ru": [48, 65, 87], "wikipedia_es": [60, 67, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [26, 30, 76, 85, 88], "wikipedia_fr": [29, 60, 72, 75], "wikipedia_ru": [], "wikipedia_es": [72]} -------------------------------------------------------------------------------- /model_equality_testing/src/__init__.py: -------------------------------------------------------------------------------- 1 | from . import algorithm 2 | from . import dataset 3 | from . import distribution 4 | from . import utils 5 | from . import pvalue 6 | from . 
import tests 7 | 8 | -------------------------------------------------------------------------------- /experiments/constants/vanilla_api.json: -------------------------------------------------------------------------------- 1 | { 2 | "top_k": null, 3 | "temperature": 1.0, 4 | "top_p": 1.0, 5 | "frequency_penalty": 0.0, 6 | "presence_penalty": 0.0, 7 | "stop": null 8 | } 9 | 10 | -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [13, 19], "wikipedia_de": [1, 4, 5, 14, 19], "wikipedia_fr": [2, 3, 5, 11, 13, 14, 17, 18], "wikipedia_ru": [0, 5, 7, 9, 11, 16, 19], "wikipedia_es": [3, 4, 16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 3, 14], "wikipedia_de": [1, 2, 4, 9, 12], "wikipedia_fr": [1, 5, 11, 12, 14], "wikipedia_ru": [2, 6, 16, 17, 19], "wikipedia_es": [0, 5, 10, 11, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 7, 8, 9, 15], "wikipedia_de": [4, 5, 6, 7, 8, 15], "wikipedia_fr": [3, 10, 14], "wikipedia_ru": [3, 5, 6, 13, 15, 18], "wikipedia_es": [6, 9, 10, 14, 15]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [5, 7, 12], "wikipedia_de": [4, 11, 12, 19], "wikipedia_fr": [4, 5, 7, 14, 15, 17], "wikipedia_ru": [4, 6, 14, 16], "wikipedia_es": [0, 3, 5, 7, 8, 
10, 12, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 2, 4, 5, 8, 12, 19], "wikipedia_de": [0, 7, 8, 11, 13, 19], "wikipedia_fr": [0, 9, 15], "wikipedia_ru": [3, 10], "wikipedia_es": [2, 3, 4, 7, 11, 12, 15]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 19], "wikipedia_de": [1, 3, 6, 8, 11, 15, 19], "wikipedia_fr": [3, 8, 15], "wikipedia_ru": [1, 3, 4, 10, 12, 17, 19], "wikipedia_es": [0, 1, 4, 6, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [3, 7, 8, 18], "wikipedia_de": [16], "wikipedia_fr": [0, 1, 2, 4, 5, 10, 12, 13, 14], "wikipedia_ru": [5, 6, 7, 9, 10, 14, 19], "wikipedia_es": [8, 14, 15, 17]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [9, 11, 14, 19], "wikipedia_de": [8, 17, 19], "wikipedia_fr": [0, 1, 6, 16, 18, 19], "wikipedia_ru": [1, 3, 4, 7, 8, 12, 14, 17], "wikipedia_es": [7, 8, 13, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [4, 10, 11, 16], "wikipedia_de": [5, 10, 18], "wikipedia_fr": [2, 4, 8, 9, 14, 16], "wikipedia_ru": [3, 6, 17, 
19], "wikipedia_es": [1, 5, 8, 10, 13, 14, 17, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 4, 12, 16], "wikipedia_de": [0, 3, 4, 9, 10, 12, 14, 18], "wikipedia_fr": [2, 8, 12, 14, 17, 18], "wikipedia_ru": [0, 10, 12, 13], "wikipedia_es": [0, 10, 17]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [39, 77, 89], "wikipedia_de": [21, 46, 67, 85, 88, 96], "wikipedia_fr": [47, 51, 76, 83, 95], "wikipedia_ru": [24, 28, 37, 57, 61, 71, 97], "wikipedia_es": [36, 60, 67, 80]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [28, 37, 42, 55, 63, 73], "wikipedia_de": [21, 66, 88, 90, 92, 94], "wikipedia_fr": [50, 70, 76, 93], "wikipedia_ru": [31, 86, 89], "wikipedia_es": [27, 28, 32, 50, 63, 99]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [22, 41, 48, 52, 65, 89], "wikipedia_de": [25, 29, 31, 44, 80, 85, 94, 99], "wikipedia_fr": [99], "wikipedia_ru": [20, 21, 29, 60, 74, 99], "wikipedia_es": [77, 81, 84, 93]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [37, 50, 52, 53, 77, 90], 
"wikipedia_de": [35, 51, 61, 72, 81], "wikipedia_fr": [24, 36, 37, 61, 95], "wikipedia_ru": [24, 37, 39, 46, 66, 90], "wikipedia_es": [27, 38, 78]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [22, 32, 57, 68, 70], "wikipedia_de": [33, 57, 72, 74, 78, 88, 95, 97], "wikipedia_fr": [52, 64, 89], "wikipedia_ru": [54, 71, 73, 82, 87], "wikipedia_es": [42, 69, 74, 85]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 24, 37, 51, 60, 62, 81, 87], "wikipedia_de": [49, 67, 83], "wikipedia_fr": [24, 34, 35, 37, 61, 63], "wikipedia_ru": [29, 56, 75, 92], "wikipedia_es": [61, 75, 81, 90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [24, 60, 79, 83, 89], "wikipedia_de": [20, 23, 27, 36, 55, 66, 92], "wikipedia_fr": [27, 28, 57, 65, 78], "wikipedia_ru": [51, 56, 69, 97], "wikipedia_es": [23, 31, 62, 81]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [50, 59, 63, 67, 91, 95, 99], "wikipedia_de": [59, 80], "wikipedia_fr": [20, 29, 45, 46, 73, 92], "wikipedia_ru": [57, 92], "wikipedia_es": [33, 37, 39, 40, 66, 72, 82, 88]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/8.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 41, 44, 46, 48, 49, 64, 66, 80, 82, 90], "wikipedia_de": [24, 59, 64], "wikipedia_fr": [64, 83, 93, 94, 98], "wikipedia_ru": [28, 30, 44, 70, 79, 89], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [28, 31, 52, 74, 75, 77], "wikipedia_de": [48, 80], "wikipedia_fr": [20, 24, 35, 67, 81, 85, 88], "wikipedia_ru": [25, 27, 35, 62, 65, 95], "wikipedia_es": [42, 80, 84, 92]} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg* 2 | model_equality_testing/dist/ 3 | sparse_logs/ 4 | *.zip 5 | *.bin 6 | __pycache__/ 7 | .ipynb_checkpoints/ 8 | wandb 9 | *.out 10 | scr/ 11 | *.pdf 12 | *.pkl 13 | cache/ 14 | old-cache/ 15 | power_*.json 16 | *.c 17 | *.so 18 | .nfs* 19 | *.o 20 | finetuned_models/ 21 | data/ -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 3, 4, 8, 9, 10, 11, 14, 16, 18, 19], "wikipedia_de": [0, 1, 6, 7, 9, 12, 17], "wikipedia_fr": [0, 4, 5, 6, 7, 9, 11, 13, 15, 18], "wikipedia_ru": [0, 2, 3, 4, 9, 10, 12, 14, 15, 16, 17, 18], "wikipedia_es": [0, 2, 3, 4, 5, 10, 11, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 3, 5, 6, 8, 9, 11, 12, 13, 17, 19], "wikipedia_de": [1, 4, 9, 10, 11, 12, 13, 14, 15, 17, 
19], "wikipedia_fr": [1, 2, 3, 6, 11, 14, 15, 17, 18, 19], "wikipedia_ru": [6, 8, 9, 12, 13, 15, 17, 19], "wikipedia_es": [1, 3, 4, 5, 6, 9, 13, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 2, 3, 4, 6, 8, 9, 10, 11, 12, 14, 15, 16, 18], "wikipedia_de": [3, 4, 6, 7, 8, 9, 13, 15, 16], "wikipedia_fr": [0, 1, 2, 4, 9, 10, 11, 13, 17, 18], "wikipedia_ru": [2, 8, 9, 13, 14, 15, 16], "wikipedia_es": [1, 2, 4, 5, 9, 11, 13, 14, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 5, 7, 9, 10, 11, 12, 13, 15, 17, 18], "wikipedia_de": [2, 3, 4, 8, 9, 10, 11, 12, 16, 17], "wikipedia_fr": [1, 4, 5, 6, 7, 9, 11, 13, 14, 17, 18], "wikipedia_ru": [3, 6, 8, 11, 12, 13, 16], "wikipedia_es": [2, 4, 6, 7, 8, 11, 15, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 3, 4, 5, 6, 12, 13, 14, 16, 17, 18], "wikipedia_de": [0, 2, 3, 5, 7, 8, 11, 15, 16, 17, 18], "wikipedia_fr": [0, 3, 6, 10, 11, 15, 17], "wikipedia_ru": [1, 4, 5, 7, 9, 10, 14, 15, 16, 17, 19], "wikipedia_es": [1, 3, 7, 8, 12, 14, 15, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 3, 4, 5, 6, 9, 10, 12, 13, 14, 16, 17], "wikipedia_de": [0, 3, 5, 6, 7, 8, 11, 12, 15, 16, 18], 
"wikipedia_fr": [2, 8, 9, 10, 12, 14, 18], "wikipedia_ru": [0, 3, 4, 5, 6, 16, 18, 19], "wikipedia_es": [0, 2, 5, 6, 7, 10, 11, 13, 15, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 3, 6, 7, 9, 10, 14, 15, 17, 18, 19], "wikipedia_de": [0, 1, 6, 8, 9, 10, 12, 13, 15, 16, 18], "wikipedia_fr": [0, 1, 4, 5, 6, 7, 9, 10, 13, 14, 15, 16, 18, 19], "wikipedia_ru": [0, 1, 6, 12, 15, 17, 19], "wikipedia_es": [1, 2, 8, 13, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 2, 4, 8, 9, 11, 12, 13, 14, 15, 16, 17], "wikipedia_de": [0, 1, 2, 10, 11, 13, 14, 17, 18, 19], "wikipedia_fr": [2, 4, 6, 8, 10, 11, 14, 17], "wikipedia_ru": [1, 2, 4, 8, 9, 12, 15, 18, 19], "wikipedia_es": [0, 2, 5, 10, 11, 12, 13, 14, 16, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 3, 4, 11, 13, 15, 16, 18], "wikipedia_de": [0, 1, 3, 4, 7, 10, 11, 12, 14, 15, 16, 17, 19], "wikipedia_fr": [2, 4, 8, 11, 12, 15, 18, 19], "wikipedia_ru": [0, 1, 4, 10, 12, 13, 14, 15, 16], "wikipedia_es": [1, 3, 4, 6, 7, 9, 10, 11, 12, 15, 17, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 10, 12, 13, 14, 15, 16, 18, 19], "wikipedia_de": [0, 1, 3, 6, 10, 12, 13, 14, 16, 17, 18], 
"wikipedia_fr": [2, 8, 9, 10, 12, 13, 17, 18, 19], "wikipedia_ru": [1, 3, 5, 8, 10, 11, 14, 17, 18], "wikipedia_es": [0, 1, 2, 10, 11, 12, 13, 16, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20, 29, 38, 58, 59, 82, 85, 88, 95], "wikipedia_de": [21, 29, 36, 49, 50, 52, 61, 74, 80, 86, 87, 92, 96], "wikipedia_fr": [23, 25, 34, 45, 48, 49, 55, 79, 83, 94], "wikipedia_ru": [23, 28, 37, 38, 43, 69, 77, 88, 95], "wikipedia_es": [24, 40, 41, 46, 50, 60, 81, 86, 91]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [45, 50, 62, 64, 65, 74, 99], "wikipedia_de": [30, 32, 35, 48, 55, 60, 63, 65, 69, 71, 79, 80, 92, 94], "wikipedia_fr": [35, 39, 51, 65, 70, 72, 85, 87, 88, 90], "wikipedia_ru": [29, 34, 63, 65, 75, 82, 83, 90], "wikipedia_es": [22, 24, 30, 43, 50, 56, 75, 77, 85, 90, 97]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [36, 46, 54, 65, 70, 78, 87, 92], "wikipedia_de": [20, 28, 29, 40, 46, 58, 61, 63, 66, 68, 71, 74, 82, 85, 97], "wikipedia_fr": [21, 37, 43, 54, 57, 70, 89], "wikipedia_ru": [24, 25, 29, 32, 35, 40, 61, 65, 68], "wikipedia_es": [23, 26, 33, 40, 45, 57, 65, 71, 77, 97, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [35, 45, 48, 54, 58, 73, 78, 
93, 96, 97], "wikipedia_de": [29, 35, 45, 56, 69, 80, 90, 91], "wikipedia_fr": [29, 38, 47, 49, 51, 55, 57, 66, 73, 83, 84], "wikipedia_ru": [22, 25, 29, 38, 56, 63, 80, 86], "wikipedia_es": [20, 26, 27, 29, 32, 35, 39, 50, 60, 77, 81, 88, 94]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [37, 39, 53, 54, 56, 74, 75, 78, 90, 92, 99], "wikipedia_de": [21, 23, 24, 26, 37, 73, 84], "wikipedia_fr": [26, 40, 41, 47, 57, 58, 77, 86, 91], "wikipedia_ru": [33, 46, 48, 53, 66, 77, 78, 80, 88, 93], "wikipedia_es": [21, 31, 36, 42, 44, 46, 50, 59, 64, 73, 80, 85, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20, 26, 33, 50, 51, 60, 61, 69, 70, 82, 83, 97, 99], "wikipedia_de": [23, 35, 36, 53, 56, 60, 88, 94], "wikipedia_fr": [23, 30, 44, 45, 50, 58, 65, 70, 74, 80, 85, 98], "wikipedia_ru": [23, 32, 39, 46, 50, 66, 79], "wikipedia_es": [32, 36, 38, 40, 45, 53, 67, 76, 78, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [62, 66, 84, 91, 92], "wikipedia_de": [26, 28, 34, 38, 52, 53, 64, 66, 74, 78, 79, 80, 87, 89, 91, 94], "wikipedia_fr": [21, 27, 35, 37, 70, 71, 74, 78, 85, 89, 94, 96], "wikipedia_ru": [20, 30, 37, 38, 39, 49, 59, 84, 90], "wikipedia_es": [47, 59, 60, 64, 69, 80, 81, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/7.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 23, 31, 34, 44, 56, 59, 62, 69, 77, 79, 81, 83, 90], "wikipedia_de": [24, 31, 65, 84, 87, 94], "wikipedia_fr": [24, 30, 31, 36, 41, 43, 45, 49, 50], "wikipedia_ru": [24, 25, 39, 51, 56, 61, 62, 67, 79, 90, 96], "wikipedia_es": [27, 31, 40, 41, 51, 53, 67, 68, 83, 95]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [28, 29, 34, 50, 52, 53, 57, 59, 83, 85, 88, 94, 95, 98], "wikipedia_de": [33, 42, 47, 56, 68, 72, 75, 78], "wikipedia_fr": [22, 28, 43, 48, 57, 65, 69, 72, 77, 86, 87, 98], "wikipedia_ru": [22, 44, 47, 54, 59, 68, 84, 87, 99], "wikipedia_es": [32, 39, 41, 67, 73, 85, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [27, 31, 35, 42, 48, 50, 56, 67, 68, 71, 85, 88, 91, 97, 99], "wikipedia_de": [30, 47, 54, 67, 70, 71, 74, 80, 96], "wikipedia_fr": [23, 34, 41, 43, 51, 54, 96], "wikipedia_ru": [20, 28, 38, 45, 46, 52, 61, 64, 66, 67, 71, 74, 77, 78, 82], "wikipedia_es": [20, 37, 54, 98]} -------------------------------------------------------------------------------- /experiments/constants/vanilla_local.json: -------------------------------------------------------------------------------- 1 | { 2 | "top_k": 0, 3 | "temperature": 1.0, 4 | "top_p": 1.0, 5 | "repetition_penalty": 1.0, 6 | "num_beams": 1, 7 | "do_sample": true, 8 | "early_stopping": false, 9 | "no_repeat_ngram_size": 0, 10 | "bad_words_ids": null, 11 | "force_words_ids": null, 12 | "use_cache": true 13 | } 14 | 15 | -------------------------------------------------------------------------------- 
/experiments/constants/wikipedia_prompt_indices_val/75/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 2, 3, 5, 6, 7, 9, 10, 11, 12, 14, 15], "wikipedia_de": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 18, 19], "wikipedia_fr": [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14, 16, 17, 19], "wikipedia_ru": [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 17, 18], "wikipedia_es": [0, 1, 2, 3, 4, 5, 7, 8, 9, 12, 15, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19], "wikipedia_de": [0, 1, 2, 4, 7, 8, 9, 12, 13, 14, 15, 16], "wikipedia_fr": [1, 2, 3, 4, 5, 6, 8, 11, 12, 14, 15, 16, 17, 18, 19], "wikipedia_ru": [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 13, 14, 15, 16, 18, 19], "wikipedia_es": [0, 1, 2, 3, 5, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 3, 4, 5, 6, 8, 9, 11, 12, 13, 14, 15, 16, 18], "wikipedia_de": [1, 2, 3, 4, 6, 7, 10, 11, 12, 14, 15, 16, 17, 18], "wikipedia_fr": [0, 2, 3, 4, 5, 6, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19], "wikipedia_ru": [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 13, 15, 17, 18, 19], "wikipedia_es": [0, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 4, 5, 6, 7, 8, 9, 11, 13, 14, 17, 18, 19], "wikipedia_de": [1, 2, 5, 7, 8, 9, 10, 11, 
12, 14, 15, 17, 19], "wikipedia_fr": [0, 1, 2, 3, 4, 6, 7, 8, 9, 12, 13, 14, 19], "wikipedia_ru": [1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_es": [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19], "wikipedia_de": [0, 1, 2, 3, 4, 5, 8, 9, 10, 14, 15, 16, 17, 18], "wikipedia_fr": [0, 2, 3, 4, 5, 7, 8, 10, 13, 14, 15, 17, 19], "wikipedia_ru": [1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 19], "wikipedia_es": [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 14, 15, 16, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 18, 19], "wikipedia_de": [0, 1, 2, 3, 4, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19], "wikipedia_fr": [0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13, 15, 16, 17, 18, 19], "wikipedia_ru": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 14, 17, 18, 19], "wikipedia_es": [2, 5, 6, 8, 9, 10, 11, 12, 13, 16, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 19], "wikipedia_de": [1, 4, 5, 6, 7, 8, 9, 12, 14, 16, 18, 19], "wikipedia_fr": [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19], "wikipedia_ru": [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 19], "wikipedia_es": [0, 1, 3, 4, 6, 8, 9, 12, 13, 14, 15, 16]} 
-------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 4, 7, 8, 10, 11, 12, 13, 14, 15, 16, 19], "wikipedia_de": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], "wikipedia_fr": [0, 2, 3, 6, 7, 8, 10, 11, 12, 13, 15, 16, 17, 18, 19], "wikipedia_ru": [1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18], "wikipedia_es": [0, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 2, 3, 4, 5, 7, 8, 11, 12, 13, 14, 15, 17, 18, 19], "wikipedia_de": [0, 1, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], "wikipedia_fr": [0, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18], "wikipedia_ru": [2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 15, 17, 19], "wikipedia_es": [0, 1, 2, 4, 6, 7, 8, 9, 10, 11, 13, 14, 16, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 3, 4, 5, 6, 7, 9, 12, 14, 15, 16, 17, 18, 19], "wikipedia_de": [1, 2, 5, 6, 7, 9, 11, 12, 14, 15, 16, 17, 19], "wikipedia_fr": [0, 1, 2, 4, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_ru": [0, 1, 2, 3, 5, 6, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19], "wikipedia_es": [0, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/llama3-chat-template.txt: -------------------------------------------------------------------------------- 1 | {% set loop_messages = messages 
%}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|> 2 | 3 | '+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|> 4 | 5 | ' }}{% endif %} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 24, 33, 35, 37, 43, 50, 53, 54, 55, 58, 64, 73, 77, 88, 90, 94], "wikipedia_de": [22, 26, 31, 32, 37, 39, 42, 46, 57, 68, 75, 77, 82, 83, 87, 90, 91], "wikipedia_fr": [21, 25, 26, 28, 36, 40, 45, 47, 55, 63, 64, 78, 79, 80, 81, 99], "wikipedia_ru": [24, 27, 29, 32, 37, 39, 41, 55, 57, 61, 66, 68, 70, 71, 75, 97], "wikipedia_es": [22, 26, 31, 37, 41, 54, 66, 74, 93]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 27, 30, 36, 38, 43, 44, 49, 56, 59, 63, 65, 67, 79, 83, 87, 90, 98], "wikipedia_de": [22, 24, 26, 35, 52, 54, 57, 60, 62, 71, 73, 75, 77, 81, 82, 95, 98], "wikipedia_fr": [23, 27, 30, 42, 49, 53, 55, 72, 76, 77, 81, 82, 83, 97], "wikipedia_ru": [27, 33, 38, 40, 47, 54, 61, 62, 63, 72, 78, 80, 81, 88, 91], "wikipedia_es": [20, 30, 31, 42, 44, 51, 61, 68, 75, 76, 90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 28, 46, 60, 62, 63, 68, 70, 74, 75, 85, 91, 96], "wikipedia_de": [26, 29, 37, 45, 46, 51, 77, 84, 88, 90, 94, 98], "wikipedia_fr": 
[20, 24, 28, 34, 36, 37, 45, 47, 48, 52, 53, 59, 60, 65, 67, 68, 80, 81], "wikipedia_ru": [24, 25, 27, 33, 35, 38, 40, 53, 58, 59, 60, 63, 77, 82, 84, 87, 95], "wikipedia_es": [20, 25, 38, 39, 40, 45, 48, 51, 65, 74, 76, 77, 78, 85, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [30, 34, 35, 51, 54, 58, 63, 72, 73, 77, 79, 82, 83, 89, 90, 91, 92], "wikipedia_de": [34, 48, 49, 54, 67, 77, 81, 83, 86, 88, 89, 96], "wikipedia_fr": [32, 33, 38, 46, 48, 61, 65, 67, 69, 83, 87, 89, 92, 95], "wikipedia_ru": [23, 32, 33, 34, 40, 44, 47, 64, 70, 81, 82, 84, 85, 89, 91, 95, 99], "wikipedia_es": [21, 25, 26, 31, 40, 42, 44, 53, 63, 67, 69, 72, 82, 88, 93]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 26, 27, 29, 31, 32, 36, 44, 46, 51, 59, 66, 74, 78, 83, 86, 87, 92, 97, 99], "wikipedia_de": [20, 21, 22, 27, 33, 34, 35, 45, 47, 48, 64, 71, 78, 79, 82, 83, 85, 86, 99], "wikipedia_fr": [28, 30, 37, 48, 65, 66, 88, 89, 95], "wikipedia_ru": [20, 31, 38, 41, 44, 53, 55, 58, 69, 70, 78, 90, 94, 99], "wikipedia_es": [22, 25, 28, 29, 58, 59, 60, 62, 63, 66, 78, 90, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20, 30, 36, 47, 48, 53, 58, 59, 60, 61, 63, 67, 72, 75, 87, 89, 95, 99], "wikipedia_de": [23, 29, 38, 45, 54, 60, 62, 72, 75, 89], "wikipedia_fr": [30, 34, 37, 40, 50, 55, 59, 60, 61, 63, 66, 79, 80, 90, 91, 92, 94], "wikipedia_ru": [20, 21, 30, 36, 46, 55, 58, 61, 68, 
72, 82, 88, 90, 96], "wikipedia_es": [27, 28, 30, 31, 32, 35, 47, 49, 53, 63, 68, 74, 75, 77, 81, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 25, 28, 30, 35, 44, 46, 48, 49, 59, 60, 72, 77, 86, 87, 88], "wikipedia_de": [22, 29, 34, 38, 42, 43, 50, 52, 53, 72, 82, 83, 89, 90, 93, 99], "wikipedia_fr": [32, 48, 52, 58, 63, 64, 66, 69, 77, 78, 80, 82, 83, 87, 92, 99], "wikipedia_ru": [22, 31, 44, 46, 47, 48, 55, 65, 84, 89, 99], "wikipedia_es": [20, 21, 23, 32, 34, 35, 36, 38, 45, 49, 53, 54, 72, 75, 76, 81]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 27, 40, 47, 51, 53, 55, 56, 68, 94, 95, 99], "wikipedia_de": [21, 23, 24, 25, 27, 28, 31, 46, 50, 51, 62, 63, 66, 80, 82, 88], "wikipedia_fr": [24, 32, 36, 38, 39, 49, 50, 51, 59, 69, 71, 75, 87, 91], "wikipedia_ru": [21, 22, 33, 40, 47, 48, 51, 56, 62, 67, 81, 84, 90, 96], "wikipedia_es": [21, 28, 35, 40, 41, 43, 45, 46, 47, 54, 55, 57, 59, 61, 65, 71, 81, 95, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [36, 50, 51, 59, 63, 64, 71, 78, 79, 81, 95, 98], "wikipedia_de": [25, 28, 31, 37, 42, 45, 48, 59, 60, 63, 65, 66, 72, 77, 88], "wikipedia_fr": [29, 34, 35, 43, 49, 52, 53, 58, 59, 61, 67, 69, 72, 74, 76, 84, 93, 94, 97], "wikipedia_ru": [21, 26, 27, 34, 39, 43, 51, 52, 65, 72, 88, 92, 98], "wikipedia_es": [26, 28, 32, 39, 47, 51, 62, 67, 68, 72, 84, 87, 88, 92, 95, 98]} 
-------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [30, 39, 40, 48, 51, 54, 55, 71, 76, 77, 80, 83, 86, 92, 98, 99], "wikipedia_de": [24, 36, 37, 38, 42, 47, 48, 52, 54, 59, 60, 69, 76, 79, 84, 85, 87, 88, 90, 98], "wikipedia_fr": [25, 26, 38, 52, 59, 67, 74, 78, 86, 97], "wikipedia_ru": [21, 39, 43, 47, 52, 53, 57, 66, 72, 73, 79, 84, 85, 92, 96, 99], "wikipedia_es": [24, 29, 30, 35, 42, 43, 51, 59, 67, 74, 87, 89, 90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/100/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_de": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_fr": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_ru": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_es": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [24, 26, 27, 29, 38, 42, 54, 56, 57, 60, 66, 76, 77, 78, 86, 92, 94, 95, 99], "wikipedia_de": [28, 29, 32, 37, 38, 45, 51, 56, 60, 61, 65, 70, 72, 73, 74, 78, 79, 81, 83, 89, 91, 94], "wikipedia_fr": [24, 26, 28, 36, 39, 47, 60, 66, 70, 71, 77, 78, 82, 83, 86, 94, 98], "wikipedia_ru": [29, 39, 43, 46, 53, 55, 56, 58, 59, 72, 75, 85, 86, 88, 90, 94, 97], "wikipedia_es": [21, 22, 26, 27, 30, 31, 32, 42, 44, 53, 54, 61, 63, 65, 68, 69, 73, 75, 78, 80, 85, 
86, 92, 93, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [22, 23, 26, 27, 28, 38, 45, 47, 48, 51, 57, 59, 62, 67, 78, 79, 81, 83, 91, 99], "wikipedia_de": [21, 35, 36, 43, 44, 46, 48, 50, 56, 71, 77, 79, 91, 92, 95], "wikipedia_fr": [20, 30, 31, 36, 42, 44, 46, 47, 48, 51, 54, 60, 62, 65, 72, 73, 75, 82, 83, 84, 88, 95, 99], "wikipedia_ru": [21, 29, 34, 41, 42, 43, 44, 48, 49, 69, 70, 78, 79, 80, 82, 85, 92, 95], "wikipedia_es": [21, 25, 33, 39, 43, 49, 53, 55, 56, 57, 67, 73, 75, 76, 78, 79, 81, 85, 86, 89, 90, 95, 97, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [26, 27, 34, 40, 41, 44, 51, 59, 60, 62, 65, 72, 73, 78, 86, 87, 90, 95, 99], "wikipedia_de": [23, 24, 28, 31, 35, 43, 47, 54, 55, 56, 61, 63, 73, 81, 85, 88, 89, 91, 94, 96, 98], "wikipedia_fr": [21, 23, 27, 29, 37, 43, 49, 52, 53, 54, 55, 59, 60, 71, 72, 83, 85, 97], "wikipedia_ru": [20, 24, 30, 33, 34, 39, 40, 45, 50, 54, 63, 64, 67, 70, 72, 78, 85, 86, 87, 90, 94, 95, 97], "wikipedia_es": [20, 22, 24, 26, 37, 44, 49, 56, 64, 72, 73, 76, 78, 81, 82, 84, 89, 95, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20, 24, 26, 27, 28, 32, 33, 36, 39, 45, 48, 54, 59, 61, 68, 76, 79, 82, 84, 85, 86, 88, 92, 96], "wikipedia_de": [20, 21, 23, 28, 32, 37, 41, 46, 50, 55, 59, 60, 64, 68, 73, 78, 80, 82, 87, 89, 91], "wikipedia_fr": [20, 21, 23, 25, 33, 34, 36, 40, 44, 46, 55, 58, 63, 66, 72, 73, 83, 85, 
94, 98], "wikipedia_ru": [27, 29, 31, 39, 45, 48, 53, 56, 57, 66, 76, 77, 80, 82, 83, 84, 89, 90], "wikipedia_es": [24, 26, 36, 43, 47, 56, 61, 66, 70, 74, 77, 78, 79, 82, 83, 85, 94]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [29, 30, 31, 34, 40, 42, 44, 45, 47, 52, 53, 54, 58, 68, 69, 71, 74, 79, 81, 85, 99], "wikipedia_de": [23, 26, 29, 31, 32, 37, 48, 53, 56, 58, 62, 63, 75, 80, 82, 86, 91, 96], "wikipedia_fr": [21, 22, 31, 33, 41, 42, 44, 49, 54, 55, 56, 57, 64, 72, 75, 76, 84, 87, 93, 94], "wikipedia_ru": [20, 22, 23, 25, 32, 34, 41, 44, 45, 47, 53, 59, 62, 66, 69, 70, 74, 80, 88, 91, 99], "wikipedia_es": [20, 22, 25, 29, 33, 43, 45, 47, 54, 57, 66, 73, 75, 78, 81, 82, 84, 88, 90, 97]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 26, 30, 33, 41, 45, 48, 55, 56, 58, 59, 60, 70, 73, 75, 76, 80, 85, 88, 93], "wikipedia_de": [22, 24, 25, 27, 31, 35, 36, 40, 41, 50, 51, 54, 57, 58, 68, 70, 75, 76, 77, 83, 87, 91], "wikipedia_fr": [20, 22, 31, 32, 33, 37, 41, 42, 45, 47, 48, 52, 54, 63, 75, 77, 79, 81, 83, 89, 91, 94, 96, 97], "wikipedia_ru": [20, 22, 23, 24, 27, 34, 35, 36, 37, 40, 41, 43, 58, 61, 64, 68, 69, 83, 84, 95, 99], "wikipedia_es": [20, 22, 37, 41, 73, 75, 77, 79, 86, 87, 90, 97, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 26, 28, 30, 31, 34, 40, 41, 43, 46, 57, 59, 69, 71, 74, 75, 77, 79, 80, 91, 92, 93, 94, 96, 97, 98], "wikipedia_de": 
[20, 32, 34, 37, 47, 50, 57, 62, 64, 65, 86, 91, 92, 95], "wikipedia_fr": [20, 22, 23, 25, 41, 46, 49, 57, 59, 61, 63, 64, 70, 73, 75, 76, 80, 86, 89, 90, 96, 97, 98], "wikipedia_ru": [20, 34, 37, 38, 45, 51, 60, 65, 71, 74, 78, 80, 86, 93, 96, 98, 99], "wikipedia_es": [27, 31, 32, 34, 36, 37, 38, 41, 43, 45, 46, 56, 61, 66, 67, 69, 72, 84, 89, 90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [24, 25, 34, 38, 42, 43, 46, 51, 52, 53, 54, 57, 58, 59, 62, 69, 77, 78, 79, 80, 82, 92, 94, 95, 97], "wikipedia_de": [24, 26, 27, 34, 35, 36, 37, 41, 42, 48, 55, 56, 61, 63, 64, 65, 74, 77, 78, 80, 81, 89, 94, 95], "wikipedia_fr": [23, 29, 35, 36, 40, 48, 54, 55, 58, 69, 70, 76, 79, 80, 85, 90, 98], "wikipedia_ru": [20, 29, 34, 35, 40, 44, 45, 53, 54, 57, 63, 64, 65, 66, 67, 70, 76, 85, 86, 88, 92], "wikipedia_es": [21, 34, 40, 55, 62, 82, 85, 86, 88, 91, 92, 95, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [25, 26, 29, 32, 33, 37, 38, 41, 43, 44, 49, 54, 61, 62, 63, 64, 68, 92, 97], "wikipedia_de": [23, 26, 36, 37, 41, 45, 46, 50, 54, 58, 61, 62, 68, 77, 80, 81, 83, 87, 91, 93], "wikipedia_fr": [20, 26, 32, 36, 38, 44, 47, 48, 51, 57, 60, 62, 66, 68, 76, 82, 86, 88, 90, 94], "wikipedia_ru": [27, 29, 35, 52, 53, 59, 62, 63, 66, 67, 68, 73, 76, 78, 81, 92, 94, 96, 98], "wikipedia_es": [22, 29, 30, 37, 41, 42, 43, 44, 48, 49, 51, 53, 58, 63, 73, 75, 76, 77, 82, 87, 89, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/9.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 22, 23, 24, 25, 32, 35, 36, 40, 43, 44, 46, 47, 52, 54, 56, 61, 63, 66, 68, 73, 77, 78, 79, 86, 99], "wikipedia_de": [24, 27, 31, 33, 40, 46, 53, 54, 55, 58, 59, 60, 61, 62, 63, 76, 91, 93, 97], "wikipedia_fr": [21, 22, 24, 32, 39, 40, 57, 61, 62, 69, 81, 93, 94, 95, 97, 98], "wikipedia_ru": [23, 25, 41, 45, 47, 49, 53, 57, 61, 62, 71, 76, 77, 83, 88, 89, 93, 97], "wikipedia_es": [20, 23, 25, 26, 38, 40, 44, 45, 46, 48, 53, 54, 58, 63, 65, 68, 72, 74, 81, 85, 99]} -------------------------------------------------------------------------------- /model_equality_testing/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [tool.setuptools] 7 | packages = ["model_equality_testing"] 8 | package-dir = { "model_equality_testing" = "src" } 9 | 10 | [project] 11 | name = "model_equality_testing" 12 | version = "0.0.1" 13 | authors = [ 14 | { name="Irena Gao", email="irena@cs.stanford.edu" }, 15 | ] 16 | description = "Package to conduct model equality testing for black-box language model APIs" 17 | readme = "README.md" 18 | requires-python = ">=3.8" 19 | classifiers = [ 20 | "Programming Language :: Python :: 3", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | ] 24 | dependencies = [ 25 | "numpy", 26 | "torch", 27 | "matplotlib" 28 | ] 29 | 30 | [project.urls] 31 | Homepage = "https://github.com/i-gao/model-equality-testing" 32 | Issues = "https://github.com/i-gao/model-equality-testing/issues" 33 | -------------------------------------------------------------------------------- /model_equality_testing/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is 
hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /experiments/sampling/test_api_tokenization.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | from accelerate import Accelerator 4 | 5 | APIS = [ 6 | "anyscale", 7 | "together", 8 | "fireworks", 9 | "perplexity", 10 | "replicate", 11 | "groq", 12 | "deepinfra", 13 | "amazon", 14 | "azure", 15 | ] 16 | MODELS = [ 17 | "meta-llama/Meta-Llama-3-8B-Instruct", 18 | "meta-llama/Meta-Llama-3-70B-Instruct", 19 | "meta-llama/Meta-Llama-3.1-8B-Instruct", 20 | "meta-llama/Meta-Llama-3.1-70B-Instruct", 21 | "meta-llama/Meta-Llama-3.1-405B-Instruct", 22 | ] 23 | 24 | ############ 25 | 26 | from experiments.sampling.model import TransformersModel 27 | import experiments.prompts as prompts 28 | from cache_api_samples import policy 29 | import experiments.sampling.api as api_library 30 | 31 | def get_expected_prompt_len(api, model, ds): # returns (token count of ds[0]'s expected-prompt field, ds[0]'s prompt to send) 32 | _policy = policy(api, model.model_name) 33 | p = ds[0][_policy["prompt_key"]] 34 | p_for_len = ds[0][_policy["expected_prompt_key"]] 35 | ids = model.tokenizer( 36 | p_for_len, add_special_tokens=('special' not in _policy["expected_prompt_key"]), 37 | )['input_ids'] 38 | return len(ids), p 39 | 40 | def test_sampling(p, model, api, expected_prompt_len, n): # sends p n times (N=1 each); True iff prompt-token counts check out and completions vary 41 | print(f">> Testing {model} with {api} and repeatedly requesting through n={n}") 42 | i = 0 43 | try: 44 | raw = "" # NOTE(review): never reassigned, so str(raw) in the failure message is always empty 45 | out = [] 46 | get_fn, kwargs = getattr(api_library, f"setup_{api}")( 47 | model=model, 48 | N=1, 49 | L=5, 50 | use_chat_endpoint=policy(api, model)["use_chat_endpoint"], 51 | do_sample=True, 52 | temperature=1, 53 | top_p=None, 54 | ) 55 | for i in range(n): 56 | o = get_fn(prompt=p, **kwargs)[0] 57 | if not expected_prompt_len == o.num_prompt_tokens: # a length mismatch is tolerated only when the API echoes back exactly p 58 | if o.prompt is None or o.prompt != p: 59 | raise Exception(f"Expected prompt len was {expected_prompt_len}, actual was {o.num_prompt_tokens}\n{o}") 60
| out.append(o.full_completion) 61 | if len(set(out)) == 1: 62 | raise Exception("Always got the same output: " + out[0]) 63 | except Exception as e: 64 | print("\tResult: FAILED") 65 | print("\tIteration " + str(i) + ": " + str(e) + str(raw)) 66 | return False 67 | print("\tResult: PASSED") 68 | return True 69 | 70 | ############# 71 | 72 | accelerator = Accelerator() 73 | 74 | for m in MODELS: 75 | model = TransformersModel(m, accelerator, skip_loading_weights=True) 76 | ds = prompts.get_bit_prompts(model) 77 | for api in APIS: 78 | expected_prompt_len, prompt = get_expected_prompt_len(api, model, ds) 79 | # N=1 per request; request twice (n=2) and fail if every completion is identical 80 | test_sampling(prompt, m, api, expected_prompt_len, 2) 81 | -------------------------------------------------------------------------------- /experiments/testing/bootstrap_manager.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union, Tuple 2 | import numpy as np 3 | import torch 4 | import pickle 5 | from model_equality_testing.pvalue import EmpiricalPvalueCalculator 6 | 7 | 8 | def load_parametric_bootstrap(path, min_b=None, max_b=None) -> np.ndarray: 9 | """ 10 | Loads pre-saved test stats from bootstrapping and enforces the correct shapes. 11 | Output shape is (b, 1, n_stats) where b is the number of bootstraps and n_stats is the number of statistics. 12 | """ 13 | with open(path, "rb") as f: 14 | stats = pickle.load(f) # NOTE: unpickling can execute code; only load trusted cache files 15 | if stats.ndim == 1: 16 | stats = np.expand_dims(stats, 0) # NOTE(review): a 1-D (b,) input ends up as (1, b, 1), not the documented (b, 1, n_stats); confirm the intended axis 17 | if stats.ndim == 2: 18 | stats = np.expand_dims(stats, 2) 19 | if max_b is not None: 20 | stats = stats[:max_b] # keep the first max_b entries along axis 0 21 | if min_b is not None: 22 | assert min_b <= len( 23 | stats 24 | ), f"min_b={min_b} is greater than the number of bootstraps {len(stats)}" 25 | return stats 26 | 27 | 28 | class BootstrapManager: 29 | """ 30 | Helper class to manage loading of bootstrapped statistics with different sample sizes.
31 | """ 32 | 33 | def __init__(self, bootstrap_path_template: str, min_b=None, max_b=None): 34 | """ 35 | Args: 36 | bootstrap_path_template: str 37 | A string template that can be filled in with the sample size n 38 | Example: "cache/parametric_bootstrap_stats/meta-llama-Meta-Llama-3-8B-wikipedia-{n}.pkl" 39 | min_b: int 40 | Minimum number of bootstraps to load 41 | max_b: int 42 | Maximum number of bootstraps to load 43 | """ 44 | self._bootstrap_path_template = bootstrap_path_template 45 | self.min_b = min_b 46 | self.max_b = max_b 47 | self._stats = None # most recently loaded stats, set by load() 48 | 49 | def load( 50 | self, 51 | return_stats=False, 52 | **kwargs, 53 | ) -> Union[EmpiricalPvalueCalculator, Tuple[np.ndarray, EmpiricalPvalueCalculator]]: 54 | """ 55 | Loads the bootstrapped statistics stored at bootstrap_path_template.format(**kwargs). 56 | Args: 57 | **kwargs: fields of the path template, e.g. n (int) 58 | single-element numpy/torch values are unwrapped with .item() before formatting 59 | return_stats: bool 60 | If True, returns the raw statistics as well as the p-value calculator 61 | Returns: 62 | EmpiricalPvalueCalculator or Tuple[np.ndarray, EmpiricalPvalueCalculator] 63 | """ 64 | for k, v in list(kwargs.items()): 65 | try: # write back into kwargs so e.g. torch.tensor([5]) formats as "5", not "tensor([5])" 66 | kwargs[k] = v.item() 67 | except (AttributeError, ValueError, RuntimeError): # not a single-element array/tensor: format as-is 68 | pass 69 | self._stats = load_parametric_bootstrap( 70 | self._bootstrap_path_template.format(**kwargs), 71 | min_b=self.min_b, 72 | max_b=self.max_b, 73 | ) 74 | 75 | if return_stats: 76 | return self._stats, EmpiricalPvalueCalculator(self._stats) 77 | return EmpiricalPvalueCalculator(self._stats) 78 | -------------------------------------------------------------------------------- /experiments/sampling/cache_local_samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caches n samples / prompt by locally inferencing a model on a dataset of prompts.
3 | """ 4 | 5 | import torch 6 | import tqdm 7 | import argparse 8 | import os 9 | import experiments.prompts as prompts_module 10 | from experiments.sampling.model import TransformersModel 11 | from experiments.utils import ( 12 | str_to_bool, 13 | ParseKwargs, 14 | build_cache_filename, 15 | ) 16 | from accelerate import Accelerator 17 | import pickle 18 | import glob 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--model", type=str, required=True) 24 | parser.add_argument("--model_kwargs", nargs="*", action=ParseKwargs, default={}) 25 | parser.add_argument( 26 | "--source", 27 | type=str, 28 | default="fp32", 29 | ) 30 | parser.add_argument("--prompts", default="dummy", type=str) 31 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 32 | parser.add_argument("--temperature", type=float, default=None) 33 | parser.add_argument("--top_p", type=float, default=None) 34 | parser.add_argument("--L", type=int, default=50) 35 | parser.add_argument("--n", type=int, default=10000) 36 | parser.add_argument("--batch_size", type=int, default=90) 37 | parser.add_argument("--save_dir", type=str, default="../cache/samples") 38 | args = parser.parse_args() 39 | accelerator = Accelerator() 40 | print(args) 41 | 42 | # setup model 43 | fixed_decoding_params = { 44 | "temperature": args.temperature if args.temperature is not None else 1, 45 | "top_p": args.top_p if args.top_p is not None else 1, 46 | } 47 | kwargs = {} # NOTE(review): any other --source (e.g. "fp32" or a typo) loads default weights, yet args.source still names the cache file 48 | if args.source == "fp16": 49 | kwargs = {"cast_dtype": torch.float16} 50 | elif args.source == "int8": 51 | kwargs = {"quantize": 8} 52 | elif args.source == "nf4": 53 | kwargs = {"quantize": 4} 54 | elif args.source == "watermark": 55 | kwargs = {"watermark_bias": 2.5} 56 | model = TransformersModel( 57 | args.model, 58 | accelerator=accelerator, 59 | fixed_decoding_params=fixed_decoding_params, 60 | batch_size=args.batch_size, 61 | **args.model_kwargs, 62 | **kwargs 63 | ) 64 | 65 | #
load dataset 66 | ds = getattr(prompts_module, f"get_{args.prompts}_prompts")(model) 67 | try: 68 | # causes issues if kept 69 | ds = ds.remove_columns(["chat_tokens", "chat_with_ellipses_tokens"]) 70 | except: 71 | pass 72 | 73 | # get save string 74 | filename = build_cache_filename( 75 | model=args.model, 76 | prompts=args.prompts, 77 | alternative=args.source, 78 | temperature=args.temperature, 79 | top_p=args.top_p, 80 | do_sample=args.do_sample, 81 | L=args.L, 82 | ) 83 | 84 | # run through batches and dump the results 85 | for i in range(len(ds["chat"])): 86 | if os.path.exists(f"{args.save_dir}/{filename}-{i}.pkl"): 87 | print("Skipping", i) 88 | continue 89 | print(f"Collecting samples for prompt {i}") 90 | 91 | out = model.sample( 92 | [ds["chat"][i]], 93 | n=args.n, 94 | L=args.L, 95 | ) # (1, n, L) 96 | out = out[0].tolist() # (n, L); index rather than squeeze() so n == 1 or L == 1 still gives a list of lists 97 | with open(f"{args.save_dir}/{filename}-{i}.pkl", "wb") as f: 98 | pickle.dump(out, f) 99 | -------------------------------------------------------------------------------- /experiments/constants/minimum_prompt_indices.json: -------------------------------------------------------------------------------- 1 | { 2 | "wikipedia_en": [ 3 | 21, 4 | 22, 5 | 24, 6 | 28, 7 | 31, 8 | 32, 9 | 37, 10 | 39, 11 | 41, 12 | 42, 13 | 44, 14 | 46, 15 | 48, 16 | 49, 17 | 50, 18 | 51, 19 | 52, 20 | 53, 21 | 55, 22 | 57, 23 | 59, 24 | 60, 25 | 62, 26 | 63, 27 | 64, 28 | 65, 29 | 66, 30 | 67, 31 | 68, 32 | 70, 33 | 73, 34 | 74, 35 | 75, 36 | 77, 37 | 79, 38 | 80, 39 | 81, 40 | 82, 41 | 83, 42 | 87, 43 | 89, 44 | 90, 45 | 91, 46 | 95, 47 | 99 48 | ], 49 | "wikipedia_de": [ 50 | 20, 51 | 21, 52 | 23, 53 | 24, 54 | 25, 55 | 27, 56 | 29, 57 | 31, 58 | 33, 59 | 35, 60 | 36, 61 | 44, 62 | 46, 63 | 48, 64 | 49, 65 | 51, 66 | 55, 67 | 57, 68 | 59, 69 | 61, 70 | 64, 71 | 66, 72 | 67, 73 | 72, 74 | 74, 75 | 78, 76 | 80, 77 | 81, 78 | 83, 79 | 85, 80 | 88, 81 | 90, 82 | 92, 83 | 94, 84 | 95, 85 | 96, 86 | 97, 87 | 99 88 | ], 89 |
"wikipedia_fr": [ 90 | 20, 91 | 24, 92 | 27, 93 | 28, 94 | 29, 95 | 34, 96 | 35, 97 | 36, 98 | 37, 99 | 45, 100 | 46, 101 | 47, 102 | 50, 103 | 51, 104 | 52, 105 | 57, 106 | 61, 107 | 63, 108 | 64, 109 | 65, 110 | 67, 111 | 70, 112 | 73, 113 | 76, 114 | 78, 115 | 81, 116 | 83, 117 | 85, 118 | 88, 119 | 89, 120 | 92, 121 | 93, 122 | 94, 123 | 95, 124 | 98, 125 | 99 126 | ], 127 | "wikipedia_ru": [ 128 | 20, 129 | 21, 130 | 24, 131 | 25, 132 | 27, 133 | 28, 134 | 29, 135 | 30, 136 | 31, 137 | 35, 138 | 37, 139 | 39, 140 | 44, 141 | 46, 142 | 51, 143 | 54, 144 | 56, 145 | 57, 146 | 60, 147 | 62, 148 | 65, 149 | 66, 150 | 69, 151 | 70, 152 | 71, 153 | 73, 154 | 74, 155 | 75, 156 | 79, 157 | 82, 158 | 86, 159 | 87, 160 | 89, 161 | 90, 162 | 92, 163 | 95, 164 | 97, 165 | 99 166 | ], 167 | "wikipedia_es": [ 168 | 23, 169 | 27, 170 | 28, 171 | 31, 172 | 32, 173 | 33, 174 | 36, 175 | 37, 176 | 38, 177 | 39, 178 | 40, 179 | 42, 180 | 50, 181 | 61, 182 | 62, 183 | 63, 184 | 66, 185 | 67, 186 | 69, 187 | 72, 188 | 74, 189 | 75, 190 | 77, 191 | 78, 192 | 80, 193 | 81, 194 | 82, 195 | 84, 196 | 85, 197 | 88, 198 | 90, 199 | 92, 200 | 93, 201 | 99 202 | ] 203 | } -------------------------------------------------------------------------------- /experiments/sampling/cache_logprobs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Given a set of samples, computes the logprobs of the individual completion tokens (one number per token) under the fp32 null. 3 | Used in dataset construction; later applied for goodness-of-fit testing. 4 | Assumes that samples are saved as *-{i}.pkl files, where i is the index of the prompt in the dataset, 5 | and that the dataset is implemented in experiments.prompts.
6 | """ 7 | 8 | import torch 9 | import tqdm 10 | import glob 11 | import pickle 12 | import argparse 13 | import os 14 | import experiments.prompts as prompts_module 15 | from experiments.sampling.model import TransformersModel 16 | from experiments.utils import ( 17 | str_to_bool, 18 | ParseKwargs, 19 | stack_with_padding, 20 | ) 21 | from accelerate import Accelerator 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--model", type=str, required=True, help="Model to get logprobs with" 27 | ) 28 | parser.add_argument("--model_kwargs", nargs="*", action=ParseKwargs, default={}) 29 | parser.add_argument( 30 | "--prompts", type=str, required=True, help="Name of prompts dataset to get logprobs for" 31 | ) 32 | parser.add_argument("--samples_path_template", type=str, required=True, help="String template for pkl file of samples, saved as list of lists of tokens ids. The first part of {template}-{i}.pkl") 33 | parser.add_argument("--do_sample", type=str_to_bool, default=True) # NOTE(review): parsed but not used anywhere in this script 34 | parser.add_argument("--temperature", type=float, default=None) 35 | parser.add_argument("--top_p", type=float, default=None) 36 | parser.add_argument("--batch_size", type=int, default=90) 37 | parser.add_argument("--save_dir", type=str, default="../cache/logprobs") 38 | args = parser.parse_args() 39 | accelerator = Accelerator() 40 | print(args) 41 | 42 | # setup model 43 | fixed_decoding_params = { # NOTE(review): None is passed through here, whereas cache_local_samples.py defaults both to 1; confirm TransformersModel treats None the same 44 | "temperature": args.temperature, 45 | "top_p": args.top_p, 46 | } 47 | model = TransformersModel( 48 | args.model, 49 | accelerator=accelerator, 50 | fixed_decoding_params=fixed_decoding_params, 51 | batch_size=args.batch_size, 52 | **args.model_kwargs 53 | ) 54 | 55 | # load dataset 56 | ds = getattr(prompts_module, f"get_{args.prompts}_prompts")(model) 57 | try: 58 | # causes issues if kept 59 | ds = ds.remove_columns(["chat_tokens", "chat_with_ellipses_tokens"]) 60 | except: 61 | pass 62 | 63 | # run through batches and dump the
results 64 | for path in tqdm.tqdm(glob.glob(f"{args.samples_path_template}-*")): 65 | i = int(os.path.basename(path).split("-")[-1].split(".")[0]) # prompt index taken from the trailing "-{i}.pkl" 66 | if os.path.exists(f"{args.save_dir}/{os.path.basename(args.samples_path_template)}-{i}.pkl"): 67 | print("Skipping", path) 68 | continue 69 | 70 | print(f"Collecting logprobs for prompt {i} based on path {path}") 71 | 72 | with open(path, "rb") as f: 73 | # note: this expects the pkl file to contain a list of lists of integers (token IDs) 74 | completions = [torch.tensor(x) for x in pickle.load(f)] 75 | 76 | completions, attention_mask = stack_with_padding(completions) # NOTE(review): if this pads to a common length, the tuple keys written below include pad ids — confirm 77 | 78 | logprobs = model.get_logprobs( 79 | prompts=[ds[i]["chat"]] * len(completions), 80 | completion_input_ids=completions, 81 | completion_attention_mask=attention_mask, 82 | ) 83 | with open(f"{args.save_dir}/{os.path.basename(args.samples_path_template)}-{i}.pkl", "wb") as f: 84 | out = {tuple(seq.tolist()): lp.tolist() for seq, lp in zip(completions, logprobs)} 85 | pickle.dump(out, f) -------------------------------------------------------------------------------- /experiments/sampling/cache_local_samples_vllm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caches n samples / prompt by locally inferencing a model on a dataset of prompts.
3 | """ 4 | 5 | from vllm import LLM, SamplingParams 6 | import torch 7 | import tqdm 8 | import argparse 9 | import os 10 | import experiments.prompts as prompts_module 11 | from experiments.utils import ( 12 | str_to_bool, 13 | ParseKwargs, 14 | build_cache_filename, 15 | ) 16 | import pickle 17 | from accelerate import Accelerator 18 | from experiments.sampling.model import TransformersModel 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--model", type=str, required=True) 24 | parser.add_argument("--model_kwargs", nargs="*", action=ParseKwargs, default={}) 25 | parser.add_argument( 26 | "--source", 27 | type=str, 28 | default="fp32", 29 | ) 30 | parser.add_argument("--prompts", default="dummy", type=str) 31 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 32 | parser.add_argument("--temperature", type=float, default=None) 33 | parser.add_argument("--top_p", type=float, default=None) 34 | parser.add_argument("--L", type=int, default=50) 35 | parser.add_argument("--n", type=int, default=10000) 36 | parser.add_argument("--batch_size", type=int, default=90) 37 | parser.add_argument("--save_dir", type=str, default="../cache/samples") 38 | args = parser.parse_args() 39 | accelerator = Accelerator() 40 | print(args) 41 | 42 | # get dataset 43 | model_to_load_prompts = TransformersModel( 44 | args.model, 45 | accelerator=accelerator, 46 | skip_loading_weights=True, 47 | ) 48 | ds = getattr(prompts_module, f"get_{args.prompts}_prompts")(model_to_load_prompts) 49 | try: 50 | # causes issues if kept 51 | ds = ds.remove_columns(["chat_tokens", "chat_with_ellipses_tokens"]) 52 | except: 53 | pass 54 | 55 | # initialize vllm model 56 | if args.source in ("None", "fp32"): # "fp32" is this script's --source default and the full-precision name used by cache_local_samples.py 57 | kwargs = {"dtype": "float32"} 58 | elif args.source == "fp16": 59 | kwargs = {"dtype": "float16"} 60 | elif args.source == "nf4": 61 | kwargs = {"quantization": "bitsandbytes", "load_format": "bitsandbytes"} 62 | else: 63 | raise
ValueError("Cannot use that with vllm") 64 | 65 | model = LLM( 66 | model=args.model, 67 | tensor_parallel_size=len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")), # NOTE(review): raises KeyError if CUDA_VISIBLE_DEVICES is unset 68 | **kwargs, 69 | **args.model_kwargs, 70 | ) 71 | 72 | # get save string 73 | filename = build_cache_filename( 74 | model=args.model, 75 | prompts=args.prompts, 76 | alternative=args.source, 77 | temperature=args.temperature, 78 | top_p=args.top_p, 79 | do_sample=args.do_sample, 80 | L=args.L, 81 | ) 82 | 83 | sampling_kwargs = dict( 84 | # n is supplied per batch below so the final partial batch only draws the remaining samples 85 | temperature=args.temperature if args.temperature is not None else 1, 86 | top_p=args.top_p if args.top_p is not None else 1, 87 | top_k=1 if not args.do_sample else -1, 88 | max_tokens=args.L, 89 | stop_token_ids=[], 90 | skip_special_tokens=False, 91 | ignore_eos=True, 92 | logprobs=None, 93 | ) 94 | 95 | # run through batches and dump the results 96 | for i in range(len(ds)): 97 | if os.path.exists(f"{args.save_dir}/{filename}-{i}.pkl"): 98 | print("Skipping", i) 99 | continue 100 | print(f"Collecting samples for prompt {i}") 101 | 102 | out = [] 103 | for batch_size in tqdm.tqdm( 104 | [ 105 | min(args.batch_size, args.n - j * args.batch_size) 106 | for j in range( 107 | (args.n + args.batch_size - 1) // args.batch_size 108 | ) 109 | ] 110 | ): 111 | out.extend( 112 | model.generate( 113 | [ds["chat"][i]], 114 | SamplingParams(n=batch_size, **sampling_kwargs), 115 | ) 116 | ) 117 | sample = [oi.token_ids for o in out for oi in o.outputs] 118 | 119 | with open(f"{args.save_dir}/{filename}-{i}.pkl", "wb") as f: 120 | pickle.dump(sample, f) -------------------------------------------------------------------------------- /experiments/testing/cache_one_sample_bootstrap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calls model_equality_testing.pvalue.one_sample_parametric_bootstrap_pvalue repeatedly to cache simulated test statistics (parametric bootstrap) for a given null x test statistic.
3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from model_equality_testing.pvalue import one_sample_parametric_bootstrap_pvalue 7 | import pickle 8 | from experiments.utils import build_cache_filename, str_to_bool 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--model", 18 | type=str, 19 | help="Name of the model", 20 | default="meta-llama/Meta-Llama-3-8B-Instruct", 21 | ) 22 | parser.add_argument( 23 | "--null_distribution_source", 24 | type=str, 25 | default="fp32", 26 | help="Source to use as the null distribution", 27 | ) 28 | parser.add_argument( 29 | "--prompt_indices_json", 30 | type=str, 31 | help="JSON of prompt dataset name: [list of indices]. Prompt distribution will be uniform over these prompts", 32 | required=True, 33 | ) 34 | parser.add_argument( 35 | "--n", 36 | type=int, 37 | help="Number of samples to draw each time we compute the test statistic", 38 | default=None, 39 | ) 40 | parser.add_argument( 41 | "--n_per_prompt", 42 | type=int, 43 | help="Alternative to --n, number of samples per prompt", 44 | default=None, 45 | ) 46 | parser.add_argument( 47 | "--stat", type=str, default="g_squared", help="One-sample test statistic" 48 | ) 49 | parser.add_argument( 50 | "--test_in_unicode", 51 | type=str_to_bool, 52 | default=True, 53 | help="Test in unicode space instead of token space", 54 | ) 55 | parser.add_argument( 56 | "--b", type=int, default=1000, help="Number of bootstrap samples" 57 | ) 58 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 59 | parser.add_argument("--do_sample", type=str_to_bool, default=True) # this and the next two args only feed build_cache_filename 60 | parser.add_argument("--temperature", type=float, default=None) 61 | parser.add_argument("--top_p", type=float, default=None) 62 | parser.add_argument( 63 | "--save_dir", type=str,
default="../cache/parametric_bootstrap_stats" 64 | ) 65 | args = parser.parse_args() 66 | print(args) 67 | 68 | # Load the prompt indices 69 | js = json.load(open(args.prompt_indices_json)) # NOTE(review): file handle is never explicitly closed 70 | prompts = list(js.keys()) 71 | # get the number of prompts m 72 | m = sum([len(js[k]) for k in js]) 73 | print("Number of prompts: ", m) 74 | 75 | assert (args.n is not None) ^ ( # NOTE: CLI validation via assert is skipped under python -O 76 | args.n_per_prompt is not None 77 | ), "Exactly one of --n or --n_per_prompt must be provided" 78 | if args.n_per_prompt is not None: 79 | args.n = m * args.n_per_prompt 80 | print("Number of samples: ", args.n) 81 | 82 | # Load dataset 83 | p = load_distribution( 84 | model=args.model, 85 | prompt_ids=js, 86 | L=args.L, 87 | source=args.null_distribution_source, 88 | load_in_unicode=args.test_in_unicode, 89 | ) 90 | print("Null shape: ", p.shape) 91 | 92 | # Construct the filename 93 | filename = build_cache_filename( 94 | model=args.model, 95 | prompts=prompts, 96 | prompt_indices_json=args.prompt_indices_json, 97 | alternative=args.null_distribution_source, 98 | temperature=args.temperature, 99 | top_p=args.top_p, 100 | do_sample=args.do_sample, 101 | L=args.L, 102 | stat=args.stat, 103 | N=args.n, 104 | use_char_space=args.test_in_unicode, 105 | ) 106 | out_path = f"{args.save_dir}/{filename}.pkl" 107 | print(f"Will write to {out_path}") 108 | 109 | # Skip if already exists 110 | if os.path.exists(out_path): 111 | print("Skipping...") 112 | exit() # NOTE(review): exit() comes from the site module; sys.exit() is the portable choice 113 | 114 | # Cache the test statistics 115 | _, s = one_sample_parametric_bootstrap_pvalue( 116 | null_dist=p, 117 | n=args.n, 118 | b=args.b, 119 | return_stats=True, 120 | stat_type=args.stat, 121 | ) 122 | 123 | with open(out_path, "wb") as f: 124 | pickle.dump(s, f) 125 | -------------------------------------------------------------------------------- /model_equality_testing/src/algorithm.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, List, Dict 2 | from
model_equality_testing.pvalue import ( 3 | EmpiricalPvalueCalculator, 4 | one_sample_parametric_bootstrap_pvalue, 5 | two_sample_parametric_bootstrap_pvalue, 6 | two_sample_permutation_pvalue, 7 | ) 8 | from model_equality_testing.tests import IMPLEMENTED_TESTS 9 | from model_equality_testing.distribution import ( 10 | CompletionSample, 11 | DistributionFromDataset, 12 | ) 13 | 14 | 15 | def _noop_pvalue(*args, **kwargs): # ignores its inputs; used for pvalue_type="dummy" 16 | return 1.0 17 | 18 | 19 | def run_goodness_of_fit_test( 20 | sample: CompletionSample, 21 | null_dist: DistributionFromDataset, 22 | get_pvalue: Union[callable, EmpiricalPvalueCalculator] = None, 23 | pvalue_type: str = "parametric_bootstrap", 24 | stat_type: str = "g_squared", 25 | b=1000, 26 | **kwargs, 27 | ) -> Tuple[float, float]: 28 | """ 29 | Tests whether the sample is drawn from the null distribution 30 | Args: 31 | sample: CompletionSample 32 | null_dist: DistributionFromDataset 33 | get_pvalue: callable or EmpiricalPvalueCalculator 34 | Given a test statistic, returns the p-value 35 | The function should take in one argument (a float, np.ndarray, or torch.Tensor) 36 | representing the observed statistic, and it should return a float (the pvalue).
37 | pvalue_type: str 38 | If get_pvalue is None, how to compute the p-value 39 | stat_type: str 40 | Which test statistic to compute 41 | b: int 42 | Number of bootstrap samples if pvalue_type is "parametric_bootstrap" 43 | kwargs 44 | Additional arguments to pass to the test statistic function 45 | Returns: 46 | pvalue: float 47 | statistic: float 48 | """ 49 | if get_pvalue is None: 50 | if pvalue_type == "parametric_bootstrap": 51 | get_pvalue = one_sample_parametric_bootstrap_pvalue( # NOTE(review): **kwargs reach the observed statistic (line 63) but not this bootstrap, unlike the two-sample path; confirm both compute the same statistic 52 | null_dist=null_dist, 53 | n=sample.N, 54 | b=b, 55 | return_stats=False, 56 | stat_type=stat_type, 57 | ) 58 | elif pvalue_type == "dummy": 59 | get_pvalue = _noop_pvalue 60 | else: 61 | raise ValueError("Unrecognized p-value type") 62 | 63 | statistic = IMPLEMENTED_TESTS[stat_type](sample, null_dist, **kwargs) 64 | pvalue = get_pvalue(statistic) 65 | if not isinstance(pvalue, float): # unwrap numpy/torch scalars; NOTE(review): a plain int return would fail on .item() 66 | pvalue = pvalue.item() 67 | return (pvalue, statistic) 68 | 69 | 70 | def run_two_sample_test( 71 | sample: CompletionSample, 72 | other_sample: CompletionSample, 73 | null_dist: DistributionFromDataset = None, 74 | get_pvalue: Union[callable, EmpiricalPvalueCalculator] = None, 75 | pvalue_type: str = "permutation_pvalue", 76 | stat_type: str = "two_sample_L2", 77 | b=1000, 78 | **kwargs, 79 | ) -> Tuple[float, float]: 80 | """ 81 | Tests whether the samples are drawn from the same distribution 82 | Args: 83 | sample: CompletionSample 84 | other_sample: CompletionSample 85 | null_dist: DistributionFromDataset (only required when pvalue_type is "parametric_bootstrap") 86 | get_pvalue: callable or EmpiricalPvalueCalculator 87 | Given a test statistic, returns the p-value 88 | The function should take in one argument (a float, np.ndarray, or torch.Tensor) 89 | representing the observed statistic, and it should return a float (the pvalue).
90 | pvalue_type: str 91 | If get_pvalue is None, how to compute the p-value 92 | stat_type: str 93 | Which test statistic to compute 94 | b: int 95 | Number of permutations or bootstrap samples used when get_pvalue is None 96 | kwargs 97 | Additional arguments to pass to the test statistic function (also forwarded to the p-value routine when get_pvalue is None) 98 | Returns: 99 | pvalue: float 100 | statistic: float 101 | """ 102 | if get_pvalue is None: 103 | if pvalue_type == "permutation_pvalue": 104 | get_pvalue = two_sample_permutation_pvalue( 105 | sample, other_sample, b=b, stat_type=stat_type, **kwargs 106 | ) 107 | elif pvalue_type == "parametric_bootstrap": 108 | assert ( # NOTE: skipped under python -O 109 | null_dist is not None 110 | ), "Must provide null distribution for parametric bootstrap" 111 | get_pvalue = two_sample_parametric_bootstrap_pvalue( 112 | null_dist=null_dist, 113 | n1=sample.N, 114 | n2=other_sample.N, 115 | b=b, 116 | stat_type=stat_type, 117 | **kwargs, 118 | ) 119 | elif pvalue_type == "dummy": 120 | get_pvalue = _noop_pvalue 121 | else: 122 | raise ValueError("Unrecognized p-value type") 123 | 124 | statistic = IMPLEMENTED_TESTS[stat_type](sample, other_sample, **kwargs) 125 | pvalue = get_pvalue(statistic) 126 | if not isinstance(pvalue, float): # unwrap numpy/torch scalars 127 | pvalue = pvalue.item() 128 | return (pvalue, statistic) 129 | -------------------------------------------------------------------------------- /experiments/testing/cache_two_sample_bootstrap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calls model_equality_testing.pvalue.two_sample_parametric_bootstrap_pvalue repeatedly to cache simulated test statistics (parametric bootstrap) for a given null x test statistic.
3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from model_equality_testing.pvalue import two_sample_parametric_bootstrap_pvalue 7 | import pickle 8 | from experiments.utils import build_cache_filename, str_to_bool 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--model", 18 | type=str, 19 | help="Name of the model", 20 | default="meta-llama/Meta-Llama-3-8B-Instruct", 21 | ) 22 | parser.add_argument( 23 | "--null_distribution_source", 24 | type=str, 25 | default="fp32", 26 | help="Source to use as the null distribution", 27 | ) 28 | parser.add_argument( 29 | "--prompt_indices_json", 30 | type=str, 31 | help="JSON of prompt dataset name: [list of indices]. Prompt distribution will be uniform over these prompts", 32 | required=True, 33 | ) 34 | parser.add_argument( 35 | "--n1", 36 | type=int, 37 | help="Number of samples to compute statistics with, for the first sample (null in cache_power)", 38 | default=None, 39 | ) 40 | parser.add_argument( 41 | "--n1_per_prompt", 42 | type=int, 43 | help="Alternative to --n1, number of samples per prompt, for the first sample (null in cache_power)", 44 | default=None, 45 | ) 46 | parser.add_argument( 47 | "--n2", 48 | type=int, 49 | help="Number of samples to compute statistics with, for the second sample (alt in cache_power)", 50 | default=None, 51 | ) 52 | parser.add_argument( 53 | "--n2_per_prompt", 54 | type=int, 55 | help="Alternative to --n2, number of samples per prompt, for the second sample (alt in cache_power)", 56 | default=None, 57 | ) 58 | parser.add_argument( 59 | "--stat", type=str, default="two_sample_L2", help="Two-sample test statistic" 60 | ) 61 | parser.add_argument( 62 | "--test_in_unicode", 63 | type=str_to_bool, 64 | default=True, 65 | help="Test in unicode space instead of token space", 66 | ) 67 | parser.add_argument( 68 | "--b",
type=int, default=1000, help="Number of bootstrap samples" 69 | ) 70 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 71 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 72 | parser.add_argument("--temperature", type=float, default=None) 73 | parser.add_argument("--top_p", type=float, default=None) 74 | parser.add_argument( 75 | "--save_dir", type=str, default="../cache/parametric_bootstrap_stats" 76 | ) 77 | args = parser.parse_args() 78 | print(args) 79 | 80 | # Load the prompt indices 81 | js = json.load(open(args.prompt_indices_json)) # NOTE(review): file handle is never explicitly closed 82 | prompts = list(js.keys()) 83 | # get the number of prompts m 84 | m = sum([len(js[k]) for k in js]) 85 | print("Number of prompts: ", m) 86 | 87 | assert (args.n1 is not None) ^ ( 88 | args.n1_per_prompt is not None 89 | ), "Exactly one of --n1 or --n1_per_prompt must be provided" 90 | if args.n1_per_prompt is not None: 91 | args.n1 = m * args.n1_per_prompt 92 | assert (args.n2 is not None) ^ ( 93 | args.n2_per_prompt is not None 94 | ), "Exactly one of --n2 or --n2_per_prompt must be provided" 95 | if args.n2_per_prompt is not None: 96 | args.n2 = m * args.n2_per_prompt 97 | print("Number of samples: ", args.n1, args.n2) 98 | 99 | # Load dataset 100 | p = load_distribution( 101 | model=args.model, 102 | prompt_ids=js, 103 | L=args.L, 104 | source=args.null_distribution_source, 105 | load_in_unicode=args.test_in_unicode, 106 | ) 107 | print("Null shape: ", p.shape) 108 | 109 | # Construct the filename 110 | filename = build_cache_filename( 111 | model=args.model, 112 | prompts=prompts, 113 | prompt_indices_json=args.prompt_indices_json, 114 | alternative=args.null_distribution_source, 115 | temperature=args.temperature, 116 | top_p=args.top_p, 117 | do_sample=args.do_sample, 118 | L=args.L, 119 | stat=args.stat, 120 | N=f"{args.n1}_{args.n2}", 121 | use_char_space=args.test_in_unicode, 122 | ) 123 | out_path = f"{args.save_dir}/{filename}.pkl" 124 | print(f"Will write
to {out_path}") 125 | 126 | # Skip if already exists 127 | if os.path.exists(out_path): 128 | print("Skipping...") 129 | exit() 130 | 131 | # Cache the test statistics 132 | _, s = two_sample_parametric_bootstrap_pvalue( 133 | null_dist=p, 134 | n1=args.n1, 135 | n2=args.n2, 136 | b=args.b, 137 | return_stats=True, 138 | stat_type=args.stat, 139 | ) 140 | 141 | with open(out_path, "wb") as f: 142 | pickle.dump(s, f) 143 | -------------------------------------------------------------------------------- /experiments/sampling/cache_api_samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caches n samples / prompt from an API on a dataset of prompts. 3 | """ 4 | 5 | import tqdm 6 | import os 7 | import argparse 8 | import experiments.prompts as prompts_module 9 | from experiments.sampling.model import TransformersModel 10 | from experiments.utils import ( 11 | str_to_bool, 12 | build_cache_filename, 13 | wait_if_error, 14 | ) 15 | import experiments.sampling.api as api_module 16 | import time 17 | from dataclasses import dataclass, asdict 18 | import time # NOTE(review): duplicate of the import on line 16 19 | from accelerate import Accelerator 20 | import json 21 | import pickle 22 | 23 | """ 24 | Logic to choose how to query APIs; this was selected to try to account for different tokenization policies by APIs. Set by manually testing whether the number of tokens in the prompt is the same as the number of prompt tokens mentioned in the returned message. 25 | """ 26 | 27 | 28 | # case-by-case policies to handle uniform tokenization.
# use test_api_tokenization.py to set these
def policy(api, model):
    """
    Return the prompt-formatting policy used when querying backend `api`.

    Args:
        api: backend name, e.g. "replicate" or "amazon"; any other name gets the
            default policy.
        model: model name. Currently unused; kept so per-model policies can be
            added without changing callers.

    Returns:
        dict with keys:
            - "prompt_key": which dataset column is sent as the prompt.
            - "use_chat_endpoint": passed to `setup_{api}` to pick the chat
              endpoint vs. the plain completion endpoint.
            - "expected_prompt_key": not read in this script; presumably the
              column whose tokenization the API's reported prompt-token count
              should match (see test_api_tokenization.py) — TODO confirm.
    """
    if api == "replicate":
        return {
            "prompt_key": "chat",
            "use_chat_endpoint": False,
            "expected_prompt_key": "chat",
        }
    if api == "amazon":
        return {
            "prompt_key": "chat_with_special",
            "use_chat_endpoint": False,
            "expected_prompt_key": "chat",
        }
    return {
        "prompt_key": "plain",
        "use_chat_endpoint": True,
        "expected_prompt_key": "chat",
    }


FILE_DIR = os.path.dirname(os.path.abspath(__file__))
# Prompt indices actually used by the paper's prompt distributions, keyed by
# prompt-set name; used with --sample_bare_minimum to skip all other prompts.
with open(f"{FILE_DIR}/../constants/minimum_prompt_indices.json") as f:
    BARE_MINIMUM = json.load(f)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, required=True, help="HF transformers model name"
    )
    parser.add_argument("--backend", type=str, required=True, help="API name")
    parser.add_argument("--prompts", default="dummy")
    parser.add_argument("--L", type=int, default=3)
    parser.add_argument(
        "--n", type=int, default=1, help="Number of samples to generate per prompt"
    )
    parser.add_argument("--do_sample", type=str_to_bool, default=True)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--val_cutoff", type=int, default=20)
    parser.add_argument(
        "--sample_bare_minimum",
        type=str_to_bool,
        default=False,
        help="Whether to only sample the bare minimum prompt indices used for the prompt distributions in the paper, listed in constants/minimum_prompt_indices.json",
    )
    parser.add_argument("--save_dir", type=str, default="../cache/api")
    args = parser.parse_args()
    accelerator = Accelerator()
    print(args)

    # get dataset; the model object is only needed for its tokenizer / chat
    # formatting, so weights are not loaded
    model_to_load_prompts = TransformersModel(
        args.model,
        accelerator=accelerator,
        skip_loading_weights=True,
    )
    ds = getattr(prompts_module, f"get_{args.prompts}_prompts")(model_to_load_prompts)
    # The token columns cause issues if kept. They may be absent (datasets raises
    # ValueError) or ds may not support remove_columns (AttributeError); in both
    # cases there is nothing to drop. Anything else is a real error.
    try:
        ds = ds.remove_columns(["chat_tokens", "chat_with_ellipses_tokens"])
    except (ValueError, AttributeError):
        pass

    setup_kwargs = policy(args.backend, args.model)

    # backend: look up the setup function first, so that AttributeErrors raised
    # *inside* the setup function are not misreported as an unknown backend
    setup_fn = getattr(api_module, f"setup_{args.backend}", None)
    if setup_fn is None:
        raise ValueError("Unrecognized backend")
    get_fn, kwargs = setup_fn(
        model=args.model,
        N=1,
        L=args.L,
        use_chat_endpoint=setup_kwargs["use_chat_endpoint"],
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    filename = build_cache_filename(
        model=args.model,
        prompts=args.prompts,
        alternative=args.backend,
        temperature=args.temperature,
        top_p=args.top_p,
        do_sample=args.do_sample,
        L=args.L,
    )
    filename = f"{args.save_dir}/{filename}"
    # Make sure the output directory exists before writing the first pickle.
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)

    # get samples
    for it, x in tqdm.tqdm(enumerate(ds)):
        if (
            args.sample_bare_minimum
            and args.prompts in BARE_MINIMUM
            and it not in BARE_MINIMUM[args.prompts]
        ):
            continue
        out_path = f"{filename}-{it}.pkl"
        if os.path.exists(out_path):
            print("Skipping, exists...")
            continue

        print("Sampling", it)
        time.sleep(30)  # avoid rate limiting

        completions = []
        for i in range(args.n):
            print("> sample #", i)
            o = wait_if_error(
                get_fn, timeout=10, prompt=x[setup_kwargs["prompt_key"]], **kwargs
            )
            # wait_if_error may return None; such draws are dropped
            if o is not None:
                completions.extend(o)
        d = {
            "samples": [asdict(c) for c in completions],
            "id": x["id"],
            "prompt": x[setup_kwargs["prompt_key"]],
            "y": x.get("y", None),
        }

        with open(out_path, "wb") as f:
            pickle.dump(d, f)
# --------------------------------------------------------------------------------
/experiments/testing/simulation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from model_equality_testing.distribution import DistributionFromDataset 4 | from model_equality_testing.algorithm import ( 5 | run_goodness_of_fit_test, 6 | run_two_sample_test, 7 | ) 8 | from typing import Union, List, Tuple 9 | import tqdm 10 | import gc 11 | 12 | 13 | def get_power_one_sample( 14 | null_dist: DistributionFromDataset, 15 | data_dist: DistributionFromDataset, 16 | n: int, 17 | n_simulations=100, 18 | alpha: float = 0.05, 19 | pvalue_type: str = "parametric_bootstrap", 20 | stat_type: str = "g_squared", 21 | get_pvalue_fn=None, 22 | return_pvalue: bool = False, 23 | return_alpha: bool = False, 24 | return_stat: bool = False, 25 | **kwargs, 26 | ): 27 | """ 28 | Power analysis for joint (prompt, completion) distribution 29 | """ 30 | rejections = [] 31 | pvalues = [] 32 | stat = [] 33 | for _ in tqdm.tqdm(range(n_simulations), desc="Power simulation"): 34 | sample = data_dist.sample(n=n) 35 | pv, s = run_goodness_of_fit_test( 36 | sample=sample, 37 | null_dist=null_dist, 38 | get_pvalue=get_pvalue_fn, 39 | pvalue_type=pvalue_type, 40 | stat_type=stat_type, 41 | **kwargs, 42 | ) 43 | stat.append(s) 44 | pvalues.append(pv) 45 | rejections.append(int(pv <= alpha)) 46 | power = sum(rejections) / n_simulations 47 | return_tuple = (power, np.array(rejections, dtype=bool)) 48 | if return_alpha: 49 | return_tuple += (alpha * np.ones((n_simulations, null_dist.m)),) 50 | if return_pvalue: 51 | return_tuple += (torch.from_numpy(np.stack(pvalues)),) 52 | if return_stat: 53 | return_tuple += (np.stack(stat),) 54 | return return_tuple 55 | 56 | 57 | def get_power_two_sample( 58 | null_dist: DistributionFromDataset, 59 | data_dist: DistributionFromDataset, 60 | n_null: int, 61 | n_data: int, 62 | n_simulations=100, 63 | alpha: float = 0.05, 64 | pvalue_type: str = "permutation_pvalue", 65 | stat_type: str 
= "two_sample_L2", 66 | get_pvalue_fn=None, 67 | return_pvalue: bool = False, 68 | return_alpha: bool = False, 69 | return_stat: bool = False, 70 | **kwargs, 71 | ): 72 | """ 73 | Power analysis for joint (prompt, completion) distribution 74 | """ 75 | rejections = [] 76 | pvalues = [] 77 | stat = [] 78 | for _ in tqdm.tqdm(range(n_simulations), desc="Power simulation"): 79 | sample_1 = null_dist.sample(n=n_null) 80 | sample_2 = data_dist.sample(n=n_data) 81 | pv, s = run_two_sample_test( 82 | sample=sample_1, 83 | other_sample=sample_2, 84 | null_dist=null_dist, 85 | get_pvalue=get_pvalue_fn, 86 | pvalue_type=pvalue_type, 87 | stat_type=stat_type, 88 | **kwargs, 89 | ) 90 | pvalues.append(pv) 91 | stat.append(s) 92 | rejections.append(int(pv <= alpha)) 93 | del sample_1, sample_2 94 | gc.collect() 95 | power = sum(rejections) / n_simulations 96 | return_tuple = (power, np.array(rejections, dtype=bool)) 97 | if return_alpha: 98 | return_tuple += (alpha * np.ones((n_simulations, null_dist.m)),) 99 | if return_pvalue: 100 | return_tuple += (torch.from_numpy(np.stack(pvalues)),) 101 | if return_stat: 102 | return_tuple += (np.stack(stat),) 103 | return return_tuple 104 | 105 | 106 | def get_power_two_sample_composite_null( 107 | null_dist_1: DistributionFromDataset, 108 | null_dist_2: DistributionFromDataset, 109 | data_dist: DistributionFromDataset, 110 | n_null: int, 111 | n_data: int, 112 | n_simulations=100, 113 | alpha: float = 0.05, 114 | pvalue_type: str = "permutation_pvalue", 115 | stat_type: str = "two_sample_L2", 116 | get_pvalue_fn_1=None, 117 | get_pvalue_fn_2=None, 118 | return_pvalue: bool = False, 119 | return_alpha: bool = False, 120 | return_stat: bool = False, 121 | **kwargs, 122 | ): 123 | """ 124 | Power analysis for joint (prompt, completion) distribution 125 | """ 126 | rejections = [] 127 | pvalues = [] 128 | stat = [] 129 | for _ in tqdm.tqdm(range(n_simulations), desc="Power simulation"): 130 | sample_1 = null_dist_1.sample(n=n_null) 131 | 
sample_2 = null_dist_2.sample(n=n_null) 132 | sample_3 = data_dist.sample(n=n_data) 133 | pv1, s1 = run_two_sample_test( 134 | sample=sample_1, 135 | other_sample=sample_3, 136 | get_pvalue=get_pvalue_fn_1, 137 | pvalue_type=pvalue_type, 138 | stat_type=stat_type, 139 | **kwargs, 140 | ) 141 | pv2, s2 = run_two_sample_test( 142 | sample=sample_2, 143 | other_sample=sample_3, 144 | get_pvalue=get_pvalue_fn_2, 145 | pvalue_type=pvalue_type, 146 | stat_type=stat_type, 147 | **kwargs, 148 | ) 149 | rejections.append(int((pv1 <= alpha) and (pv2 <= alpha))) 150 | pvalues.append((pv1, pv2)) 151 | stat.append((s1, s2)) 152 | del sample_1, sample_2 153 | gc.collect() 154 | power = sum(rejections) / n_simulations 155 | return_tuple = (power, np.array(rejections, dtype=bool)) 156 | if return_alpha: 157 | return_tuple += (alpha * np.ones((n_simulations, null_dist_1.m)),) 158 | if return_pvalue: 159 | return_tuple += (torch.from_numpy(np.stack(pvalues)),) 160 | if return_stat: 161 | return_tuple += (np.stack(stat),) 162 | return return_tuple 163 | -------------------------------------------------------------------------------- /model_equality_testing/src/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | import torch 3 | import torch.nn.functional as F 4 | from time import perf_counter 5 | import numpy as np 6 | from functools import lru_cache 7 | 8 | def tokenize_unicode(strings: List[str], pad_token_id: int = -1) -> np.ndarray: 9 | """ 10 | Tokenize a list of strings into a list of lists of unicode codepoints, 11 | and then stack them into a 2D numpy array, using -1 as padding. 
12 | 13 | Args: 14 | strings (List[str]): list of strings to tokenize 15 | 16 | Returns: 17 | List[List[int]]: list of lists of unicode codepoints 18 | """ 19 | strings = [torch.tensor([ord(c) for c in s]) for s in strings] 20 | chr_array = stack_with_padding(strings, padding_token=-1)[0].numpy() 21 | return chr_array 22 | 23 | def pad_to_length(samples: np.ndarray, L: int, pad_token_id: int = -1) -> np.ndarray: 24 | """ 25 | Pad a 2D numpy array of samples to a fixed length L using a pad token. 26 | 27 | Args: 28 | samples (np.ndarray): 2D numpy array of samples 29 | L (int): length to pad to 30 | pad_token_id (int): padding token 31 | 32 | If the current length is longer than L, throws an error. 33 | 34 | Returns: 35 | np.ndarray: padded 2D numpy array 36 | """ 37 | if samples.shape[1] > L: 38 | raise ValueError(f"Cannot pad to length {L} because the current length is {samples.shape[1]}") 39 | padded_samples = np.full((samples.shape[0], L), pad_token_id) 40 | padded_samples[:, :samples.shape[1]] = samples 41 | return padded_samples 42 | 43 | def sanitize(s): 44 | """Sanitize a string for use as a filename.""" 45 | s = str(s) 46 | s = s.replace(" ", "-") 47 | s = s.replace("[", "") 48 | s = s.replace("]", "") 49 | s = s.replace(",", "_") 50 | s = s.replace("/", "-") 51 | s = s.replace("(", "") 52 | s = s.replace(")", "") 53 | return s 54 | 55 | 56 | def stack_with_padding( 57 | tensors: List[torch.Tensor], 58 | dim: int = 0, 59 | padding_side: str = "right", 60 | padding_mode: str = "constant", 61 | padding_token: Union[int, float] = 0, 62 | ) -> Tuple[torch.Tensor, torch.Tensor]: 63 | """ 64 | Stack tensors along specified dimension and pad them to ensure their size is equal in all dimensions. 65 | Returns the stacked tensor and a boolean mask indicating valid (non-padded) elements. 66 | 67 | Args: 68 | tensors (List[torch.Tensor]): list of tensors to stack 69 | dim (int): dimension along which to stack tensors. Defaults to 0. 
70 | padding_side (str): side on which to pad - "left" or "right". Defaults to "right". 71 | padding_mode (str): 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant' 72 | padding_value (Union[int, float]): value to use for constant padding 73 | 74 | Returns: 75 | Tuple[torch.Tensor, torch.Tensor]: stacked tensor and boolean mask 76 | """ 77 | # Ensure all tensors have the same number of dimensions 78 | max_dims = max(t.dim() for t in tensors) 79 | tensors = [t.view(*t.shape, *([1] * (max_dims - t.dim()))) for t in tensors] 80 | 81 | # Find the maximum size for each dimension 82 | max_sizes = [max(t.shape[i] for t in tensors) for i in range(max_dims)] 83 | 84 | def make_padding(tensor_shape): 85 | padding = [] 86 | for i in reversed(range(max_dims)): # Reverse for F.pad expectations 87 | pad_size = max_sizes[i] - tensor_shape[i] 88 | if padding_side == "left": 89 | padding.extend([pad_size, 0]) 90 | elif padding_side == "right": 91 | padding.extend([0, pad_size]) 92 | else: 93 | raise ValueError(f"padding_side '{padding_side}' is unknown") 94 | return tuple(padding) 95 | 96 | padded_tensors = [] 97 | masks = [] 98 | 99 | for t in tensors: 100 | padding = make_padding(t.shape) 101 | padded_t = F.pad(t, padding, mode=padding_mode, value=padding_token) 102 | 103 | mask = torch.zeros_like(padded_t, dtype=torch.bool) 104 | slices = [] 105 | for i in range(max_dims): 106 | if padding_side == "left": 107 | slices.append(slice(max_sizes[i] - t.shape[i], None)) 108 | else: 109 | slices.append(slice(None, t.shape[i])) 110 | mask[tuple(slices)] = True 111 | 112 | padded_tensors.append(padded_t) 113 | masks.append(mask) 114 | 115 | stacked_tensor = torch.stack(padded_tensors, dim=dim) 116 | stacked_mask = torch.stack(masks, dim=dim) 117 | 118 | return stacked_tensor, stacked_mask 119 | 120 | 121 | class Stopwatch: 122 | """ 123 | Context manager for timing a block of code 124 | Source: 
https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time 125 | """ 126 | 127 | def __enter__(self): 128 | if torch.cuda.is_available(): 129 | torch.cuda.synchronize() 130 | self.time = perf_counter() 131 | return self 132 | 133 | def __exit__(self, type, value, traceback): 134 | if torch.cuda.is_available(): 135 | torch.cuda.synchronize() 136 | self.time = perf_counter() - self.time 137 | 138 | 139 | def ndim(p): 140 | """ 141 | Args: 142 | p: either a tensor or a list of tensors or a list of lists of tensors 143 | """ 144 | if isinstance(p, torch.Tensor): 145 | return p.ndim 146 | if not isinstance(p, (list, np.ndarray)): 147 | return 0 148 | elif len(p) > 0: 149 | return ndim(p[0]) + 1 150 | else: 151 | return 1 152 | 153 | 154 | @lru_cache(maxsize=100) 155 | def get_inv(lst): 156 | """ 157 | Given a list of items, returns a dict of {item: ix} 158 | """ 159 | return {x: idx for idx, x in enumerate(lst)} 160 | -------------------------------------------------------------------------------- /experiments/testing/simulate_one_sample_power.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simulate power of a one-sample test. 
3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from experiments.testing.simulation import get_power_one_sample 7 | from experiments.utils import build_cache_filename, str_to_bool 8 | from experiments.testing.bootstrap_manager import BootstrapManager 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | import argparse 14 | import os 15 | import torch 16 | import numpy as np 17 | from accelerate import Accelerator 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--model", 23 | type=str, 24 | help="Name of the model", 25 | default="meta-llama/Meta-Llama-3-8B-Instruct", 26 | ) 27 | parser.add_argument( 28 | "--null_distribution_source", 29 | type=str, 30 | default="fp32", 31 | help="Source to use as the null distribution P", 32 | ) 33 | parser.add_argument( 34 | "--alternative_distribution_source", 35 | type=str, 36 | default="fp32", 37 | help="Source to use as the alternative distribution Q", 38 | ) 39 | parser.add_argument( 40 | "--n_simulations", 41 | type=int, 42 | default=100, 43 | help="Number of simulations to run to estimate power", 44 | ) 45 | parser.add_argument("--alpha", type=float, default=0.05, help="Significance level") 46 | parser.add_argument( 47 | "--prompt_indices_json", 48 | type=str, 49 | help="JSON of prompt dataset name: [list of indices]. 
Prompt distribution will be uniform over these prompts", 50 | required=True, 51 | ) 52 | parser.add_argument( 53 | "--n", 54 | type=int, 55 | help="Number of samples to draw each time we compute the test statistic", 56 | default=None, 57 | ) 58 | parser.add_argument( 59 | "--n_per_prompt", 60 | type=int, 61 | help="Alternative to --n, number of samples per prompt", 62 | default=None, 63 | ) 64 | parser.add_argument( 65 | "--stat", type=str, default="g_squared", help="One-sample test statistic" 66 | ) 67 | parser.add_argument( 68 | "--pvalue_type", 69 | type=str, 70 | default="parametric_bootstrap", 71 | help="Type of p-value", 72 | ) 73 | parser.add_argument( 74 | "--test_in_unicode", 75 | type=str_to_bool, 76 | default=True, 77 | help="Test in unicode space instead of token space", 78 | ) 79 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 80 | parser.add_argument("--max_b", type=int, default=None) 81 | parser.add_argument("--min_b", type=int, default=None) 82 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 83 | parser.add_argument("--temperature", type=float, default=None) 84 | parser.add_argument("--top_p", type=float, default=None) 85 | parser.add_argument("--save_dir", type=str, default="../cache/power") 86 | parser.add_argument("--bootstrap_dir", type=str, default="../cache/parametric_bootstrap_stats") 87 | accelerator = Accelerator() 88 | args = parser.parse_args() 89 | print(args) 90 | 91 | # Load the prompt indices 92 | js = json.load(open(args.prompt_indices_json)) 93 | prompts = list(js.keys()) 94 | # get the number of prompts m 95 | m = sum([len(js[k]) for k in js]) 96 | print("Number of prompts: ", m) 97 | 98 | assert (args.n is not None) ^ ( 99 | args.n_per_prompt is not None 100 | ), "Exactly one of --n or --n_per_prompt must be provided" 101 | if args.n_per_prompt is not None: 102 | args.n = m * args.n_per_prompt 103 | print("Number of samples: ", args.n) 104 | 105 | # Load null distribution 
106 | p = load_distribution( 107 | model=args.model, 108 | prompt_ids=js, 109 | L=args.L, 110 | source=args.null_distribution_source, 111 | load_in_unicode=args.test_in_unicode, 112 | ) 113 | print("Null shape: ", p.shape) 114 | 115 | # Start building the output filename 116 | filename_stem = build_cache_filename( 117 | model=args.model, 118 | prompts=prompts, 119 | prompt_indices_json=args.prompt_indices_json, 120 | use_char_space=args.test_in_unicode, 121 | alternative=args.null_distribution_source, 122 | temperature=args.temperature, 123 | top_p=args.top_p, 124 | do_sample=args.do_sample, 125 | L=args.L, 126 | stat=args.stat, 127 | N="{n}", 128 | ) 129 | filename = filename_stem.format(n=args.n) 130 | filename += f"-pvalue={args.pvalue_type}-alt={args.alternative_distribution_source}" 131 | 132 | if os.path.exists(f"{args.save_dir}/{filename}.pt"): 133 | print("Skipping b/c already exists...") 134 | exit() 135 | 136 | # Set up bootstrap manager 137 | get_pvalue = None 138 | if args.pvalue_type == "parametric_bootstrap": 139 | try: 140 | bootstrap_manager = BootstrapManager( 141 | bootstrap_path_template=f"{args.bootstrap_dir}/{filename_stem}.pkl", 142 | min_b=args.min_b, 143 | max_b=args.max_b, 144 | ) 145 | get_pvalue = bootstrap_manager.load(n=args.n) 146 | except: 147 | get_pvalue = None 148 | 149 | # Load alternative distribution 150 | q = load_distribution( 151 | model=args.model, 152 | prompt_ids=js, 153 | L=args.L, 154 | source=args.alternative_distribution_source, 155 | load_in_unicode=args.test_in_unicode, 156 | ) 157 | print("Alternative shape: ", q.shape) 158 | 159 | # Simulate! 
160 | ( 161 | pwr, 162 | reject_history, 163 | alpha_history, 164 | pvalue_history, 165 | test_stats, 166 | ) = get_power_one_sample( 167 | p, 168 | q, 169 | n=args.n, 170 | n_simulations=args.n_simulations, 171 | pvalue_type=args.pvalue_type, 172 | stat_type=args.stat, 173 | get_pvalue_fn=get_pvalue, 174 | return_pvalue=True, 175 | return_alpha=True, 176 | return_stat=True, 177 | alpha=args.alpha, 178 | ) 179 | 180 | # Save results 181 | print("Power: ", pwr) 182 | torch.save( 183 | { 184 | "power": pwr, 185 | "reject_history": reject_history, 186 | "alpha_history": alpha_history, 187 | "pvalue_history": pvalue_history, 188 | "test_stats": test_stats, 189 | }, 190 | f"{args.save_dir}/{filename}.pt", 191 | ) 192 | -------------------------------------------------------------------------------- /experiments/testing/simulate_two_sample_power.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simulate power of a two-sample test. 3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from experiments.testing.simulation import get_power_two_sample 7 | from experiments.utils import build_cache_filename, str_to_bool 8 | from experiments.testing.bootstrap_manager import BootstrapManager 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | import argparse 14 | import os 15 | import torch 16 | import numpy as np 17 | from accelerate import Accelerator 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--model", 23 | type=str, 24 | help="Name of the model", 25 | default="meta-llama/Meta-Llama-3-8B-Instruct", 26 | ) 27 | parser.add_argument( 28 | "--null_distribution_source", 29 | type=str, 30 | default="fp32", 31 | help="Source to use as the null distribution P", 32 | ) 33 | parser.add_argument( 34 | "--alternative_distribution_source", 35 | type=str, 36 | default="fp32", 37 | help="Source to use as the alternative 
distribution Q", 38 | ) 39 | parser.add_argument( 40 | "--n_simulations", 41 | type=int, 42 | default=100, 43 | help="Number of simulations to run to estimate power", 44 | ) 45 | parser.add_argument("--alpha", type=float, default=0.05, help="Significance level") 46 | parser.add_argument( 47 | "--prompt_indices_json", 48 | type=str, 49 | help="JSON of prompt dataset name: [list of indices]. Prompt distribution will be uniform over these prompts", 50 | required=True, 51 | ) 52 | parser.add_argument( 53 | "--n1", 54 | type=int, 55 | help="Number of samples to compute statistics with, for the first sample (null in cache_power)", 56 | default=None, 57 | ) 58 | parser.add_argument( 59 | "--n1_per_prompt", 60 | type=int, 61 | help="Alternative to --n1, number of samples per prompt, for the first sample (null in cache_power)", 62 | default=None, 63 | ) 64 | parser.add_argument( 65 | "--n2", 66 | type=int, 67 | help="Number of samples to compute statistics with, for the second sample (alt in cache_power)", 68 | default=None, 69 | ) 70 | parser.add_argument( 71 | "--n2_per_prompt", 72 | type=int, 73 | help="Alternative to --n1, number of samples per prompt, for the second sample (alt in cache_power)", 74 | default=None, 75 | ) 76 | parser.add_argument( 77 | "--stat", type=str, default="two_sample_L2", help="Two-sample test statistic" 78 | ) 79 | parser.add_argument( 80 | "--pvalue_type", 81 | type=str, 82 | default="parametric_bootstrap", 83 | help="Type of p-value", 84 | ) 85 | parser.add_argument( 86 | "--test_in_unicode", 87 | type=str_to_bool, 88 | default=True, 89 | help="Test in unicode space instead of token space", 90 | ) 91 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 92 | parser.add_argument("--max_b", type=int, default=None) 93 | parser.add_argument("--min_b", type=int, default=None) 94 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 95 | parser.add_argument("--temperature", type=float, default=None) 
96 | parser.add_argument("--top_p", type=float, default=None) 97 | parser.add_argument("--save_dir", type=str, default="../cache/power") 98 | parser.add_argument( 99 | "--bootstrap_dir", type=str, default="../cache/parametric_bootstrap_stats" 100 | ) 101 | accelerator = Accelerator() 102 | args = parser.parse_args() 103 | print(args) 104 | 105 | # Load the prompt indices 106 | js = json.load(open(args.prompt_indices_json)) 107 | prompts = list(js.keys()) 108 | # get the number of prompts m 109 | m = sum([len(js[k]) for k in js]) 110 | print("Number of prompts: ", m) 111 | 112 | assert (args.n1 is not None) ^ ( 113 | args.n1_per_prompt is not None 114 | ), "Exactly one of --n1 or --n1_per_prompt must be provided" 115 | if args.n1_per_prompt is not None: 116 | args.n1 = m * args.n1_per_prompt 117 | assert (args.n2 is not None) ^ ( 118 | args.n2_per_prompt is not None 119 | ), "Exactly one of --n2 or --n2_per_prompt must be provided" 120 | if args.n2_per_prompt is not None: 121 | args.n2 = m * args.n2_per_prompt 122 | print("Number of samples: ", args.n1, args.n2) 123 | 124 | # Load null distribution 125 | p = load_distribution( 126 | model=args.model, 127 | prompt_ids=js, 128 | L=args.L, 129 | source=args.null_distribution_source, 130 | load_in_unicode=args.test_in_unicode, 131 | ) 132 | print("Null shape: ", p.shape) 133 | 134 | # Start building the output filename 135 | filename_stem = build_cache_filename( 136 | model=args.model, 137 | prompts=prompts, 138 | prompt_indices_json=args.prompt_indices_json, 139 | use_char_space=args.test_in_unicode, 140 | alternative=args.null_distribution_source, 141 | temperature=args.temperature, 142 | top_p=args.top_p, 143 | do_sample=args.do_sample, 144 | L=args.L, 145 | stat=args.stat, 146 | N="{n1}_{n2}", 147 | ) 148 | filename = filename_stem.format(n1=args.n1, n2=args.n2) 149 | filename += f"-pvalue={args.pvalue_type}-alt={args.alternative_distribution_source}" 150 | 151 | if os.path.exists(f"{args.save_dir}/{filename}.pt"): 
152 | print("Skipping b/c already exists...") 153 | exit() 154 | 155 | # Set up bootstrap manager 156 | get_pvalue = None 157 | if args.pvalue_type == "parametric_bootstrap": 158 | try: 159 | bootstrap_manager = BootstrapManager( 160 | bootstrap_path_template=f"{args.bootstrap_dir}/{filename_stem}.pkl", 161 | min_b=args.min_b, 162 | max_b=args.max_b, 163 | ) 164 | get_pvalue = bootstrap_manager.load(n1=args.n1, n2=args.n2) 165 | except: 166 | get_pvalue = None 167 | 168 | # Load alternative distribution 169 | q = load_distribution( 170 | model=args.model, 171 | prompt_ids=js, 172 | L=args.L, 173 | source=args.alternative_distribution_source, 174 | load_in_unicode=args.test_in_unicode, 175 | ) 176 | print("Alternative shape: ", q.shape) 177 | 178 | # Simulate! 179 | ( 180 | pwr, 181 | reject_history, 182 | alpha_history, 183 | pvalue_history, 184 | test_stats, 185 | ) = get_power_two_sample( 186 | p, 187 | q, 188 | n_null=args.n1, 189 | n_data=args.n2, 190 | n_simulations=args.n_simulations, 191 | pvalue_type=args.pvalue_type, 192 | stat_type=args.stat, 193 | get_pvalue_fn=get_pvalue, 194 | return_pvalue=True, 195 | return_alpha=True, 196 | return_stat=True, 197 | alpha=args.alpha, 198 | ) 199 | 200 | # Save results 201 | print("Power: ", pwr) 202 | torch.save( 203 | { 204 | "power": pwr, 205 | "reject_history": reject_history, 206 | "alpha_history": alpha_history, 207 | "pvalue_history": pvalue_history, 208 | "test_stats": test_stats, 209 | }, 210 | f"{args.save_dir}/{filename}.pt", 211 | ) 212 | -------------------------------------------------------------------------------- /model_equality_testing/src/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions for downloading and loading the dataset we release. 
3 | """ 4 | 5 | from model_equality_testing.distribution import DistributionFromDataset 6 | from typing import Literal, List, Dict, Union, Tuple 7 | from model_equality_testing.utils import sanitize, stack_with_padding, tokenize_unicode 8 | import glob 9 | import numpy as np 10 | import torch 11 | from transformers import AutoTokenizer 12 | import pickle 13 | from collections import defaultdict 14 | import os 15 | import zipfile 16 | 17 | def download_dataset(root_dir="./data"): 18 | """ 19 | Download the dataset from Google Drive and extract it into the specified directory. 20 | 21 | Arguments: 22 | - root_dir (str): The root directory where the dataset will be saved. 23 | - split (str or None): 'test', 'val', or None. If None, both splits will be downloaded. 24 | """ 25 | try: 26 | import gdown 27 | except: 28 | raise ImportError("Please install gdown to download the dataset: pip install gdown") 29 | 30 | os.makedirs(root_dir, exist_ok=True) 31 | file_id = "1csgp83tx04kVA9ejp6MNSC7Ni0vh0J9U" 32 | download_url = f"https://drive.google.com/uc?id={file_id}" 33 | 34 | zip_file_path = os.path.join(root_dir, "model_equality_testing_ds.zip") 35 | extract_dir = root_dir 36 | 37 | print(f"Downloading dataset...") 38 | if not os.path.exists(zip_file_path): 39 | gdown.download(download_url, zip_file_path, quiet=False) 40 | 41 | print(f"Extracting dataset...") 42 | with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: 43 | zip_ref.extractall(extract_dir) 44 | 45 | os.remove(zip_file_path) 46 | print(f"Dataset downloaded and extracted successfully.") 47 | 48 | 49 | SOURCES = [ 50 | "fp32", 51 | "fp16", 52 | "nf4", 53 | "int8", 54 | "watermark", 55 | "anyscale", 56 | "amazon", 57 | "fireworks", 58 | "replicate", 59 | "deepinfra", 60 | "groq", 61 | "perplexity", 62 | "together", 63 | "azure", 64 | ] 65 | 66 | MODELS = [ 67 | "meta-llama/Meta-Llama-3-8B-Instruct", 68 | "meta-llama/Meta-Llama-3-70B-Instruct", 69 | "meta-llama/Meta-Llama-3.1-8B-Instruct", 70 | 
"meta-llama/Meta-Llama-3.1-70B-Instruct", 71 | "mistralai/Mistral-7B-Instruct-v0.3", 72 | ] 73 | 74 | 75 | def load_distribution( 76 | model: str, 77 | prompt_ids: dict, 78 | L: int, 79 | source: str, 80 | prompt_distribution: np.ndarray = None, 81 | load_in_unicode: bool = True, 82 | root_dir="./data", 83 | ) -> DistributionFromDataset: 84 | """ 85 | Given a dictionary mapping {dataset_name: [prompt_ids]}, load the dataset and return a DistributionFromDataset object. 86 | Args: 87 | model: a string representing the model name 88 | prompt_ids: a dictionary mapping {dataset_name: [prompt_ids]} 89 | where dataset_name is a string and prompt_ids is a list of strings. 90 | Example: 91 | { 92 | "wikipedia_en": [0, 1, 4], 93 | "wikipedia_es": [5], 94 | } 95 | => DistributionFromDataset(m=4, L=L) 96 | L: completion length 97 | source: a string representing 98 | Returns: 99 | a DistributionFromDataset object, which allows for sampling from the dataset. 100 | """ 101 | assert source in SOURCES, f"source must be one of {SOURCES}" 102 | print(f"Loading dataset from source: {source}") 103 | print(f"Prompt IDs: {prompt_ids}") 104 | 105 | tokenizer = AutoTokenizer.from_pretrained(model) 106 | tokenizer.padding_side = "left" 107 | tokenizer.truncation_side = "left" 108 | if tokenizer.pad_token_id is None: 109 | tokenizer.pad_token = tokenizer.eos_token 110 | tokenizer.pad_token_id = tokenizer.eos_token_id 111 | 112 | if source in ["fp32", "fp16", "nf4", "int8"]: 113 | if load_in_unicode: 114 | load_fn = lambda x: _load_local_samples_unicode(x, tokenizer) 115 | else: 116 | load_fn = lambda x: _load_local_samples_tokens(x, tokenizer) 117 | else: 118 | assert load_in_unicode, "Only unicode supported for API samples" 119 | load_fn = lambda x: _load_api_samples_unicode(x, tokenizer) 120 | 121 | filenames_and_callables = [] 122 | logprob_filenames_and_callables = [] 123 | for ds, ids in prompt_ids.items(): 124 | for id in ids: 125 | name = f"{sanitize(model)}-{ds}-{source}-L=*-{id}" 
126 | lst = glob.glob(f"{root_dir}/samples/{name}.pkl") 127 | if len(lst) == 0: 128 | print( 129 | f"Warning: no pkl file matching the pattern {name} was found in the specified dataset directory {root_dir}.", 130 | "Make sure the `root_dir` variable is set correctly and that `model` and `prompt_ids` exist in the dataset." 131 | ) 132 | fn = lst[0] 133 | filenames_and_callables.append((fn, load_fn)) 134 | 135 | # logprobs: include if not in unicode and the file exists 136 | if not load_in_unicode: 137 | try: 138 | fn = glob.glob(f"{root_dir}/logprobs/{name}.pkl")[0] 139 | logprob_filenames_and_callables.append((fn, lambda x: _load_logprobs(x, tokenizer))) 140 | except IndexError: 141 | pass 142 | 143 | return DistributionFromDataset( 144 | sample_paths=filenames_and_callables, 145 | L=L, 146 | prompt_distribution=prompt_distribution, 147 | logprob_paths=logprob_filenames_and_callables, 148 | pad_token_id=tokenizer.pad_token_id if not load_in_unicode else -1, 149 | ) 150 | 151 | 152 | def _pad_after_first_eos(array, eos_token_id, pad_token_id): 153 | """ 154 | In each sequence, converts all tokens to the right of the first occurrence 155 | of the eos_token to a pad_token. 156 | Edge case: does not pad if the eos token is the first token in the sequence 157 | """ 158 | assert array.ndim == 2 159 | eos_locs = np.expand_dims((array == eos_token_id).astype(float).argmax(axis=1), 1) 160 | eos_locs[eos_locs == 0] = array.shape[1] 161 | col_indices = np.repeat( 162 | np.expand_dims(np.arange(array.shape[1]), 0), array.shape[0], axis=0 163 | ) 164 | array[col_indices > eos_locs] = pad_token_id 165 | return array 166 | 167 | 168 | def _load_local_samples_tokens(path, tok) -> np.ndarray: 169 | """ 170 | Loads local samples into a (k, max_L) array of token ids. 171 | Assumes samples are saved as lists of lists of token ids (.pkl files). 
172 | """ 173 | with open(path, "rb") as f: 174 | array = pickle.load(f) 175 | array = np.array(array) 176 | array = _pad_after_first_eos(array, tok.eos_token_id, tok.pad_token_id) 177 | return array 178 | 179 | 180 | def _load_local_samples_unicode(path, tok) -> np.ndarray: 181 | """ 182 | Loads local samples into a (k, max_L) array of character ids (using Python's ord function). 183 | Assumes samples are saved as lists of token ids (.pkl files). 184 | Uses -1 as a padding token in the output. 185 | """ 186 | with open(path, "rb") as f: 187 | array = pickle.load(f) 188 | array = np.array(array) 189 | array = _pad_after_first_eos(array, tok.eos_token_id, tok.pad_token_id) 190 | strings = tok.batch_decode(array, skip_special_tokens=True) 191 | return tokenize_unicode(strings) 192 | 193 | 194 | def _load_api_samples_unicode(path, tok) -> np.ndarray: 195 | """ 196 | Loads API samples from disk where each character is assigned an id, rather than in token space. 197 | """ 198 | with open(path, "rb") as f: 199 | js = pickle.load(f) 200 | samples = js["samples"] 201 | samples = [torch.tensor([ord(c) for c in s["full_completion"]]) for s in samples] 202 | chr_array = stack_with_padding(samples, padding_token=-1)[0].numpy() 203 | return chr_array 204 | 205 | 206 | def _load_logprobs(path, tok) -> dict: 207 | """ 208 | Returns a dictionary of {prompt: {sequence: logprob}} as saved 209 | by cache_logprobs.py or cache_distribution_by_sampling.py using utils.dump. 
210 | """ 211 | with open(path, "rb") as f: 212 | logprobs = pickle.load(f) 213 | 214 | # merge dictionaries with the same keys 215 | new_dict = {} 216 | for k, v in logprobs.items(): 217 | k = list(k) 218 | v = torch.tensor(v) 219 | try: 220 | ix = k.index(tok.eos_token_id) 221 | k[ix + 1 :] = [tok.pad_token_id] * (len(k) - ix - 1) 222 | v[ix + 1 :] = 0 223 | except: 224 | pass 225 | new_dict[tuple(k)] = v 226 | return new_dict 227 | -------------------------------------------------------------------------------- /experiments/testing/simulate_two_sample_power_composite.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simulate power of a two-sample test. 3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from experiments.testing.simulation import get_power_two_sample_composite_null 7 | from experiments.utils import build_cache_filename, str_to_bool 8 | from experiments.testing.bootstrap_manager import BootstrapManager 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | import argparse 14 | import os 15 | import torch 16 | import numpy as np 17 | from accelerate import Accelerator 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--model", 23 | type=str, 24 | help="Name of the model", 25 | default="meta-llama/Meta-Llama-3-8B-Instruct", 26 | ) 27 | parser.add_argument( 28 | "--null_distribution_1_source", 29 | type=str, 30 | default="fp32", 31 | help="Source to use as the null distribution P1", 32 | ) 33 | parser.add_argument( 34 | "--null_distribution_2_source", 35 | type=str, 36 | default="fp32", 37 | help="Source to use as the null distribution P2", 38 | ) 39 | parser.add_argument( 40 | "--alternative_distribution_source", 41 | type=str, 42 | default="fp32", 43 | help="Source to use as the alternative distribution Q", 44 | ) 45 | parser.add_argument( 46 | "--n_simulations", 47 | type=int, 48 | 
default=100, 49 | help="Number of simulations to run to estimate power", 50 | ) 51 | parser.add_argument("--alpha", type=float, default=0.05, help="Significance level") 52 | parser.add_argument( 53 | "--prompt_indices_json", 54 | type=str, 55 | help="JSON of prompt dataset name: [list of indices]. Prompt distribution will be uniform over these prompts", 56 | required=True, 57 | ) 58 | parser.add_argument( 59 | "--n1", 60 | type=int, 61 | help="Number of samples to compute statistics with, for the first sample (null in cache_power)", 62 | default=None, 63 | ) 64 | parser.add_argument( 65 | "--n1_per_prompt", 66 | type=int, 67 | help="Alternative to --n1, number of samples per prompt, for the first sample (null in cache_power)", 68 | default=None, 69 | ) 70 | parser.add_argument( 71 | "--n2", 72 | type=int, 73 | help="Number of samples to compute statistics with, for the second sample (alt in cache_power)", 74 | default=None, 75 | ) 76 | parser.add_argument( 77 | "--n2_per_prompt", 78 | type=int, 79 | help="Alternative to --n1, number of samples per prompt, for the second sample (alt in cache_power)", 80 | default=None, 81 | ) 82 | parser.add_argument( 83 | "--stat", type=str, default="two_sample_L2", help="Two-sample test statistic" 84 | ) 85 | parser.add_argument( 86 | "--pvalue_type", 87 | type=str, 88 | default="parametric_bootstrap", 89 | help="Type of p-value", 90 | ) 91 | parser.add_argument( 92 | "--test_in_unicode", 93 | type=str_to_bool, 94 | default=True, 95 | help="Test in unicode space instead of token space", 96 | ) 97 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 98 | parser.add_argument("--max_b", type=int, default=None) 99 | parser.add_argument("--min_b", type=int, default=None) 100 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 101 | parser.add_argument("--temperature", type=float, default=None) 102 | parser.add_argument("--top_p", type=float, default=None) 103 | 
parser.add_argument("--save_dir", type=str, default="../cache/power") 104 | parser.add_argument( 105 | "--bootstrap_dir", type=str, default="../cache/parametric_bootstrap_stats" 106 | ) 107 | accelerator = Accelerator() 108 | args = parser.parse_args() 109 | print(args) 110 | 111 | # Load the prompt indices 112 | js = json.load(open(args.prompt_indices_json)) 113 | prompts = list(js.keys()) 114 | # get the number of prompts m 115 | m = sum([len(js[k]) for k in js]) 116 | print("Number of prompts: ", m) 117 | 118 | assert (args.n1 is not None) ^ ( 119 | args.n1_per_prompt is not None 120 | ), "Exactly one of --n1 or --n1_per_prompt must be provided" 121 | if args.n1_per_prompt is not None: 122 | args.n1 = m * args.n1_per_prompt 123 | assert (args.n2 is not None) ^ ( 124 | args.n2_per_prompt is not None 125 | ), "Exactly one of --n2 or --n2_per_prompt must be provided" 126 | if args.n2_per_prompt is not None: 127 | args.n2 = m * args.n2_per_prompt 128 | print("Number of samples: ", args.n1, args.n2) 129 | 130 | # Load null distribution 131 | p1 = load_distribution( 132 | model=args.model, 133 | prompt_ids=js, 134 | L=args.L, 135 | source=args.null_distribution_1_source, 136 | load_in_unicode=args.test_in_unicode, 137 | ) 138 | p2 = load_distribution( 139 | model=args.model, 140 | prompt_ids=js, 141 | L=args.L, 142 | source=args.null_distribution_2_source, 143 | load_in_unicode=args.test_in_unicode, 144 | ) 145 | print("Null shape: ", p1.shape, p2.shape) 146 | 147 | # Start building the output filename 148 | filename_stem = build_cache_filename( 149 | model=args.model, 150 | prompts=prompts, 151 | use_char_space=args.test_in_unicode, 152 | alternative=( 153 | args.null_distribution_1_source, 154 | args.null_distribution_2_source, 155 | ), 156 | temperature=args.temperature, 157 | top_p=args.top_p, 158 | do_sample=args.do_sample, 159 | L=args.L, 160 | stat=args.stat, 161 | N="{n1}_{n2}", 162 | prompt_indices_json=args.prompt_indices_json, 163 | ) 164 | filename = 
filename_stem.format(n1=args.n1, n2=args.n2) 165 | filename += f"-pvalue={args.pvalue_type}-alt={args.alternative_distribution_source}" 166 | 167 | if os.path.exists(f"{args.save_dir}/{filename}.pt"): 168 | print("Skipping b/c already exists...") 169 | exit() 170 | 171 | # Set up bootstrap manager 172 | get_pvalue_1 = None 173 | get_pvalue_2 = None 174 | if args.pvalue_type == "parametric_bootstrap": 175 | fn1 = filename_stem.replace( 176 | f"{args.null_distribution_1_source}_{args.null_distribution_2_source}", 177 | args.null_distribution_1_source, 178 | ) 179 | bootstrap_manager_1 = BootstrapManager( 180 | bootstrap_path_template=f"{args.bootstrap_dir}/{fn1}.pkl", 181 | min_b=args.min_b, 182 | max_b=args.max_b, 183 | ) 184 | get_pvalue_1 = bootstrap_manager_1.load(n1=args.n1, n2=args.n2) 185 | fn2 = filename_stem.replace( 186 | f"{args.null_distribution_1_source}_{args.null_distribution_2_source}", 187 | args.null_distribution_2_source, 188 | ) 189 | bootstrap_manager_2 = BootstrapManager( 190 | bootstrap_path_template=f"{args.bootstrap_dir}/{fn2}.pkl", 191 | min_b=args.min_b, 192 | max_b=args.max_b, 193 | ) 194 | get_pvalue_2 = bootstrap_manager_2.load(n1=args.n1, n2=args.n2) 195 | 196 | # Load alternative distribution 197 | q = load_distribution( 198 | model=args.model, 199 | prompt_ids=js, 200 | L=args.L, 201 | source=args.alternative_distribution_source, 202 | load_in_unicode=args.test_in_unicode, 203 | ) 204 | print("Alternative shape: ", q.shape) 205 | 206 | # Simulate! 
207 | ( 208 | pwr, 209 | reject_history, 210 | alpha_history, 211 | pvalue_history, 212 | test_stats, 213 | ) = get_power_two_sample_composite_null( 214 | p1, 215 | p2, 216 | q, 217 | n_null=args.n1, 218 | n_data=args.n2, 219 | n_simulations=args.n_simulations, 220 | pvalue_type=args.pvalue_type, 221 | stat_type=args.stat, 222 | get_pvalue_fn_1=get_pvalue_1, 223 | get_pvalue_fn_2=get_pvalue_2, 224 | return_pvalue=True, 225 | return_alpha=True, 226 | return_stat=True, 227 | alpha=args.alpha, 228 | ) 229 | 230 | # Save results 231 | print("Power: ", pwr) 232 | torch.save( 233 | { 234 | "power": pwr, 235 | "reject_history": reject_history, 236 | "alpha_history": alpha_history, 237 | "pvalue_history": pvalue_history, 238 | "test_stats": test_stats, 239 | }, 240 | f"{args.save_dir}/{filename}.pt", 241 | ) 242 | -------------------------------------------------------------------------------- /model_equality_testing/src/pvalue.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Union, Tuple, List, Dict 4 | from model_equality_testing.tests import IMPLEMENTED_TESTS 5 | from model_equality_testing.distribution import ( 6 | CompletionSample, 7 | DistributionFromDataset, 8 | ) 9 | import tqdm 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def _plot_empirical_distribution(stats, ax=None, label="", **kwargs): 14 | """ 15 | Given a numpy array of test stats, plots the empirical distribution using a histogram 16 | """ 17 | if ax is None: 18 | plt.figure() 19 | ax = plt.gca() 20 | ax.hist(stats, bins="auto", **kwargs) # Adjust bins as needed 21 | ax.set_xlabel("Test Statistic") 22 | ax.set_ylabel("Frequency") 23 | ax.set_title( 24 | f'Distribution of Test Statistic {"(" + label + ")" if len(label) else ""}' 25 | ) 26 | return ax 27 | 28 | 29 | ############################################## 30 | # Functions to simulate and compute p-values 31 | 
############################################## 32 | 33 | 34 | class EmpiricalPvalueCalculator: 35 | """ 36 | Given an empirical sample of test statitics, provides a callable that returns the p-value of an observed test statistic. 37 | """ 38 | 39 | def __init__(self, observed_stats: np.ndarray): 40 | """ 41 | Args: 42 | observed_stats: a numpy array of shape (b,) 43 | where b is the number of bootstrap samples 44 | """ 45 | self.stats = observed_stats 46 | 47 | def __call__(self, obs_stat: Union[float, np.ndarray, torch.Tensor]) -> float: 48 | # handle obs_stat: make sure it's a float 49 | if isinstance(obs_stat, (torch.Tensor, np.ndarray)): 50 | obs_stat = obs_stat.item() 51 | 52 | # compare to self.stats and average across the batch dimension (b) 53 | return np.mean((self.stats >= obs_stat), axis=0).item() 54 | 55 | 56 | def one_sample_parametric_bootstrap_pvalue( 57 | null_dist: DistributionFromDataset, 58 | n: int, 59 | b=1000, 60 | plot=False, 61 | return_stats=False, 62 | stat_type="g_squared", 63 | **kwargs, 64 | ) -> Union[EmpiricalPvalueCalculator, Tuple[EmpiricalPvalueCalculator, np.ndarray]]: 65 | """ 66 | Simulates the empirical distribution of the test statistic by repeatedly drawing samples 67 | from the null distribution and computing the test statistic. 68 | Args: 69 | null_dist: a distribution object from which to draw samples 70 | n: the size of the sample to take 71 | b: the number of times to draw samples and compute the test statistic 72 | plot: whether to plot the empirical distribution of the test statistics 73 | return_stats: whether to return the raw test statistics, in addition to 74 | the p-value calculator 75 | stat_type: the type of test statistic to compute as a string. 
76 | Must be a key in IMPLEMENTED_TESTS 77 | **kwargs: additional arguments to pass to the test computation function 78 | """ 79 | stats = [] 80 | for _ in tqdm.tqdm(range(b), desc="Parametric bootstrap"): 81 | bootstrap_sample = null_dist.sample(n=n) 82 | stat = IMPLEMENTED_TESTS[stat_type](bootstrap_sample, null_dist, **kwargs) 83 | stats.append(stat) 84 | stats = np.array(stats) 85 | if stats.ndim == 1: 86 | stats = np.expand_dims(stats, 1) 87 | if stats.ndim == 2: 88 | stats = np.expand_dims(stats, 2) 89 | 90 | # plot the empirical distribution 91 | if plot: 92 | b, m, nstats = stats.shape 93 | assert m == 1, "Incorrect shape for plotting" 94 | for i in range(nstats): 95 | _plot_empirical_distribution(stats[:, :, i], label=f"{stat_type} dim {i}") 96 | 97 | get_pvalue = EmpiricalPvalueCalculator(stats) 98 | if return_stats: 99 | return get_pvalue, stats 100 | return get_pvalue 101 | 102 | 103 | def two_sample_parametric_bootstrap_pvalue( 104 | null_dist: DistributionFromDataset, 105 | n1: int, 106 | n2: int, 107 | b=1000, 108 | plot=False, 109 | return_stats=False, 110 | stat_type="two_sample_L2", 111 | **kwargs, 112 | ) -> Union[EmpiricalPvalueCalculator, Tuple[EmpiricalPvalueCalculator, np.ndarray]]: 113 | """ 114 | Simulates the empirical distribution of the test statistic by repeatedly drawing samples 115 | from the null distribution and computing the test statistic. 116 | Args: 117 | null_dist: a distribution object from which to draw samples 118 | n1: the size of the first sample to take 119 | n2: the size of the second sample to take 120 | b: the number of times to draw samples and compute the test statistic 121 | plot: whether to plot the empirical distribution of the test statistics 122 | return_stats: whether to return the raw test statistics, in addition to 123 | the p-value calculator 124 | stat_type: the type of test statistic to compute as a string. 
125 | Must be a key in IMPLEMENTED_TESTS 126 | **kwargs: additional arguments to pass to the test computation function 127 | """ 128 | stats = [] 129 | for _ in tqdm.tqdm(range(b), desc="Parametric bootstrap"): 130 | sample1 = null_dist.sample(n=n1) 131 | sample2 = null_dist.sample(n=n2) 132 | stat = IMPLEMENTED_TESTS[stat_type](sample1, sample2, **kwargs) 133 | stats.append(stat) 134 | stats = np.array(stats) 135 | if stats.ndim == 1: 136 | stats = np.expand_dims(stats, 1) 137 | if stats.ndim == 2: 138 | stats = np.expand_dims(stats, 2) 139 | 140 | # plot the empirical distribution 141 | if plot: 142 | b, m, nstats = stats.shape 143 | assert m == 1, "Incorrect shape for plotting" 144 | for i in range(nstats): 145 | _plot_empirical_distribution(stats[:, :, i], label=f"{stat_type} dim {i}") 146 | 147 | get_pvalue = EmpiricalPvalueCalculator(stats) 148 | if return_stats: 149 | return get_pvalue, stats 150 | return get_pvalue 151 | 152 | 153 | def two_sample_permutation_pvalue( 154 | sample1: CompletionSample, 155 | sample2: CompletionSample, 156 | b=1000, 157 | plot=False, 158 | return_stats=False, 159 | stat_type="two_sample_L2", 160 | **kwargs, 161 | ) -> Union[EmpiricalPvalueCalculator, Tuple[EmpiricalPvalueCalculator, np.ndarray]]: 162 | """ 163 | Simulates the empirical distribution of the test statistic by repeatedly permuting the labels 164 | of the samples and computing the test statistic. 165 | Args: 166 | sample1: the first sample 167 | sample2: the second sample 168 | b: the number of times to draw samples and compute the test statistic 169 | plot: whether to plot the empirical distribution of the test statistics 170 | return_stats: whether to return the raw test statistics, in addition to 171 | the p-value calculator 172 | stat_type: the type of test statistic to compute as a string. 
173 | Must be a key in IMPLEMENTED_TESTS 174 | **kwargs: additional arguments to pass to the test computation function 175 | """ 176 | stats = [] 177 | all_samples = torch.cat( 178 | [ 179 | sample1.sequences, 180 | sample2.sequences, 181 | ], 182 | dim=0, 183 | ) 184 | for _ in tqdm.tqdm(range(b), desc="Permutation bootstrap"): 185 | ix = torch.randperm(len(all_samples)) 186 | permuted_sample1 = CompletionSample( 187 | prompts=all_samples[ix][: sample1.N, 0], 188 | completions=all_samples[ix][: sample1.N, 1:], 189 | m=sample1.m, 190 | ) 191 | permuted_sample2 = CompletionSample( 192 | prompts=all_samples[ix][sample1.N :, 0], 193 | completions=all_samples[ix][sample1.N :, 1:], 194 | m=sample1.m, 195 | ) 196 | 197 | stat = IMPLEMENTED_TESTS[stat_type]( 198 | permuted_sample1, permuted_sample2, **kwargs 199 | ) 200 | stats.append(stat) 201 | 202 | stats = np.array(stats) 203 | if stats.ndim == 1: 204 | stats = np.expand_dims(stats, 1) 205 | if stats.ndim == 2: 206 | stats = np.expand_dims(stats, 2) 207 | 208 | # plot the empirical distribution 209 | if plot: 210 | b, m, nstats = stats.shape 211 | assert m == 1, "Incorrect shape for plotting" 212 | for i in range(nstats): 213 | _plot_empirical_distribution(stats[:, :, i], label=f"{stat_type} dim {i}") 214 | 215 | get_pvalue = EmpiricalPvalueCalculator(stats) 216 | if return_stats: 217 | return get_pvalue, stats 218 | del stats 219 | return get_pvalue 220 | 221 | 222 | ###### map from name to function ###### 223 | 224 | IMPLEMENTED_PVALUES = { 225 | "one_sample_parametric_bootstrap": one_sample_parametric_bootstrap_pvalue, 226 | "two_sample_parametric_bootstrap": two_sample_parametric_bootstrap_pvalue, 227 | "two_sample_permutation": two_sample_permutation_pvalue, 228 | } 229 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model Equality Testing: Which Model Is This API Serving? 
 2 | 3 | [Paper link](https://arxiv.org/abs/2410.20247) | [Dataset link](https://drive.google.com/drive/folders/1TgxlUp3n-BFh-A6ARA42jkvxkv7Leccv?usp=drive_link) | [Twitter announcement](https://x.com/irena_gao/status/1851273269690908777) 4 | 5 | 6 | Users often interact with large language models through black-box inference APIs, for both closed- and open-weight models (e.g., Llama models are commonly accessed via Amazon Bedrock and Azure AI Studio). 7 | To cut costs or add functionality, API providers may quantize, watermark, or finetune the underlying model, changing the output distribution — often without notifying users. How can we detect whether an API has changed for our particular task using only sample access? 8 | 9 |

