├── experiments
│   ├── __init__.py
│   ├── sampling
│   │   ├── __init__.py
│   │   ├── test_api_tokenization.py
│   │   ├── cache_local_samples.py
│   │   ├── cache_logprobs.py
│   │   ├── cache_local_samples_vllm.py
│   │   └── cache_api_samples.py
│   ├── testing
│   │   ├── __init__.py
│   │   ├── bootstrap_manager.py
│   │   ├── cache_one_sample_bootstrap.py
│   │   ├── cache_two_sample_bootstrap.py
│   │   ├── simulation.py
│   │   ├── simulate_one_sample_power.py
│   │   ├── simulate_two_sample_power.py
│   │   └── simulate_two_sample_power_composite.py
│   ├── constants
│   │   ├── wikipedia_prompt_indices_test
│   │   │   ├── 1
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 4.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 6.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 8.json
│   │   │   │   ├── 9.json
│   │   │   │   └── 3.json
│   │   │   ├── 5
│   │   │   │   ├── 4.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 6.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 9.json
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 3.json
│   │   │   │   └── 8.json
│   │   │   ├── 10
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 4.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 6.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 9.json
│   │   │   │   └── 8.json
│   │   │   ├── 25
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 4.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 6.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 8.json
│   │   │   │   └── 9.json
│   │   │   ├── 50
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 4.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 6.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 8.json
│   │   │   │   └── 9.json
│   │   │   ├── 75
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 4.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 6.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 8.json
│   │   │   │   └── 9.json
│   │   │   └── 100
│   │   │       ├── 0.json
│   │   │       ├── 1.json
│   │   │       ├── 2.json
│   │   │       ├── 3.json
│   │   │       ├── 4.json
│   │   │       ├── 5.json
│   │   │       ├── 6.json
│   │   │       ├── 7.json
│   │   │       ├── 8.json
│   │   │       └── 9.json
│   │   ├── wikipedia_prompt_indices_val
│   │   │   ├── 1
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 4.json
│   │   │   │   ├── 6.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 8.json
│   │   │   │   ├── 9.json
│   │   │   │   └── 5.json
│   │   │   ├── 5
│   │   │   │   ├── 4.json
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 6.json
│   │   │   │   ├── 9.json
│   │   │   │   ├── 7.json
│   │   │   │   └── 8.json
│   │   │   ├── 25
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 4.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 8.json
│   │   │   │   ├── 6.json
│   │   │   │   └── 9.json
│   │   │   ├── 50
│   │   │   │   ├── 6.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 4.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 9.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 8.json
│   │   │   │   └── 0.json
│   │   │   ├── 75
│   │   │   │   ├── 6.json
│   │   │   │   ├── 0.json
│   │   │   │   ├── 1.json
│   │   │   │   ├── 2.json
│   │   │   │   ├── 5.json
│   │   │   │   ├── 7.json
│   │   │   │   ├── 9.json
│   │   │   │   ├── 3.json
│   │   │   │   ├── 4.json
│   │   │   │   └── 8.json
│   │   │   └── 100
│   │   │       └── 0.json
│   │   ├── vanilla_api.json
│   │   ├── vanilla_local.json
│   │   ├── llama3-chat-template.txt
│   │   └── minimum_prompt_indices.json
│   ├── utils.py
│   └── prompts.py
├── hero.png
├── setup.png
├── model_equality_testing
│   ├── src
│   │   ├── __init__.py
│   │   ├── algorithm.py
│   │   ├── utils.py
│   │   ├── dataset.py
│   │   ├── pvalue.py
│   │   ├── distribution.py
│   │   └── tests.py
│   ├── pyproject.toml
│   └── LICENSE
├── .gitignore
└── README.md
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/sampling/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
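For orientation, the model_equality_testing package in the tree maps its src/ directory to the `model_equality_testing` import name (see the pyproject.toml further below), so after an editable install its submodules should be importable directly. A minimal sketch, assuming `pip install -e model_equality_testing` has been run from the repository root; this is an assumption about intended usage, not documented API:

```python
# Minimal sketch (assumption): the submodule names below are taken from the
# directory tree and from src/__init__.py shown later in this dump.
from model_equality_testing import algorithm, dataset, distribution, pvalue, tests, utils

print([m.__name__ for m in (algorithm, dataset, distribution, pvalue, tests, utils)])
```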
/experiments/testing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/i-gao/model-equality-testing/HEAD/hero.png -------------------------------------------------------------------------------- /setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/i-gao/model-equality-testing/HEAD/setup.png -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [75]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [70], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [69], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [38], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [43], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [67]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [34], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [44], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [45], 
"wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [18], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [2], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [14], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [9]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [13]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [1]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/1/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} 2 | -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/1/5.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [6], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": []} 2 | -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [9], "wikipedia_de": [], "wikipedia_fr": [4], "wikipedia_ru": [4], "wikipedia_es": [11, 13]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [87], "wikipedia_de": [43], "wikipedia_fr": [44, 63], "wikipedia_ru": [], "wikipedia_es": [90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [98], "wikipedia_de": [46], "wikipedia_fr": [], "wikipedia_ru": [99], "wikipedia_es": [56, 58]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [95], "wikipedia_de": [], "wikipedia_fr": [48], "wikipedia_ru": [25, 96], "wikipedia_es": [28]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [30, 54], "wikipedia_de": [], "wikipedia_fr": [88], "wikipedia_ru": [41], "wikipedia_es": [64]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [36, 90], "wikipedia_de": [74], "wikipedia_fr": [95], "wikipedia_ru": [], "wikipedia_es": [28]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [19], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [2, 9], "wikipedia_es": [4, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [8, 15, 18], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [2, 13], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [11], "wikipedia_fr": [19], "wikipedia_ru": [3, 17, 19], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [0, 
3, 6], "wikipedia_ru": [12], "wikipedia_es": [18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0], "wikipedia_de": [8, 17], "wikipedia_fr": [], "wikipedia_ru": [], "wikipedia_es": [3, 12]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [9, 15], "wikipedia_de": [10], "wikipedia_fr": [], "wikipedia_ru": [14, 15], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [14], "wikipedia_de": [], "wikipedia_fr": [11], "wikipedia_ru": [6], "wikipedia_es": [10, 16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [79], "wikipedia_fr": [], "wikipedia_ru": [21, 48], "wikipedia_es": [74, 84]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [55, 56], "wikipedia_fr": [42, 74], "wikipedia_ru": [78], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [48, 81, 92], "wikipedia_fr": [68], "wikipedia_ru": [], "wikipedia_es": [33]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [60, 90, 95], "wikipedia_ru": [], "wikipedia_es": [69, 77]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/5/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 79], "wikipedia_de": [87], "wikipedia_fr": [38, 72], "wikipedia_ru": [], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [], "wikipedia_ru": [6, 12, 13, 18, 19], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/5/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [], "wikipedia_fr": [6], "wikipedia_ru": [], "wikipedia_es": [10, 12, 13, 16]} 
-------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [28], "wikipedia_de": [45, 58], "wikipedia_fr": [69, 89], "wikipedia_ru": [39, 48, 81, 85], "wikipedia_es": [95]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [31], "wikipedia_de": [], "wikipedia_fr": [30, 88], "wikipedia_ru": [27, 40, 41, 47], "wikipedia_es": [46, 67, 78]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [29, 30, 44], "wikipedia_de": [26, 31, 40, 46, 61], "wikipedia_fr": [], "wikipedia_ru": [22], "wikipedia_es": [64]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [58, 79], "wikipedia_fr": [35, 45], "wikipedia_ru": [30, 31, 55, 63, 91], "wikipedia_es": [34]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [39, 63, 85], "wikipedia_de": [20, 31, 39, 68], "wikipedia_fr": [49], "wikipedia_ru": [], "wikipedia_es": [52, 57]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [42, 61], "wikipedia_de": [83], "wikipedia_fr": [46, 63], "wikipedia_ru": [36, 52, 90, 96], "wikipedia_es": [70]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [55, 62, 99], "wikipedia_de": [72], "wikipedia_fr": [26, 43, 54, 80], "wikipedia_ru": [93], "wikipedia_es": [88]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [33, 45, 79, 89], "wikipedia_de": [56, 89], "wikipedia_fr": [], "wikipedia_ru": [25, 71], "wikipedia_es": [38, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [42], "wikipedia_de": [51], "wikipedia_fr": [57, 69], "wikipedia_ru": [48, 65, 87], "wikipedia_es": [60, 67, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/10/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [], "wikipedia_de": [26, 30, 76, 85, 88], 
"wikipedia_fr": [29, 60, 72, 75], "wikipedia_ru": [], "wikipedia_es": [72]} -------------------------------------------------------------------------------- /model_equality_testing/src/__init__.py: -------------------------------------------------------------------------------- 1 | from . import algorithm 2 | from . import dataset 3 | from . import distribution 4 | from . import utils 5 | from . import pvalue 6 | from . import tests 7 | 8 | -------------------------------------------------------------------------------- /experiments/constants/vanilla_api.json: -------------------------------------------------------------------------------- 1 | { 2 | "top_k": null, 3 | "temperature": 1.0, 4 | "top_p": 1.0, 5 | "frequency_penalty": 0.0, 6 | "presence_penalty": 0.0, 7 | "stop": null 8 | } 9 | 10 | -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [13, 19], "wikipedia_de": [1, 4, 5, 14, 19], "wikipedia_fr": [2, 3, 5, 11, 13, 14, 17, 18], "wikipedia_ru": [0, 5, 7, 9, 11, 16, 19], "wikipedia_es": [3, 4, 16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 3, 14], "wikipedia_de": [1, 2, 4, 9, 12], "wikipedia_fr": [1, 5, 11, 12, 14], "wikipedia_ru": [2, 6, 16, 17, 19], "wikipedia_es": [0, 5, 10, 11, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 7, 8, 9, 15], "wikipedia_de": [4, 5, 6, 7, 8, 15], "wikipedia_fr": [3, 10, 14], "wikipedia_ru": [3, 5, 6, 13, 15, 18], "wikipedia_es": [6, 9, 10, 14, 15]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [5, 7, 12], "wikipedia_de": [4, 11, 12, 19], "wikipedia_fr": [4, 5, 7, 14, 15, 17], "wikipedia_ru": [4, 6, 14, 16], "wikipedia_es": [0, 3, 5, 7, 8, 10, 12, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 2, 4, 5, 8, 12, 19], "wikipedia_de": [0, 7, 8, 11, 13, 19], "wikipedia_fr": [0, 9, 15], "wikipedia_ru": [3, 10], "wikipedia_es": [2, 3, 4, 7, 11, 12, 15]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 19], "wikipedia_de": [1, 3, 6, 8, 11, 15, 19], "wikipedia_fr": [3, 8, 15], "wikipedia_ru": [1, 3, 4, 10, 12, 17, 19], "wikipedia_es": [0, 1, 4, 6, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [3, 7, 8, 18], "wikipedia_de": [16], "wikipedia_fr": [0, 1, 2, 
4, 5, 10, 12, 13, 14], "wikipedia_ru": [5, 6, 7, 9, 10, 14, 19], "wikipedia_es": [8, 14, 15, 17]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [9, 11, 14, 19], "wikipedia_de": [8, 17, 19], "wikipedia_fr": [0, 1, 6, 16, 18, 19], "wikipedia_ru": [1, 3, 4, 7, 8, 12, 14, 17], "wikipedia_es": [7, 8, 13, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [4, 10, 11, 16], "wikipedia_de": [5, 10, 18], "wikipedia_fr": [2, 4, 8, 9, 14, 16], "wikipedia_ru": [3, 6, 17, 19], "wikipedia_es": [1, 5, 8, 10, 13, 14, 17, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/25/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 4, 12, 16], "wikipedia_de": [0, 3, 4, 9, 10, 12, 14, 18], "wikipedia_fr": [2, 8, 12, 14, 17, 18], "wikipedia_ru": [0, 10, 12, 13], "wikipedia_es": [0, 10, 17]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [39, 77, 89], "wikipedia_de": [21, 46, 67, 85, 88, 96], "wikipedia_fr": [47, 51, 76, 83, 95], "wikipedia_ru": [24, 28, 37, 57, 61, 71, 97], "wikipedia_es": [36, 60, 67, 80]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [28, 37, 42, 55, 63, 73], "wikipedia_de": [21, 66, 88, 90, 92, 94], "wikipedia_fr": [50, 70, 76, 93], "wikipedia_ru": [31, 86, 89], "wikipedia_es": [27, 28, 32, 50, 63, 99]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [22, 41, 48, 52, 65, 89], "wikipedia_de": [25, 29, 31, 44, 80, 85, 94, 99], "wikipedia_fr": [99], "wikipedia_ru": [20, 21, 29, 60, 74, 99], "wikipedia_es": [77, 81, 84, 93]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [37, 50, 52, 53, 77, 90], "wikipedia_de": [35, 51, 61, 72, 81], "wikipedia_fr": [24, 36, 37, 61, 95], "wikipedia_ru": [24, 37, 39, 46, 66, 90], "wikipedia_es": [27, 38, 78]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [22, 32, 57, 68, 70], "wikipedia_de": [33, 57, 72, 74, 78, 88, 95, 97], "wikipedia_fr": [52, 64, 89], "wikipedia_ru": [54, 71, 73, 82, 87], "wikipedia_es": [42, 69, 74, 85]} -------------------------------------------------------------------------------- 
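The wikipedia_prompt_indices_test and wikipedia_prompt_indices_val files above all share one schema: a JSON object mapping each Wikipedia language split (wikipedia_en, wikipedia_de, wikipedia_fr, wikipedia_ru, wikipedia_es) to a list of prompt indices, where the directory name (1, 5, 10, 25, 50, 75, 100) gives the total number of indices and 0.json through 9.json appear to be independent draws. A minimal sketch of reading one of these files; the chosen path is just an example, and how the indices are consumed by the sampling scripts is not shown here:

```python
import json
from pathlib import Path

# Example file (hypothetical choice): test split, 25 prompts in total, draw 0.
path = Path("experiments/constants/wikipedia_prompt_indices_test/25/0.json")
indices = json.loads(path.read_text())

total = sum(len(v) for v in indices.values())
print(f"{path.name}: {total} prompt indices")   # should match the directory name, 25
for split, idx in sorted(indices.items()):
    print(f"  {split}: {idx}")
```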
/experiments/constants/wikipedia_prompt_indices_test/25/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 24, 37, 51, 60, 62, 81, 87], "wikipedia_de": [49, 67, 83], "wikipedia_fr": [24, 34, 35, 37, 61, 63], "wikipedia_ru": [29, 56, 75, 92], "wikipedia_es": [61, 75, 81, 90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [24, 60, 79, 83, 89], "wikipedia_de": [20, 23, 27, 36, 55, 66, 92], "wikipedia_fr": [27, 28, 57, 65, 78], "wikipedia_ru": [51, 56, 69, 97], "wikipedia_es": [23, 31, 62, 81]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [50, 59, 63, 67, 91, 95, 99], "wikipedia_de": [59, 80], "wikipedia_fr": [20, 29, 45, 46, 73, 92], "wikipedia_ru": [57, 92], "wikipedia_es": [33, 37, 39, 40, 66, 72, 82, 88]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 41, 44, 46, 48, 49, 64, 66, 80, 82, 90], "wikipedia_de": [24, 59, 64], "wikipedia_fr": [64, 83, 93, 94, 98], "wikipedia_ru": [28, 30, 44, 70, 79, 89], "wikipedia_es": []} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/25/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [28, 31, 52, 74, 75, 77], "wikipedia_de": [48, 80], "wikipedia_fr": [20, 24, 35, 67, 81, 85, 88], "wikipedia_ru": [25, 27, 35, 62, 65, 95], "wikipedia_es": [42, 80, 84, 92]} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg* 2 | model_equality_testing/dist/ 3 | sparse_logs/ 4 | *.zip 5 | *.bin 6 | __pycache__/ 7 | .ipynb_checkpoints/ 8 | wandb 9 | *.out 10 | scr/ 11 | *.pdf 12 | *.pkl 13 | cache/ 14 | old-cache/ 15 | power_*.json 16 | *.c 17 | *.so 18 | .nfs* 19 | *.o 20 | finetuned_models/ 21 | data/ -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 3, 4, 8, 9, 10, 11, 14, 16, 18, 19], "wikipedia_de": [0, 1, 6, 7, 9, 12, 17], "wikipedia_fr": [0, 4, 5, 6, 7, 9, 11, 13, 15, 18], "wikipedia_ru": [0, 2, 3, 4, 9, 10, 12, 14, 15, 16, 17, 18], "wikipedia_es": [0, 2, 3, 4, 5, 10, 11, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 3, 5, 6, 8, 9, 11, 12, 13, 17, 19], "wikipedia_de": [1, 4, 9, 10, 11, 12, 13, 14, 15, 17, 19], "wikipedia_fr": [1, 2, 3, 6, 11, 14, 15, 17, 18, 19], "wikipedia_ru": [6, 8, 9, 12, 13, 15, 17, 19], "wikipedia_es": [1, 3, 4, 5, 6, 9, 13, 17, 18, 19]} 
-------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 2, 3, 4, 6, 8, 9, 10, 11, 12, 14, 15, 16, 18], "wikipedia_de": [3, 4, 6, 7, 8, 9, 13, 15, 16], "wikipedia_fr": [0, 1, 2, 4, 9, 10, 11, 13, 17, 18], "wikipedia_ru": [2, 8, 9, 13, 14, 15, 16], "wikipedia_es": [1, 2, 4, 5, 9, 11, 13, 14, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 5, 7, 9, 10, 11, 12, 13, 15, 17, 18], "wikipedia_de": [2, 3, 4, 8, 9, 10, 11, 12, 16, 17], "wikipedia_fr": [1, 4, 5, 6, 7, 9, 11, 13, 14, 17, 18], "wikipedia_ru": [3, 6, 8, 11, 12, 13, 16], "wikipedia_es": [2, 4, 6, 7, 8, 11, 15, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 3, 4, 5, 6, 12, 13, 14, 16, 17, 18], "wikipedia_de": [0, 2, 3, 5, 7, 8, 11, 15, 16, 17, 18], "wikipedia_fr": [0, 3, 6, 10, 11, 15, 17], "wikipedia_ru": [1, 4, 5, 7, 9, 10, 14, 15, 16, 17, 19], "wikipedia_es": [1, 3, 7, 8, 12, 14, 15, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 3, 4, 5, 6, 9, 10, 12, 13, 14, 16, 17], "wikipedia_de": [0, 3, 5, 6, 7, 8, 11, 12, 15, 16, 18], "wikipedia_fr": [2, 8, 9, 10, 12, 14, 18], "wikipedia_ru": [0, 3, 4, 5, 6, 16, 18, 19], "wikipedia_es": [0, 2, 5, 6, 7, 10, 11, 13, 15, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 3, 6, 7, 9, 10, 14, 15, 17, 18, 19], "wikipedia_de": [0, 1, 6, 8, 9, 10, 12, 13, 15, 16, 18], "wikipedia_fr": [0, 1, 4, 5, 6, 7, 9, 10, 13, 14, 15, 16, 18, 19], "wikipedia_ru": [0, 1, 6, 12, 15, 17, 19], "wikipedia_es": [1, 2, 8, 13, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 2, 4, 8, 9, 11, 12, 13, 14, 15, 16, 17], "wikipedia_de": [0, 1, 2, 10, 11, 13, 14, 17, 18, 19], "wikipedia_fr": [2, 4, 6, 8, 10, 11, 14, 17], "wikipedia_ru": [1, 2, 4, 8, 9, 12, 15, 18, 19], "wikipedia_es": [0, 2, 5, 10, 11, 12, 13, 14, 16, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [2, 3, 4, 11, 13, 15, 16, 18], "wikipedia_de": [0, 1, 3, 4, 7, 10, 11, 12, 14, 15, 16, 17, 19], "wikipedia_fr": [2, 4, 8, 11, 12, 15, 18, 19], "wikipedia_ru": [0, 1, 4, 10, 12, 13, 14, 15, 16], "wikipedia_es": [1, 3, 4, 6, 7, 9, 10, 11, 12, 15, 17, 19]} 
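The vanilla_api.json constants shown earlier hold untouched API decoding settings (temperature 1.0, top_p 1.0, zero penalties, no stop sequences). A minimal sketch of loading them and keeping only the non-null keys before forwarding them as request parameters to an OpenAI-compatible client; dropping the nulls and the forwarding itself are assumptions about usage, not something the repo prescribes:

```python
import json

with open("experiments/constants/vanilla_api.json") as f:
    api_params = json.load(f)

# Null-valued keys (top_k, stop) are left to the provider's defaults here (assumption);
# the remaining keys are standard OpenAI-style sampling parameters.
request_params = {k: v for k, v in api_params.items() if v is not None}
print(request_params)
# -> {'temperature': 1.0, 'top_p': 1.0, 'frequency_penalty': 0.0, 'presence_penalty': 0.0}
```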
-------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/50/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 10, 12, 13, 14, 15, 16, 18, 19], "wikipedia_de": [0, 1, 3, 6, 10, 12, 13, 14, 16, 17, 18], "wikipedia_fr": [2, 8, 9, 10, 12, 13, 17, 18, 19], "wikipedia_ru": [1, 3, 5, 8, 10, 11, 14, 17, 18], "wikipedia_es": [0, 1, 2, 10, 11, 12, 13, 16, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20, 29, 38, 58, 59, 82, 85, 88, 95], "wikipedia_de": [21, 29, 36, 49, 50, 52, 61, 74, 80, 86, 87, 92, 96], "wikipedia_fr": [23, 25, 34, 45, 48, 49, 55, 79, 83, 94], "wikipedia_ru": [23, 28, 37, 38, 43, 69, 77, 88, 95], "wikipedia_es": [24, 40, 41, 46, 50, 60, 81, 86, 91]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [45, 50, 62, 64, 65, 74, 99], "wikipedia_de": [30, 32, 35, 48, 55, 60, 63, 65, 69, 71, 79, 80, 92, 94], "wikipedia_fr": [35, 39, 51, 65, 70, 72, 85, 87, 88, 90], "wikipedia_ru": [29, 34, 63, 65, 75, 82, 83, 90], "wikipedia_es": [22, 24, 30, 43, 50, 56, 75, 77, 85, 90, 97]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [36, 46, 54, 65, 70, 78, 87, 92], "wikipedia_de": [20, 28, 29, 40, 46, 58, 61, 63, 66, 68, 71, 74, 82, 85, 97], "wikipedia_fr": [21, 37, 43, 54, 57, 70, 89], "wikipedia_ru": [24, 25, 29, 32, 35, 40, 61, 65, 68], "wikipedia_es": [23, 26, 33, 40, 45, 57, 65, 71, 77, 97, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [35, 45, 48, 54, 58, 73, 78, 93, 96, 97], "wikipedia_de": [29, 35, 45, 56, 69, 80, 90, 91], "wikipedia_fr": [29, 38, 47, 49, 51, 55, 57, 66, 73, 83, 84], "wikipedia_ru": [22, 25, 29, 38, 56, 63, 80, 86], "wikipedia_es": [20, 26, 27, 29, 32, 35, 39, 50, 60, 77, 81, 88, 94]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [37, 39, 53, 54, 56, 74, 75, 78, 90, 92, 99], "wikipedia_de": [21, 23, 24, 26, 37, 73, 84], "wikipedia_fr": [26, 40, 41, 47, 57, 58, 77, 86, 91], "wikipedia_ru": [33, 46, 48, 53, 66, 77, 78, 80, 88, 93], "wikipedia_es": [21, 31, 36, 42, 44, 46, 50, 59, 64, 73, 80, 85, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20, 26, 33, 50, 51, 60, 61, 69, 70, 82, 83, 97, 99], "wikipedia_de": [23, 35, 36, 53, 56, 60, 88, 94], "wikipedia_fr": [23, 30, 44, 45, 50, 58, 65, 70, 74, 80, 85, 98], "wikipedia_ru": [23, 32, 39, 
46, 50, 66, 79], "wikipedia_es": [32, 36, 38, 40, 45, 53, 67, 76, 78, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [62, 66, 84, 91, 92], "wikipedia_de": [26, 28, 34, 38, 52, 53, 64, 66, 74, 78, 79, 80, 87, 89, 91, 94], "wikipedia_fr": [21, 27, 35, 37, 70, 71, 74, 78, 85, 89, 94, 96], "wikipedia_ru": [20, 30, 37, 38, 39, 49, 59, 84, 90], "wikipedia_es": [47, 59, 60, 64, 69, 80, 81, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 23, 31, 34, 44, 56, 59, 62, 69, 77, 79, 81, 83, 90], "wikipedia_de": [24, 31, 65, 84, 87, 94], "wikipedia_fr": [24, 30, 31, 36, 41, 43, 45, 49, 50], "wikipedia_ru": [24, 25, 39, 51, 56, 61, 62, 67, 79, 90, 96], "wikipedia_es": [27, 31, 40, 41, 51, 53, 67, 68, 83, 95]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [28, 29, 34, 50, 52, 53, 57, 59, 83, 85, 88, 94, 95, 98], "wikipedia_de": [33, 42, 47, 56, 68, 72, 75, 78], "wikipedia_fr": [22, 28, 43, 48, 57, 65, 69, 72, 77, 86, 87, 98], "wikipedia_ru": [22, 44, 47, 54, 59, 68, 84, 87, 99], "wikipedia_es": [32, 39, 41, 67, 73, 85, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/50/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [27, 31, 35, 42, 48, 50, 56, 67, 68, 71, 85, 88, 91, 97, 99], "wikipedia_de": [30, 47, 54, 67, 70, 71, 74, 80, 96], "wikipedia_fr": [23, 34, 41, 43, 51, 54, 96], "wikipedia_ru": [20, 28, 38, 45, 46, 52, 61, 64, 66, 67, 71, 74, 77, 78, 82], "wikipedia_es": [20, 37, 54, 98]} -------------------------------------------------------------------------------- /experiments/constants/vanilla_local.json: -------------------------------------------------------------------------------- 1 | { 2 | "top_k": 0, 3 | "temperature": 1.0, 4 | "top_p": 1.0, 5 | "repetition_penalty": 1.0, 6 | "num_beams": 1, 7 | "do_sample": true, 8 | "early_stopping": false, 9 | "no_repeat_ngram_size": 0, 10 | "bad_words_ids": null, 11 | "force_words_ids": null, 12 | "use_cache": true 13 | } 14 | 15 | -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 2, 3, 5, 6, 7, 9, 10, 11, 12, 14, 15], "wikipedia_de": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 18, 19], "wikipedia_fr": [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14, 16, 17, 19], "wikipedia_ru": [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 17, 18], "wikipedia_es": [0, 1, 2, 3, 4, 5, 7, 8, 9, 12, 15, 16, 17, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19], 
"wikipedia_de": [0, 1, 2, 4, 7, 8, 9, 12, 13, 14, 15, 16], "wikipedia_fr": [1, 2, 3, 4, 5, 6, 8, 11, 12, 14, 15, 16, 17, 18, 19], "wikipedia_ru": [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 13, 14, 15, 16, 18, 19], "wikipedia_es": [0, 1, 2, 3, 5, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 3, 4, 5, 6, 8, 9, 11, 12, 13, 14, 15, 16, 18], "wikipedia_de": [1, 2, 3, 4, 6, 7, 10, 11, 12, 14, 15, 16, 17, 18], "wikipedia_fr": [0, 2, 3, 4, 5, 6, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19], "wikipedia_ru": [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 13, 15, 17, 18, 19], "wikipedia_es": [0, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 4, 5, 6, 7, 8, 9, 11, 13, 14, 17, 18, 19], "wikipedia_de": [1, 2, 5, 7, 8, 9, 10, 11, 12, 14, 15, 17, 19], "wikipedia_fr": [0, 1, 2, 3, 4, 6, 7, 8, 9, 12, 13, 14, 19], "wikipedia_ru": [1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_es": [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19], "wikipedia_de": [0, 1, 2, 3, 4, 5, 8, 9, 10, 14, 15, 16, 17, 18], "wikipedia_fr": [0, 2, 3, 4, 5, 7, 8, 10, 13, 14, 15, 17, 19], "wikipedia_ru": [1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 19], "wikipedia_es": [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 14, 15, 16, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 18, 19], "wikipedia_de": [0, 1, 2, 3, 4, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19], "wikipedia_fr": [0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13, 15, 16, 17, 18, 19], "wikipedia_ru": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 14, 17, 18, 19], "wikipedia_es": [2, 5, 6, 8, 9, 10, 11, 12, 13, 16, 18]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 19], "wikipedia_de": [1, 4, 5, 6, 7, 8, 9, 12, 14, 16, 18, 19], "wikipedia_fr": [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19], "wikipedia_ru": [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 19], "wikipedia_es": [0, 1, 3, 4, 6, 8, 9, 12, 13, 14, 15, 16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 2, 4, 7, 8, 10, 11, 12, 13, 14, 15, 16, 19], "wikipedia_de": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 
17], "wikipedia_fr": [0, 2, 3, 6, 7, 8, 10, 11, 12, 13, 15, 16, 17, 18, 19], "wikipedia_ru": [1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18], "wikipedia_es": [0, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 2, 3, 4, 5, 7, 8, 11, 12, 13, 14, 15, 17, 18, 19], "wikipedia_de": [0, 1, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], "wikipedia_fr": [0, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18], "wikipedia_ru": [2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 15, 17, 19], "wikipedia_es": [0, 1, 2, 4, 6, 7, 8, 9, 10, 11, 13, 14, 16, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/75/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [1, 3, 4, 5, 6, 7, 9, 12, 14, 15, 16, 17, 18, 19], "wikipedia_de": [1, 2, 5, 6, 7, 9, 11, 12, 14, 15, 16, 17, 19], "wikipedia_fr": [0, 1, 2, 4, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_ru": [0, 1, 2, 3, 5, 6, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19], "wikipedia_es": [0, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/llama3-chat-template.txt: -------------------------------------------------------------------------------- 1 | {% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|> 2 | 3 | '+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|> 4 | 5 | ' }}{% endif %} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 24, 33, 35, 37, 43, 50, 53, 54, 55, 58, 64, 73, 77, 88, 90, 94], "wikipedia_de": [22, 26, 31, 32, 37, 39, 42, 46, 57, 68, 75, 77, 82, 83, 87, 90, 91], "wikipedia_fr": [21, 25, 26, 28, 36, 40, 45, 47, 55, 63, 64, 78, 79, 80, 81, 99], "wikipedia_ru": [24, 27, 29, 32, 37, 39, 41, 55, 57, 61, 66, 68, 70, 71, 75, 97], "wikipedia_es": [22, 26, 31, 37, 41, 54, 66, 74, 93]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 27, 30, 36, 38, 43, 44, 49, 56, 59, 63, 65, 67, 79, 83, 87, 90, 98], "wikipedia_de": [22, 24, 26, 35, 52, 54, 57, 60, 62, 71, 73, 75, 77, 81, 82, 95, 98], "wikipedia_fr": [23, 27, 30, 42, 49, 53, 55, 72, 76, 77, 81, 82, 83, 97], "wikipedia_ru": [27, 33, 38, 40, 47, 54, 61, 62, 63, 72, 78, 80, 81, 88, 91], "wikipedia_es": [20, 30, 31, 42, 44, 51, 61, 68, 75, 76, 90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 28, 46, 60, 62, 63, 68, 70, 74, 75, 85, 91, 96], 
"wikipedia_de": [26, 29, 37, 45, 46, 51, 77, 84, 88, 90, 94, 98], "wikipedia_fr": [20, 24, 28, 34, 36, 37, 45, 47, 48, 52, 53, 59, 60, 65, 67, 68, 80, 81], "wikipedia_ru": [24, 25, 27, 33, 35, 38, 40, 53, 58, 59, 60, 63, 77, 82, 84, 87, 95], "wikipedia_es": [20, 25, 38, 39, 40, 45, 48, 51, 65, 74, 76, 77, 78, 85, 86]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [30, 34, 35, 51, 54, 58, 63, 72, 73, 77, 79, 82, 83, 89, 90, 91, 92], "wikipedia_de": [34, 48, 49, 54, 67, 77, 81, 83, 86, 88, 89, 96], "wikipedia_fr": [32, 33, 38, 46, 48, 61, 65, 67, 69, 83, 87, 89, 92, 95], "wikipedia_ru": [23, 32, 33, 34, 40, 44, 47, 64, 70, 81, 82, 84, 85, 89, 91, 95, 99], "wikipedia_es": [21, 25, 26, 31, 40, 42, 44, 53, 63, 67, 69, 72, 82, 88, 93]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 26, 27, 29, 31, 32, 36, 44, 46, 51, 59, 66, 74, 78, 83, 86, 87, 92, 97, 99], "wikipedia_de": [20, 21, 22, 27, 33, 34, 35, 45, 47, 48, 64, 71, 78, 79, 82, 83, 85, 86, 99], "wikipedia_fr": [28, 30, 37, 48, 65, 66, 88, 89, 95], "wikipedia_ru": [20, 31, 38, 41, 44, 53, 55, 58, 69, 70, 78, 90, 94, 99], "wikipedia_es": [22, 25, 28, 29, 58, 59, 60, 62, 63, 66, 78, 90, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20, 30, 36, 47, 48, 53, 58, 59, 60, 61, 63, 67, 72, 75, 87, 89, 95, 99], "wikipedia_de": [23, 29, 38, 45, 54, 60, 62, 72, 75, 89], "wikipedia_fr": [30, 34, 37, 40, 50, 55, 59, 60, 61, 63, 66, 79, 80, 90, 91, 92, 94], "wikipedia_ru": [20, 21, 30, 36, 46, 55, 58, 61, 68, 72, 82, 88, 90, 96], "wikipedia_es": [27, 28, 30, 31, 32, 35, 47, 49, 53, 63, 68, 74, 75, 77, 81, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 25, 28, 30, 35, 44, 46, 48, 49, 59, 60, 72, 77, 86, 87, 88], "wikipedia_de": [22, 29, 34, 38, 42, 43, 50, 52, 53, 72, 82, 83, 89, 90, 93, 99], "wikipedia_fr": [32, 48, 52, 58, 63, 64, 66, 69, 77, 78, 80, 82, 83, 87, 92, 99], "wikipedia_ru": [22, 31, 44, 46, 47, 48, 55, 65, 84, 89, 99], "wikipedia_es": [20, 21, 23, 32, 34, 35, 36, 38, 45, 49, 53, 54, 72, 75, 76, 81]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/7.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 27, 40, 47, 51, 53, 55, 56, 68, 94, 95, 99], "wikipedia_de": [21, 23, 24, 25, 27, 28, 31, 46, 50, 51, 62, 63, 66, 80, 82, 88], "wikipedia_fr": [24, 32, 36, 38, 39, 49, 50, 51, 59, 69, 71, 75, 87, 91], "wikipedia_ru": [21, 22, 33, 40, 47, 48, 51, 56, 62, 67, 81, 84, 90, 96], "wikipedia_es": [21, 28, 35, 40, 41, 43, 45, 46, 47, 54, 55, 57, 59, 61, 65, 71, 81, 95, 96]} -------------------------------------------------------------------------------- 
/experiments/constants/wikipedia_prompt_indices_test/75/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [36, 50, 51, 59, 63, 64, 71, 78, 79, 81, 95, 98], "wikipedia_de": [25, 28, 31, 37, 42, 45, 48, 59, 60, 63, 65, 66, 72, 77, 88], "wikipedia_fr": [29, 34, 35, 43, 49, 52, 53, 58, 59, 61, 67, 69, 72, 74, 76, 84, 93, 94, 97], "wikipedia_ru": [21, 26, 27, 34, 39, 43, 51, 52, 65, 72, 88, 92, 98], "wikipedia_es": [26, 28, 32, 39, 47, 51, 62, 67, 68, 72, 84, 87, 88, 92, 95, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/75/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [30, 39, 40, 48, 51, 54, 55, 71, 76, 77, 80, 83, 86, 92, 98, 99], "wikipedia_de": [24, 36, 37, 38, 42, 47, 48, 52, 54, 59, 60, 69, 76, 79, 84, 85, 87, 88, 90, 98], "wikipedia_fr": [25, 26, 38, 52, 59, 67, 74, 78, 86, 97], "wikipedia_ru": [21, 39, 43, 47, 52, 53, 57, 66, 72, 73, 79, 84, 85, 92, 96, 99], "wikipedia_es": [24, 29, 30, 35, 42, 43, 51, 59, 67, 74, 87, 89, 90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_val/100/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_de": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_fr": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_ru": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "wikipedia_es": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/0.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [24, 26, 27, 29, 38, 42, 54, 56, 57, 60, 66, 76, 77, 78, 86, 92, 94, 95, 99], "wikipedia_de": [28, 29, 32, 37, 38, 45, 51, 56, 60, 61, 65, 70, 72, 73, 74, 78, 79, 81, 83, 89, 91, 94], "wikipedia_fr": [24, 26, 28, 36, 39, 47, 60, 66, 70, 71, 77, 78, 82, 83, 86, 94, 98], "wikipedia_ru": [29, 39, 43, 46, 53, 55, 56, 58, 59, 72, 75, 85, 86, 88, 90, 94, 97], "wikipedia_es": [21, 22, 26, 27, 30, 31, 32, 42, 44, 53, 54, 61, 63, 65, 68, 69, 73, 75, 78, 80, 85, 86, 92, 93, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/1.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [22, 23, 26, 27, 28, 38, 45, 47, 48, 51, 57, 59, 62, 67, 78, 79, 81, 83, 91, 99], "wikipedia_de": [21, 35, 36, 43, 44, 46, 48, 50, 56, 71, 77, 79, 91, 92, 95], "wikipedia_fr": [20, 30, 31, 36, 42, 44, 46, 47, 48, 51, 54, 60, 62, 65, 72, 73, 75, 82, 83, 84, 88, 95, 99], "wikipedia_ru": [21, 29, 34, 41, 42, 43, 44, 48, 49, 69, 70, 78, 79, 80, 82, 85, 92, 95], "wikipedia_es": [21, 25, 33, 39, 43, 49, 53, 55, 56, 57, 67, 73, 75, 76, 78, 79, 81, 85, 86, 89, 90, 95, 97, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/2.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [26, 27, 34, 40, 41, 44, 51, 
59, 60, 62, 65, 72, 73, 78, 86, 87, 90, 95, 99], "wikipedia_de": [23, 24, 28, 31, 35, 43, 47, 54, 55, 56, 61, 63, 73, 81, 85, 88, 89, 91, 94, 96, 98], "wikipedia_fr": [21, 23, 27, 29, 37, 43, 49, 52, 53, 54, 55, 59, 60, 71, 72, 83, 85, 97], "wikipedia_ru": [20, 24, 30, 33, 34, 39, 40, 45, 50, 54, 63, 64, 67, 70, 72, 78, 85, 86, 87, 90, 94, 95, 97], "wikipedia_es": [20, 22, 24, 26, 37, 44, 49, 56, 64, 72, 73, 76, 78, 81, 82, 84, 89, 95, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/3.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [20, 24, 26, 27, 28, 32, 33, 36, 39, 45, 48, 54, 59, 61, 68, 76, 79, 82, 84, 85, 86, 88, 92, 96], "wikipedia_de": [20, 21, 23, 28, 32, 37, 41, 46, 50, 55, 59, 60, 64, 68, 73, 78, 80, 82, 87, 89, 91], "wikipedia_fr": [20, 21, 23, 25, 33, 34, 36, 40, 44, 46, 55, 58, 63, 66, 72, 73, 83, 85, 94, 98], "wikipedia_ru": [27, 29, 31, 39, 45, 48, 53, 56, 57, 66, 76, 77, 80, 82, 83, 84, 89, 90], "wikipedia_es": [24, 26, 36, 43, 47, 56, 61, 66, 70, 74, 77, 78, 79, 82, 83, 85, 94]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/4.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [29, 30, 31, 34, 40, 42, 44, 45, 47, 52, 53, 54, 58, 68, 69, 71, 74, 79, 81, 85, 99], "wikipedia_de": [23, 26, 29, 31, 32, 37, 48, 53, 56, 58, 62, 63, 75, 80, 82, 86, 91, 96], "wikipedia_fr": [21, 22, 31, 33, 41, 42, 44, 49, 54, 55, 56, 57, 64, 72, 75, 76, 84, 87, 93, 94], "wikipedia_ru": [20, 22, 23, 25, 32, 34, 41, 44, 45, 47, 53, 59, 62, 66, 69, 70, 74, 80, 88, 91, 99], "wikipedia_es": [20, 22, 25, 29, 33, 43, 45, 47, 54, 57, 66, 73, 75, 78, 81, 82, 84, 88, 90, 97]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/5.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 26, 30, 33, 41, 45, 48, 55, 56, 58, 59, 60, 70, 73, 75, 76, 80, 85, 88, 93], "wikipedia_de": [22, 24, 25, 27, 31, 35, 36, 40, 41, 50, 51, 54, 57, 58, 68, 70, 75, 76, 77, 83, 87, 91], "wikipedia_fr": [20, 22, 31, 32, 33, 37, 41, 42, 45, 47, 48, 52, 54, 63, 75, 77, 79, 81, 83, 89, 91, 94, 96, 97], "wikipedia_ru": [20, 22, 23, 24, 27, 34, 35, 36, 37, 40, 41, 43, 58, 61, 64, 68, 69, 83, 84, 95, 99], "wikipedia_es": [20, 22, 37, 41, 73, 75, 77, 79, 86, 87, 90, 97, 98]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/6.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [23, 26, 28, 30, 31, 34, 40, 41, 43, 46, 57, 59, 69, 71, 74, 75, 77, 79, 80, 91, 92, 93, 94, 96, 97, 98], "wikipedia_de": [20, 32, 34, 37, 47, 50, 57, 62, 64, 65, 86, 91, 92, 95], "wikipedia_fr": [20, 22, 23, 25, 41, 46, 49, 57, 59, 61, 63, 64, 70, 73, 75, 76, 80, 86, 89, 90, 96, 97, 98], "wikipedia_ru": [20, 34, 37, 38, 45, 51, 60, 65, 71, 74, 78, 80, 86, 93, 96, 98, 99], "wikipedia_es": [27, 31, 32, 34, 36, 37, 38, 41, 43, 45, 46, 56, 61, 66, 67, 69, 72, 84, 89, 90]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/7.json: 
-------------------------------------------------------------------------------- 1 | {"wikipedia_en": [24, 25, 34, 38, 42, 43, 46, 51, 52, 53, 54, 57, 58, 59, 62, 69, 77, 78, 79, 80, 82, 92, 94, 95, 97], "wikipedia_de": [24, 26, 27, 34, 35, 36, 37, 41, 42, 48, 55, 56, 61, 63, 64, 65, 74, 77, 78, 80, 81, 89, 94, 95], "wikipedia_fr": [23, 29, 35, 36, 40, 48, 54, 55, 58, 69, 70, 76, 79, 80, 85, 90, 98], "wikipedia_ru": [20, 29, 34, 35, 40, 44, 45, 53, 54, 57, 63, 64, 65, 66, 67, 70, 76, 85, 86, 88, 92], "wikipedia_es": [21, 34, 40, 55, 62, 82, 85, 86, 88, 91, 92, 95, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/8.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [25, 26, 29, 32, 33, 37, 38, 41, 43, 44, 49, 54, 61, 62, 63, 64, 68, 92, 97], "wikipedia_de": [23, 26, 36, 37, 41, 45, 46, 50, 54, 58, 61, 62, 68, 77, 80, 81, 83, 87, 91, 93], "wikipedia_fr": [20, 26, 32, 36, 38, 44, 47, 48, 51, 57, 60, 62, 66, 68, 76, 82, 86, 88, 90, 94], "wikipedia_ru": [27, 29, 35, 52, 53, 59, 62, 63, 66, 67, 68, 73, 76, 78, 81, 92, 94, 96, 98], "wikipedia_es": [22, 29, 30, 37, 41, 42, 43, 44, 48, 49, 51, 53, 58, 63, 73, 75, 76, 77, 82, 87, 89, 96]} -------------------------------------------------------------------------------- /experiments/constants/wikipedia_prompt_indices_test/100/9.json: -------------------------------------------------------------------------------- 1 | {"wikipedia_en": [21, 22, 23, 24, 25, 32, 35, 36, 40, 43, 44, 46, 47, 52, 54, 56, 61, 63, 66, 68, 73, 77, 78, 79, 86, 99], "wikipedia_de": [24, 27, 31, 33, 40, 46, 53, 54, 55, 58, 59, 60, 61, 62, 63, 76, 91, 93, 97], "wikipedia_fr": [21, 22, 24, 32, 39, 40, 57, 61, 62, 69, 81, 93, 94, 95, 97, 98], "wikipedia_ru": [23, 25, 41, 45, 47, 49, 53, 57, 61, 62, 71, 76, 77, 83, 88, 89, 93, 97], "wikipedia_es": [20, 23, 25, 26, 38, 40, 44, 45, 46, 48, 53, 54, 58, 63, 65, 68, 72, 74, 81, 85, 99]} -------------------------------------------------------------------------------- /model_equality_testing/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [tool.setuptools] 7 | packages = ["model_equality_testing"] 8 | package-dir = { "model_equality_testing" = "src" } 9 | 10 | [project] 11 | name = "model_equality_testing" 12 | version = "0.0.1" 13 | authors = [ 14 | { name="Irena Gao", email="irena@cs.stanford.edu" }, 15 | ] 16 | description = "Package to conduct model equality testing for black-box language model APIs" 17 | readme = "README.md" 18 | requires-python = ">=3.8" 19 | classifiers = [ 20 | "Programming Language :: Python :: 3", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | ] 24 | dependencies = [ 25 | "numpy", 26 | "torch", 27 | "matplotlib" 28 | ] 29 | 30 | [project.urls] 31 | Homepage = "https://github.com/i-gao/model-equality-testing" 32 | Issues = "https://github.com/i-gao/model-equality-testing/issues" 33 | -------------------------------------------------------------------------------- /model_equality_testing/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation 
files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /experiments/sampling/test_api_tokenization.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | from accelerate import Accelerator 4 | 5 | APIS = [ 6 | "anyscale", 7 | "together", 8 | "fireworks", 9 | "perplexity", 10 | "replicate", 11 | "groq", 12 | "deepinfra", 13 | "amazon", 14 | "azure", 15 | ] 16 | MODELS = [ 17 | "meta-llama/Meta-Llama-3-8B-Instruct", 18 | "meta-llama/Meta-Llama-3-70B-Instruct", 19 | "meta-llama/Meta-Llama-3.1-8B-Instruct", 20 | "meta-llama/Meta-Llama-3.1-70B-Instruct", 21 | "meta-llama/Meta-Llama-3.1-405B-Instruct", 22 | ] 23 | 24 | ############ 25 | 26 | from experiments.sampling.model import TransformersModel 27 | import experiments.prompts as prompts 28 | from cache_api_samples import policy 29 | import experiments.sampling.api as api_library 30 | 31 | def get_expected_prompt_len(api, model, ds): 32 | _policy = policy(api, model.model_name) 33 | p = ds[0][_policy["prompt_key"]] 34 | p_for_len = ds[0][_policy["expected_prompt_key"]] 35 | ids = model.tokenizer( 36 | p_for_len, add_special_tokens=('special' not in _policy["expected_prompt_key"]), 37 | )['input_ids'] 38 | return len(ids), p 39 | 40 | def test_sampling(p, model, api, expected_prompt_len, n): 41 | print(f">> Testing {model} with {api} and repeatedly requesting through n={n}") 42 | i = 0 43 | try: 44 | raw = "" 45 | out = [] 46 | get_fn, kwargs = getattr(api_library, f"setup_{api}")( 47 | model=model, 48 | N=1, 49 | L=5, 50 | use_chat_endpoint=policy(api, model)["use_chat_endpoint"], 51 | do_sample=True, 52 | temperature=1, 53 | top_p=None, 54 | ) 55 | for i in range(n): 56 | o = get_fn(prompt=p, **kwargs)[0] 57 | if not expected_prompt_len == o.num_prompt_tokens: 58 | if o.prompt is None or o.prompt != p: 59 | raise Exception(f"Expected prompt len was {expected_prompt_len}, actual was {o.num_prompt_tokens}\n{o}") 60 | out.append(o.full_completion) 61 | if len(set(out)) == 1: 62 | raise Exception("Always got the same output: " + out[0]) 63 | except Exception as e: 64 | print("\tResult: FAILED") 65 | print("\tIteration " + str(i) + ": " + str(e) + str(raw)) 66 | return False 67 | print("\tResult: PASSED") 68 | return True 69 | 70 | ############# 71 | 72 | accelerator = Accelerator() 73 | 74 | for m in MODELS: 75 | model = TransformersModel(m, accelerator, skip_loading_weights=True) 76 | ds = prompts.get_bit_prompts(model) 77 | for api in APIS: 78 | expected_prompt_len, prompt = 
get_expected_prompt_len(api, model, ds) 79 | # n=1, repeated sampling 2 times 80 | test_sampling(prompt, m, api, expected_prompt_len, 2) 81 | -------------------------------------------------------------------------------- /experiments/testing/bootstrap_manager.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union, Tuple 2 | import numpy as np 3 | import torch 4 | import pickle 5 | from model_equality_testing.pvalue import EmpiricalPvalueCalculator 6 | 7 | 8 | def load_parametric_bootstrap(path, min_b=None, max_b=None) -> np.ndarray: 9 | """ 10 | Loads pre-saved test stats from bootstrapping and enforces the correct shapes. 11 | Output shape is (b, 1, n_stats) where b is the number of bootstraps and n_stats is the number of statistics. 12 | """ 13 | with open(path, "rb") as f: 14 | stats = pickle.load(f) 15 | if stats.ndim == 1: 16 | stats = np.expand_dims(stats, 0) 17 | if stats.ndim == 2: 18 | stats = np.expand_dims(stats, 2) 19 | if max_b is not None: 20 | stats = stats[:max_b] 21 | if min_b is not None: 22 | assert min_b <= len( 23 | stats 24 | ), f"min_b={min_b} is greater than the number of bootstraps {len(stats)}" 25 | return stats 26 | 27 | 28 | class BootstrapManager: 29 | """ 30 | Helper class to manage loading of bootstrapped statistics with different sample sizes. 31 | """ 32 | 33 | def __init__(self, bootstrap_path_template: str, min_b=None, max_b=None): 34 | """ 35 | Args: 36 | bootstrap_path_template: str 37 | A string template that can be filled in with the sample size n 38 | Example: "cache/parametric_bootstrap_stats/meta-llama-Meta-Llama-3-8B-wikipedia-{n}.pkl" 39 | min_b: int 40 | Minimum number of bootstraps to load 41 | max_b: int 42 | Maximum number of bootstraps to load 43 | """ 44 | self._bootstrap_path_template = bootstrap_path_template 45 | self.min_b = min_b 46 | self.max_b = max_b 47 | self._stats = None 48 | 49 | def load( 50 | self, 51 | return_stats=False, 52 | **kwargs, 53 | ) -> Union[EmpiricalPvalueCalculator, Tuple[np.ndarray, EmpiricalPvalueCalculator]]: 54 | """ 55 | Loads the bootstrapped statistics for a given sample size n. 56 | Args: 57 | n: int or List[int] 58 | The sample size to load statistics for 59 | return_stats: bool 60 | If True, returns the raw statistics as well as the p-value calculator 61 | Returns: 62 | EmpiricalPvalueCalculator or Tuple[np.ndarray, EmpiricalPvalueCalculator] 63 | """ 64 | for k, v in kwargs.items(): 65 | try: 66 | kwargs[k] = v.item()  # unwrap tensor/array scalars so the path template formats cleanly 67 | except: 68 | pass 69 | self._stats = load_parametric_bootstrap( 70 | self._bootstrap_path_template.format(**kwargs), 71 | min_b=self.min_b, 72 | max_b=self.max_b, 73 | ) 74 | 75 | if return_stats: 76 | return self._stats, EmpiricalPvalueCalculator(self._stats) 77 | return EmpiricalPvalueCalculator(self._stats) 78 | -------------------------------------------------------------------------------- /experiments/sampling/cache_local_samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caches n samples / prompt by locally inferencing a model on a dataset of prompts. 
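Example invocation (an illustrative sketch; the model name and paths are placeholders, and the --prompts value must name a get_<prompts>_prompts loader defined in experiments/prompts.py):
    python cache_local_samples.py --model meta-llama/Meta-Llama-3-8B-Instruct --source nf4 --prompts dummy --n 10000 --L 50 --save_dir ../cache/samples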
3 | """ 4 | 5 | import torch 6 | import tqdm 7 | import argparse 8 | import os 9 | import experiments.prompts as prompts_module 10 | from experiments.sampling.model import TransformersModel 11 | from experiments.utils import ( 12 | str_to_bool, 13 | ParseKwargs, 14 | build_cache_filename, 15 | ) 16 | from accelerate import Accelerator 17 | import pickle 18 | import glob 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--model", type=str, required=True) 24 | parser.add_argument("--model_kwargs", nargs="*", action=ParseKwargs, default={}) 25 | parser.add_argument( 26 | "--source", 27 | type=str, 28 | default="fp32", 29 | ) 30 | parser.add_argument("--prompts", default="dummy", type=str) 31 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 32 | parser.add_argument("--temperature", type=float, default=None) 33 | parser.add_argument("--top_p", type=float, default=None) 34 | parser.add_argument("--L", type=int, default=50) 35 | parser.add_argument("--n", type=int, default=10000) 36 | parser.add_argument("--batch_size", type=int, default=90) 37 | parser.add_argument("--save_dir", type=str, default="../cache/samples") 38 | args = parser.parse_args() 39 | accelerator = Accelerator() 40 | print(args) 41 | 42 | # setup model 43 | fixed_decoding_params = { 44 | "temperature": args.temperature if args.temperature is not None else 1, 45 | "top_p": args.top_p if args.top_p is not None else 1, 46 | } 47 | kwargs = {} 48 | if args.source == "fp16": 49 | kwargs = {"cast_dtype": torch.float16} 50 | elif args.source == "int8": 51 | kwargs = {"quantize": 8} 52 | elif args.source == "nf4": 53 | kwargs = {"quantize": 4} 54 | elif args.source == "watermark": 55 | kwargs = {"watermark_bias": 2.5} 56 | model = TransformersModel( 57 | args.model, 58 | accelerator=accelerator, 59 | fixed_decoding_params=fixed_decoding_params, 60 | batch_size=args.batch_size, 61 | **args.model_kwargs, 62 | **kwargs 63 | ) 64 | 65 | # load dataset 66 | ds = getattr(prompts_module, f"get_{args.prompts}_prompts")(model) 67 | try: 68 | # causes issues if kept 69 | ds = ds.remove_columns(["chat_tokens", "chat_with_ellipses_tokens"]) 70 | except: 71 | pass 72 | 73 | # get save string 74 | filename = build_cache_filename( 75 | model=args.model, 76 | prompts=args.prompts, 77 | alternative=args.source, 78 | temperature=args.temperature, 79 | top_p=args.top_p, 80 | do_sample=args.do_sample, 81 | L=args.L, 82 | ) 83 | 84 | # run through batches and dump the results 85 | for i in range(len(ds["chat"])): 86 | if os.path.exists(f"{args.save_dir}/{filename}-{i}.pkl"): 87 | print("Skipping", i) 88 | continue 89 | print(f"Collecting samples for prompt {i}") 90 | 91 | out = model.sample( 92 | [ds["chat"][i]], 93 | n=args.n, 94 | L=args.L, 95 | ) # (1, n, L) 96 | out = out.squeeze().tolist() # (n, L) 97 | with open(f"{args.save_dir}/{filename}-{i}.pkl", "wb") as f: 98 | pickle.dump(out, f) 99 | -------------------------------------------------------------------------------- /experiments/constants/minimum_prompt_indices.json: -------------------------------------------------------------------------------- 1 | { 2 | "wikipedia_en": [ 3 | 21, 4 | 22, 5 | 24, 6 | 28, 7 | 31, 8 | 32, 9 | 37, 10 | 39, 11 | 41, 12 | 42, 13 | 44, 14 | 46, 15 | 48, 16 | 49, 17 | 50, 18 | 51, 19 | 52, 20 | 53, 21 | 55, 22 | 57, 23 | 59, 24 | 60, 25 | 62, 26 | 63, 27 | 64, 28 | 65, 29 | 66, 30 | 67, 31 | 68, 32 | 70, 33 | 73, 34 | 74, 35 | 75, 36 | 77, 37 | 79, 38 | 80, 39 | 81, 40 | 82, 41 | 83, 42 | 
87, 43 | 89, 44 | 90, 45 | 91, 46 | 95, 47 | 99 48 | ], 49 | "wikipedia_de": [ 50 | 20, 51 | 21, 52 | 23, 53 | 24, 54 | 25, 55 | 27, 56 | 29, 57 | 31, 58 | 33, 59 | 35, 60 | 36, 61 | 44, 62 | 46, 63 | 48, 64 | 49, 65 | 51, 66 | 55, 67 | 57, 68 | 59, 69 | 61, 70 | 64, 71 | 66, 72 | 67, 73 | 72, 74 | 74, 75 | 78, 76 | 80, 77 | 81, 78 | 83, 79 | 85, 80 | 88, 81 | 90, 82 | 92, 83 | 94, 84 | 95, 85 | 96, 86 | 97, 87 | 99 88 | ], 89 | "wikipedia_fr": [ 90 | 20, 91 | 24, 92 | 27, 93 | 28, 94 | 29, 95 | 34, 96 | 35, 97 | 36, 98 | 37, 99 | 45, 100 | 46, 101 | 47, 102 | 50, 103 | 51, 104 | 52, 105 | 57, 106 | 61, 107 | 63, 108 | 64, 109 | 65, 110 | 67, 111 | 70, 112 | 73, 113 | 76, 114 | 78, 115 | 81, 116 | 83, 117 | 85, 118 | 88, 119 | 89, 120 | 92, 121 | 93, 122 | 94, 123 | 95, 124 | 98, 125 | 99 126 | ], 127 | "wikipedia_ru": [ 128 | 20, 129 | 21, 130 | 24, 131 | 25, 132 | 27, 133 | 28, 134 | 29, 135 | 30, 136 | 31, 137 | 35, 138 | 37, 139 | 39, 140 | 44, 141 | 46, 142 | 51, 143 | 54, 144 | 56, 145 | 57, 146 | 60, 147 | 62, 148 | 65, 149 | 66, 150 | 69, 151 | 70, 152 | 71, 153 | 73, 154 | 74, 155 | 75, 156 | 79, 157 | 82, 158 | 86, 159 | 87, 160 | 89, 161 | 90, 162 | 92, 163 | 95, 164 | 97, 165 | 99 166 | ], 167 | "wikipedia_es": [ 168 | 23, 169 | 27, 170 | 28, 171 | 31, 172 | 32, 173 | 33, 174 | 36, 175 | 37, 176 | 38, 177 | 39, 178 | 40, 179 | 42, 180 | 50, 181 | 61, 182 | 62, 183 | 63, 184 | 66, 185 | 67, 186 | 69, 187 | 72, 188 | 74, 189 | 75, 190 | 77, 191 | 78, 192 | 80, 193 | 81, 194 | 82, 195 | 84, 196 | 85, 197 | 88, 198 | 90, 199 | 92, 200 | 93, 201 | 99 202 | ] 203 | } -------------------------------------------------------------------------------- /experiments/sampling/cache_logprobs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Given a set of samples, computes the logprobs of the individual completion tokens (one number per token) under the fp32 null. 3 | Used in dataset construction; later applied for goodness-of-fit testing. 4 | Assumes that samples are saved as *-{i}.pkl files, where i is the index of the prompt in the dataset, 5 | and that the dataset is implemented in experiments.prompts. 6 | """ 7 | 8 | import torch 9 | import tqdm 10 | import glob 11 | import pickle 12 | import argparse 13 | import os 14 | import experiments.prompts as prompts_module 15 | from experiments.sampling.model import TransformersModel 16 | from experiments.utils import ( 17 | str_to_bool, 18 | ParseKwargs, 19 | stack_with_padding, 20 | ) 21 | from accelerate import Accelerator 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--model", type=str, required=True, help="Model to get logprobs with" 27 | ) 28 | parser.add_argument("--model_kwargs", nargs="*", action=ParseKwargs, default={}) 29 | parser.add_argument( 30 | "--prompts", type=str, required=True, help="Name of prompts dataset to get logprobs for" 31 | ) 32 | parser.add_argument("--samples_path_template", type=str, required=True, help="String template for pkl file of samples, saved as list of lists of tokens ids. 
The first part of {template}-{i}.pkl") 33 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 34 | parser.add_argument("--temperature", type=float, default=None) 35 | parser.add_argument("--top_p", type=float, default=None) 36 | parser.add_argument("--batch_size", type=int, default=90) 37 | parser.add_argument("--save_dir", type=str, default="../cache/logprobs") 38 | args = parser.parse_args() 39 | accelerator = Accelerator() 40 | print(args) 41 | 42 | # setup model 43 | fixed_decoding_params = { 44 | "temperature": args.temperature, 45 | "top_p": args.top_p, 46 | } 47 | model = TransformersModel( 48 | args.model, 49 | accelerator=accelerator, 50 | fixed_decoding_params=fixed_decoding_params, 51 | batch_size=args.batch_size, 52 | **args.model_kwargs 53 | ) 54 | 55 | # load dataset 56 | ds = getattr(prompts_module, f"get_{args.prompts}_prompts")(model) 57 | try: 58 | # causes issues if kept 59 | ds = ds.remove_columns(["chat_tokens", "chat_with_ellipses_tokens"]) 60 | except: 61 | pass 62 | 63 | # run through batches and dump the results 64 | for path in tqdm.tqdm(glob.glob(f"{args.samples_path_template}-*")): 65 | i = int(os.path.basename(path).split("-")[-1].split(".")[0]) 66 | if os.path.exists(f"{args.save_dir}/{os.path.basename(args.samples_path_template)}-{i}.pkl"): 67 | print("Skipping", path) 68 | continue 69 | 70 | print(f"Collecting logprobs for prompt {i} based on path {path}") 71 | 72 | with open(path, "rb") as f: 73 | # note: this expects to pkl file to contain a list of lists of integers (token IDs) 74 | completions = [torch.tensor(x) for x in pickle.load(f)] 75 | 76 | completions, attention_mask = stack_with_padding(completions) 77 | 78 | logprobs = model.get_logprobs( 79 | prompts=[ds[i]["chat"]] * len(completions), 80 | completion_input_ids=completions, 81 | completion_attention_mask=attention_mask, 82 | ) 83 | with open(f"{args.save_dir}/{os.path.basename(args.samples_path_template)}-{i}.pkl", "wb") as f: 84 | out = {tuple(seq.tolist()): lp.tolist() for seq, lp in zip(completions, logprobs)} 85 | pickle.dump(out, f) -------------------------------------------------------------------------------- /experiments/sampling/cache_local_samples_vllm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caches n samples / prompt by locally inferencing a model on a dataset of prompts. 
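Example invocation (an illustrative sketch; the model name is a placeholder, and note that the vLLM loader below only accepts --source values of None, fp16, or nf4):
    python cache_local_samples_vllm.py --model meta-llama/Meta-Llama-3-8B-Instruct --source fp16 --prompts dummy --n 10000 --L 50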
3 | """ 4 | 5 | from vllm import LLM, SamplingParams 6 | import torch 7 | import tqdm 8 | import argparse 9 | import os 10 | import experiments.prompts as prompts_module 11 | from experiments.utils import ( 12 | str_to_bool, 13 | ParseKwargs, 14 | build_cache_filename, 15 | ) 16 | import pickle 17 | from accelerate import Accelerator 18 | from experiments.sampling.model import TransformersModel 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--model", type=str, required=True) 24 | parser.add_argument("--model_kwargs", nargs="*", action=ParseKwargs, default={}) 25 | parser.add_argument( 26 | "--source", 27 | type=str, 28 | default="fp32", 29 | ) 30 | parser.add_argument("--prompts", default="dummy", type=str) 31 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 32 | parser.add_argument("--temperature", type=float, default=None) 33 | parser.add_argument("--top_p", type=float, default=None) 34 | parser.add_argument("--L", type=int, default=50) 35 | parser.add_argument("--n", type=int, default=10000) 36 | parser.add_argument("--batch_size", type=int, default=90) 37 | parser.add_argument("--save_dir", type=str, default="../cache/samples") 38 | args = parser.parse_args() 39 | accelerator = Accelerator() 40 | print(args) 41 | 42 | # get dataset 43 | model_to_load_prompts = TransformersModel( 44 | args.model, 45 | accelerator=accelerator, 46 | skip_loading_weights=True, 47 | ) 48 | ds = getattr(prompts_module, f"get_{args.prompts}_prompts")(model_to_load_prompts) 49 | try: 50 | # causes issues if kept 51 | ds = ds.remove_columns(["chat_tokens", "chat_with_ellipses_tokens"]) 52 | except: 53 | pass 54 | 55 | # initialize vllm model 56 | if args.source == "None": 57 | kwargs = {"dtype": "float32"} 58 | elif args.source == "fp16": 59 | kwargs = {"dtype": "float16"} 60 | elif args.source == "nf4": 61 | kwargs = {"quantization": "bitsandbytes", "load_format": "bitsandbytes"} 62 | else: 63 | raise ValueError("Cannot use that with vllm") 64 | 65 | model = LLM( 66 | model=args.model, 67 | tensor_parallel_size=len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")), 68 | **kwargs, 69 | **args.model_kwargs, 70 | ) 71 | 72 | # get save string 73 | filename = build_cache_filename( 74 | model=args.model, 75 | prompts=args.prompts, 76 | alternative=args.source, 77 | temperature=args.temperature, 78 | top_p=args.top_p, 79 | do_sample=args.do_sample, 80 | L=args.L, 81 | ) 82 | 83 | sampling_params = SamplingParams( 84 | n=args.batch_size, 85 | temperature=args.temperature if args.temperature is not None else 1, 86 | top_p=args.top_p if args.top_p is not None else 1, 87 | top_k=1 if not args.do_sample else -1, 88 | max_tokens=args.L, 89 | stop_token_ids=[], 90 | skip_special_tokens=False, 91 | ignore_eos=True, 92 | logprobs=None, 93 | ) 94 | 95 | # run through batches and dump the results 96 | for i in range(len(ds)): 97 | if os.path.exists(f"{args.save_dir}/{filename}-{i}.pkl"): 98 | print("Skipping", i) 99 | continue 100 | print(f"Collecting samples for prompt {i}") 101 | 102 | out = [] 103 | for batch_size in tqdm.tqdm( 104 | [ 105 | min(args.batch_size, args.n - j * args.batch_size) 106 | for j in range( 107 | (args.n + args.batch_size - 1) // args.batch_size 108 | ) 109 | ] 110 | ): 111 | out.extend( 112 | model.generate( 113 | [ds["chat"][i]], 114 | sampling_params, 115 | ) 116 | ) 117 | sample = [oi.token_ids for o in out for oi in o.outputs] 118 | 119 | with open(f"{args.save_dir}/{filename}-{i}.pkl", "wb") as f: 120 | 
pickle.dump(sample, f) -------------------------------------------------------------------------------- /experiments/testing/cache_one_sample_bootstrap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calls model_equality_testing.pvalue.one_sample_parametric_bootstrap_pvalue repeatedly to cache simulated test statistics (parametric bootstrap) for a given null x test statistic. 3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from model_equality_testing.pvalue import one_sample_parametric_bootstrap_pvalue 7 | import pickle 8 | from experiments.utils import build_cache_filename, str_to_bool 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--model", 18 | type=str, 19 | help="Name of the model", 20 | default="meta-llama/Meta-Llama-3-8B-Instruct", 21 | ) 22 | parser.add_argument( 23 | "--null_distribution_source", 24 | type=str, 25 | default="fp32", 26 | help="Source to use as the null distribution", 27 | ) 28 | parser.add_argument( 29 | "--prompt_indices_json", 30 | type=str, 31 | help="JSON of prompt dataset name: [list of indices]. Prompt distribution will be uniform over these prompts", 32 | required=True, 33 | ) 34 | parser.add_argument( 35 | "--n", 36 | type=int, 37 | help="Number of samples to draw each time we compute the test statistic", 38 | default=None, 39 | ) 40 | parser.add_argument( 41 | "--n_per_prompt", 42 | type=int, 43 | help="Alternative to --n, number of samples per prompt", 44 | default=None, 45 | ) 46 | parser.add_argument( 47 | "--stat", type=str, default="g_squared", help="One-sample test statistic" 48 | ) 49 | parser.add_argument( 50 | "--test_in_unicode", 51 | type=str_to_bool, 52 | default=True, 53 | help="Test in unicode space instead of token space", 54 | ) 55 | parser.add_argument( 56 | "--b", type=int, default=1000, help="Number of bootstrap samples" 57 | ) 58 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 59 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 60 | parser.add_argument("--temperature", type=float, default=None) 61 | parser.add_argument("--top_p", type=float, default=None) 62 | parser.add_argument( 63 | "--save_dir", type=str, default="../cache/parametric_bootstrap_stats" 64 | ) 65 | args = parser.parse_args() 66 | print(args) 67 | 68 | # Load the prompt indices 69 | js = json.load(open(args.prompt_indices_json)) 70 | prompts = list(js.keys()) 71 | # get the number of prompts m 72 | m = sum([len(js[k]) for k in js]) 73 | print("Number of prompts: ", m) 74 | 75 | assert (args.n is not None) ^ ( 76 | args.n_per_prompt is not None 77 | ), "Exactly one of --n or --n_per_prompt must be provided" 78 | if args.n_per_prompt is not None: 79 | args.n = m * args.n_per_prompt 80 | print("Number of samples: ", args.n) 81 | 82 | # Load dataset 83 | p = load_distribution( 84 | model=args.model, 85 | prompt_ids=js, 86 | L=args.L, 87 | source=args.null_distribution_source, 88 | load_in_unicode=args.test_in_unicode, 89 | ) 90 | print("Null shape: ", p.shape) 91 | 92 | # Construct the filename 93 | filename = build_cache_filename( 94 | model=args.model, 95 | prompts=prompts, 96 | prompt_indices_json=args.prompt_indices_json, 97 | alternative=args.null_distribution_source, 98 | temperature=args.temperature, 99 | top_p=args.top_p, 100 | do_sample=args.do_sample, 101 | L=args.L, 102 | 
stat=args.stat, 103 | N=args.n, 104 | use_char_space=args.test_in_unicode, 105 | ) 106 | out_path = f"{args.save_dir}/{filename}.pkl" 107 | print(f"Will write to {out_path}") 108 | 109 | # Skip if already exists 110 | if os.path.exists(out_path): 111 | print("Skipping...") 112 | exit() 113 | 114 | # Cache the test statistics 115 | _, s = one_sample_parametric_bootstrap_pvalue( 116 | null_dist=p, 117 | n=args.n, 118 | b=args.b, 119 | return_stats=True, 120 | stat_type=args.stat, 121 | ) 122 | 123 | with open(out_path, "wb") as f: 124 | pickle.dump(s, f) 125 | -------------------------------------------------------------------------------- /model_equality_testing/src/algorithm.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, List, Dict 2 | from model_equality_testing.pvalue import ( 3 | EmpiricalPvalueCalculator, 4 | one_sample_parametric_bootstrap_pvalue, 5 | two_sample_parametric_bootstrap_pvalue, 6 | two_sample_permutation_pvalue, 7 | ) 8 | from model_equality_testing.tests import IMPLEMENTED_TESTS 9 | from model_equality_testing.distribution import ( 10 | CompletionSample, 11 | DistributionFromDataset, 12 | ) 13 | 14 | 15 | def _noop_pvalue(*args, **kwargs): 16 | return 1.0 17 | 18 | 19 | def run_goodness_of_fit_test( 20 | sample: CompletionSample, 21 | null_dist: DistributionFromDataset, 22 | get_pvalue: Union[callable, EmpiricalPvalueCalculator] = None, 23 | pvalue_type: str = "parametric_bootstrap", 24 | stat_type: str = "g_squared", 25 | b=1000, 26 | **kwargs, 27 | ) -> Tuple[float, float]: 28 | """ 29 | Tests whether the sample is drawn from the null distribution 30 | Args: 31 | sample: CompletionSample 32 | null_dist: DistributionFromDataset 33 | get_pvalue: callable or EmpiricalPvalueCalculator 34 | Given a test statistic, returns the p-value 35 | The function should take in one argument (a float, np.ndarray, or torch.Tensor) 36 | representing the observed statistic, and it should return a float (the pvalue). 
37 | pvalue_type: str 38 | If get_pvalue is None, how to compute the p-value 39 | stat_type: str 40 | Which test statistic to compute 41 | b: int 42 | Number of bootstrap samples if pvalue_type is "parametric_bootstrap" 43 | kwargs 44 | Additional arguments to pass to the test statistic function 45 | Returns: 46 | pvalue: float 47 | statistic: float 48 | """ 49 | if get_pvalue is None: 50 | if pvalue_type == "parametric_bootstrap": 51 | get_pvalue = one_sample_parametric_bootstrap_pvalue( 52 | null_dist=null_dist, 53 | n=sample.N, 54 | b=b, 55 | return_stats=False, 56 | stat_type=stat_type, 57 | ) 58 | elif pvalue_type == "dummy": 59 | get_pvalue = _noop_pvalue 60 | else: 61 | raise ValueError("Unrecognized p-value type") 62 | 63 | statistic = IMPLEMENTED_TESTS[stat_type](sample, null_dist, **kwargs) 64 | pvalue = get_pvalue(statistic) 65 | if not isinstance(pvalue, float): 66 | pvalue = pvalue.item() 67 | return (pvalue, statistic) 68 | 69 | 70 | def run_two_sample_test( 71 | sample: CompletionSample, 72 | other_sample: CompletionSample, 73 | null_dist: DistributionFromDataset = None, 74 | get_pvalue: Union[callable, EmpiricalPvalueCalculator] = None, 75 | pvalue_type: str = "permutation_pvalue", 76 | stat_type: str = "two_sample_L2", 77 | b=1000, 78 | **kwargs, 79 | ) -> Tuple[float, float]: 80 | """ 81 | Tests whether the samples are drawn from the same distribution 82 | Args: 83 | sample: CompletionSample 84 | other_sample: CompletionSample 85 | null_dist: DistributionFromDataset 86 | get_pvalue: callable or EmpiricalPvalueCalculator 87 | Given a test statistic, returns the p-value 88 | The function should take in one argument (a float, np.ndarray, or torch.Tensor) 89 | representing the observed statistic, and it should return a float (the pvalue). 90 | pvalue_type: str 91 | If get_pvalue is None, how to compute the p-value 92 | stat_type: str 93 | Which test statistic to compute 94 | b: int 95 | Number of bootstrap samples if pvalue_type is "parametric_bootstrap" 96 | kwargs 97 | Additional arguments to pass to the test statistic function 98 | Returns: 99 | pvalue: float 100 | statistic: float 101 | """ 102 | if get_pvalue is None: 103 | if pvalue_type == "permutation_pvalue": 104 | get_pvalue = two_sample_permutation_pvalue( 105 | sample, other_sample, b=b, stat_type=stat_type, **kwargs 106 | ) 107 | elif pvalue_type == "parametric_bootstrap": 108 | assert ( 109 | null_dist is not None 110 | ), "Must provide null distribution for parametric bootstrap" 111 | get_pvalue = two_sample_parametric_bootstrap_pvalue( 112 | null_dist=null_dist, 113 | n1=sample.N, 114 | n2=other_sample.N, 115 | b=b, 116 | stat_type=stat_type, 117 | **kwargs, 118 | ) 119 | elif pvalue_type == "dummy": 120 | get_pvalue = _noop_pvalue 121 | else: 122 | raise ValueError("Unrecognized p-value type") 123 | 124 | statistic = IMPLEMENTED_TESTS[stat_type](sample, other_sample, **kwargs) 125 | pvalue = get_pvalue(statistic) 126 | if not isinstance(pvalue, float): 127 | pvalue = pvalue.item() 128 | return (pvalue, statistic) 129 | -------------------------------------------------------------------------------- /experiments/testing/cache_two_sample_bootstrap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calls model_equality_testing.pvalue.two_sample_parametric_bootstrap_pvalue repeatedly to cache simulated test statistics (parametric bootstrap) for a given null x test statistic. 
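Example invocation (an illustrative sketch; the model name and JSON path are placeholders -- any file under experiments/constants/wikipedia_prompt_indices_* has the expected {dataset_name: [indices]} format):
    python cache_two_sample_bootstrap.py --model meta-llama/Meta-Llama-3-8B-Instruct --prompt_indices_json ../constants/wikipedia_prompt_indices_test/25/0.json --n1_per_prompt 10 --n2_per_prompt 10 --stat two_sample_L2 --b 1000 --L 50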
3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from model_equality_testing.pvalue import two_sample_parametric_bootstrap_pvalue 7 | import pickle 8 | from experiments.utils import build_cache_filename, str_to_bool 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--model", 18 | type=str, 19 | help="Name of the model", 20 | default="meta-llama/Meta-Llama-3-8B-Instruct", 21 | ) 22 | parser.add_argument( 23 | "--null_distribution_source", 24 | type=str, 25 | default="fp32", 26 | help="Source to use as the null distribution", 27 | ) 28 | parser.add_argument( 29 | "--prompt_indices_json", 30 | type=str, 31 | help="JSON of prompt dataset name: [list of indices]. Prompt distribution will be uniform over these prompts", 32 | required=True, 33 | ) 34 | parser.add_argument( 35 | "--n1", 36 | type=int, 37 | help="Number of samples to compute statistics with, for the first sample (null in cache_power)", 38 | default=None, 39 | ) 40 | parser.add_argument( 41 | "--n1_per_prompt", 42 | type=int, 43 | help="Alternative to --n1, number of samples per prompt, for the first sample (null in cache_power)", 44 | default=None, 45 | ) 46 | parser.add_argument( 47 | "--n2", 48 | type=int, 49 | help="Number of samples to compute statistics with, for the second sample (alt in cache_power)", 50 | default=None, 51 | ) 52 | parser.add_argument( 53 | "--n2_per_prompt", 54 | type=int, 55 | help="Alternative to --n1, number of samples per prompt, for the second sample (alt in cache_power)", 56 | default=None, 57 | ) 58 | parser.add_argument( 59 | "--stat", type=str, default="two_sample_L2", help="One-sample test statistic" 60 | ) 61 | parser.add_argument( 62 | "--test_in_unicode", 63 | type=str_to_bool, 64 | default=True, 65 | help="Test in unicode space instead of token space", 66 | ) 67 | parser.add_argument( 68 | "--b", type=int, default=1000, help="Number of bootstrap samples" 69 | ) 70 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 71 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 72 | parser.add_argument("--temperature", type=float, default=None) 73 | parser.add_argument("--top_p", type=float, default=None) 74 | parser.add_argument( 75 | "--save_dir", type=str, default="../cache/parametric_bootstrap_stats" 76 | ) 77 | args = parser.parse_args() 78 | print(args) 79 | 80 | # Load the prompt indices 81 | js = json.load(open(args.prompt_indices_json)) 82 | prompts = list(js.keys()) 83 | # get the number of prompts m 84 | m = sum([len(js[k]) for k in js]) 85 | print("Number of prompts: ", m) 86 | 87 | assert (args.n1 is not None) ^ ( 88 | args.n1_per_prompt is not None 89 | ), "Exactly one of --n1 or --n1_per_prompt must be provided" 90 | if args.n1_per_prompt is not None: 91 | args.n1 = m * args.n1_per_prompt 92 | assert (args.n2 is not None) ^ ( 93 | args.n2_per_prompt is not None 94 | ), "Exactly one of --n2 or --n2_per_prompt must be provided" 95 | if args.n2_per_prompt is not None: 96 | args.n2 = m * args.n2_per_prompt 97 | print("Number of samples: ", args.n1, args.n2) 98 | 99 | # Load dataset 100 | p = load_distribution( 101 | model=args.model, 102 | prompt_ids=js, 103 | L=args.L, 104 | source=args.null_distribution_source, 105 | load_in_unicode=args.test_in_unicode, 106 | ) 107 | print("Null shape: ", p.shape) 108 | 109 | # Construct the filename 110 | filename = 
build_cache_filename( 111 | model=args.model, 112 | prompts=prompts, 113 | prompt_indices_json=args.prompt_indices_json, 114 | alternative=args.null_distribution_source, 115 | temperature=args.temperature, 116 | top_p=args.top_p, 117 | do_sample=args.do_sample, 118 | L=args.L, 119 | stat=args.stat, 120 | N=f"{args.n1}_{args.n2}", 121 | use_char_space=args.test_in_unicode, 122 | ) 123 | out_path = f"{args.save_dir}/{filename}.pkl" 124 | print(f"Will write to {out_path}") 125 | 126 | # Skip if already exists 127 | if os.path.exists(out_path): 128 | print("Skipping...") 129 | exit() 130 | 131 | # Cache the test statistics 132 | _, s = two_sample_parametric_bootstrap_pvalue( 133 | null_dist=p, 134 | n1=args.n1, 135 | n2=args.n2, 136 | b=args.b, 137 | return_stats=True, 138 | stat_type=args.stat, 139 | ) 140 | 141 | with open(out_path, "wb") as f: 142 | pickle.dump(s, f) 143 | -------------------------------------------------------------------------------- /experiments/sampling/cache_api_samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caches n samples / prompt from an API on a dataset of prompts. 3 | """ 4 | 5 | import tqdm 6 | import os 7 | import argparse 8 | import experiments.prompts as prompts_module 9 | from experiments.sampling.model import TransformersModel 10 | from experiments.utils import ( 11 | str_to_bool, 12 | build_cache_filename, 13 | wait_if_error, 14 | ) 15 | import experiments.sampling.api as api_module 16 | import time 17 | from dataclasses import dataclass, asdict 18 | import time 19 | from accelerate import Accelerator 20 | import json 21 | import pickle 22 | 23 | """ 24 | Logic to choose how to query APIs; this was selected to try to account for different tokenization policies by APIs. Set by manually testing whether the number of tokens in the prompt is the same as the number of prompt tokens mentioned in the returned message. 25 | """ 26 | 27 | 28 | # case-by-case policies to handle uniform tokenization. 
use test_api_tokenization.py to set these 29 | def policy(api, model): 30 | if api in ["replicate"]: 31 | default = { 32 | "prompt_key": "chat", 33 | "use_chat_endpoint": False, 34 | "expected_prompt_key": "chat", 35 | } 36 | elif api == "amazon": 37 | default = { 38 | "prompt_key": "chat_with_special", 39 | "use_chat_endpoint": False, 40 | "expected_prompt_key": "chat", 41 | } 42 | else: 43 | default = { 44 | "prompt_key": "plain", 45 | "use_chat_endpoint": True, 46 | "expected_prompt_key": "chat", 47 | } 48 | return default 49 | 50 | 51 | FILE_DIR = os.path.dirname(os.path.abspath(__file__)) 52 | with open(f"{FILE_DIR}/../constants/minimum_prompt_indices.json") as f: 53 | BARE_MINIMUM = json.load(f) 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument( 58 | "--model", type=str, required=True, help="HF transformers model name" 59 | ) 60 | parser.add_argument("--backend", type=str, required=True, help="API name") 61 | parser.add_argument("--prompts", default="dummy") 62 | parser.add_argument("--L", type=int, default=3) 63 | parser.add_argument( 64 | "--n", type=int, default=1, help="Number of samples to generate per prompt" 65 | ) 66 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 67 | parser.add_argument("--temperature", type=float, default=None) 68 | parser.add_argument("--top_p", type=float, default=None) 69 | parser.add_argument("--val_cutoff", type=int, default=20) 70 | parser.add_argument( 71 | "--sample_bare_minimum", 72 | type=str_to_bool, 73 | default=False, 74 | help="Whether to only sample the bare minimum prompt indices used for the prompt distributions in the paper, listed in constants/minimum_prompt_indices.json", 75 | ) 76 | parser.add_argument("--save_dir", type=str, default="../cache/api") 77 | args = parser.parse_args() 78 | accelerator = Accelerator() 79 | print(args) 80 | 81 | # get dataset 82 | model_to_load_prompts = TransformersModel( 83 | args.model, 84 | accelerator=accelerator, 85 | skip_loading_weights=True, 86 | ) 87 | ds = getattr(prompts_module, f"get_{args.prompts}_prompts")(model_to_load_prompts) 88 | try: 89 | # causes issues if kept 90 | ds = ds.remove_columns(["chat_tokens", "chat_with_ellipses_tokens"]) 91 | except: 92 | pass 93 | 94 | setup_kwargs = policy(args.backend, args.model) 95 | 96 | # backend 97 | try: 98 | get_fn, kwargs = getattr(api_module, f"setup_{args.backend}")( 99 | model=args.model, 100 | N=1, 101 | L=args.L, 102 | use_chat_endpoint=setup_kwargs["use_chat_endpoint"], 103 | do_sample=args.do_sample, 104 | temperature=args.temperature, 105 | top_p=args.top_p, 106 | ) 107 | except AttributeError: 108 | raise ValueError("Unrecognized backend") 109 | 110 | filename = build_cache_filename( 111 | model=args.model, 112 | prompts=args.prompts, 113 | alternative=args.backend, 114 | temperature=args.temperature, 115 | top_p=args.top_p, 116 | do_sample=args.do_sample, 117 | L=args.L, 118 | ) 119 | filename = f"{args.save_dir}/{filename}" 120 | 121 | # get samples 122 | for it, x in tqdm.tqdm(enumerate(ds)): 123 | if ( 124 | args.sample_bare_minimum 125 | and args.prompts in BARE_MINIMUM 126 | and it not in BARE_MINIMUM[args.prompts] 127 | ): 128 | continue 129 | if os.path.exists(f"{filename}-{it}.pkl"): 130 | print("Skipping, exists...") 131 | continue 132 | 133 | print("Sampling", it) 134 | time.sleep(30) # avoid rate limiting 135 | 136 | out = [] 137 | for i in range(args.n): 138 | print("> sample #", i) 139 | o = wait_if_error( 140 | get_fn, timeout=10, 
prompt=x[setup_kwargs["prompt_key"]], **kwargs 141 | ) 142 | if o is not None: 143 | out.extend(o) 144 | out = [asdict(o) for o in out] 145 | d = { 146 | "samples": out, 147 | "id": x["id"], 148 | "prompt": x[setup_kwargs["prompt_key"]], 149 | "y": x.get("y", None), 150 | } 151 | 152 | with open(f"{filename}-{it}.pkl", "wb") as f: 153 | pickle.dump(d, f) 154 | -------------------------------------------------------------------------------- /experiments/testing/simulation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from model_equality_testing.distribution import DistributionFromDataset 4 | from model_equality_testing.algorithm import ( 5 | run_goodness_of_fit_test, 6 | run_two_sample_test, 7 | ) 8 | from typing import Union, List, Tuple 9 | import tqdm 10 | import gc 11 | 12 | 13 | def get_power_one_sample( 14 | null_dist: DistributionFromDataset, 15 | data_dist: DistributionFromDataset, 16 | n: int, 17 | n_simulations=100, 18 | alpha: float = 0.05, 19 | pvalue_type: str = "parametric_bootstrap", 20 | stat_type: str = "g_squared", 21 | get_pvalue_fn=None, 22 | return_pvalue: bool = False, 23 | return_alpha: bool = False, 24 | return_stat: bool = False, 25 | **kwargs, 26 | ): 27 | """ 28 | Power analysis for joint (prompt, completion) distribution 29 | """ 30 | rejections = [] 31 | pvalues = [] 32 | stat = [] 33 | for _ in tqdm.tqdm(range(n_simulations), desc="Power simulation"): 34 | sample = data_dist.sample(n=n) 35 | pv, s = run_goodness_of_fit_test( 36 | sample=sample, 37 | null_dist=null_dist, 38 | get_pvalue=get_pvalue_fn, 39 | pvalue_type=pvalue_type, 40 | stat_type=stat_type, 41 | **kwargs, 42 | ) 43 | stat.append(s) 44 | pvalues.append(pv) 45 | rejections.append(int(pv <= alpha)) 46 | power = sum(rejections) / n_simulations 47 | return_tuple = (power, np.array(rejections, dtype=bool)) 48 | if return_alpha: 49 | return_tuple += (alpha * np.ones((n_simulations, null_dist.m)),) 50 | if return_pvalue: 51 | return_tuple += (torch.from_numpy(np.stack(pvalues)),) 52 | if return_stat: 53 | return_tuple += (np.stack(stat),) 54 | return return_tuple 55 | 56 | 57 | def get_power_two_sample( 58 | null_dist: DistributionFromDataset, 59 | data_dist: DistributionFromDataset, 60 | n_null: int, 61 | n_data: int, 62 | n_simulations=100, 63 | alpha: float = 0.05, 64 | pvalue_type: str = "permutation_pvalue", 65 | stat_type: str = "two_sample_L2", 66 | get_pvalue_fn=None, 67 | return_pvalue: bool = False, 68 | return_alpha: bool = False, 69 | return_stat: bool = False, 70 | **kwargs, 71 | ): 72 | """ 73 | Power analysis for joint (prompt, completion) distribution 74 | """ 75 | rejections = [] 76 | pvalues = [] 77 | stat = [] 78 | for _ in tqdm.tqdm(range(n_simulations), desc="Power simulation"): 79 | sample_1 = null_dist.sample(n=n_null) 80 | sample_2 = data_dist.sample(n=n_data) 81 | pv, s = run_two_sample_test( 82 | sample=sample_1, 83 | other_sample=sample_2, 84 | null_dist=null_dist, 85 | get_pvalue=get_pvalue_fn, 86 | pvalue_type=pvalue_type, 87 | stat_type=stat_type, 88 | **kwargs, 89 | ) 90 | pvalues.append(pv) 91 | stat.append(s) 92 | rejections.append(int(pv <= alpha)) 93 | del sample_1, sample_2 94 | gc.collect() 95 | power = sum(rejections) / n_simulations 96 | return_tuple = (power, np.array(rejections, dtype=bool)) 97 | if return_alpha: 98 | return_tuple += (alpha * np.ones((n_simulations, null_dist.m)),) 99 | if return_pvalue: 100 | return_tuple += (torch.from_numpy(np.stack(pvalues)),) 101 | if 
return_stat: 102 | return_tuple += (np.stack(stat),) 103 | return return_tuple 104 | 105 | 106 | def get_power_two_sample_composite_null( 107 | null_dist_1: DistributionFromDataset, 108 | null_dist_2: DistributionFromDataset, 109 | data_dist: DistributionFromDataset, 110 | n_null: int, 111 | n_data: int, 112 | n_simulations=100, 113 | alpha: float = 0.05, 114 | pvalue_type: str = "permutation_pvalue", 115 | stat_type: str = "two_sample_L2", 116 | get_pvalue_fn_1=None, 117 | get_pvalue_fn_2=None, 118 | return_pvalue: bool = False, 119 | return_alpha: bool = False, 120 | return_stat: bool = False, 121 | **kwargs, 122 | ): 123 | """ 124 | Power analysis for joint (prompt, completion) distribution 125 | """ 126 | rejections = [] 127 | pvalues = [] 128 | stat = [] 129 | for _ in tqdm.tqdm(range(n_simulations), desc="Power simulation"): 130 | sample_1 = null_dist_1.sample(n=n_null) 131 | sample_2 = null_dist_2.sample(n=n_null) 132 | sample_3 = data_dist.sample(n=n_data) 133 | pv1, s1 = run_two_sample_test( 134 | sample=sample_1, 135 | other_sample=sample_3, 136 | get_pvalue=get_pvalue_fn_1, 137 | pvalue_type=pvalue_type, 138 | stat_type=stat_type, 139 | **kwargs, 140 | ) 141 | pv2, s2 = run_two_sample_test( 142 | sample=sample_2, 143 | other_sample=sample_3, 144 | get_pvalue=get_pvalue_fn_2, 145 | pvalue_type=pvalue_type, 146 | stat_type=stat_type, 147 | **kwargs, 148 | ) 149 | rejections.append(int((pv1 <= alpha) and (pv2 <= alpha))) 150 | pvalues.append((pv1, pv2)) 151 | stat.append((s1, s2)) 152 | del sample_1, sample_2 153 | gc.collect() 154 | power = sum(rejections) / n_simulations 155 | return_tuple = (power, np.array(rejections, dtype=bool)) 156 | if return_alpha: 157 | return_tuple += (alpha * np.ones((n_simulations, null_dist_1.m)),) 158 | if return_pvalue: 159 | return_tuple += (torch.from_numpy(np.stack(pvalues)),) 160 | if return_stat: 161 | return_tuple += (np.stack(stat),) 162 | return return_tuple 163 | -------------------------------------------------------------------------------- /model_equality_testing/src/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | import torch 3 | import torch.nn.functional as F 4 | from time import perf_counter 5 | import numpy as np 6 | from functools import lru_cache 7 | 8 | def tokenize_unicode(strings: List[str], pad_token_id: int = -1) -> np.ndarray: 9 | """ 10 | Tokenize a list of strings into a list of lists of unicode codepoints, 11 | and then stack them into a 2D numpy array, using -1 as padding. 12 | 13 | Args: 14 | strings (List[str]): list of strings to tokenize 15 | 16 | Returns: 17 | List[List[int]]: list of lists of unicode codepoints 18 | """ 19 | strings = [torch.tensor([ord(c) for c in s]) for s in strings] 20 | chr_array = stack_with_padding(strings, padding_token=-1)[0].numpy() 21 | return chr_array 22 | 23 | def pad_to_length(samples: np.ndarray, L: int, pad_token_id: int = -1) -> np.ndarray: 24 | """ 25 | Pad a 2D numpy array of samples to a fixed length L using a pad token. 26 | 27 | Args: 28 | samples (np.ndarray): 2D numpy array of samples 29 | L (int): length to pad to 30 | pad_token_id (int): padding token 31 | 32 | If the current length is longer than L, throws an error. 
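Example (illustrative, with the default pad_token_id of -1):
        >>> pad_to_length(np.array([[1, 2], [3, 4]]), L=4)
        array([[ 1,  2, -1, -1],
               [ 3,  4, -1, -1]])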
33 | 34 | Returns: 35 | np.ndarray: padded 2D numpy array 36 | """ 37 | if samples.shape[1] > L: 38 | raise ValueError(f"Cannot pad to length {L} because the current length is {samples.shape[1]}") 39 | padded_samples = np.full((samples.shape[0], L), pad_token_id) 40 | padded_samples[:, :samples.shape[1]] = samples 41 | return padded_samples 42 | 43 | def sanitize(s): 44 | """Sanitize a string for use as a filename.""" 45 | s = str(s) 46 | s = s.replace(" ", "-") 47 | s = s.replace("[", "") 48 | s = s.replace("]", "") 49 | s = s.replace(",", "_") 50 | s = s.replace("/", "-") 51 | s = s.replace("(", "") 52 | s = s.replace(")", "") 53 | return s 54 | 55 | 56 | def stack_with_padding( 57 | tensors: List[torch.Tensor], 58 | dim: int = 0, 59 | padding_side: str = "right", 60 | padding_mode: str = "constant", 61 | padding_token: Union[int, float] = 0, 62 | ) -> Tuple[torch.Tensor, torch.Tensor]: 63 | """ 64 | Stack tensors along specified dimension and pad them to ensure their size is equal in all dimensions. 65 | Returns the stacked tensor and a boolean mask indicating valid (non-padded) elements. 66 | 67 | Args: 68 | tensors (List[torch.Tensor]): list of tensors to stack 69 | dim (int): dimension along which to stack tensors. Defaults to 0. 70 | padding_side (str): side on which to pad - "left" or "right". Defaults to "right". 71 | padding_mode (str): 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant' 72 | padding_value (Union[int, float]): value to use for constant padding 73 | 74 | Returns: 75 | Tuple[torch.Tensor, torch.Tensor]: stacked tensor and boolean mask 76 | """ 77 | # Ensure all tensors have the same number of dimensions 78 | max_dims = max(t.dim() for t in tensors) 79 | tensors = [t.view(*t.shape, *([1] * (max_dims - t.dim()))) for t in tensors] 80 | 81 | # Find the maximum size for each dimension 82 | max_sizes = [max(t.shape[i] for t in tensors) for i in range(max_dims)] 83 | 84 | def make_padding(tensor_shape): 85 | padding = [] 86 | for i in reversed(range(max_dims)): # Reverse for F.pad expectations 87 | pad_size = max_sizes[i] - tensor_shape[i] 88 | if padding_side == "left": 89 | padding.extend([pad_size, 0]) 90 | elif padding_side == "right": 91 | padding.extend([0, pad_size]) 92 | else: 93 | raise ValueError(f"padding_side '{padding_side}' is unknown") 94 | return tuple(padding) 95 | 96 | padded_tensors = [] 97 | masks = [] 98 | 99 | for t in tensors: 100 | padding = make_padding(t.shape) 101 | padded_t = F.pad(t, padding, mode=padding_mode, value=padding_token) 102 | 103 | mask = torch.zeros_like(padded_t, dtype=torch.bool) 104 | slices = [] 105 | for i in range(max_dims): 106 | if padding_side == "left": 107 | slices.append(slice(max_sizes[i] - t.shape[i], None)) 108 | else: 109 | slices.append(slice(None, t.shape[i])) 110 | mask[tuple(slices)] = True 111 | 112 | padded_tensors.append(padded_t) 113 | masks.append(mask) 114 | 115 | stacked_tensor = torch.stack(padded_tensors, dim=dim) 116 | stacked_mask = torch.stack(masks, dim=dim) 117 | 118 | return stacked_tensor, stacked_mask 119 | 120 | 121 | class Stopwatch: 122 | """ 123 | Context manager for timing a block of code 124 | Source: https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time 125 | """ 126 | 127 | def __enter__(self): 128 | if torch.cuda.is_available(): 129 | torch.cuda.synchronize() 130 | self.time = perf_counter() 131 | return self 132 | 133 | def __exit__(self, type, value, traceback): 134 | if torch.cuda.is_available(): 135 | torch.cuda.synchronize() 136 
| self.time = perf_counter() - self.time 137 | 138 | 139 | def ndim(p): 140 | """ 141 | Args: 142 | p: either a tensor or a list of tensors or a list of lists of tensors 143 | """ 144 | if isinstance(p, torch.Tensor): 145 | return p.ndim 146 | if not isinstance(p, (list, np.ndarray)): 147 | return 0 148 | elif len(p) > 0: 149 | return ndim(p[0]) + 1 150 | else: 151 | return 1 152 | 153 | 154 | @lru_cache(maxsize=100) 155 | def get_inv(lst): 156 | """ 157 | Given a list of items, returns a dict of {item: ix} 158 | """ 159 | return {x: idx for idx, x in enumerate(lst)} 160 | -------------------------------------------------------------------------------- /experiments/testing/simulate_one_sample_power.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simulate power of a one-sample test. 3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from experiments.testing.simulation import get_power_one_sample 7 | from experiments.utils import build_cache_filename, str_to_bool 8 | from experiments.testing.bootstrap_manager import BootstrapManager 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | import argparse 14 | import os 15 | import torch 16 | import numpy as np 17 | from accelerate import Accelerator 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--model", 23 | type=str, 24 | help="Name of the model", 25 | default="meta-llama/Meta-Llama-3-8B-Instruct", 26 | ) 27 | parser.add_argument( 28 | "--null_distribution_source", 29 | type=str, 30 | default="fp32", 31 | help="Source to use as the null distribution P", 32 | ) 33 | parser.add_argument( 34 | "--alternative_distribution_source", 35 | type=str, 36 | default="fp32", 37 | help="Source to use as the alternative distribution Q", 38 | ) 39 | parser.add_argument( 40 | "--n_simulations", 41 | type=int, 42 | default=100, 43 | help="Number of simulations to run to estimate power", 44 | ) 45 | parser.add_argument("--alpha", type=float, default=0.05, help="Significance level") 46 | parser.add_argument( 47 | "--prompt_indices_json", 48 | type=str, 49 | help="JSON of prompt dataset name: [list of indices]. 
Prompt distribution will be uniform over these prompts", 50 | required=True, 51 | ) 52 | parser.add_argument( 53 | "--n", 54 | type=int, 55 | help="Number of samples to draw each time we compute the test statistic", 56 | default=None, 57 | ) 58 | parser.add_argument( 59 | "--n_per_prompt", 60 | type=int, 61 | help="Alternative to --n, number of samples per prompt", 62 | default=None, 63 | ) 64 | parser.add_argument( 65 | "--stat", type=str, default="g_squared", help="One-sample test statistic" 66 | ) 67 | parser.add_argument( 68 | "--pvalue_type", 69 | type=str, 70 | default="parametric_bootstrap", 71 | help="Type of p-value", 72 | ) 73 | parser.add_argument( 74 | "--test_in_unicode", 75 | type=str_to_bool, 76 | default=True, 77 | help="Test in unicode space instead of token space", 78 | ) 79 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 80 | parser.add_argument("--max_b", type=int, default=None) 81 | parser.add_argument("--min_b", type=int, default=None) 82 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 83 | parser.add_argument("--temperature", type=float, default=None) 84 | parser.add_argument("--top_p", type=float, default=None) 85 | parser.add_argument("--save_dir", type=str, default="../cache/power") 86 | parser.add_argument("--bootstrap_dir", type=str, default="../cache/parametric_bootstrap_stats") 87 | accelerator = Accelerator() 88 | args = parser.parse_args() 89 | print(args) 90 | 91 | # Load the prompt indices 92 | js = json.load(open(args.prompt_indices_json)) 93 | prompts = list(js.keys()) 94 | # get the number of prompts m 95 | m = sum([len(js[k]) for k in js]) 96 | print("Number of prompts: ", m) 97 | 98 | assert (args.n is not None) ^ ( 99 | args.n_per_prompt is not None 100 | ), "Exactly one of --n or --n_per_prompt must be provided" 101 | if args.n_per_prompt is not None: 102 | args.n = m * args.n_per_prompt 103 | print("Number of samples: ", args.n) 104 | 105 | # Load null distribution 106 | p = load_distribution( 107 | model=args.model, 108 | prompt_ids=js, 109 | L=args.L, 110 | source=args.null_distribution_source, 111 | load_in_unicode=args.test_in_unicode, 112 | ) 113 | print("Null shape: ", p.shape) 114 | 115 | # Start building the output filename 116 | filename_stem = build_cache_filename( 117 | model=args.model, 118 | prompts=prompts, 119 | prompt_indices_json=args.prompt_indices_json, 120 | use_char_space=args.test_in_unicode, 121 | alternative=args.null_distribution_source, 122 | temperature=args.temperature, 123 | top_p=args.top_p, 124 | do_sample=args.do_sample, 125 | L=args.L, 126 | stat=args.stat, 127 | N="{n}", 128 | ) 129 | filename = filename_stem.format(n=args.n) 130 | filename += f"-pvalue={args.pvalue_type}-alt={args.alternative_distribution_source}" 131 | 132 | if os.path.exists(f"{args.save_dir}/{filename}.pt"): 133 | print("Skipping b/c already exists...") 134 | exit() 135 | 136 | # Set up bootstrap manager 137 | get_pvalue = None 138 | if args.pvalue_type == "parametric_bootstrap": 139 | try: 140 | bootstrap_manager = BootstrapManager( 141 | bootstrap_path_template=f"{args.bootstrap_dir}/{filename_stem}.pkl", 142 | min_b=args.min_b, 143 | max_b=args.max_b, 144 | ) 145 | get_pvalue = bootstrap_manager.load(n=args.n) 146 | except: 147 | get_pvalue = None 148 | 149 | # Load alternative distribution 150 | q = load_distribution( 151 | model=args.model, 152 | prompt_ids=js, 153 | L=args.L, 154 | source=args.alternative_distribution_source, 155 | load_in_unicode=args.test_in_unicode, 156 | ) 
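# Note: q is the alternative distribution; each power simulation below draws n samples from q and runs the goodness-of-fit test against the null p loaded above.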
157 | print("Alternative shape: ", q.shape) 158 | 159 | # Simulate! 160 | ( 161 | pwr, 162 | reject_history, 163 | alpha_history, 164 | pvalue_history, 165 | test_stats, 166 | ) = get_power_one_sample( 167 | p, 168 | q, 169 | n=args.n, 170 | n_simulations=args.n_simulations, 171 | pvalue_type=args.pvalue_type, 172 | stat_type=args.stat, 173 | get_pvalue_fn=get_pvalue, 174 | return_pvalue=True, 175 | return_alpha=True, 176 | return_stat=True, 177 | alpha=args.alpha, 178 | ) 179 | 180 | # Save results 181 | print("Power: ", pwr) 182 | torch.save( 183 | { 184 | "power": pwr, 185 | "reject_history": reject_history, 186 | "alpha_history": alpha_history, 187 | "pvalue_history": pvalue_history, 188 | "test_stats": test_stats, 189 | }, 190 | f"{args.save_dir}/{filename}.pt", 191 | ) 192 | -------------------------------------------------------------------------------- /experiments/testing/simulate_two_sample_power.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simulate power of a two-sample test. 3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from experiments.testing.simulation import get_power_two_sample 7 | from experiments.utils import build_cache_filename, str_to_bool 8 | from experiments.testing.bootstrap_manager import BootstrapManager 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | import argparse 14 | import os 15 | import torch 16 | import numpy as np 17 | from accelerate import Accelerator 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--model", 23 | type=str, 24 | help="Name of the model", 25 | default="meta-llama/Meta-Llama-3-8B-Instruct", 26 | ) 27 | parser.add_argument( 28 | "--null_distribution_source", 29 | type=str, 30 | default="fp32", 31 | help="Source to use as the null distribution P", 32 | ) 33 | parser.add_argument( 34 | "--alternative_distribution_source", 35 | type=str, 36 | default="fp32", 37 | help="Source to use as the alternative distribution Q", 38 | ) 39 | parser.add_argument( 40 | "--n_simulations", 41 | type=int, 42 | default=100, 43 | help="Number of simulations to run to estimate power", 44 | ) 45 | parser.add_argument("--alpha", type=float, default=0.05, help="Significance level") 46 | parser.add_argument( 47 | "--prompt_indices_json", 48 | type=str, 49 | help="JSON of prompt dataset name: [list of indices]. 
Prompt distribution will be uniform over these prompts", 50 | required=True, 51 | ) 52 | parser.add_argument( 53 | "--n1", 54 | type=int, 55 | help="Number of samples to compute statistics with, for the first sample (null in cache_power)", 56 | default=None, 57 | ) 58 | parser.add_argument( 59 | "--n1_per_prompt", 60 | type=int, 61 | help="Alternative to --n1, number of samples per prompt, for the first sample (null in cache_power)", 62 | default=None, 63 | ) 64 | parser.add_argument( 65 | "--n2", 66 | type=int, 67 | help="Number of samples to compute statistics with, for the second sample (alt in cache_power)", 68 | default=None, 69 | ) 70 | parser.add_argument( 71 | "--n2_per_prompt", 72 | type=int, 73 | help="Alternative to --n2, number of samples per prompt, for the second sample (alt in cache_power)", 74 | default=None, 75 | ) 76 | parser.add_argument( 77 | "--stat", type=str, default="two_sample_L2", help="Two-sample test statistic" 78 | ) 79 | parser.add_argument( 80 | "--pvalue_type", 81 | type=str, 82 | default="parametric_bootstrap", 83 | help="Type of p-value", 84 | ) 85 | parser.add_argument( 86 | "--test_in_unicode", 87 | type=str_to_bool, 88 | default=True, 89 | help="Test in unicode space instead of token space", 90 | ) 91 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 92 | parser.add_argument("--max_b", type=int, default=None) 93 | parser.add_argument("--min_b", type=int, default=None) 94 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 95 | parser.add_argument("--temperature", type=float, default=None) 96 | parser.add_argument("--top_p", type=float, default=None) 97 | parser.add_argument("--save_dir", type=str, default="../cache/power") 98 | parser.add_argument( 99 | "--bootstrap_dir", type=str, default="../cache/parametric_bootstrap_stats" 100 | ) 101 | accelerator = Accelerator() 102 | args = parser.parse_args() 103 | print(args) 104 | 105 | # Load the prompt indices 106 | js = json.load(open(args.prompt_indices_json)) 107 | prompts = list(js.keys()) 108 | # get the number of prompts m 109 | m = sum([len(js[k]) for k in js]) 110 | print("Number of prompts: ", m) 111 | 112 | assert (args.n1 is not None) ^ ( 113 | args.n1_per_prompt is not None 114 | ), "Exactly one of --n1 or --n1_per_prompt must be provided" 115 | if args.n1_per_prompt is not None: 116 | args.n1 = m * args.n1_per_prompt 117 | assert (args.n2 is not None) ^ ( 118 | args.n2_per_prompt is not None 119 | ), "Exactly one of --n2 or --n2_per_prompt must be provided" 120 | if args.n2_per_prompt is not None: 121 | args.n2 = m * args.n2_per_prompt 122 | print("Number of samples: ", args.n1, args.n2) 123 | 124 | # Load null distribution 125 | p = load_distribution( 126 | model=args.model, 127 | prompt_ids=js, 128 | L=args.L, 129 | source=args.null_distribution_source, 130 | load_in_unicode=args.test_in_unicode, 131 | ) 132 | print("Null shape: ", p.shape) 133 | 134 | # Start building the output filename 135 | filename_stem = build_cache_filename( 136 | model=args.model, 137 | prompts=prompts, 138 | prompt_indices_json=args.prompt_indices_json, 139 | use_char_space=args.test_in_unicode, 140 | alternative=args.null_distribution_source, 141 | temperature=args.temperature, 142 | top_p=args.top_p, 143 | do_sample=args.do_sample, 144 | L=args.L, 145 | stat=args.stat, 146 | N="{n1}_{n2}", 147 | ) 148 | filename = filename_stem.format(n1=args.n1, n2=args.n2) 149 | filename += f"-pvalue={args.pvalue_type}-alt={args.alternative_distribution_source}" 150 | 151 | if
os.path.exists(f"{args.save_dir}/{filename}.pt"): 152 | print("Skipping b/c already exists...") 153 | exit() 154 | 155 | # Set up bootstrap manager 156 | get_pvalue = None 157 | if args.pvalue_type == "parametric_bootstrap": 158 | try: 159 | bootstrap_manager = BootstrapManager( 160 | bootstrap_path_template=f"{args.bootstrap_dir}/{filename_stem}.pkl", 161 | min_b=args.min_b, 162 | max_b=args.max_b, 163 | ) 164 | get_pvalue = bootstrap_manager.load(n1=args.n1, n2=args.n2) 165 | except: 166 | get_pvalue = None 167 | 168 | # Load alternative distribution 169 | q = load_distribution( 170 | model=args.model, 171 | prompt_ids=js, 172 | L=args.L, 173 | source=args.alternative_distribution_source, 174 | load_in_unicode=args.test_in_unicode, 175 | ) 176 | print("Alternative shape: ", q.shape) 177 | 178 | # Simulate! 179 | ( 180 | pwr, 181 | reject_history, 182 | alpha_history, 183 | pvalue_history, 184 | test_stats, 185 | ) = get_power_two_sample( 186 | p, 187 | q, 188 | n_null=args.n1, 189 | n_data=args.n2, 190 | n_simulations=args.n_simulations, 191 | pvalue_type=args.pvalue_type, 192 | stat_type=args.stat, 193 | get_pvalue_fn=get_pvalue, 194 | return_pvalue=True, 195 | return_alpha=True, 196 | return_stat=True, 197 | alpha=args.alpha, 198 | ) 199 | 200 | # Save results 201 | print("Power: ", pwr) 202 | torch.save( 203 | { 204 | "power": pwr, 205 | "reject_history": reject_history, 206 | "alpha_history": alpha_history, 207 | "pvalue_history": pvalue_history, 208 | "test_stats": test_stats, 209 | }, 210 | f"{args.save_dir}/{filename}.pt", 211 | ) 212 | -------------------------------------------------------------------------------- /model_equality_testing/src/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions for downloading and loading the dataset we release. 3 | """ 4 | 5 | from model_equality_testing.distribution import DistributionFromDataset 6 | from typing import Literal, List, Dict, Union, Tuple 7 | from model_equality_testing.utils import sanitize, stack_with_padding, tokenize_unicode 8 | import glob 9 | import numpy as np 10 | import torch 11 | from transformers import AutoTokenizer 12 | import pickle 13 | from collections import defaultdict 14 | import os 15 | import zipfile 16 | 17 | def download_dataset(root_dir="./data"): 18 | """ 19 | Download the dataset from Google Drive and extract it into the specified directory. 20 | 21 | Arguments: 22 | - root_dir (str): The root directory where the dataset will be saved. 23 | - split (str or None): 'test', 'val', or None. If None, both splits will be downloaded. 
24 | """ 25 | try: 26 | import gdown 27 | except: 28 | raise ImportError("Please install gdown to download the dataset: pip install gdown") 29 | 30 | os.makedirs(root_dir, exist_ok=True) 31 | file_id = "1csgp83tx04kVA9ejp6MNSC7Ni0vh0J9U" 32 | download_url = f"https://drive.google.com/uc?id={file_id}" 33 | 34 | zip_file_path = os.path.join(root_dir, "model_equality_testing_ds.zip") 35 | extract_dir = root_dir 36 | 37 | print(f"Downloading dataset...") 38 | if not os.path.exists(zip_file_path): 39 | gdown.download(download_url, zip_file_path, quiet=False) 40 | 41 | print(f"Extracting dataset...") 42 | with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: 43 | zip_ref.extractall(extract_dir) 44 | 45 | os.remove(zip_file_path) 46 | print(f"Dataset downloaded and extracted successfully.") 47 | 48 | 49 | SOURCES = [ 50 | "fp32", 51 | "fp16", 52 | "nf4", 53 | "int8", 54 | "watermark", 55 | "anyscale", 56 | "amazon", 57 | "fireworks", 58 | "replicate", 59 | "deepinfra", 60 | "groq", 61 | "perplexity", 62 | "together", 63 | "azure", 64 | ] 65 | 66 | MODELS = [ 67 | "meta-llama/Meta-Llama-3-8B-Instruct", 68 | "meta-llama/Meta-Llama-3-70B-Instruct", 69 | "meta-llama/Meta-Llama-3.1-8B-Instruct", 70 | "meta-llama/Meta-Llama-3.1-70B-Instruct", 71 | "mistralai/Mistral-7B-Instruct-v0.3", 72 | ] 73 | 74 | 75 | def load_distribution( 76 | model: str, 77 | prompt_ids: dict, 78 | L: int, 79 | source: str, 80 | prompt_distribution: np.ndarray = None, 81 | load_in_unicode: bool = True, 82 | root_dir="./data", 83 | ) -> DistributionFromDataset: 84 | """ 85 | Given a dictionary mapping {dataset_name: [prompt_ids]}, load the dataset and return a DistributionFromDataset object. 86 | Args: 87 | model: a string representing the model name 88 | prompt_ids: a dictionary mapping {dataset_name: [prompt_ids]} 89 | where dataset_name is a string and prompt_ids is a list of strings. 90 | Example: 91 | { 92 | "wikipedia_en": [0, 1, 4], 93 | "wikipedia_es": [5], 94 | } 95 | => DistributionFromDataset(m=4, L=L) 96 | L: completion length 97 | source: a string representing 98 | Returns: 99 | a DistributionFromDataset object, which allows for sampling from the dataset. 100 | """ 101 | assert source in SOURCES, f"source must be one of {SOURCES}" 102 | print(f"Loading dataset from source: {source}") 103 | print(f"Prompt IDs: {prompt_ids}") 104 | 105 | tokenizer = AutoTokenizer.from_pretrained(model) 106 | tokenizer.padding_side = "left" 107 | tokenizer.truncation_side = "left" 108 | if tokenizer.pad_token_id is None: 109 | tokenizer.pad_token = tokenizer.eos_token 110 | tokenizer.pad_token_id = tokenizer.eos_token_id 111 | 112 | if source in ["fp32", "fp16", "nf4", "int8"]: 113 | if load_in_unicode: 114 | load_fn = lambda x: _load_local_samples_unicode(x, tokenizer) 115 | else: 116 | load_fn = lambda x: _load_local_samples_tokens(x, tokenizer) 117 | else: 118 | assert load_in_unicode, "Only unicode supported for API samples" 119 | load_fn = lambda x: _load_api_samples_unicode(x, tokenizer) 120 | 121 | filenames_and_callables = [] 122 | logprob_filenames_and_callables = [] 123 | for ds, ids in prompt_ids.items(): 124 | for id in ids: 125 | name = f"{sanitize(model)}-{ds}-{source}-L=*-{id}" 126 | lst = glob.glob(f"{root_dir}/samples/{name}.pkl") 127 | if len(lst) == 0: 128 | print( 129 | f"Warning: no pkl file matching the pattern {name} was found in the specified dataset directory {root_dir}.", 130 | "Make sure the `root_dir` variable is set correctly and that `model` and `prompt_ids` exist in the dataset." 
131 | ) 132 | fn = lst[0] 133 | filenames_and_callables.append((fn, load_fn)) 134 | 135 | # logprobs: include if not in unicode and the file exists 136 | if not load_in_unicode: 137 | try: 138 | fn = glob.glob(f"{root_dir}/logprobs/{name}.pkl")[0] 139 | logprob_filenames_and_callables.append((fn, lambda x: _load_logprobs(x, tokenizer))) 140 | except IndexError: 141 | pass 142 | 143 | return DistributionFromDataset( 144 | sample_paths=filenames_and_callables, 145 | L=L, 146 | prompt_distribution=prompt_distribution, 147 | logprob_paths=logprob_filenames_and_callables, 148 | pad_token_id=tokenizer.pad_token_id if not load_in_unicode else -1, 149 | ) 150 | 151 | 152 | def _pad_after_first_eos(array, eos_token_id, pad_token_id): 153 | """ 154 | In each sequence, converts all tokens to the right of the first occurrence 155 | of the eos_token to a pad_token. 156 | Edge case: does not pad if the eos token is the first token in the sequence 157 | """ 158 | assert array.ndim == 2 159 | eos_locs = np.expand_dims((array == eos_token_id).astype(float).argmax(axis=1), 1) 160 | eos_locs[eos_locs == 0] = array.shape[1] 161 | col_indices = np.repeat( 162 | np.expand_dims(np.arange(array.shape[1]), 0), array.shape[0], axis=0 163 | ) 164 | array[col_indices > eos_locs] = pad_token_id 165 | return array 166 | 167 | 168 | def _load_local_samples_tokens(path, tok) -> np.ndarray: 169 | """ 170 | Loads local samples into a (k, max_L) array of token ids. 171 | Assumes samples are saved as lists of lists of token ids (.pkl files). 172 | """ 173 | with open(path, "rb") as f: 174 | array = pickle.load(f) 175 | array = np.array(array) 176 | array = _pad_after_first_eos(array, tok.eos_token_id, tok.pad_token_id) 177 | return array 178 | 179 | 180 | def _load_local_samples_unicode(path, tok) -> np.ndarray: 181 | """ 182 | Loads local samples into a (k, max_L) array of character ids (using Python's ord function). 183 | Assumes samples are saved as lists of token ids (.pkl files). 184 | Uses -1 as a padding token in the output. 185 | """ 186 | with open(path, "rb") as f: 187 | array = pickle.load(f) 188 | array = np.array(array) 189 | array = _pad_after_first_eos(array, tok.eos_token_id, tok.pad_token_id) 190 | strings = tok.batch_decode(array, skip_special_tokens=True) 191 | return tokenize_unicode(strings) 192 | 193 | 194 | def _load_api_samples_unicode(path, tok) -> np.ndarray: 195 | """ 196 | Loads API samples from disk where each character is assigned an id, rather than in token space. 197 | """ 198 | with open(path, "rb") as f: 199 | js = pickle.load(f) 200 | samples = js["samples"] 201 | samples = [torch.tensor([ord(c) for c in s["full_completion"]]) for s in samples] 202 | chr_array = stack_with_padding(samples, padding_token=-1)[0].numpy() 203 | return chr_array 204 | 205 | 206 | def _load_logprobs(path, tok) -> dict: 207 | """ 208 | Returns a dictionary of {prompt: {sequence: logprob}} as saved 209 | by cache_logprobs.py or cache_distribution_by_sampling.py using utils.dump. 
210 | """ 211 | with open(path, "rb") as f: 212 | logprobs = pickle.load(f) 213 | 214 | # merge dictionaries with the same keys 215 | new_dict = {} 216 | for k, v in logprobs.items(): 217 | k = list(k) 218 | v = torch.tensor(v) 219 | try: 220 | ix = k.index(tok.eos_token_id) 221 | k[ix + 1 :] = [tok.pad_token_id] * (len(k) - ix - 1) 222 | v[ix + 1 :] = 0 223 | except: 224 | pass 225 | new_dict[tuple(k)] = v 226 | return new_dict 227 | -------------------------------------------------------------------------------- /experiments/testing/simulate_two_sample_power_composite.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simulate power of a two-sample test. 3 | """ 4 | 5 | from model_equality_testing.dataset import load_distribution 6 | from experiments.testing.simulation import get_power_two_sample_composite_null 7 | from experiments.utils import build_cache_filename, str_to_bool 8 | from experiments.testing.bootstrap_manager import BootstrapManager 9 | import argparse 10 | import os 11 | from typing import List 12 | import json 13 | import argparse 14 | import os 15 | import torch 16 | import numpy as np 17 | from accelerate import Accelerator 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--model", 23 | type=str, 24 | help="Name of the model", 25 | default="meta-llama/Meta-Llama-3-8B-Instruct", 26 | ) 27 | parser.add_argument( 28 | "--null_distribution_1_source", 29 | type=str, 30 | default="fp32", 31 | help="Source to use as the null distribution P1", 32 | ) 33 | parser.add_argument( 34 | "--null_distribution_2_source", 35 | type=str, 36 | default="fp32", 37 | help="Source to use as the null distribution P2", 38 | ) 39 | parser.add_argument( 40 | "--alternative_distribution_source", 41 | type=str, 42 | default="fp32", 43 | help="Source to use as the alternative distribution Q", 44 | ) 45 | parser.add_argument( 46 | "--n_simulations", 47 | type=int, 48 | default=100, 49 | help="Number of simulations to run to estimate power", 50 | ) 51 | parser.add_argument("--alpha", type=float, default=0.05, help="Significance level") 52 | parser.add_argument( 53 | "--prompt_indices_json", 54 | type=str, 55 | help="JSON of prompt dataset name: [list of indices]. 
Prompt distribution will be uniform over these prompts", 56 | required=True, 57 | ) 58 | parser.add_argument( 59 | "--n1", 60 | type=int, 61 | help="Number of samples to compute statistics with, for the first sample (null in cache_power)", 62 | default=None, 63 | ) 64 | parser.add_argument( 65 | "--n1_per_prompt", 66 | type=int, 67 | help="Alternative to --n1, number of samples per prompt, for the first sample (null in cache_power)", 68 | default=None, 69 | ) 70 | parser.add_argument( 71 | "--n2", 72 | type=int, 73 | help="Number of samples to compute statistics with, for the second sample (alt in cache_power)", 74 | default=None, 75 | ) 76 | parser.add_argument( 77 | "--n2_per_prompt", 78 | type=int, 79 | help="Alternative to --n2, number of samples per prompt, for the second sample (alt in cache_power)", 80 | default=None, 81 | ) 82 | parser.add_argument( 83 | "--stat", type=str, default="two_sample_L2", help="Two-sample test statistic" 84 | ) 85 | parser.add_argument( 86 | "--pvalue_type", 87 | type=str, 88 | default="parametric_bootstrap", 89 | help="Type of p-value", 90 | ) 91 | parser.add_argument( 92 | "--test_in_unicode", 93 | type=str_to_bool, 94 | default=True, 95 | help="Test in unicode space instead of token space", 96 | ) 97 | parser.add_argument("--L", type=int, default=1000, help="Maximum completion length") 98 | parser.add_argument("--max_b", type=int, default=None) 99 | parser.add_argument("--min_b", type=int, default=None) 100 | parser.add_argument("--do_sample", type=str_to_bool, default=True) 101 | parser.add_argument("--temperature", type=float, default=None) 102 | parser.add_argument("--top_p", type=float, default=None) 103 | parser.add_argument("--save_dir", type=str, default="../cache/power") 104 | parser.add_argument( 105 | "--bootstrap_dir", type=str, default="../cache/parametric_bootstrap_stats" 106 | ) 107 | accelerator = Accelerator() 108 | args = parser.parse_args() 109 | print(args) 110 | 111 | # Load the prompt indices 112 | js = json.load(open(args.prompt_indices_json)) 113 | prompts = list(js.keys()) 114 | # get the number of prompts m 115 | m = sum([len(js[k]) for k in js]) 116 | print("Number of prompts: ", m) 117 | 118 | assert (args.n1 is not None) ^ ( 119 | args.n1_per_prompt is not None 120 | ), "Exactly one of --n1 or --n1_per_prompt must be provided" 121 | if args.n1_per_prompt is not None: 122 | args.n1 = m * args.n1_per_prompt 123 | assert (args.n2 is not None) ^ ( 124 | args.n2_per_prompt is not None 125 | ), "Exactly one of --n2 or --n2_per_prompt must be provided" 126 | if args.n2_per_prompt is not None: 127 | args.n2 = m * args.n2_per_prompt 128 | print("Number of samples: ", args.n1, args.n2) 129 | 130 | # Load null distribution 131 | p1 = load_distribution( 132 | model=args.model, 133 | prompt_ids=js, 134 | L=args.L, 135 | source=args.null_distribution_1_source, 136 | load_in_unicode=args.test_in_unicode, 137 | ) 138 | p2 = load_distribution( 139 | model=args.model, 140 | prompt_ids=js, 141 | L=args.L, 142 | source=args.null_distribution_2_source, 143 | load_in_unicode=args.test_in_unicode, 144 | ) 145 | print("Null shape: ", p1.shape, p2.shape) 146 | 147 | # Start building the output filename 148 | filename_stem = build_cache_filename( 149 | model=args.model, 150 | prompts=prompts, 151 | use_char_space=args.test_in_unicode, 152 | alternative=( 153 | args.null_distribution_1_source, 154 | args.null_distribution_2_source, 155 | ), 156 | temperature=args.temperature, 157 | top_p=args.top_p, 158 | do_sample=args.do_sample, 159 | L=args.L, 160
| stat=args.stat, 161 | N="{n1}_{n2}", 162 | prompt_indices_json=args.prompt_indices_json, 163 | ) 164 | filename = filename_stem.format(n1=args.n1, n2=args.n2) 165 | filename += f"-pvalue={args.pvalue_type}-alt={args.alternative_distribution_source}" 166 | 167 | if os.path.exists(f"{args.save_dir}/{filename}.pt"): 168 | print("Skipping b/c already exists...") 169 | exit() 170 | 171 | # Set up bootstrap manager 172 | get_pvalue_1 = None 173 | get_pvalue_2 = None 174 | if args.pvalue_type == "parametric_bootstrap": 175 | fn1 = filename_stem.replace( 176 | f"{args.null_distribution_1_source}_{args.null_distribution_2_source}", 177 | args.null_distribution_1_source, 178 | ) 179 | bootstrap_manager_1 = BootstrapManager( 180 | bootstrap_path_template=f"{args.bootstrap_dir}/{fn1}.pkl", 181 | min_b=args.min_b, 182 | max_b=args.max_b, 183 | ) 184 | get_pvalue_1 = bootstrap_manager_1.load(n1=args.n1, n2=args.n2) 185 | fn2 = filename_stem.replace( 186 | f"{args.null_distribution_1_source}_{args.null_distribution_2_source}", 187 | args.null_distribution_2_source, 188 | ) 189 | bootstrap_manager_2 = BootstrapManager( 190 | bootstrap_path_template=f"{args.bootstrap_dir}/{fn2}.pkl", 191 | min_b=args.min_b, 192 | max_b=args.max_b, 193 | ) 194 | get_pvalue_2 = bootstrap_manager_2.load(n1=args.n1, n2=args.n2) 195 | 196 | # Load alternative distribution 197 | q = load_distribution( 198 | model=args.model, 199 | prompt_ids=js, 200 | L=args.L, 201 | source=args.alternative_distribution_source, 202 | load_in_unicode=args.test_in_unicode, 203 | ) 204 | print("Alternative shape: ", q.shape) 205 | 206 | # Simulate! 207 | ( 208 | pwr, 209 | reject_history, 210 | alpha_history, 211 | pvalue_history, 212 | test_stats, 213 | ) = get_power_two_sample_composite_null( 214 | p1, 215 | p2, 216 | q, 217 | n_null=args.n1, 218 | n_data=args.n2, 219 | n_simulations=args.n_simulations, 220 | pvalue_type=args.pvalue_type, 221 | stat_type=args.stat, 222 | get_pvalue_fn_1=get_pvalue_1, 223 | get_pvalue_fn_2=get_pvalue_2, 224 | return_pvalue=True, 225 | return_alpha=True, 226 | return_stat=True, 227 | alpha=args.alpha, 228 | ) 229 | 230 | # Save results 231 | print("Power: ", pwr) 232 | torch.save( 233 | { 234 | "power": pwr, 235 | "reject_history": reject_history, 236 | "alpha_history": alpha_history, 237 | "pvalue_history": pvalue_history, 238 | "test_stats": test_stats, 239 | }, 240 | f"{args.save_dir}/{filename}.pt", 241 | ) 242 | -------------------------------------------------------------------------------- /model_equality_testing/src/pvalue.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Union, Tuple, List, Dict 4 | from model_equality_testing.tests import IMPLEMENTED_TESTS 5 | from model_equality_testing.distribution import ( 6 | CompletionSample, 7 | DistributionFromDataset, 8 | ) 9 | import tqdm 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def _plot_empirical_distribution(stats, ax=None, label="", **kwargs): 14 | """ 15 | Given a numpy array of test stats, plots the empirical distribution using a histogram 16 | """ 17 | if ax is None: 18 | plt.figure() 19 | ax = plt.gca() 20 | ax.hist(stats, bins="auto", **kwargs) # Adjust bins as needed 21 | ax.set_xlabel("Test Statistic") 22 | ax.set_ylabel("Frequency") 23 | ax.set_title( 24 | f'Distribution of Test Statistic {"(" + label + ")" if len(label) else ""}' 25 | ) 26 | return ax 27 | 28 | 29 | ############################################## 30 | # Functions to simulate 
and compute p-values 31 | ############################################## 32 | 33 | 34 | class EmpiricalPvalueCalculator: 35 | """ 36 | Given an empirical sample of test statitics, provides a callable that returns the p-value of an observed test statistic. 37 | """ 38 | 39 | def __init__(self, observed_stats: np.ndarray): 40 | """ 41 | Args: 42 | observed_stats: a numpy array of shape (b,) 43 | where b is the number of bootstrap samples 44 | """ 45 | self.stats = observed_stats 46 | 47 | def __call__(self, obs_stat: Union[float, np.ndarray, torch.Tensor]) -> float: 48 | # handle obs_stat: make sure it's a float 49 | if isinstance(obs_stat, (torch.Tensor, np.ndarray)): 50 | obs_stat = obs_stat.item() 51 | 52 | # compare to self.stats and average across the batch dimension (b) 53 | return np.mean((self.stats >= obs_stat), axis=0).item() 54 | 55 | 56 | def one_sample_parametric_bootstrap_pvalue( 57 | null_dist: DistributionFromDataset, 58 | n: int, 59 | b=1000, 60 | plot=False, 61 | return_stats=False, 62 | stat_type="g_squared", 63 | **kwargs, 64 | ) -> Union[EmpiricalPvalueCalculator, Tuple[EmpiricalPvalueCalculator, np.ndarray]]: 65 | """ 66 | Simulates the empirical distribution of the test statistic by repeatedly drawing samples 67 | from the null distribution and computing the test statistic. 68 | Args: 69 | null_dist: a distribution object from which to draw samples 70 | n: the size of the sample to take 71 | b: the number of times to draw samples and compute the test statistic 72 | plot: whether to plot the empirical distribution of the test statistics 73 | return_stats: whether to return the raw test statistics, in addition to 74 | the p-value calculator 75 | stat_type: the type of test statistic to compute as a string. 76 | Must be a key in IMPLEMENTED_TESTS 77 | **kwargs: additional arguments to pass to the test computation function 78 | """ 79 | stats = [] 80 | for _ in tqdm.tqdm(range(b), desc="Parametric bootstrap"): 81 | bootstrap_sample = null_dist.sample(n=n) 82 | stat = IMPLEMENTED_TESTS[stat_type](bootstrap_sample, null_dist, **kwargs) 83 | stats.append(stat) 84 | stats = np.array(stats) 85 | if stats.ndim == 1: 86 | stats = np.expand_dims(stats, 1) 87 | if stats.ndim == 2: 88 | stats = np.expand_dims(stats, 2) 89 | 90 | # plot the empirical distribution 91 | if plot: 92 | b, m, nstats = stats.shape 93 | assert m == 1, "Incorrect shape for plotting" 94 | for i in range(nstats): 95 | _plot_empirical_distribution(stats[:, :, i], label=f"{stat_type} dim {i}") 96 | 97 | get_pvalue = EmpiricalPvalueCalculator(stats) 98 | if return_stats: 99 | return get_pvalue, stats 100 | return get_pvalue 101 | 102 | 103 | def two_sample_parametric_bootstrap_pvalue( 104 | null_dist: DistributionFromDataset, 105 | n1: int, 106 | n2: int, 107 | b=1000, 108 | plot=False, 109 | return_stats=False, 110 | stat_type="two_sample_L2", 111 | **kwargs, 112 | ) -> Union[EmpiricalPvalueCalculator, Tuple[EmpiricalPvalueCalculator, np.ndarray]]: 113 | """ 114 | Simulates the empirical distribution of the test statistic by repeatedly drawing samples 115 | from the null distribution and computing the test statistic. 
116 | Args: 117 | null_dist: a distribution object from which to draw samples 118 | n1: the size of the first sample to take 119 | n2: the size of the second sample to take 120 | b: the number of times to draw samples and compute the test statistic 121 | plot: whether to plot the empirical distribution of the test statistics 122 | return_stats: whether to return the raw test statistics, in addition to 123 | the p-value calculator 124 | stat_type: the type of test statistic to compute as a string. 125 | Must be a key in IMPLEMENTED_TESTS 126 | **kwargs: additional arguments to pass to the test computation function 127 | """ 128 | stats = [] 129 | for _ in tqdm.tqdm(range(b), desc="Parametric bootstrap"): 130 | sample1 = null_dist.sample(n=n1) 131 | sample2 = null_dist.sample(n=n2) 132 | stat = IMPLEMENTED_TESTS[stat_type](sample1, sample2, **kwargs) 133 | stats.append(stat) 134 | stats = np.array(stats) 135 | if stats.ndim == 1: 136 | stats = np.expand_dims(stats, 1) 137 | if stats.ndim == 2: 138 | stats = np.expand_dims(stats, 2) 139 | 140 | # plot the empirical distribution 141 | if plot: 142 | b, m, nstats = stats.shape 143 | assert m == 1, "Incorrect shape for plotting" 144 | for i in range(nstats): 145 | _plot_empirical_distribution(stats[:, :, i], label=f"{stat_type} dim {i}") 146 | 147 | get_pvalue = EmpiricalPvalueCalculator(stats) 148 | if return_stats: 149 | return get_pvalue, stats 150 | return get_pvalue 151 | 152 | 153 | def two_sample_permutation_pvalue( 154 | sample1: CompletionSample, 155 | sample2: CompletionSample, 156 | b=1000, 157 | plot=False, 158 | return_stats=False, 159 | stat_type="two_sample_L2", 160 | **kwargs, 161 | ) -> Union[EmpiricalPvalueCalculator, Tuple[EmpiricalPvalueCalculator, np.ndarray]]: 162 | """ 163 | Simulates the empirical distribution of the test statistic by repeatedly permuting the labels 164 | of the samples and computing the test statistic. 165 | Args: 166 | sample1: the first sample 167 | sample2: the second sample 168 | b: the number of times to draw samples and compute the test statistic 169 | plot: whether to plot the empirical distribution of the test statistics 170 | return_stats: whether to return the raw test statistics, in addition to 171 | the p-value calculator 172 | stat_type: the type of test statistic to compute as a string. 
173 | Must be a key in IMPLEMENTED_TESTS 174 | **kwargs: additional arguments to pass to the test computation function 175 | """ 176 | stats = [] 177 | all_samples = torch.cat( 178 | [ 179 | sample1.sequences, 180 | sample2.sequences, 181 | ], 182 | dim=0, 183 | ) 184 | for _ in tqdm.tqdm(range(b), desc="Permutation bootstrap"): 185 | ix = torch.randperm(len(all_samples)) 186 | permuted_sample1 = CompletionSample( 187 | prompts=all_samples[ix][: sample1.N, 0], 188 | completions=all_samples[ix][: sample1.N, 1:], 189 | m=sample1.m, 190 | ) 191 | permuted_sample2 = CompletionSample( 192 | prompts=all_samples[ix][sample1.N :, 0], 193 | completions=all_samples[ix][sample1.N :, 1:], 194 | m=sample1.m, 195 | ) 196 | 197 | stat = IMPLEMENTED_TESTS[stat_type]( 198 | permuted_sample1, permuted_sample2, **kwargs 199 | ) 200 | stats.append(stat) 201 | 202 | stats = np.array(stats) 203 | if stats.ndim == 1: 204 | stats = np.expand_dims(stats, 1) 205 | if stats.ndim == 2: 206 | stats = np.expand_dims(stats, 2) 207 | 208 | # plot the empirical distribution 209 | if plot: 210 | b, m, nstats = stats.shape 211 | assert m == 1, "Incorrect shape for plotting" 212 | for i in range(nstats): 213 | _plot_empirical_distribution(stats[:, :, i], label=f"{stat_type} dim {i}") 214 | 215 | get_pvalue = EmpiricalPvalueCalculator(stats) 216 | if return_stats: 217 | return get_pvalue, stats 218 | del stats 219 | return get_pvalue 220 | 221 | 222 | ###### map from name to function ###### 223 | 224 | IMPLEMENTED_PVALUES = { 225 | "one_sample_parametric_bootstrap": one_sample_parametric_bootstrap_pvalue, 226 | "two_sample_parametric_bootstrap": two_sample_parametric_bootstrap_pvalue, 227 | "two_sample_permutation": two_sample_permutation_pvalue, 228 | } 229 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model Equality Testing: Which Model Is This API Serving? 2 | 3 | [Paper link](https://arxiv.org/abs/2410.20247) | [Dataset link](https://drive.google.com/drive/folders/1TgxlUp3n-BFh-A6ARA42jkvxkv7Leccv?usp=drive_link) | [Twitter announcement](https://x.com/irena_gao/status/1851273269690908777) 4 | 5 | 6 | Users often interact with large language models through black-box inference APIs, both for closed- and open-weight models (e.g., Llama models are popularly accessed via Amazon Bedrock and Azure AI Studios). 7 | In order to cut costs or add functionality, API providers may quantize, watermark, or finetune the underlying model, changing the output distribution — often without notifying users. How can we detect if an API has changed for our particular task using only sample access? 8 | 9 |
10 | 11 | 12 | We formalize this problem as Model Equality Testing, a two-sample testing problem where the user collects samples from the API and a reference distribution, and conducts a statistical test to see if the two distributions are the same. Unlike current approaches that simply compare numbers on standard benchmarks, this approach is specific to a user’s distribution of task prompts, is applicable to tasks without automated evaluation metrics, and can be more powerful at distinguishing distributions. 13 | 14 | 15 |
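As a toy illustration of this two-sample setup (separate from the package API shown under Installation below), the sketch that follows runs a plain permutation test with a deliberately simple summary statistic. The example completions and the length-based statistic are made up purely for illustration; the actual tests in the package use much more powerful statistics (e.g. an MMD with a Hamming kernel) computed jointly over prompts and completions.

```python
import numpy as np

# Toy completions "from the API" and from a reference model.
api_samples = ["a time to be born", "a time to die", "a time to weep and laugh"]
reference_samples = ["a time to mourn and a time to dance", "come now, I will test you", "a time to embrace"]

def statistic(x, y):
    # Toy summary statistic: absolute difference in mean completion length.
    return abs(np.mean([len(s) for s in x]) - np.mean([len(s) for s in y]))

observed = statistic(api_samples, reference_samples)

# Permutation test: under H0 (same distribution), the source labels are exchangeable.
rng = np.random.default_rng(0)
pooled = api_samples + reference_samples
n1 = len(api_samples)
perm_stats = []
for _ in range(10_000):
    perm = rng.permutation(len(pooled))
    shuffled = [pooled[i] for i in perm]
    perm_stats.append(statistic(shuffled[:n1], shuffled[n1:]))

pvalue = float(np.mean(np.array(perm_stats) >= observed))
print(f"p-value: {pvalue:.3f}")  # a small p-value is evidence that the two distributions differ
```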
16 | 17 | To enable users to test APIs on their own tasks, we open-source a Python package here. Additionally, to encourage future research into this problem, we also release a dataset of 1 million LLM completions that can be used to learn / evaluate more powerful tests. 18 | 19 | ## Installation 20 | To run Model Equality Testing on your own samples, we recommend using pip to install the package: 21 | 22 | ``` 23 | pip install model-equality-testing 24 | ``` 25 | 26 | The package provides functions to run the tests discussed in the paper on your samples. This includes functions to compute test statistics and simulate p-values. 27 | 28 | An example of how to use the package to test two samples is below, and additional examples will be continually added to `demo.ipynb`. 29 | 30 | ```python 31 | import numpy as np 32 | 33 | ########## Example data ############### 34 | sampled_prompts_1 = np.array([0, 1, 0]) # integers representing which prompt was selected 35 | corresponding_completions_1 = [ 36 | "...a time to be born and a time to die", 37 | "'Laughter,' I said, 'is madness.'", 38 | "...a time to weep and a time to laugh", 39 | ] # corresponding completions 40 | sampled_prompts_2 = np.array([0, 0, 1]) # integers representing which prompt was selected 41 | corresponding_completions_2 = [ 42 | "...a time to mourn and a time to dance", 43 | "...a time to embrace and a time to refrain from embracing", 44 | "I said to myself, 'Come now, I will test you'", 45 | ] # corresponding completions 46 | 47 | ######### Testing code ################ 48 | # Tokenize the string completions as unicode codepoints 49 | # and pad both completion arrays to a shared maximum length of 200 chars 50 | from model_equality_testing.utils import tokenize_unicode 51 | corresponding_completions_1 = tokenize_unicode(corresponding_completions_1) 52 | corresponding_completions_1 = pad_to_length(corresponding_completions_1, L=200) 53 | corresponding_completions_2 = tokenize_unicode(corresponding_completions_2) 54 | corresponding_completions_2 = pad_to_length(corresponding_completions_2, L=200) 55 | 56 | 57 | # Wrap these as CompletionSample objects 58 | # m is the total number of prompts supported by the distribution 59 | from model_equality_testing.distribution import CompletionSample 60 | 61 | sample1 = CompletionSample(prompts=sampled_prompts_1, completions=corresponding_completions_1, m=2) 62 | sample2 = CompletionSample(prompts=sampled_prompts_2, completions=corresponding_completions_2, m=2) 63 | 64 | from model_equality_testing.algorithm import run_two_sample_test 65 | 66 | # Run the two-sample test 67 | pvalue, test_statistic = run_two_sample_test( 68 | sample1, 69 | sample2, 70 | pvalue_type="permutation_pvalue", # use the permutation procedure to compute the p-value 71 | stat_type="mmd_hamming", # use the MMD with Hamming kernel as the test statistic 72 | b=100, # number of permutations 73 | ) 74 | print(f"p-value: {pvalue}, test statistic: {test_statistic}") 75 | print("Should we reject P = Q?", pvalue < 0.05) 76 | ``` 77 | 78 | ## Dataset 79 | To enable future research on better tests for Model Equality Testing, we release a dataset of LLM completions, including samples used in the paper experiments. At a high level, this dataset includes 1.6M completion samples collected across 5 language models, each served by various sources (e.g. in `fp32` and `int8` precisions, as well as by various inference API providers, e.g. `amazon` and `azure`). These completions are collected for a fixed set of 540 prompts. 
For 100 of these prompts (the "dev set"), we additionally collect logprobs for each completion under the fp32 model. 80 | 81 | The data (and a spreadsheet documenting its contents) are hosted as a 37.1GB `dataset.zip` file [here](https://drive.google.com/drive/folders/1TgxlUp3n-BFh-A6ARA42jkvxkv7Leccv?usp=drive_link). For convenience, we provide a function in the `model-equality-testing` package to automatically download and unzip the dataset. 82 | 83 | ```python 84 | # make sure to first install gdown 85 | # ! pip install gdown 86 | from model_equality_testing.dataset import download_dataset 87 | download_dataset(root_dir="./data") # will download to ./data 88 | ``` 89 | 90 | You can also download just the samples (dev/test set) or just the logprobs (dev set); make sure to set these inside `{root_dir}/samples/` and `{root_dir}/logprobs` paths respectively. These can be found as separate zip files in the Google Drive link above. 91 | 92 | Once downloaded, you can load the dataset using the function `load_distribution`, which returns a `DistributionFromDataset` object. 93 | 94 | ```python 95 | # load a distribution object representing the joint distribution 96 | # where prompts come from Wikipedia (Ru) with prompt ids 0, 3, 10 97 | # and Wikipedia (De) with prompt id 5 98 | # and completions come from meta-llama/Meta-Llama-3-8B-Instruct 99 | from model_equality_testing.dataset import load_distribution 100 | p = load_distribution( 101 | model="meta-llama/Meta-Llama-3-8B-Instruct", # model 102 | prompt_ids={"wikipedia_ru": [0, 3, 10], "wikipedia_de": [5]}, # prompts 103 | L=1000, # number of characters to pad / truncate to 104 | source="fp32", # or replace with 'nf4', 'int8', 'amazon', etc. 105 | load_in_unicode=True, # instead of tokens 106 | root_dir="./data", 107 | ) 108 | ``` 109 | 110 | [This spreadsheet in the Google Drive](https://docs.google.com/spreadsheets/d/1T9aPZHK1xxfxogrHYaHqvW0Blqi-XJx2rgOPktWcN0w/edit?usp=sharing) catalogs all samples present in the dataset, including for which prompt they were collected, under which language model, and served by which source. 111 | 112 | At a high level, samples are split into those from local sources (`fp32`, `fp16`, `int8`, `nf4`, `watermark`) vs. APIs (`anyscale`, `amazon`, `fireworks`, `replicate`, `deepinfra`, `groq`, `perplexity`, `together`, `azure`). 113 | Local samples are saved as `.pkl` files containing numpy arrays of token IDs (integers). 114 | API samples are saved as `.pkl` files containing dictionaries; the samples themselves are strings returned directly by the API. 115 | 116 | * Note that the data loading code in our package does additional postprocessing on local samples, e.g. replacing tokens to the right of the first `` token with pad tokens, and padding to the max observed length. This is because our desired behavior is to only test strings up to the first `` token. When the model tokenizer does not specify a pad token, we set it to the `` token. For convenience, we've preprocessed local samples before saving in the zip files above, but we include the processing code in the `model_equality_testing.dataset` module. 117 | * To the best of our knowledge, API samples were all returned without special tokens. 118 | * Local samples may be loaded in unicode or token space; API samples can only be loaded in unicode. When loading samples in unicode, samples are batch decoded skipping special tokens, and each character is represented by its integer Unicode codepoint. Padding is represented as `-1`. 
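To make the unicode representation described in the last bullet concrete, here is a minimal numpy-only sketch; the packaged `tokenize_unicode` and padding utilities implement the equivalent logic for you.

```python
import numpy as np

completions = ["héllo", "hi"]  # already-decoded strings, special tokens removed
L = max(len(s) for s in completions)
rows = [[ord(c) for c in s] + [-1] * (L - len(s)) for s in completions]
arr = np.array(rows)
print(arr)
# row 0: codepoints of "héllo" -> 104 233 108 108 111
# row 1: "hi" -> 104 105, padded with -1 up to the max length
```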
119 | 120 | Additional details about dataset collection can be found in Appendix B.1 in [the paper](https://arxiv.org/abs/2410.20247). 121 | 122 | ## Reproducing paper experiments 123 | 124 | In `experiments/`, we include the code used to produce the experiments shown in the paper, including the code to generate the dataset (`experiments/sampling`) and the code to simulate power (`experiments/testing`). 125 | 126 | Note that APIs are actively evolving: many APIs have changed behavior since when we used these scripts to collect samples between July and August 2024. For full details documenting the dates we queried each API for the samples in our dataset, see Appendix B.1 in [the paper](https://arxiv.org/abs/2410.20247). 127 | 128 | ## Citation 129 | 130 | If you use our dataset or code, please cite this work as 131 | ```bibtex 132 | @misc{gao2024model, 133 | title={Model Equality Testing: Which model is this API serving?}, 134 | author={Gao, Irena and Liang, Percy and Guestrin, Carlos}, 135 | journal={arXiv preprint}, 136 | year={2024} 137 | } 138 | ``` -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | from argparse import Action 2 | from typing import List, Tuple, Union 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | from ast import literal_eval 7 | import time 8 | import numpy as np 9 | import hashlib 10 | 11 | FILE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | 13 | 14 | def sanitize(s): 15 | """Sanitize a string for use as a filename.""" 16 | s = str(s) 17 | s = s.replace(" ", "-") 18 | s = s.replace("[", "") 19 | s = s.replace("]", "") 20 | s = s.replace(",", "_") 21 | s = s.replace("/", "-") 22 | s = s.replace("(", "") 23 | s = s.replace(")", "") 24 | return s 25 | 26 | 27 | def build_cache_filename( 28 | *, 29 | filename: str = None, 30 | model: str = None, 31 | prompts: Union[str, List[str]] = None, 32 | alternative: Union[str, Tuple[str, str]] = None, 33 | use_char_space: bool = None, 34 | temperature: float = None, 35 | top_p: float = None, 36 | do_sample: bool = True, 37 | L: int = None, 38 | stat: str = None, 39 | N: Union[int, Tuple[int]] = None, 40 | prompt_indices_json: str = None, 41 | ): 42 | """ 43 | Build the filename used for logits, parametric bootstrap, power caches, etc. 44 | Appends to the given filename if provided. 
45 | Format: 46 | {model}-{prompts}-{alternative}-{temperature}-{top_p}-{do_sample}-{L}-{stat}-{N}-{prompt_indices} 47 | """ 48 | if filename is None: 49 | filename = "" 50 | if model is not None: 51 | filename += f"{sanitize(model)}" 52 | if prompts is not None: 53 | charspace = "-char" if use_char_space == True else "" 54 | if isinstance(prompts, list): 55 | prompts = "-".join(sorted(prompts)) 56 | if ( 57 | prompts 58 | == "wikipedia_de-wikipedia_en-wikipedia_es-wikipedia_fr-wikipedia_ru" 59 | ): 60 | prompts = "wikipedia" 61 | filename += f"-{prompts}{charspace}" 62 | if alternative is not None: 63 | if type(alternative) == tuple: 64 | assert ( 65 | len(alternative) == 2 66 | ), "Only support a tuple in the case of a composite null with two simple nulls" 67 | filename += f"-{sanitize(alternative[0])}_{sanitize(alternative[1])}" # note: None will be shown in the composite case 68 | else: 69 | filename += ( 70 | f"-{sanitize(alternative)}" if alternative != "None" else "" 71 | ) # None will not be shown in the simple case 72 | if temperature is not None: 73 | filename += f"-temp={temperature}" 74 | if top_p is not None: 75 | filename += f"-top_p={top_p}" 76 | if do_sample is False: 77 | filename += "-greedy" 78 | if L is not None: 79 | filename += f"-L={L}" 80 | if stat is not None: 81 | filename += f"-{stat}" 82 | if N is not None: 83 | if type(N) == tuple: 84 | assert len(N) == 2, "N should only be a tuple for two-sample testing" 85 | N = f"{N[0]}_{N[1]}" 86 | filename += f"-n={N}" 87 | if prompt_indices_json is not None: 88 | # convert prompt_indices_json to relative filepath from root 89 | to_hash = os.path.abspath(prompt_indices_json).replace(FILE_DIR, "") 90 | setting = hash_fn(to_hash) 91 | filename += f"-prompt_indices={setting}" 92 | return filename 93 | 94 | 95 | ################################### 96 | # Argparse utils 97 | ################################### 98 | 99 | 100 | def str_to_bool(value): 101 | """ 102 | Function to parse boolean values from argparse. 103 | """ 104 | if value.lower() in ("true", "t", "yes", "y"): 105 | return True 106 | elif value.lower() in ("false", "f", "no", "n"): 107 | return False 108 | else: 109 | raise Exception(f"Invalid boolean value: {value}") 110 | 111 | 112 | class ParseKwargs(Action): 113 | """ 114 | Helper function s.t. argparse can parse kwargs of the form --kwarg key1=value1 key2=value2 115 | """ 116 | 117 | def __call__(self, parser, namespace, values, option_string=None): 118 | setattr(namespace, self.dest, dict()) 119 | for pair in values: 120 | key, value = pair.split("=") 121 | processed_value = infer_type(value) 122 | getattr(namespace, self.dest)[key] = processed_value 123 | 124 | 125 | def infer_type(s): 126 | """ 127 | If the str can be interpreted as a float or an int, convert it to that type. 128 | """ 129 | try: 130 | return str_to_bool(s) 131 | except: 132 | pass 133 | try: 134 | return literal_eval(s) 135 | except: 136 | pass 137 | try: 138 | return str_to_torchdtype(s) 139 | except: 140 | pass 141 | try: 142 | return str_to_list(s) 143 | except: 144 | return s 145 | 146 | 147 | def str_to_torchdtype(value): 148 | if not value.startswith("torch."): 149 | raise Exception(f"Invalid torch dtype: {value}") 150 | return getattr(torch, value.split(".")[1]) 151 | 152 | 153 | def str_to_list(value): 154 | """ 155 | Helper function to parse a string of the form "[x,y,z]" into a list [x, y, z]. 156 | Catches some cases where ast.literal_eval fails because the elements in the list 157 | contain non-standard characters. 
158 | """ 159 | if value.startswith("[") and value.endswith("]"): 160 | value = value[1:-1] 161 | else: 162 | raise Exception(f"Invalid list value: {value}") 163 | 164 | return value.split(",") 165 | 166 | 167 | def str_to_bool(value): 168 | """ 169 | Function to parse boolean values from argparse. 170 | """ 171 | if value.lower() in ("true", "t", "yes", "y"): 172 | return True 173 | elif value.lower() in ("false", "f", "no", "n"): 174 | return False 175 | else: 176 | raise Exception(f"Invalid boolean value: {value}") 177 | 178 | 179 | ################################### 180 | # Misc utils 181 | ################################### 182 | 183 | 184 | def collate(list_of_dicts): 185 | """ 186 | Collate a list of dictionaries into a single dictionary. 187 | """ 188 | collated_dict = defaultdict(list) 189 | for d in list_of_dicts: 190 | for k, v in d.items(): 191 | collated_dict[k].append(v) 192 | return collated_dict 193 | 194 | 195 | def uncollate(dict_of_lists, step=1): 196 | """ 197 | Uncollate a dictionary of lists into a list of dictionaries. 198 | If step > 1, the output will be a list of dicts where each dict key 199 | has values that are themselves lists of size step. 200 | If step = 1, the outputs will be a list of dicts where each dict value 201 | are not lists. 202 | """ 203 | list_of_dicts = [] 204 | num_elements = len(dict_of_lists[list(dict_of_lists.keys())[0]]) 205 | 206 | for i in range(0, num_elements, step): 207 | dict_entry = {} 208 | for key, value_list in dict_of_lists.items(): 209 | _n = min(i + step, num_elements) - i 210 | if value_list is None: 211 | dict_entry[key] = ([None] * _n) if _n > 1 else None 212 | else: 213 | dict_entry[key] = value_list[i : i + step] if _n > 1 else value_list[i] 214 | list_of_dicts.append(dict_entry) 215 | 216 | return list_of_dicts 217 | 218 | 219 | def stack_with_padding( 220 | tensors: List[torch.Tensor], 221 | dim: int = 0, 222 | padding_side: str = "right", 223 | padding_mode: str = "constant", 224 | padding_token: Union[int, float] = 0, 225 | ) -> Tuple[torch.Tensor, torch.Tensor]: 226 | """ 227 | Stack tensors along specified dimension and pad them to ensure their size is equal in all dimensions. 228 | Returns the stacked tensor and a boolean mask indicating valid (non-padded) elements. 229 | 230 | Args: 231 | tensors (List[torch.Tensor]): list of tensors to stack 232 | dim (int): dimension along which to stack tensors. Defaults to 0. 233 | padding_side (str): side on which to pad - "left" or "right". Defaults to "right". 234 | padding_mode (str): 'constant', 'reflect', 'replicate' or 'circular'. 
Default: 'constant' 235 | padding_value (Union[int, float]): value to use for constant padding 236 | 237 | Returns: 238 | Tuple[torch.Tensor, torch.Tensor]: stacked tensor and boolean mask 239 | """ 240 | # Ensure all tensors have the same number of dimensions 241 | max_dims = max(t.dim() for t in tensors) 242 | tensors = [t.view(*t.shape, *([1] * (max_dims - t.dim()))) for t in tensors] 243 | 244 | # Find the maximum size for each dimension 245 | max_sizes = [max(t.shape[i] for t in tensors) for i in range(max_dims)] 246 | 247 | def make_padding(tensor_shape): 248 | padding = [] 249 | for i in reversed(range(max_dims)): # Reverse for F.pad expectations 250 | pad_size = max_sizes[i] - tensor_shape[i] 251 | if padding_side == "left": 252 | padding.extend([pad_size, 0]) 253 | elif padding_side == "right": 254 | padding.extend([0, pad_size]) 255 | else: 256 | raise ValueError(f"padding_side '{padding_side}' is unknown") 257 | return tuple(padding) 258 | 259 | padded_tensors = [] 260 | masks = [] 261 | 262 | for t in tensors: 263 | padding = make_padding(t.shape) 264 | padded_t = F.pad(t, padding, mode=padding_mode, value=padding_token) 265 | 266 | mask = torch.zeros_like(padded_t, dtype=torch.bool) 267 | slices = [] 268 | for i in range(max_dims): 269 | if padding_side == "left": 270 | slices.append(slice(max_sizes[i] - t.shape[i], None)) 271 | else: 272 | slices.append(slice(None, t.shape[i])) 273 | mask[tuple(slices)] = True 274 | 275 | padded_tensors.append(padded_t) 276 | masks.append(mask) 277 | 278 | stacked_tensor = torch.stack(padded_tensors, dim=dim) 279 | stacked_mask = torch.stack(masks, dim=dim) 280 | 281 | return stacked_tensor, stacked_mask 282 | 283 | 284 | def wait_if_error( 285 | callable, 286 | *args, 287 | timeout=1, 288 | max_retries=5, 289 | exception_if_fail=False, 290 | special_exceptions=[], 291 | **kwargs, 292 | ): 293 | """ 294 | Call a function, and if it raises an exception, wait for timeout seconds and try again, 295 | up to max_retries times. 296 | """ 297 | for try_number in range(max_retries): 298 | try: 299 | return callable(*args, **kwargs) 300 | except Exception as e: 301 | if type(e) in special_exceptions: 302 | print(f"Error: {e}; finishing on a special exception") 303 | return None 304 | print(f"Error: {e}; waiting {timeout} seconds and trying again") 305 | time.sleep(timeout ** (try_number + 1)) 306 | 307 | if exception_if_fail: 308 | raise Exception(f"Failed after {max_retries} tries") 309 | else: 310 | print(f"Failed after {max_retries} tries") 311 | return None 312 | 313 | 314 | def seed_everything(seed: int): 315 | """ 316 | Helper function to seed everything. 317 | """ 318 | import random 319 | 320 | random.seed(seed) 321 | np.random.seed(seed) 322 | torch.manual_seed(seed) 323 | torch.cuda.manual_seed_all(seed) 324 | 325 | 326 | class StreamingDataset(torch.utils.data.Dataset): 327 | """ 328 | Truncates a stream dataset at num_samples. 
329 | """ 330 | 331 | def __init__(self, stream_dataset, num_samples, char_limit=None): 332 | self.stream_dataset = stream_dataset 333 | self.num_samples = num_samples 334 | self.samples = list(self._truncated_dataset(char_limit=char_limit)) 335 | 336 | def _truncated_dataset(self, char_limit=None): 337 | i = 0 338 | seen = set() 339 | for sample in self.stream_dataset: 340 | if i == self.num_samples: 341 | break 342 | if sample["id"] in seen: 343 | continue 344 | if char_limit is None or len(sample["plain"]) <= char_limit: 345 | i += 1 346 | seen.add(sample["id"]) 347 | yield sample 348 | 349 | def __len__(self): 350 | return self.num_samples 351 | 352 | def __getitem__(self, idx): 353 | # if integer, get the row 354 | # if string, get the column 355 | if isinstance(idx, (int, slice)): 356 | return self.samples[idx] 357 | elif isinstance(idx, str): 358 | return [sample[idx] for sample in self.samples] 359 | else: 360 | raise ValueError("idx must be an integer or a string") 361 | 362 | def remove_columns(self, columns): 363 | for sample in self.samples: 364 | for col in columns: 365 | sample.pop(col, None) 366 | return self 367 | 368 | 369 | def hash_fn(x: object, type="md5"): 370 | """ 371 | Hash an object determinisitically. 372 | """ 373 | # encode the object 374 | if isinstance(x, torch.Tensor): 375 | encoded = x.numpy().tobytes() 376 | elif isinstance(x, np.ndarray): 377 | encoded = x.tobytes() 378 | elif isinstance(x, str): 379 | encoded = x.encode("utf-8") 380 | else: 381 | encoded = pickle.dumps(x) 382 | # hash the encoded object 383 | if type == "md5": 384 | return hashlib.md5(encoded).hexdigest() 385 | elif type == "sha256": 386 | return hashlib.sha256(encoded).hexdigest() 387 | else: 388 | return zlib.adler32(encoded) 389 | -------------------------------------------------------------------------------- /experiments/prompts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Library of functions to generate datasets of prompts. 3 | The exposed functions are of the form get_{prompt}_prompts. 4 | They return Datasets with these columns: 5 | 6 | - `plain` (str): the prompt without any model-specific modifications, and without the '"...' at the start of the model response. 7 | This is useful for APIs that apply chat templates automatically. 8 | Example: 9 | 'Continue the sequence. Do not output anything else except the continuation of the sequence. Start the continuation immediately.\nSequence: 123' 10 | 11 | - `chat` (str): `plain` but formatted in the model's chat template, but WITHOUT the initial token or the ending token. 12 | The goal is that later calls to the tokenizer with add_special_tokens=True will produce the same output as `chat_tokens`. 13 | Example for llama-2-7b-chat: 14 | '[INST] Continue the sequence. Do not output anything else except the continuation of the sequence. Start the continuation immediately.\nSequence: 123 [/INST]' 15 | After tokenizer: ' [INST] Continue the sequence. Do not output anything else except the continuation of the sequence. Start the continuation immediately.\nSequence: 123 [/INST]' 16 | 17 | - `chat_with_special` (str): `chat` but keeping the special tokens 18 | 19 | - `chat_tokens`: the prompt with model-specific chat templates applied, including all or tokens, tokenized. 20 | This is for testing only, and should not be used as input to the model. 21 | 22 | - `chat_with_ellipses` (str): `chat` but with the ellipses at the end of the model response. 
23 | The ellipses at the start of the model response, after the chat template, help chat models to start immediately instead of hedging. 24 | Whenever possible, we try to use this column instead of `chat`. 25 | This requires us to call the Completions endpoints for APIs. That's actually nice because it removes uncertainty about whether 26 | chat templates are being applied correctly by the API. 27 | Example for llama-2-7b-chat: 28 | '[INST] Continue the sequence. Do not output anything else except the continuation of the sequence. Start the continuation immediately.\nSequence: 123 [/INST] "...' 29 | 30 | - `chat_with_ellipses_special` (str): `chat_with_ellipses` but keeping the special tokens, e.g. the BOS token `<s>` 31 | 32 | - `chat_with_ellipses_tokens`: `chat_with_ellipses` tokenized. 33 | This is for testing only, and should not be used as input to the model. 34 | 35 | - `id`: a unique identifier for the prompt, generated by hashing the `plain` prompt. 36 | 37 | - `y`: for datasets where there is a gold response, this is the gold response. 38 | """ 39 | 40 | import os 41 | import random 42 | import datasets 43 | from experiments.utils import seed_everything, StreamingDataset, hash_fn 44 | 45 | # seed everything with seed 0 for consistent hashing 46 | seed_everything(0) 47 | 48 | ####################################### 49 | 50 | def _map_to_chat_format(ds, model, ellipses='"...'): 51 | return ds.map( 52 | lambda x: { 53 | "chat": model.format_as_chat(x["plain"], add_ellipses=False), 54 | "chat_with_special": model.format_as_chat(x["plain"], add_ellipses=False, remove_special=False), 55 | "chat_tokens": model.format_as_chat(x["plain"], tokenize=True, add_ellipses=False), 56 | "chat_with_ellipses": model.format_as_chat(x["plain"], add_ellipses=True, ellipses=ellipses), 57 | "chat_with_ellipses_special": model.format_as_chat(x["plain"], add_ellipses=True, remove_special=False, ellipses=ellipses), 58 | "chat_with_ellipses_tokens": model.format_as_chat( 59 | x["plain"], tokenize=True, add_ellipses=True, ellipses=ellipses 60 | ), 61 | "id": hash_fn(x["plain"]), 62 | } 63 | ) 64 | 65 | def _load_local_prompts( 66 | filename: str, num_prompts: int, num_digits: int, random_digit_fn: callable, model 67 | ): 68 | """ 69 | Load prompts from a local file, or generate new ones if the file does not exist. 70 | """ 71 | if os.path.exists(filename): 72 | with open(filename, "r") as f: 73 | prompts = f.readlines() 74 | prompts = [p.strip().replace("\\n", "\n") for p in prompts] 75 | else: 76 | prompts = [] 77 | 78 | def random_x(): 79 | number = str(random_digit_fn()) 80 | for _ in range(num_digits - 1): 81 | number += str(random_digit_fn()) 82 | return number 83 | 84 | if len(prompts) < num_prompts: 85 | prompts += [ 86 | "Continue the sequence. Do not output anything else except the continuation of the sequence. 
Start the continuation immediately.\nSequence: " 87 | + random_x() 88 | for _ in range(num_prompts - len(prompts)) 89 | ] 90 | 91 | with open(filename, "w") as f: 92 | for p in prompts: 93 | f.write(p.replace("\n", "\\n") + "\n") 94 | ds = datasets.Dataset.from_dict({"plain": prompts}) 95 | ds = _map_to_chat_format(ds, model) 96 | print(f"{filename} prompts loaded: {ds['id']}") 97 | return ds 98 | 99 | ########################### 100 | 101 | def get_dummy_prompts(model) -> datasets.Dataset: 102 | """Dummy prompts""" 103 | ds = datasets.Dataset.from_dict({"plain": ["hello", "world", "the sun", "the moon", "the stars"]}) 104 | ds = _map_to_chat_format(ds, model) 105 | print(f"Dummy prompts loaded: {ds['id']}") 106 | return ds 107 | 108 | 109 | def get_number_prompts(model): 110 | return _load_local_prompts( 111 | "number_prompts.txt", 100, 100, lambda: random.randint(0, 9), model 112 | ) 113 | 114 | 115 | def get_bit_prompts(model): 116 | return _load_local_prompts( 117 | "bit_prompts.txt", 100, 100, lambda: random.randint(0, 1), model 118 | ) 119 | 120 | 121 | def get_alphanumericpunct_prompts(model): 122 | ALPHANUMERIC_PUNCTUATION = ( 123 | list(range(33, 48)) # !"#$%&'()*+,-./ 124 | + list(range(58, 65)) # :;<=>?@ 125 | + list(range(91, 97)) # [\]^_` 126 | + list(range(123, 127)) # {|}~ 127 | + list(range(48, 58)) # 0123456789 128 | + list(range(65, 91)) # ABCDEFGHIJKLMNOPQRSTUVWXYZ 129 | + list(range(97, 123)) # abcdefghijklmnopqrstuvwxyz 130 | ) 131 | ALPHANUMERIC_PUNCTUATION = [chr(i) for i in ALPHANUMERIC_PUNCTUATION] 132 | return _load_local_prompts( 133 | "alphanumericpunct_prompts.txt", 134 | 100, 135 | 100, 136 | lambda: random.choice(ALPHANUMERIC_PUNCTUATION), 137 | model, 138 | ) 139 | 140 | def _get_wikipedia_prompts(model, language: str) -> datasets.Dataset: 141 | """100-char Wikipedia snippets as prompts""" 142 | LANGUAGES = [ 143 | "en", "de", "fr", "ru", "es", "it", "ceb", "uk", "ja", "nl" 144 | ] 145 | assert language in LANGUAGES, f"Language {language} not in {LANGUAGES}" 146 | 147 | ds = datasets.load_dataset("Cohere/wikipedia-2023-11-embed-multilingual-v3", language, split="train", streaming=True) 148 | 149 | def get_prompt(row): 150 | out = "Continue the paragraph. Do not output anything except the continuation to the paragraph. 
Start the continuation immediately.\n" 151 | out += '"' + row["text"][:100] + '..."' 152 | return out 153 | 154 | ds = ds.remove_columns(["url", "_id", "title","emb"]).map( 155 | lambda x: { 156 | "plain": get_prompt(x), 157 | } 158 | ) 159 | ds = _map_to_chat_format(ds, model) 160 | print(f"Wikipedia prompts loaded in streaming mode") 161 | return StreamingDataset(ds, 100) 162 | 163 | get_wikipedia_en_prompts = lambda model: _get_wikipedia_prompts(model, "en") 164 | get_wikipedia_de_prompts = lambda model: _get_wikipedia_prompts(model, "de") 165 | get_wikipedia_fr_prompts = lambda model: _get_wikipedia_prompts(model, "fr") 166 | get_wikipedia_ru_prompts = lambda model: _get_wikipedia_prompts(model, "ru") 167 | get_wikipedia_es_prompts = lambda model: _get_wikipedia_prompts(model, "es") 168 | get_wikipedia_it_prompts = lambda model: _get_wikipedia_prompts(model, "it") 169 | get_wikipedia_ceb_prompts = lambda model: _get_wikipedia_prompts(model, "ceb") 170 | get_wikipedia_uk_prompts = lambda model: _get_wikipedia_prompts(model, "uk") 171 | get_wikipedia_ja_prompts = lambda model: _get_wikipedia_prompts(model, "ja") 172 | get_wikipedia_nl_prompts = lambda model: _get_wikipedia_prompts(model, "nl") 173 | 174 | def get_humaneval_prompts(model) -> datasets.Dataset: 175 | """HumanEval prompts""" 176 | ds = datasets.load_dataset("openai/openai_humaneval", split="test", streaming=True) 177 | 178 | def get_prompt(row): 179 | out = "Complete the code. Do not output anything except the completion. Start the continuation immediately.\n" 180 | out += '```\n' + row["prompt"] 181 | return out 182 | 183 | ds = ds.map( 184 | lambda x: { 185 | "plain": get_prompt(x), 186 | "y": x["canonical_solution"], 187 | } 188 | ) 189 | ds = _map_to_chat_format(ds, model, ellipses="```\n...\n") 190 | ds = ds.remove_columns(["entry_point", "test", "canonical_solution"]) 191 | print(f"HumanEval prompts loaded in streaming mode") 192 | return StreamingDataset(ds, 100) 193 | 194 | def get_ultrachat_prompts(model) -> datasets.Dataset: 195 | """Ultrachat_200K prompts""" 196 | ds = datasets.load_dataset("HuggingFaceH4/ultrachat_200k", split="test_gen", streaming=True) 197 | 198 | def get_prompt(row): 199 | out = row["prompt"] 200 | return out 201 | 202 | ds = ds.map( 203 | lambda x: { 204 | "plain": get_prompt(x), 205 | } 206 | ) 207 | ds = _map_to_chat_format(ds, model, ellipses="Sure, here's what you asked for:") 208 | ds = ds.remove_columns(["prompt_id", "messages"]) 209 | print(f"Ultrachat_200K prompts loaded in streaming mode") 210 | return StreamingDataset(ds, 100, char_limit=1000) 211 | 212 | ANSWER_SEPARATOR = "|||" 213 | 214 | def get_natural_questions_prompts(model) -> datasets.Dataset: 215 | """Natural Questions prompts""" 216 | ds = datasets.load_dataset("google-research-datasets/nq_open", split="validation", streaming=True) 217 | 218 | def get_prompt(row): 219 | out = "Answer the question. Do not output anything except the answer. 
Start the answer immediately.\n" 220 | out += row["question"] 221 | return out 222 | 223 | ds = ds.map( 224 | lambda x: { 225 | "plain": get_prompt(x), 226 | "y": ANSWER_SEPARATOR.join(x["answer"]), 227 | } 228 | ) 229 | ds = ds.remove_columns(['answer']) 230 | ds = _map_to_chat_format(ds, model, ellipses="Answer:") 231 | print(f"Natural Questions prompts loaded in streaming mode") 232 | return StreamingDataset(ds, 100, char_limit=1000) 233 | 234 | def get_cnn_prompts(model) -> datasets.Dataset: 235 | """CNN / DailyMail dataset""" 236 | ds = datasets.load_dataset("abisee/cnn_dailymail", "3.0.0", split="validation", streaming=True) 237 | 238 | def get_prompt(row): 239 | out = "Summarize the article in two sentences. Do not output anything except the summary. Start the summary immediately.\n\n" 240 | out += row["article"] 241 | return out 242 | 243 | ds = ds.map( 244 | lambda x: { 245 | "plain": get_prompt(x), 246 | "y": x["highlights"], 247 | } 248 | ) 249 | ds = ds.remove_columns(["id"]) 250 | ds = _map_to_chat_format(ds, model, ellipses="") 251 | print(f"CNN / DailyMail prompts loaded in streaming mode") 252 | return StreamingDataset(ds, 100, char_limit=10000) 253 | 254 | def _get_mmlu_prompts(model, subset) -> datasets.Dataset: 255 | """MMLU prompts""" 256 | ds = datasets.load_dataset("cais/mmlu", subset, split="test", streaming=True) 257 | 258 | def get_prompt(row): 259 | out = "Answer the multiple choice question. Respond with 'A', 'B', 'C', or 'D' for your answer choice.\n\n" 260 | out += f"Question: {row['question']}\n" 261 | for choice, letter in zip(row["choices"], ["A", "B", "C", "D"]): 262 | out += f"{letter}. {choice}\n" 263 | out = out.strip() 264 | return out 265 | 266 | ds = ds.map( 267 | lambda x: { 268 | "plain": get_prompt(x), 269 | } 270 | ) 271 | ds = _map_to_chat_format(ds, model, ellipses="Answer:") 272 | print(f"MMLU prompts loaded in streaming mode") 273 | return StreamingDataset(ds, 100, char_limit=1000) 274 | 275 | get_mmlu_abstract_algebra_prompts = lambda model: _get_mmlu_prompts(model, "abstract_algebra") 276 | get_mmlu_anatomy_prompts = lambda model: _get_mmlu_prompts(model, "anatomy") 277 | get_mmlu_astronomy_prompts = lambda model: _get_mmlu_prompts(model, "astronomy") 278 | get_mmlu_business_ethics_prompts = lambda model: _get_mmlu_prompts(model, "business_ethics") 279 | get_mmlu_clinical_knowledge_prompts = lambda model: _get_mmlu_prompts(model, "clinical_knowledge") 280 | get_mmlu_college_biology_prompts = lambda model: _get_mmlu_prompts(model, "college_biology") 281 | get_mmlu_college_chemistry_prompts = lambda model: _get_mmlu_prompts(model, "college_chemistry") 282 | get_mmlu_college_computer_science_prompts = lambda model: _get_mmlu_prompts(model, "college_computer_science") 283 | get_mmlu_college_mathematics_prompts = lambda model: _get_mmlu_prompts(model, "college_mathematics") 284 | get_mmlu_college_medicine_prompts = lambda model: _get_mmlu_prompts(model, "college_medicine") 285 | get_mmlu_college_physics_prompts = lambda model: _get_mmlu_prompts(model, "college_physics") 286 | get_mmlu_computer_security_prompts = lambda model: _get_mmlu_prompts(model, "computer_security") 287 | get_mmlu_conceptual_physics_prompts = lambda model: _get_mmlu_prompts(model, "conceptual_physics") 288 | get_mmlu_econometrics_prompts = lambda model: _get_mmlu_prompts(model, "econometrics") 289 | get_mmlu_electrical_engineering_prompts = lambda model: _get_mmlu_prompts(model, "electrical_engineering") 290 | get_mmlu_elementary_mathematics_prompts = lambda model: 
_get_mmlu_prompts(model, "elementary_mathematics") 291 | get_mmlu_formal_logic_prompts = lambda model: _get_mmlu_prompts(model, "formal_logic") 292 | get_mmlu_global_facts_prompts = lambda model: _get_mmlu_prompts(model, "global_facts") 293 | get_mmlu_high_school_biology_prompts = lambda model: _get_mmlu_prompts(model, "high_school_biology") 294 | get_mmlu_high_school_chemistry_prompts = lambda model: _get_mmlu_prompts(model, "high_school_chemistry") 295 | get_mmlu_high_school_computer_science_prompts = lambda model: _get_mmlu_prompts(model, "high_school_computer_science") 296 | get_mmlu_high_school_european_history_prompts = lambda model: _get_mmlu_prompts(model, "high_school_european_history") 297 | get_mmlu_high_school_geography_prompts = lambda model: _get_mmlu_prompts(model, "high_school_geography") 298 | get_mmlu_high_school_government_and_politics_prompts = lambda model: _get_mmlu_prompts(model, "high_school_government_and_politics") 299 | get_mmlu_high_school_macroeconomics_prompts = lambda model: _get_mmlu_prompts(model, "high_school_macroeconomics") 300 | get_mmlu_high_school_mathematics_prompts = lambda model: _get_mmlu_prompts(model, "high_school_mathematics") 301 | get_mmlu_high_school_microeconomics_prompts = lambda model: _get_mmlu_prompts(model, "high_school_microeconomics") 302 | get_mmlu_high_school_physics_prompts = lambda model: _get_mmlu_prompts(model, "high_school_physics") 303 | get_mmlu_high_school_psychology_prompts = lambda model: _get_mmlu_prompts(model, "high_school_psychology") 304 | get_mmlu_high_school_statistics_prompts = lambda model: _get_mmlu_prompts(model, "high_school_statistics") 305 | get_mmlu_high_school_us_history_prompts = lambda model: _get_mmlu_prompts(model, "high_school_us_history") 306 | get_mmlu_high_school_world_history_prompts = lambda model: _get_mmlu_prompts(model, "high_school_world_history") 307 | get_mmlu_human_aging_prompts = lambda model: _get_mmlu_prompts(model, "human_aging") 308 | get_mmlu_human_sexuality_prompts = lambda model: _get_mmlu_prompts(model, "human_sexuality") 309 | get_mmlu_international_law_prompts = lambda model: _get_mmlu_prompts(model, "international_law") 310 | get_mmlu_jurisprudence_prompts = lambda model: _get_mmlu_prompts(model, "jurisprudence") 311 | get_mmlu_logical_fallacies_prompts = lambda model: _get_mmlu_prompts(model, "logical_fallacies") 312 | get_mmlu_machine_learning_prompts = lambda model: _get_mmlu_prompts(model, "machine_learning") 313 | get_mmlu_management_prompts = lambda model: _get_mmlu_prompts(model, "management") 314 | get_mmlu_marketing_prompts = lambda model: _get_mmlu_prompts(model, "marketing") 315 | get_mmlu_medical_genetics_prompts = lambda model: _get_mmlu_prompts(model, "medical_genetics") 316 | get_mmlu_miscellaneous_prompts = lambda model: _get_mmlu_prompts(model, "miscellaneous") 317 | get_mmlu_moral_disputes_prompts = lambda model: _get_mmlu_prompts(model, "moral_disputes") 318 | get_mmlu_moral_scenarios_prompts = lambda model: _get_mmlu_prompts(model, "moral_scenarios") 319 | get_mmlu_nutrition_prompts = lambda model: _get_mmlu_prompts(model, "nutrition") 320 | get_mmlu_philosophy_prompts = lambda model: _get_mmlu_prompts(model, "philosophy") 321 | get_mmlu_prehistory_prompts = lambda model: _get_mmlu_prompts(model, "prehistory") 322 | get_mmlu_professional_accounting_prompts = lambda model: _get_mmlu_prompts(model, "professional_accounting") 323 | get_mmlu_professional_law_prompts = lambda model: _get_mmlu_prompts(model, "professional_law") 324 | 
get_mmlu_professional_medicine_prompts = lambda model: _get_mmlu_prompts(model, "professional_medicine") 325 | get_mmlu_professional_psychology_prompts = lambda model: _get_mmlu_prompts(model, "professional_psychology") 326 | get_mmlu_public_relations_prompts = lambda model: _get_mmlu_prompts(model, "public_relations") 327 | get_mmlu_security_studies_prompts = lambda model: _get_mmlu_prompts(model, "security_studies") 328 | get_mmlu_sociology_prompts = lambda model: _get_mmlu_prompts(model, "sociology") 329 | get_mmlu_us_foreign_policy_prompts = lambda model: _get_mmlu_prompts(model, "us_foreign_policy") 330 | get_mmlu_virology_prompts = lambda model: _get_mmlu_prompts(model, "virology") 331 | get_mmlu_world_religions_prompts = lambda model: _get_mmlu_prompts(model, "world_religions") 332 | -------------------------------------------------------------------------------- /model_equality_testing/src/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Union, Tuple, List, Dict 4 | from model_equality_testing.utils import ( 5 | ndim, 6 | stack_with_padding, 7 | Stopwatch, 8 | ) 9 | import os 10 | from functools import lru_cache, cache 11 | 12 | FILE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | 14 | 15 | def _sample_from_categorical( 16 | p: Union[np.ndarray, torch.Tensor, List[torch.Tensor]], 17 | n: Union[int, List[int], np.ndarray, torch.Tensor] = 1, 18 | ) -> Union[torch.Tensor, List[torch.Tensor]]: 19 | """ 20 | Samples n indices according to the distribution given by p 21 | Args: 22 | p: (m, k) tensor or m-list of (ki,) tensors giving the probabilities of each 23 | of the k items for each of the m prompts 24 | n: number of samples to draw per prompt (along the first dimension) 25 | can be a scalar or a list of ints 26 | Returns: 27 | sample: 28 | if n is a scalar, an (m, n) tensor 29 | otherwise, a m-list of tensors, where each tensor is (ni,) 30 | """ 31 | if isinstance(n, (list, torch.Tensor, np.ndarray)): 32 | assert len(n) == len(p), "n must be a scalar or have the same length as p" 33 | if isinstance(p, torch.Tensor): 34 | # strategy: draw the max n samples so that we can parallelize, and then truncate 35 | # since p is a tensor, directly use torch.multinomial 36 | max_n = max(n) if isinstance(n, (list, torch.Tensor, np.ndarray)) else n 37 | sample = torch.multinomial(p, max_n, replacement=True) 38 | sample = sample.unsqueeze(-1) 39 | if isinstance(n, (list, torch.Tensor, np.ndarray)) and not all( 40 | ni == max_n for ni in n 41 | ): 42 | sample = [si[:ni] for si, ni in zip(sample, n)] 43 | else: 44 | # since p is a (potentially jagged) list, we need to iterate over each prompt 45 | if isinstance(n, (list, torch.Tensor, np.ndarray)): 46 | sample = [ 47 | torch.multinomial(pi, ni, replacement=True).unsqueeze(-1) 48 | for pi, ni in zip(p, n) 49 | ] 50 | else: 51 | sample = [ 52 | torch.multinomial(pi, n, replacement=True).unsqueeze(-1) for pi in p 53 | ] 54 | 55 | if isinstance(sample, list): 56 | try: 57 | return torch.stack(sample) 58 | except: 59 | return sample 60 | else: 61 | return sample 62 | 63 | 64 | ######################## 65 | # Distribution objects # 66 | ######################## 67 | 68 | 69 | class CompletionSample: 70 | def __init__(self, prompts: Union[np.ndarray, torch.Tensor], completions: Union[np.ndarray, torch.Tensor], m: int): 71 | """ 72 | Represents a sample from a DistributionFromDataset object. 
73 | Args: 74 | prompts: a (N,) tensor or array of prompt indices 75 | completions: a (N, L) tensor or array of completions, where L is the maximum completion length 76 | i.e. this is pre-padded. prompts[i] should correspond to the prompt for completions[i] 77 | m: total number of prompts; used to enforce that the prompt indices are in [0, m) 78 | """ 79 | self.m = m 80 | 81 | if isinstance(prompts, np.ndarray): 82 | prompts = torch.from_numpy(prompts) 83 | if isinstance(completions, np.ndarray): 84 | completions = torch.from_numpy(completions) 85 | 86 | self.prompt_sample = prompts.clone().detach() 87 | self.completion_sample = completions.clone().detach() 88 | if self.completion_sample.ndim == 1: 89 | self.completion_sample = self.completion_sample.unsqueeze(-1) 90 | 91 | self.L = self.completion_sample.shape[-1] 92 | self.N = len(self.prompt_sample) 93 | self.ns = torch.tensor( 94 | [(self.prompt_sample == i).sum().item() for i in range(m)] 95 | ) # effective number of completions for each prompt 96 | assert self.completion_sample.ndim == 2 97 | assert self.prompt_sample.ndim == 1 98 | 99 | @property 100 | def shape(self): 101 | return (self.m, self.ns.tolist(), self.L) 102 | 103 | @property 104 | @cache 105 | def sample(self): 106 | """(m, n, L) tensor or len-m list of (ni, L) tensors""" 107 | return self._prompt_completion_to_sample( 108 | self.prompt_sample, self.completion_sample 109 | ) 110 | 111 | @property 112 | @cache 113 | def sequences(self): 114 | return torch.cat( 115 | [self.prompt_sample.unsqueeze(1), self.completion_sample], dim=1 116 | ) 117 | 118 | def __str__(self): 119 | return f"CompletionSample with n={self.ns}" 120 | 121 | def __repr__(self): 122 | return str(self.sequences) 123 | 124 | @cache 125 | def _prompt_completion_to_sample(self, prompt_sample, completion_sample): 126 | """ 127 | Converts view 1 to view 2 128 | """ 129 | assert len(prompt_sample) == len(completion_sample) == self.N 130 | max_n = max(self.ns) 131 | if all([ni == max_n for ni in self.ns]): 132 | sample = torch.zeros(self.m, max_n, self.shape[-1], dtype=int) 133 | else: 134 | sample = [torch.zeros(ni, self.shape[-1], dtype=int) for ni in self.ns] 135 | count_up = torch.zeros(self.m, dtype=int) 136 | for i, (prompt, completion) in enumerate(zip(prompt_sample, completion_sample)): 137 | sample[prompt][count_up[prompt].item()] = completion 138 | count_up[prompt] += 1 139 | return sample 140 | 141 | def _sample_to_prompt_completion(self, sample): 142 | """ 143 | Converts view 2 to view 1 144 | """ 145 | assert ndim(sample) == 3 146 | assert len(sample) == self.m 147 | if isinstance(sample, torch.Tensor): 148 | indices = [] 149 | for i, tensor in enumerate(sample): 150 | indices.extend([i] * len(tensor)) 151 | indices = torch.tensor(indices) 152 | completion_sample = sample.view(-1, sample.shape[-1]) 153 | else: 154 | indices = [] 155 | for i, tensor in enumerate(sample): 156 | indices.extend([i] * len(tensor)) 157 | indices = torch.tensor(indices) 158 | completion_sample = torch.cat(sample) 159 | return indices, completion_sample 160 | 161 | 162 | class DistributionFromDataset: 163 | """ 164 | Helper class to create a prompt-completion distribution from a dataset of text samples. 165 | Given a dataset of completions for each prompt, we can create a distribution over 166 | prompt-completion pairs by sampling with replacement from the completions for each prompt. 
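Example (a minimal, hypothetical sketch; the file paths and load_fn below are placeholders rather than files that ship with this repository):
    >>> import numpy as np
    >>> load_fn = lambda path: np.load(path)  # assumed to return an (N, L') integer array padded with pad_token_id
    >>> dist = DistributionFromDataset(
    ...     sample_paths=[("prompt0_completions.npy", load_fn), ("prompt1_completions.npy", load_fn)],
    ...     L=50,
    ...     pad_token_id=-1,
    ... )
    >>> sample = dist.sample(n=200)  # a CompletionSample of 200 prompt-completion pairs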
167 | """ 168 | 169 | def __init__( 170 | self, 171 | sample_paths: List[Tuple[str, callable]], 172 | L: int, 173 | prompt_distribution: Union[np.ndarray, torch.Tensor] = None, 174 | logprob_paths: List[Tuple[str, callable]] = None, 175 | pad_token_id: int = -1, 176 | ): 177 | """ 178 | Args: 179 | sample_paths: list of tuples of paths to completion files and a callable to load them 180 | The length of the list is the number of prompts m. 181 | Example: sample_paths = [ 182 | ("completions-to-prompt-1-file-500-samples", load_fn), 183 | ("completions-to-prompt-2-file-200-samples", load_fn), 184 | ("completions-to-prompt-2-file-500-samples", load_fn), 185 | ] 186 | => DistributionFromDataset(m=3, k=[500, 200, 500], L=self.L) 187 | The callable should return an integer numpy array of shape (N, L') and pad with 188 | the same pad_token_id passed into this class if any completions are shorter than L' 189 | Example: 190 | np.array([[1, 2, 3], [4, 5, 6], [7, 8, self.pad_token_id]]) 191 | L: maximum completion length 192 | If the length of completions in files is L' > L, we truncate to L 193 | prompt_distribution: (m,) tensor 194 | The distribution over prompts 195 | If None, defaults to uniform 196 | logprob_paths: list of tuples of paths to logprob files and a callable to load them 197 | The length of the list is the number of prompts m. 198 | All completions that we load from the sample_paths should be in the logprob files 199 | The callable should return a dictionary of completions (tuples) to numpy arrays of 200 | logprobs, where the length of the key is the same length as the value. 201 | Example: 202 | { 203 | (0, 1, 2): [logprob_0, logprob_1, logprob_2], 204 | } 205 | pad_token_id: the id of the padding token 206 | If the length of completions in files is L' < L, we pad with this token 207 | """ 208 | self.sample_paths = sample_paths 209 | self.logprob_paths = logprob_paths 210 | self.m = len(sample_paths) 211 | self.L = L 212 | if prompt_distribution is None: 213 | prompt_distribution = torch.ones(self.m) / self.m 214 | self.prompt_distribution = prompt_distribution 215 | self.pad_token_id = pad_token_id 216 | 217 | # first call to _load_samples: establish ki, the number of possible completions for each prompt 218 | print("Initializing DistributionFromDataset; loading k...") 219 | self.k = torch.tensor( 220 | [len(self._load_samples(i, print_info=True)) for i in range(self.m)] 221 | ) 222 | print("Done initializing") 223 | 224 | def __len__(self): 225 | return self.m 226 | 227 | @property 228 | def shape(self): 229 | return (self.m, self.k.tolist(), self.L) 230 | 231 | @lru_cache(maxsize=100) 232 | def _load_samples(self, i, print_info=False) -> torch.Tensor: 233 | """ 234 | Loads the cached samples for the given prompt i based on the sample_paths passed 235 | in at initialization 236 | Recall that sample_paths is of type List[Tuple[str, callable]] 237 | The callable should return an integer numpy array of shape (k, L') and pad with 238 | the same pad_token_id passed into this class if any completions are shorter than L' 239 | Example: 240 | np.array([[1, 2, 3], [4, 5, 6], [7, 8, self.pad_token_id]]) 241 | Args: 242 | i: the index of the prompt 243 | Returns: 244 | samples: (k, L) tensor 245 | """ 246 | path, load_fn = self.sample_paths[i] 247 | with Stopwatch() as sw: 248 | try: 249 | sample = load_fn(path) # should be a numpy array 250 | sample = torch.from_numpy(sample) 251 | except Exception as e: 252 | raise ValueError(f"Error loading {path} for prompt {i}: {e}") 253 | if print_info: 
254 | print( 255 | f"\tTime to load distribution {i} w/ {len(sample)} entries: {sw.time}" 256 | ) 257 | if len(sample) == 0: 258 | print(f"Warning: sample {i} is empty") 259 | else: 260 | # truncate to L 261 | sample = stack_with_padding( 262 | sample[:, : self.L], 263 | padding_token=self.pad_token_id, 264 | )[0] 265 | # pad to L 266 | sample = torch.cat( 267 | [ 268 | sample, 269 | self.pad_token_id 270 | * torch.ones((len(sample), (self.L - sample.shape[1])), dtype=int), 271 | ], 272 | dim=-1, 273 | ) 274 | return sample 275 | 276 | @lru_cache(maxsize=100) 277 | def _load_probs(self, i, print_info=False) -> dict: 278 | """ 279 | Loads the cached probs for the given prompt i based on the logprob_paths passed 280 | in at initialization 281 | Recall that logprob_paths is of type List[Tuple[str, callable]] 282 | The callable should return a dictionary of completions (tuples) to numpy arrays of 283 | logprobs, where the length of the key is the same length as the value. 284 | Example: 285 | { 286 | (0, 1, 2): [logprob_0, logprob_1, logprob_2], 287 | } 288 | Args: 289 | i: the index of the prompt 290 | Returns: 291 | d: dictionary of completions (tuples) to probs (floats) 292 | """ 293 | path, load_fn = self.logprob_paths[i] 294 | with Stopwatch() as sw: 295 | d = load_fn( 296 | path 297 | ) # map from completions (as tuples of ints) to probs (floats) 298 | # truncate to L 299 | new_d = {} 300 | for k, v in d.items(): 301 | key = tuple( 302 | k[:self.L] + (self.pad_token_id, ) * (self.L - min(self.L, len(k))) 303 | ) 304 | new_d[key] = v[: self.L].sum().exp().item() 305 | d = new_d 306 | 307 | if print_info: 308 | print(f"\tTime to load distribution {i} w/ {len(d)} entries: {sw.time}") 309 | print(f"Sum of cached probs is {np.sum(list(d.values()))}") 310 | return d 311 | 312 | def sample(self, n: int = 1, prompt_indices=None) -> CompletionSample: 313 | """ 314 | Samples n prompt-completion pairs according to the joint distribution 315 | P(x, y) = P(x) * P(y | x) 316 | by using the following two-step process: 317 | 1. Sample a prompt x from the prompt distribution 318 | 2. Sample a completion y from the completion distribution given x 319 | The completion distribution is assumed to be uniform over the completions 320 | loaded from the dataset; we sample with replacement. 
321 | Args: 322 | n: number of samples to draw overall 323 | """ 324 | if prompt_indices is not None: 325 | assert 0 <= min(prompt_indices) and max(prompt_indices) < self.m 326 | else: 327 | prompt_indices = list(range(self.m)) 328 | 329 | # First, sample from the prompt distribution 330 | prompt_samples = _sample_from_categorical( 331 | self.prompt_distribution, n=n 332 | ).squeeze() 333 | 334 | # Then, count how many completions we need for each prompt 335 | prompt_indices, prompt_counts = torch.unique(prompt_samples, return_counts=True) 336 | prompt_indices = prompt_indices.tolist() 337 | prompt_counts = prompt_counts.tolist() 338 | 339 | # Finally, sample from the completion distribution for each prompt 340 | samples = [] # note: will be ordered by prompt index 341 | for i, ni in zip(prompt_indices, prompt_counts): 342 | s = self._load_samples(i) # (ki, L) tensor 343 | samples.append(s[np.random.choice(len(s), ni, replace=True)]) 344 | assert len(samples) == len(prompt_indices) 345 | 346 | # Create a CompletionSample object 347 | sample = CompletionSample( 348 | prompts=torch.cat( 349 | [torch.tensor([i] * ni) for i, ni in zip(prompt_indices, prompt_counts)] 350 | ), # prompt samples (N,) 351 | completions=torch.cat(samples), # completion samples (N,) 352 | m=self.m, 353 | ) 354 | return sample 355 | 356 | def get_completion_probabilities(self, sequences: torch.Tensor) -> torch.Tensor: 357 | """ 358 | Given a list of n sequences [(x, y_1, ...., y_L)] 359 | return an (n,) array of [P(y_1, ..., y_L | x)] 360 | by looking up the completion in the output of self._load_probs(x) 361 | Args: 362 | sequences: (n, L+1) tensor of sequences 363 | where the first column is the prompt index 364 | Returns: 365 | probabilities: (n,) tensor of probabilities 366 | """ 367 | assert sequences.ndim == 2 368 | final_probs = torch.zeros(len(sequences), dtype=self.prompt_distribution.dtype) 369 | prompts = sequences[:, 0] 370 | 371 | for prompt_index in range(self.m): 372 | mask = prompts == prompt_index 373 | if mask.sum() == 0: 374 | continue 375 | try: 376 | d = self._load_probs(prompt_index) 377 | except: 378 | raise ValueError( 379 | "Cannot get completion probabilities because could not find logprob files." 380 | ) 381 | vals = torch.tensor([d[tuple(row[1:].tolist())] for row in sequences[mask]]) 382 | final_probs[mask] = vals 383 | return final_probs 384 | 385 | def get_all_completion_probabilities(self, i): 386 | """ 387 | Enumerate all completions (y_1, ..., y_L) given x=i and their probabilities 388 | P(y_1, ..., y_L | x=i) 389 | Args: 390 | i (int): the prompt index 391 | Returns: 392 | completions: (n, L) tensor 393 | p: (n,) tensor 394 | """ 395 | try: 396 | d = self._load_probs(i) 397 | except: 398 | raise ValueError( 399 | "Cannot get completion probabilities because could not find logprob files." 
400 | ) 401 | sequences_i = torch.tensor(list(d.keys())) 402 | p_i = torch.tensor(list(d.values())) 403 | return sequences_i, p_i 404 | 405 | @cache 406 | def get_all_joint_probabilities(self): 407 | """ 408 | Enumerate all joint sequences (x, y_1, ..., y_L) and their joint probabilities 409 | P(x, y_1, ..., y_L) = P(x) * P(y_1, ..., y_L | x) 410 | Returns: 411 | unique: (ntilde, L+1) tensor of joint sequences 412 | probs: (ntilde,) tensor of joint probabilities 413 | """ 414 | sequences, p = [], [] 415 | for i in range(self.m): 416 | sequences_i, p_i = self.get_all_completion_probabilities(i) 417 | sequences.append(sequences_i) 418 | p.append(p_i) 419 | probs = torch.cat( 420 | [p[i] * self.prompt_distribution[i] for i in range(self.m)] 421 | ).view(-1) 422 | unique = torch.cat( 423 | [ 424 | torch.tensor([[i] + list(s) for s in sequences_i]) 425 | for i, sequences_i in enumerate(sequences) 426 | ] 427 | ) 428 | return unique, probs 429 | 430 | def __del__(self): 431 | self._load_samples.cache_clear() 432 | self._load_probs.cache_clear() 433 | self.get_all_joint_probabilities.cache_clear() 434 | -------------------------------------------------------------------------------- /model_equality_testing/src/tests.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions to compute test statistics given sample(s) and null distribution(s). 3 | """ 4 | 5 | import numpy as np 6 | from model_equality_testing.utils import ( 7 | Stopwatch, 8 | get_inv, 9 | ) 10 | import torch 11 | from typing import Union, Tuple, List, Dict 12 | from model_equality_testing.distribution import ( 13 | CompletionSample, 14 | DistributionFromDataset, 15 | ) 16 | from functools import lru_cache 17 | from collections import Counter 18 | 19 | ####################### 20 | # Two sample tests 21 | ####################### 22 | 23 | #### MMD tests #### 24 | 25 | 26 | def _mmd( 27 | X: np.ndarray, 28 | Y: np.ndarray, 29 | get_kernel: callable, 30 | normalize=True, 31 | print_info=False, 32 | ) -> float: 33 | """ 34 | Helper function to compute MMD test statistic. 35 | Handles normalization. 36 | Args: 37 | X: (n, L+1) numpy array of n sequences of length L. 38 | The first column X[:, 0] is an integer indicating the prompt. 39 | Y: (m, L+1) numpy array of m sequences of length L. 40 | The first column Y[:, 0] is an integer indicating the prompt. 
41 | get_kernel: function that computes the kernel matrices K_XX, K_XY, K_YY given X, Y 42 | normalize: whether to normalize the kernel matrices 43 | print_info: whether to print time to compute kernels 44 | Returns: 45 | MMD test statistic 46 | """ 47 | # Create a mask that is True if the prompts are different 48 | # When computing the kernel, we will zero out the entries where the mask is True 49 | # since we define the kernel to be 0 when the prompts are different 50 | prompts_x, prompts_y = X[:, 0], Y[:, 0] 51 | mask_XY = prompts_x[:, None] != prompts_y[None, :] 52 | mask_XX = prompts_x[:, None] != prompts_x[None, :] 53 | mask_YY = prompts_y[:, None] != prompts_y[None, :] 54 | 55 | # Call get_kernel to compute the kernel matrices 56 | with Stopwatch() as sw: 57 | K_XX, K_XY, K_YY = get_kernel( 58 | X[:, 1:], Y[:, 1:], mask_XX, mask_XY, mask_YY 59 | ) # remove prompt from seq 60 | if print_info: 61 | print("Time to compute kernels", sw.time) 62 | n_XX, n_XY, n_YY = K_XX.size, K_XY.size, K_YY.size 63 | 64 | # Zero out sequences from different prompts according to the mask 65 | K_XY[mask_XY] = 0 66 | n_XY -= mask_XY.sum() 67 | K_XX[mask_XX] = 0 68 | n_XX -= mask_XX.sum() 69 | K_YY[mask_YY] = 0 70 | n_YY -= mask_YY.sum() 71 | 72 | # Normalize the kernel matrices s.t. diagonal is 1 73 | if normalize: 74 | # kernel'[x, y] = kernel[x, y] / sqrt(kernel[x, x] * kernel[y, y]) 75 | diagX = np.sqrt(np.diag(K_XX)) 76 | diagY = np.sqrt(np.diag(K_YY)) 77 | diagX[diagX == 0] = 1 78 | diagY[diagY == 0] = 1 79 | K_XX /= np.outer(diagX, diagX) 80 | K_YY /= np.outer(diagY, diagY) 81 | K_XY /= np.outer(diagX, diagY) 82 | 83 | # Zero out samples with themselves 84 | np.fill_diagonal(K_XX, 0) 85 | n_XX -= len(K_XX) 86 | np.fill_diagonal(K_YY, 0) 87 | n_YY -= len(K_YY) 88 | 89 | # Compute empirical MMD estimate 90 | return np.sum(K_XX) / n_XX - 2 * np.sum(K_XY) / n_XY + np.sum(K_YY) / n_YY 91 | 92 | 93 | def _reflect_upper_triangular(K: np.ndarray) -> np.ndarray: 94 | """ 95 | Helper function to reflect the upper triangular part of a matrix to the lower triangular part. 96 | Args: 97 | K: (n, n) numpy array with the diagonal + upper right part filled in 98 | Returns: 99 | (n, n) numpy array with the diagonal + upper right part filled in, and the lower left part 100 | filled in by reflecting the upper right part 101 | """ 102 | np.fill_diagonal(K, K.diagonal() / 2) 103 | return K + K.T 104 | 105 | 106 | def mmd_hamming( 107 | sample1: CompletionSample, 108 | sample2: CompletionSample, 109 | ) -> float: 110 | """ 111 | MMD test statistic using K(x, y) = sum_i^L 1[x_i == y_i], 112 | i.e. 
whether the marginal densities match 113 | """ 114 | 115 | def get_hamming_kernel( 116 | X: np.ndarray, Y: np.ndarray, *args, memory_threshold: int = 10000, **kwargs 117 | ): 118 | """ 119 | Args: 120 | X: (n, L) numpy array of n sequences of length L 121 | Y: (m, L) numpy array of m sequences of length L 122 | Returns: 123 | K(X, X), K(X, Y), K(Y, Y) as a tuple 124 | """ 125 | n, L = X.shape 126 | m, _ = Y.shape 127 | max_size_XX = n * n 128 | max_size_XY = n * m 129 | max_size_YY = m * m 130 | if max(max_size_XX, max_size_XY, max_size_YY) <= memory_threshold**2: 131 | K_XX = np.sum(X[:, None, :] == X[None, :, :], axis=-1).astype(float) 132 | K_XY = np.sum(X[:, None, :] == Y[None, :, :], axis=-1).astype(float) 133 | K_YY = np.sum(Y[:, None, :] == Y[None, :, :], axis=-1).astype(float) 134 | else: 135 | print("To save memory, computing Hamming using for loops") 136 | K_XX = np.zeros((n, n), dtype=float) 137 | K_XY = np.zeros((n, m), dtype=float) 138 | K_YY = np.zeros((m, m), dtype=float) 139 | for i in range(n): 140 | for j in range(i, n): 141 | K_XX[i, j] = K_XX[j, i] = np.sum(X[i] == X[j]) 142 | for i in range(n): 143 | for j in range(m): 144 | K_XY[i, j] = np.sum(X[i] == Y[j]) 145 | for i in range(m): 146 | for j in range(i, m): 147 | K_YY[i, j] = K_YY[j, i] = np.sum(Y[i] == Y[j]) 148 | return K_XX, K_XY, K_YY 149 | 150 | return _mmd( 151 | X=sample1.sequences.numpy(), 152 | Y=sample2.sequences.numpy(), 153 | get_kernel=get_hamming_kernel, 154 | ) 155 | 156 | 157 | @lru_cache(maxsize=10000) 158 | def _get_kgrams(input_list: tuple, k: int, cumulative: bool = False) -> Counter: 159 | """ 160 | Given a sequence (a tuple of tokens), returns a Counter of all k-grams in the sequence. 161 | Args: 162 | input_list: the sequence, as a tuple of tokens 163 | k: length of k-grams 164 | cumulative: whether to return all k-grams up to length k or just k-grams of length k 165 | """ 166 | out = Counter() 167 | for ki in range(1, k + 1) if cumulative else [k]: 168 | out.update(zip(*(input_list[i:] for i in range(ki)))) 169 | return out 170 | 171 | 172 | @lru_cache(maxsize=100000) 173 | def _compute_dot_product_counts( 174 | a: tuple, b: tuple, k: int, cumulative: bool 175 | ) -> int: 176 | r""" 177 | Given two sequences, computes the dot product of the counts of k-grams in the two sequences. 
178 | $$ 179 | \sum_{s \in a} \#(s \in a) \#(s \in b) 180 | $$ 181 | Args: 182 | a: sequence 183 | """ 184 | d1 = _get_kgrams(a, k, cumulative) 185 | d2 = _get_kgrams(b, k, cumulative) 186 | return sum(d1[key] * d2.get(key, 0) for key in d1) 187 | 188 | 189 | def mmd_kspectrum( 190 | sample1: CompletionSample, 191 | sample2: CompletionSample, 192 | k: int = 5, 193 | cumulative: bool = True, 194 | ): 195 | r""" 196 | MMD test statistic using K(x, y) = \sum_{len-k substrings of len L} #(s in x) #(s in y) 197 | Args: 198 | sample1: CompletionSample 199 | sample2: CompletionSample 200 | k: length of k-grams 201 | cumulative: whether to use all k-grams up to length k or just k-grams of length k 202 | """ 203 | 204 | def get_kspectrum_kernel( 205 | X: np.ndarray, 206 | Y: np.ndarray, 207 | mask_XX: np.ndarray = None, 208 | mask_XY: np.ndarray = None, 209 | mask_YY: np.ndarray = None, 210 | ): 211 | """ 212 | Args: 213 | X: (n, L) numpy array of n sequences of length L 214 | Y: (m, L) numpy array of m sequences of length L 215 | Returns: 216 | K(X, X), K(X, Y), K(Y, Y) as a tuple 217 | """ 218 | # functools caches require hashable inputs, so convert the 2D numpy arrays to lists of tuples 219 | Xp = list(map(tuple, X)) 220 | Yp = list(map(tuple, Y)) 221 | 222 | def _get_kernel(A, B, mask, diag=False): 223 | out = np.zeros((len(A), len(B))) 224 | for i in range(len(A)): 225 | for j in range(i if diag else 0, len(B)): 226 | if mask is not None and mask[i, j]: 227 | continue 228 | ordered = (A[i], B[j]) if A[i] > B[j] else (B[j], A[i]) 229 | out[i, j] = _compute_dot_product_counts(*ordered, k, cumulative) 230 | return out 231 | 232 | K_XX = _get_kernel(Xp, Xp, mask_XX, diag=True) 233 | K_YY = _get_kernel(Yp, Yp, mask_YY, diag=True) 234 | K_XY = _get_kernel(Xp, Yp, mask_XY, diag=False) 235 | K_XX = _reflect_upper_triangular(K_XX) 236 | K_YY = _reflect_upper_triangular(K_YY) 237 | return K_XX, K_XY, K_YY 238 | 239 | return _mmd( 240 | X=sample1.sequences.numpy(), 241 | Y=sample2.sequences.numpy(), 242 | get_kernel=get_kspectrum_kernel, 243 | ) 244 | 245 | 246 | def mmd_all_subsequences( 247 | sample1: CompletionSample, 248 | sample2: CompletionSample, 249 | ): 250 | r""" 251 | Computes the all-subsequences MMD test statistic, which is the MMD test statistic using 252 | K(x, y) = \sum_{s \in x} \sum_{s \in y} 1[s in x and s in y] 253 | for all subsequences s of x and y. 254 | Args: 255 | sample1: CompletionSample 256 | sample2: CompletionSample 257 | """ 258 | L = sample1.shape[-1] 259 | return mmd_kspectrum(sample1, sample2, k=L, cumulative=True) 260 | 261 | 262 | #### Other two-sample tests #### 263 | 264 | 265 | def _get_counts(sample1: torch.Tensor, sample2: torch.Tensor): 266 | """ 267 | Given two 2D samples, finds the unique rows in the union of the two 268 | samples, and returns counts for each unique row in sample1, sample2. 
269 | Args: 270 | sample1: (n, L) tensor where each row is a sequence 271 | sample2: (m, L) tensor where each row is a sequence 272 | Returns: 273 | unique_sequences: (k, L) tensor of unique sequences 274 | counts1_full: (k,) tensor of counts for each unique sequence in sample1 275 | counts2_full: (k,) tensor of counts for each unique sequence in sample2 276 | """ 277 | unique1, counts1 = torch.unique(sample1, dim=0, return_counts=True) 278 | unique2, counts2 = torch.unique(sample2, dim=0, return_counts=True) 279 | all_sequences = torch.cat((unique1, unique2)) 280 | unique_sequences, inverse = torch.unique(all_sequences, dim=0, return_inverse=True) 281 | 282 | counts1_full = torch.zeros(len(unique_sequences), dtype=torch.int64) 283 | counts2_full = torch.zeros(len(unique_sequences), dtype=torch.int64) 284 | counts1_full[inverse[: len(unique1)]] = counts1 285 | counts2_full[inverse[len(unique1) :]] = counts2 286 | return unique_sequences, counts1_full, counts2_full 287 | 288 | 289 | def two_sample_chi_squared( 290 | sample1: CompletionSample, 291 | sample2: CompletionSample, 292 | ): 293 | r""" 294 | Computes the two-sample (centered) chi-squared test statistic, which has been modified for the imbalanced sample size case by 295 | Bhattacharya and Valiant (2015), as cited in Balakrishnan & Wasserman (2017). 296 | $$ 297 | \sum_{i=1}^k \frac{(N_2 c^1_i - N_1 c^2_i)^2 - N_2^2 c^1_i - N_1^2 c^2_i}{c^1_i + c^2_i} 298 | $$ 299 | where $c^1_i$ is the count of the $i$th unique sequence in sample1, $c^2_i$ is the count of the $i$th unique sequence in sample2, 300 | $N_1$ is the total number of sequences in sample1, and $N_2$ is the total number of sequences in sample2. 301 | 302 | References: 303 | - Bhattacharya and Valiant (2015) "Testing Closeness with Unequal Sized Samples" [step 2 in Alg 1] 304 | https://arxiv.org/abs/1504.04599 305 | - Balakrishnan & Wasserman (2017) "Hypothesis Testing for High-Dimensional Multinomials: A Selective Review" 306 | https://arxiv.org/abs/1712.06120 307 | 308 | Args: 309 | sample1: CompletionSample 310 | sample2: CompletionSample 311 | """ 312 | _, c1, c2 = _get_counts(sample1.sequences, sample2.sequences) 313 | return np.nansum( 314 | ( 315 | np.square(sample2.N * c1 - sample1.N * c2) 316 | - sample2.N**2 * c1 317 | - sample1.N**2 * c2 318 | ) 319 | / (c1 + c2) 320 | ) 321 | 322 | 323 | def two_sample_L1( 324 | sample1: CompletionSample, 325 | sample2: CompletionSample, 326 | ): 327 | r""" 328 | Computes the two-sample L1 test statistic 329 | $$ 330 | \sum_{i=1}^k |c^1_i / N_1 - c^2_i / N_2| 331 | $$ 332 | where $c^1_i$ is the count of the $i$th unique sequence in sample1, $c^2_i$ is the count of the $i$th unique sequence in sample2, 333 | $N_1$ is the total number of sequences in sample1, and $N_2$ is the total number of sequences in sample2.
334 | 335 | References: 336 | - Balakrishnan & Wasserman (2017) "Hypothesis Testing for High-Dimensional Multinomials: A Selective Review" 337 | https://arxiv.org/abs/1712.06120 338 | """ 339 | _, c1, c2 = _get_counts(sample1.sequences, sample2.sequences) 340 | return torch.sum(torch.abs(c1 / sample1.N - c2 / sample2.N)).item() 341 | 342 | 343 | def two_sample_L2( 344 | sample1: CompletionSample, 345 | sample2: CompletionSample, 346 | ): 347 | r""" 348 | Computes the two-sample L2 test statistic 349 | $$ 350 | \sum_{i=1}^k (c^1_i / N_1 - c^2_i / N_2)^2 351 | $$ 352 | where $c^1_i$ is the count of the $i$th unique sequence in sample1, $c^2_i$ is the count of the $i$th unique sequence in sample2, 353 | $N_1$ is the total number of sequences in sample1, and $N_2$ is the total number of sequences in sample2. 354 | 355 | References: 356 | - Balakrishnan & Wasserman (2017) "Hypothesis Testing for High-Dimensional Multinomials: A Selective Review" 357 | https://arxiv.org/abs/1712.06120 358 | """ 359 | _, c1, c2 = _get_counts(sample1.sequences, sample2.sequences) 360 | return torch.sum(torch.square(c1 / sample1.N - c2 / sample2.N)).item() 361 | 362 | 363 | ####################### 364 | # Goodness of fit tests 365 | ####################### 366 | 367 | 368 | def g_squared( 369 | sample: CompletionSample, 370 | null_dist: DistributionFromDataset, 371 | ): 372 | r""" 373 | Computes the G^2 / LRT test statistic 374 | $$ 375 | -2 \sum_{i=1}^k o_i \log(p_i / m_i) 376 | $$ 377 | where $o_i$ is the observed count of the $i$th unique sequence in the sample, 378 | $p_i$ is the completion probability of the $i$th unique sequence under the null distribution, 379 | and $m_i$ is the MLE of the completion probability of the $i$th unique sequence in the sample. 380 | 381 | Args: 382 | sample: CompletionSample 383 | null_dist: DistributionFromDataset 384 | """ 385 | 386 | def _stat(p, m, o): 387 | # p = probabilities, m = mles, o = observed counts 388 | test_stat = np.log(p) - np.log(m) 389 | test_stat *= o 390 | test_stat = np.nansum(test_stat, axis=-1) 391 | return -2 * test_stat 392 | 393 | sequences, counts = torch.unique( 394 | sample.sequences, 395 | return_counts=True, 396 | dim=0, 397 | ) 398 | probs = ( 399 | null_dist.get_completion_probabilities(sequences) 400 | * null_dist.prompt_distribution[sequences[:, 0]] 401 | ) 402 | return _stat(probs, counts / sample.N, counts) 403 | 404 | 405 | def chi_squared( 406 | sample: CompletionSample, 407 | null_dist: DistributionFromDataset, 408 | ): 409 | r""" 410 | Computes the Pearson chi_squared test statistic 411 | $$ 412 | \sum_{i=1}^k \frac{(o_i - n p_i)^2}{n p_i} 413 | $$ 414 | where $o_i$ is the observed count of the $i$th unique sequence in the sample, 415 | $p_i$ is the completion probability of the $i$th unique sequence under the null distribution, 416 | and $n$ is the total number of sequences in the sample. 
417 | 418 | Args: 419 | sample: CompletionSample 420 | null_dist: DistributionFromDataset 421 | """ 422 | 423 | def _stat(seqs, probs, obs): 424 | hashmap = get_inv(tuple(map(tuple, seqs.numpy()))) 425 | o, counts = torch.unique(obs, return_counts=True, dim=0) 426 | ixs = [hashmap[tuple(row.tolist())] for row in o] 427 | n = len(obs) 428 | obs_probs = probs[ixs] 429 | unobs_probs = probs[np.setdiff1d(np.arange(len(seqs)), ixs)] 430 | return ( 431 | (np.square(counts - n * obs_probs) / (n * obs_probs)).nansum() 432 | + (np.square(0 - n * unobs_probs) / (n * unobs_probs)).nansum() 433 | ).item() 434 | 435 | sequences, probs = null_dist.get_all_joint_probabilities() 436 | return _stat(sequences, probs, sample.sequences) 437 | 438 | 439 | def truncated_chi_squared( 440 | sample: CompletionSample, 441 | null_dist: DistributionFromDataset, 442 | ): 443 | r""" 444 | Computes the truncated chi_squared test statistic 445 | $$ 446 | \sum_{i=1}^k \frac{(o_i - n p_i)^2 - o_i}{\max(p_i, 1/k)} 447 | $$ 448 | where $o_i$ is the observed count of the $i$th unique sequence in the sample, 449 | $p_i$ is the completion probability of the $i$th unique sequence under the null distribution, 450 | and $n$ is the total number of sequences in the sample. 451 | 452 | References: 453 | - Balakrishnan & Wasserman (2017) "Hypothesis Testing for High-Dimensional Multinomials: A Selective Review" 454 | https://arxiv.org/abs/1712.06120 455 | 456 | Args: 457 | sample: CompletionSample 458 | null_dist: DistributionFromDataset 459 | """ 460 | 461 | def _stat(seqs, probs, obs): 462 | hashmap = get_inv(tuple(map(tuple, seqs.numpy()))) 463 | o, counts = torch.unique(obs, return_counts=True, dim=0) 464 | n = len(obs) 465 | ixs = [hashmap[tuple(row.tolist())] for row in o] 466 | obs_probs = probs[ixs] 467 | unobs_probs = probs[np.setdiff1d(np.arange(len(seqs)), ixs)] 468 | k = len(seqs) 469 | return ( 470 | ( 471 | (np.square(counts - n * obs_probs) - counts) 472 | / np.maximum(obs_probs, 1 / k * torch.ones_like(obs_probs)) 473 | ).nansum() 474 | + ( 475 | np.square(0 - n * unobs_probs) 476 | / np.maximum(unobs_probs, 1 / k * torch.ones_like(unobs_probs)) 477 | ).nansum() 478 | ).item() 479 | 480 | sequences, probs = null_dist.get_all_joint_probabilities() 481 | return _stat(sequences, probs, sample.sequences) 482 | 483 | 484 | def L1( 485 | sample: CompletionSample, 486 | null_dist: DistributionFromDataset, 487 | ): 488 | r""" 489 | Computes the L1 test statistic 490 | $$ 491 | \sum_{i=1}^k |o_i - n p_i| 492 | $$ 493 | 494 | Args: 495 | sample: CompletionSample 496 | null_dist: DistributionFromDataset 497 | """ 498 | 499 | def _stat(seqs, probs, obs): 500 | hashmap = get_inv(tuple(map(tuple, seqs.numpy()))) 501 | o, counts = torch.unique(obs, return_counts=True, dim=0) 502 | n = len(obs) 503 | ixs = [hashmap[tuple(row.tolist())] for row in o] 504 | obs_probs = probs[ixs] 505 | unobs_probs = probs[np.setdiff1d(np.arange(len(seqs)), ixs)] 506 | return ( 507 | np.abs(counts - n * obs_probs).nansum() 508 | + np.abs(0 - n * unobs_probs).nansum() 509 | ).item() 510 | 511 | sequences, probs = null_dist.get_all_joint_probabilities() 512 | return _stat(sequences, probs, sample.sequences) 513 | 514 | 515 | def L2( 516 | sample: CompletionSample, 517 | null_dist: DistributionFromDataset, 518 | ): 519 | r""" 520 | Computes the L2 test statistic 521 | $$ 522 | \sum_{i=1}^k (o_i - n p_i)^2 523 | $$ 524 | where $o_i$ is the observed count of the $i$th unique sequence in the sample, 525 | $p_i$ is the completion probability of the $i$th unique 
sequence under the null distribution, 526 | and $n$ is the total number of sequences in the sample. 527 | 528 | Args: 529 | sample: CompletionSample 530 | null_dist: DistributionFromDataset 531 | """ 532 | 533 | def _stat(seqs, probs, obs): 534 | hashmap = get_inv(tuple(map(tuple, seqs.numpy()))) 535 | o, counts = torch.unique(obs, return_counts=True, dim=0) 536 | n = len(obs) 537 | ixs = [hashmap[tuple(row.tolist())] for row in o] 538 | obs_probs = probs[ixs] 539 | unobs_probs = probs[np.setdiff1d(np.arange(len(seqs)), ixs)] 540 | return ( 541 | np.square(counts - n * obs_probs).nansum() 542 | + np.square(0 - n * unobs_probs).nansum() 543 | ).item() 544 | 545 | sequences, probs = null_dist.get_all_joint_probabilities() 546 | return _stat(sequences, probs, sample.sequences) 547 | 548 | 549 | ###### map from name to function ###### 550 | 551 | IMPLEMENTED_TESTS = { 552 | "g_squared": g_squared, 553 | "chi_squared": chi_squared, 554 | "truncated_chi_squared": truncated_chi_squared, 555 | "L1": L1, 556 | "L2": L2, 557 | "two_sample_chi_squared": two_sample_chi_squared, 558 | "two_sample_L1": two_sample_L1, 559 | "two_sample_L2": two_sample_L2, 560 | "mmd_hamming": mmd_hamming, 561 | "mmd_kspectrum": mmd_kspectrum, 562 | "mmd_all_subsequences": mmd_all_subsequences, 563 | } 564 | --------------------------------------------------------------------------------
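A minimal usage sketch (not part of the repository files): the toy tensors below are invented for illustration, and the import paths assume the package is installed as model_equality_testing, as the modules' own imports suggest. Two-sample statistics take a pair of CompletionSample objects; the goodness-of-fit statistics instead take a sample and a DistributionFromDataset null.

import torch
from model_equality_testing.distribution import CompletionSample
from model_equality_testing.tests import IMPLEMENTED_TESTS

# Two toy samples over m=2 prompts, completions of length L=3, padded with -1
prompts = torch.tensor([0, 0, 1, 1])
completions_a = torch.tensor([[5, 7, -1], [5, 7, -1], [2, 2, 2], [2, 3, -1]])
completions_b = torch.tensor([[5, 7, -1], [5, 9, -1], [2, 2, 2], [2, 3, -1]])
sample_a = CompletionSample(prompts=prompts, completions=completions_a, m=2)
sample_b = CompletionSample(prompts=prompts, completions=completions_b, m=2)

# Look up any two-sample statistic by name in the registry, e.g. the L1 statistic
stat = IMPLEMENTED_TESTS["two_sample_L1"](sample_a, sample_b)
print(stat)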