├── tests ├── __init__.py └── test_dp_few_shot_generation │ ├── __init__.py │ └── test_import.py ├── lmapi ├── tests │ ├── __init__.py │ └── test_lmapi │ │ ├── __init__.py │ │ ├── test_import.py │ │ └── async_tools │ │ └── test_asyncitertools.py ├── src │ └── lmapi │ │ ├── __init__.py │ │ ├── py.typed │ │ ├── async_tools │ │ ├── __init__.py │ │ ├── server_sent_events.py │ │ ├── asyncitertools.py │ │ └── limits.py │ │ ├── auth.py │ │ ├── lm.py │ │ └── openai.py ├── poetry.toml ├── README.md └── pyproject.toml ├── src └── dp_few_shot_generation │ ├── __init__.py │ ├── prob_utils.py │ ├── lm.py │ ├── run_exp_movie.py │ ├── run_exp_agnews.py │ ├── run_exp_trec.py │ └── run_exp_dbpedia.py ├── poetry.toml ├── CODE_OF_CONDUCT.md ├── SUPPORT.md ├── NOTICE.txt ├── Makefile ├── LICENSE ├── privacy_analysis ├── exponential.py └── subgaussian.py ├── pyproject.toml ├── run.sh ├── SECURITY.md ├── data └── process_movie.py ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lmapi/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lmapi/src/lmapi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lmapi/src/lmapi/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lmapi/tests/test_lmapi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dp_few_shot_generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lmapi/src/lmapi/async_tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_dp_few_shot_generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = "true" 3 | -------------------------------------------------------------------------------- /lmapi/poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = "true" 3 | -------------------------------------------------------------------------------- /lmapi/README.md: -------------------------------------------------------------------------------- 1 | This is a small library for accessing language models. 2 | -------------------------------------------------------------------------------- /lmapi/tests/test_lmapi/test_import.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # ruff: noqa: F401 5 | def test_package_import(): 6 | "Tests that lmapi is importable." 
7 | import lmapi 8 | -------------------------------------------------------------------------------- /tests/test_dp_few_shot_generation/test_import.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # ruff: noqa: F401 5 | def test_package_import(): 6 | "Tests that dp_few_shot_generation is importable." 7 | import dp_few_shot_generation 8 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please use GitHub Issues or Discussions. 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 14 | -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 2 | Do Not Translate or Localize 3 | 4 | This software incorporates components from the projects listed below. The original copyright notices 5 | and the licenses under which Microsoft received such components are set forth below and are provided for 6 | informational purposes only. Microsoft reserves all rights not expressly granted herein, whether by 7 | implication, estoppel or otherwise. 8 | 9 | This software includes parts of the following repository: https://github.com/tonyzhaozh/few-shot-learning. 
10 | This repository is licensed under Apache License 2.0, you can find a copy of this license at https://github.com/tonyzhaozh/few-shot-learning/blob/main/LICENSE 11 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SOURCE_DIRS = src 2 | TEST_DIRS = tests 3 | SOURCE_AND_TEST_DIRS = $(SOURCE_DIRS) $(TEST_DIRS) 4 | PREFIX = poetry run 5 | 6 | .PHONY: format lint-fix fix format-check lint pyright test 7 | 8 | all: format-check lint pyright test 9 | 10 | format: 11 | $(PREFIX) ruff -e --fix-only --select I001 $(SOURCE_AND_TEST_DIRS) 12 | $(PREFIX) black $(SOURCE_AND_TEST_DIRS) 13 | 14 | lint-fix: 15 | $(PREFIX) ruff -e --fix-only $(SOURCE_AND_TEST_DIRS) 16 | 17 | fix: lint-fix 18 | $(PREFIX) black $(SOURCE_AND_TEST_DIRS) 19 | 20 | format-check: 21 | @($(PREFIX) ruff --select I001 $(SOURCE_AND_TEST_DIRS)) && ($(PREFIX) black --check $(SOURCE_AND_TEST_DIRS)) || (echo "run \"make format\" to format the code"; exit 1) 22 | 23 | lint: 24 | @($(PREFIX) ruff $(SOURCE_AND_TEST_DIRS)) || (echo "run \"make lint-fix\" to fix some lint errors automatically"; exit 1) 25 | 26 | pyright: 27 | $(PREFIX) pyright 28 | 29 | test: 30 | $(PREFIX) python -m pytest $(TEST_DIRS) 31 | -------------------------------------------------------------------------------- /lmapi/src/lmapi/auth.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import dataclasses 5 | from dataclasses import InitVar, dataclass 6 | from typing import Protocol 7 | 8 | 9 | class AuthorizationProvider(Protocol): 10 | def headers(self) -> dict[str, str]: 11 | """Returns the headers to be used for authorization.""" 12 | ... 13 | 14 | 15 | @dataclass 16 | class OpenAiApiKey(AuthorizationProvider): 17 | key: InitVar[str] 18 | _headers: dict[str, str] = dataclasses.field(init=False) 19 | 20 | def __post_init__(self, key: str) -> None: 21 | self._headers = {"Authorization": f"Bearer {key}"} 22 | 23 | def headers(self) -> dict[str, str]: 24 | return self._headers 25 | 26 | 27 | @dataclass 28 | class AoaiApiKey(AuthorizationProvider): 29 | key: InitVar[str] 30 | _headers: dict[str, str] = dataclasses.field(init=False) 31 | 32 | def __post_init__(self, key: str) -> None: 33 | self._headers = {"api-key": key} 34 | 35 | def headers(self) -> dict[str, str]: 36 | return self._headers 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /privacy_analysis/exponential.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from prv_accountant.dpsgd import DPSGDAccountant 5 | import numpy as np 6 | from prv_accountant.privacy_random_variables import PureDPMechanism 7 | from prv_accountant import PRVAccountant 8 | import math 9 | 10 | ### assuming sensitivity of 1. 11 | dataset = "DBPedia" 12 | assert dataset in ["DBPedia"] 13 | sigma_list = [] 14 | 15 | if dataset == "DBPedia": 16 | full_train_num = 40000 17 | n = 80 18 | max_len_token = 100 19 | sigma_list = [2.73, 3.34, 3.95, 4.57] 20 | 21 | 22 | sample_rate = n / full_train_num 23 | print(n, full_train_num) 24 | print("sample rate", sample_rate) 25 | print("steps", max_len_token) 26 | 27 | for sigma in sigma_list: 28 | sigma_bar = math.log(1+sample_rate*(math.exp(sigma)-1)) 29 | print(sigma_bar) 30 | prv_0 = PureDPMechanism(sigma_bar) 31 | 32 | accountant = PRVAccountant( 33 | prvs=[prv_0, ], 34 | max_self_compositions=[max_len_token+1], 35 | eps_error=0.01, 36 | delta_error=1e-10 37 | ) 38 | eps_low, eps_est, eps_up = accountant.compute_epsilon(delta=1/full_train_num, num_self_compositions=[max_len_token]) 39 | 40 | print(sigma, eps_low, eps_est, eps_up) 41 | -------------------------------------------------------------------------------- /lmapi/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "lmapi" 3 | version = "0.0.1dev14" 4 | description = "" 5 | authors = ["Richard Shin "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | aiohttp = "^3.0" 10 | python = "^3.10" 11 | tiktoken = "^0.3.0" 12 | 13 | [tool.poetry.group.dev.dependencies] 14 | black = "^23.3.0" 15 | pyright = "^1.1.303" 16 | pytest = "^7.3.1" 17 | pytest-asyncio = "^0.21.0" 18 | ruff = "^0.0.261" 19 | 20 | [tool.black] 21 | skip-magic-trailing-comma = true 22 | target-version = ["py310"] 23 | 24 | [tool.pyright] 25 | include = [ 26 | "src", 27 | "tests", 28 | ] 29 | reportUnnecessaryCast = "error" 30 | reportUnnecessaryTypeIgnoreComment = "error" 31 | 32 | [tool.ruff] 33 | # See https://beta.ruff.rs/docs/rules/ for a list of rules. 34 | # This list is kept in the same order as the documentation.
35 | select = [ 36 | "E", 37 | "F", 38 | "W", 39 | "I", 40 | "UP", 41 | "B", 42 | "C4", 43 | "RUF", 44 | ] 45 | ignore = [ 46 | # Do not perform function call in argument defaults 47 | "B008", 48 | # Line too long 49 | "E501", 50 | ] 51 | target-version = "py310" 52 | src = [ 53 | "src", 54 | "tests", 55 | ] 56 | 57 | [build-system] 58 | requires = ["poetry-core>=1.0.0"] 59 | build-backend = "poetry.core.masonry.api" 60 | -------------------------------------------------------------------------------- /privacy_analysis/subgaussian.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from prv_accountant.dpsgd import DPSGDAccountant 5 | import numpy as np 6 | from prv_accountant.privacy_random_variables import PoissonSubsampledGaussianMechanism, GaussianMechanism, LaplaceMechanism 7 | from prv_accountant import PRVAccountant 8 | 9 | 10 | dataset = "MIT-G" 11 | assert dataset in ["AGNEWS", "DBPedia", "MIT-D", "MIT-G"] 12 | sigma_list = [] 13 | 14 | if dataset == "AGNEWS": 15 | full_train_num = 30000 16 | n = 20 17 | max_token_cnt = 100 18 | sigma_list = [0.51, 0.46, 0.39, 0.31] 19 | elif dataset == "DBPedia": 20 | full_train_num = 40000 21 | n = 80 22 | max_token_cnt = 100 23 | sigma_list = [0.63, 0.54, 0.45, 0.36] 24 | elif dataset == "MIT-G": 25 | full_train_num = 2953 26 | n = 80 27 | max_token_cnt = 80 28 | sigma_list = [1.08, 0.81, 0.64, 0.5] 29 | elif dataset == "MIT-D": 30 | full_train_num = 1561 31 | n = 80 32 | max_token_cnt = 80 33 | sigma_list = [1.52, 1.04, 0.77, 0.58] 34 | 35 | sample_rate = n / full_train_num 36 | print(dataset) 37 | print(n, full_train_num) 38 | print("sample rate", sample_rate) 39 | print("steps", max_token_cnt) 40 | 41 | for sigma in sigma_list: 42 | prv_0 = PoissonSubsampledGaussianMechanism(noise_multiplier=sigma, sampling_probability=sample_rate) 43 | print("sample rate", sample_rate) 44 | 45 | accountant = PRVAccountant( 46 | prvs=[prv_0, ], 47 | max_self_compositions=[1000], 48 | eps_error=0.01, 49 | delta_error=1e-10 50 | ) 51 | eps_low, eps_est, eps_up = accountant.compute_epsilon(delta=1/full_train_num, num_self_compositions=[max_token_cnt]) 52 | 53 | print(sigma, eps_low, eps_est, eps_up) 54 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "dp-few-shot-generation" 3 | version = "0.1.0" 4 | description = "Differentially private few-shot generation using in-context learning with LLMs." 5 | authors = [ 6 | # See paper for full list of authors 7 | "Xinyu Tang ", 8 | "Richard Shin ", 9 | "Huseyin A. 
Inan " 10 | ] 11 | readme = "README.md" 12 | packages = [ 13 | { include = "dp_few_shot_generation", from = "src" }, 14 | ] 15 | 16 | [tool.poetry.dependencies] 17 | python = ">=3.10,<3.12" 18 | datasets = "^2.11.0" 19 | autodp = "^0.2" 20 | lmapi = {path = "./lmapi", develop = true} 21 | more-itertools = "^9.1.0" 22 | scipy = "^1.10.1" 23 | typer = "^0.9.0" 24 | numpy = "^1.25.2" 25 | openai = "^0.28.0" 26 | pandas = "^2.1.0" 27 | prv-accountant = "^0.2.0" 28 | aiohttp = "^3.8.5" 29 | tqdm = "^4.66.1" 30 | 31 | [tool.poetry.dev-dependencies] 32 | black = "^23.3.0" 33 | pyright = "^1.1.303" 34 | pytest = "^7.3.1" 35 | ruff = "^0.0.261" 36 | 37 | [tool.poetry.group.dev.dependencies] 38 | pudb = "^2022.1.3" 39 | ipykernel = "^6.23.1" 40 | 41 | [tool.black] 42 | skip-magic-trailing-comma = true 43 | target-version = ["py310"] 44 | 45 | [tool.pyright] 46 | include = [ 47 | "src", 48 | "tests", 49 | ] 50 | reportUnnecessaryCast = "error" 51 | reportUnnecessaryTypeIgnoreComment = "error" 52 | 53 | [tool.ruff] 54 | # See https://beta.ruff.rs/docs/rules/ for a list of rules. 55 | # This list is kept in the same order as the documentation. 56 | select = [ 57 | "E", 58 | "F", 59 | "W", 60 | "I", 61 | "UP", 62 | "B", 63 | "C4", 64 | "RUF", 65 | ] 66 | ignore = [ 67 | # Do not perform function call in argument defaults 68 | "B008", 69 | # Line too long 70 | "E501", 71 | ] 72 | target-version = "py310" 73 | src = [ 74 | "src", 75 | "tests", 76 | ] 77 | 78 | [build-system] 79 | requires = ["poetry-core>=1.0.0"] 80 | build-backend = "poetry.core.masonry.api" 81 | -------------------------------------------------------------------------------- /lmapi/src/lmapi/async_tools/server_sent_events.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from collections.abc import AsyncIterable, AsyncIterator 5 | from dataclasses import dataclass 6 | 7 | 8 | @dataclass 9 | class ServerSentEvent: 10 | event: str | None = None 11 | data: str | None = None 12 | id: str | None = None 13 | retry: int | None = None 14 | 15 | 16 | async def parse_event_stream( 17 | lines: AsyncIterable[str], 18 | ) -> AsyncIterator[ServerSentEvent]: 19 | """Implements part of https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation 20 | This function doesn't implement the logic specified there for dispatching the event; it only performs parsing. 21 | """ 22 | next_event = ServerSentEvent() 23 | 24 | async for line in lines: 25 | line = line.rstrip("\n") 26 | if len(line) == 0: 27 | yield next_event 28 | next_event = ServerSentEvent() 29 | continue 30 | 31 | if line[0] == ":": 32 | continue 33 | 34 | field, _, value = line.partition(":") 35 | if value.startswith(" "): 36 | value = value[1:] 37 | if field == "event": 38 | next_event.event = value 39 | elif field == "data": 40 | if next_event.data is None: 41 | next_event.data = value + "\n" 42 | else: 43 | next_event.data += value + "\n" 44 | elif field == "id": 45 | next_event.id = value 46 | elif field == "retry": 47 | try: 48 | next_event.retry = int(value) 49 | except ValueError: 50 | pass 51 | 52 | # No need to handle `next_event` now, as the specification states: 53 | # Once the end of the file is reached, any pending data must be discarded. 54 | # (If the file ends in the middle of an event, before the final empty line, 55 | # the incomplete event is not dispatched.)
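# Usage note (illustrative sketch only, not part of this module): parse_event_stream
# just parses; the caller decides how to react to each event. Assuming `lines` is any
# async iterable of decoded text lines from an SSE response, the data fields could be
# collected like this:
#
#     async def collect_data(lines: AsyncIterable[str]) -> list[str]:
#         chunks: list[str] = []
#         async for event in parse_event_stream(lines):
#             if event.data is not None:
#                 chunks.append(event.data)
#         return chunks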
56 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | ### example scripts for epsilon=4. Hyperparameters are set according to Table 9. 2 | ### For 0-shot, set --num-valid 0 3 | ### For 4-shot (epsilon=0), set --num-private-train 0 and --num-private-train-splits 0 4 | ### For 4-shot (epsilon=infinity), set --sigma 0 and other parameters as Table 9. 5 | 6 | # AGNEWS 7 | for seed in 0 1 2 3 4 8 | do 9 | python -m dp_few_shot_generation.run_exp_agnews \ 10 | --sigma 0.39 \ 11 | --openai-model "babbage" \ 12 | --num-private-train 20 \ 13 | --set-num-public-train 0 \ 14 | --num-valid 4 \ 15 | --num-private-train-splits 10 \ 16 | --num-test 1000 \ 17 | --use-dp-prompts \ 18 | --sample-same-label-prompts \ 19 | --subsample-per-token \ 20 | --synth-seed $seed \ 21 | --eval-seed $seed 22 | done 23 | 24 | # DBPedia 25 | for seed in 0 1 2 3 4 26 | do 27 | python -m dp_few_shot_generation.run_exp_dbpedia \ 28 | --sigma 0.45 \ 29 | --openai-model "babbage" \ 30 | --num-private-train 80 \ 31 | --set-num-public-train 0 \ 32 | --num-valid 4 \ 33 | --num-private-train-splits 40 \ 34 | --num-test 1000 \ 35 | --use-dp-prompts \ 36 | --sample-same-label-prompts \ 37 | --subsample-per-token \ 38 | --synth-seed $seed \ 39 | --eval-seed $seed 40 | done 41 | 42 | # TREC 43 | for seed in 0 1 2 3 4 44 | do 45 | python -m dp_few_shot_generation.run_exp_trec \ 46 | --sigma 0.69 \ 47 | --openai-model "babbage" \ 48 | --num-private-train 80 \ 49 | --set-num-public-train 0 \ 50 | --num-valid 4 \ 51 | --num-private-train-splits 80 \ 52 | --num-test 1000 \ 53 | --no-public-token \ 54 | --use-dp-prompts \ 55 | --sample-same-label-prompts \ 56 | --subsample-per-token \ 57 | --synth-seed $seed \ 58 | --eval-seed $seed 59 | done 60 | 61 | # MIT-G 62 | for seed in 0 1 2 3 4 63 | do 64 | python -m dp_few_shot_generation.run_exp_movie \ 65 | --sigma 0.64 \ 66 | --openai-model "babbage" \ 67 | --num-private-train 80 \ 68 | --set-num-public-train 0 \ 69 | --num-valid 4 \ 70 | --num-private-train-splits 20 \ 71 | --num-test 1000 \ 72 | --use-dp-prompts \ 73 | --field-name Genre \ 74 | --subsample-per-token \ 75 | --synth-seed $seed \ 76 | --eval-seed $seed 77 | done 78 | 79 | # MIT-D 80 | for seed in 0 1 2 3 4 81 | do 82 | python -m dp_few_shot_generation.run_exp_movie \ 83 | --sigma 0.77 \ 84 | --openai-model "babbage" \ 85 | --num-private-train 80 \ 86 | --set-num-public-train 0 \ 87 | --num-valid 4 \ 88 | --num-private-train-splits 20 \ 89 | --num-test 1000 \ 90 | --use-dp-prompts \ 91 | --field-name Director \ 92 | --subsample-per-token \ 93 | --synth-seed $seed \ 94 | --eval-seed $seed 95 | done 96 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/dp_few_shot_generation/prob_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from typing import TypeVar, cast 5 | 6 | import numpy as np 7 | import scipy.special 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | def log_normalize(logprobs: dict[T, float]) -> dict[T, float]: 13 | """Normalize a log probability distribution so that all probabilities sum to 1. 14 | 15 | The input is a sparse distribution represented as a dictionary mapping from token IDs to log probabilities. 16 | """ 17 | 18 | normalizer = cast(float, scipy.special.logsumexp(list(logprobs.values()))) 19 | return {k: v - normalizer for k, v in logprobs.items()} 20 | 21 | 22 | def log_max_normalize(logprobs: dict[T, float]) -> dict[T, float]: 23 | """Normalize a log probability distribution so that the maximum probability is 1. 24 | 25 | The input is a sparse distribution represented as a dictionary mapping from token IDs to log probabilities.
26 | """ 27 | 28 | normalizer = cast(float, max(list(logprobs.values()))) 29 | return {k: v - normalizer for k, v in logprobs.items()} 30 | 31 | 32 | def densify(vocab_size: int, logprobs: dict[int, float]) -> np.ndarray: 33 | """Convert a sparse log-probability distribution into a dense one. 34 | 35 | The dense distribution is represented as a 1D tensor of size `vocab_size`. 36 | """ 37 | 38 | assert len(logprobs) > 0 39 | result = np.full((vocab_size,), -np.inf) 40 | for k, v in logprobs.items(): 41 | result[k] = v 42 | return result 43 | 44 | 45 | MINIMUM_MISSING_PROB = 1e-6 46 | 47 | 48 | def remove_logit_bias( 49 | biased_logprobs: dict[T, float], logit_bias: dict[T, float] | dict[T, int] 50 | ) -> dict[T, float]: 51 | """Undo the effects of logit_bias. 52 | 53 | This function is useful to get log probabilities of arbitrary tokens, 54 | as a workaround for how OpenAI's API only returns the top K most likely tokens. 55 | Give those arbitrary tokens a large logit_bias (about 70 works well in practice). 56 | 57 | For this function to work properly: 58 | - The logit_bias should be set so that logprobs for all tokens in the logit_bias 59 | are provided in biased_logprobs. For example, they should all be large positive numbers. 60 | - The logit_bias should not be too big, because of limited floating point precision. 61 | OpenAI's API will not return log probabilities smaller than np.log(np.finfo(np.float32).smallest_normal). 62 | - biased_logprobs should have more elements than logit_bias. Otherwise it's 63 | not possible to recover the original log probabilities. 64 | """ 65 | if not (logit_bias.keys() <= biased_logprobs.keys()): 66 | raise ValueError("logit_bias must be a subset of biased_logprobs") 67 | 68 | missing_key = object() 69 | missing_prob = 1 - np.exp(scipy.special.logsumexp(list(biased_logprobs.values()))) 70 | missing_logprob = float( 71 | np.log(missing_prob) if missing_prob > MINIMUM_MISSING_PROB else -np.inf 72 | ) 73 | 74 | debiased_logits: dict[T | object, float] = { 75 | k: v - logit_bias.get(k, 0) for k, v in biased_logprobs.items() 76 | } 77 | debiased_logits[missing_key] = missing_logprob 78 | 79 | normalized = log_normalize(debiased_logits) 80 | del normalized[missing_key] 81 | return cast(dict[T, float], normalized) 82 | -------------------------------------------------------------------------------- /data/process_movie.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
3 | 4 | # Assume saving data from https://github.com/tonyzhaozh/few-shot-learning/tree/main/data/slot-movies in the current folder 5 | 6 | import pandas as pd 7 | import json 8 | import pickle 9 | import numpy as np 10 | import os 11 | 12 | 13 | # Imported from https://github.com/tonyzhaozh/few-shot-learning/blob/main/data_utils.py#L122-#L173 14 | def generate_dataset(field_name, data_path="."): 15 | 16 | all_fields = ["Actor", "Award", "Character_Name", "Director", "Genre", "Opinion", "Origin", "Plot", "Quote", "Relationship", "Soundtrack", "Year"] 17 | assert field_name in all_fields 18 | all_fields.remove(field_name) 19 | filter_tags = [f"B-{field}" for field in all_fields] + [f"I-{field}" for field in all_fields] + ["O"] 20 | target_tags = [f"B-{field_name}", f"I-{field_name}"] 21 | 22 | with open(f'{data_path}/slot-movies/train', 'r') as f: 23 | lines = f.readlines() 24 | lines = [line.replace(' <=> ','').strip() for line in lines] 25 | train_answers = [] 26 | train_sentences = [] 27 | for line in lines: 28 | answer = '' 29 | untagged_line = '' 30 | for word in line.split(' '): 31 | contains_target = [tag in word for tag in target_tags] 32 | if np.any(contains_target): 33 | for tag in target_tags: 34 | word = word.replace(':' + tag, '') 35 | answer += word + ' ' 36 | for tag in filter_tags: 37 | word = word.replace(':' + tag, '') 38 | untagged_line += word + ' ' 39 | 40 | if answer != '': 41 | train_answers.append(answer.strip()) 42 | train_sentences.append(untagged_line.strip()) 43 | 44 | 45 | with open(f'{data_path}/slot-movies/test', 'r') as f: 46 | lines = f.readlines() 47 | lines = [line.replace(' <=> ','').strip() for line in lines] 48 | test_answers = [] 49 | test_sentences = [] 50 | for line in lines: 51 | answer = '' 52 | untagged_line = '' 53 | for word in line.split(' '): 54 | contains_target = [tag in word for tag in target_tags] 55 | if np.any(contains_target): 56 | for tag in target_tags: 57 | word = word.replace(':' + tag, '') 58 | answer += word + ' ' 59 | for tag in filter_tags: 60 | word = word.replace(':' + tag, '') 61 | untagged_line += word + ' ' 62 | 63 | if answer != '': 64 | test_answers.append(answer.strip()) 65 | test_sentences.append(untagged_line.strip()) 66 | if not os.path.isdir((f"{data_path}/movie/{field_name}")): 67 | os.makedirs(f"{data_path}/movie/{field_name}", exist_ok=True) 68 | train_data = {} 69 | train_data['content'] = train_sentences 70 | train_data['label'] = train_answers 71 | df = pd.DataFrame(train_data) 72 | df.to_csv(f"{data_path}/movie/{field_name}/train.csv") 73 | 74 | test_data = {} 75 | test_data['content'] = test_sentences 76 | test_data['label'] = test_answers 77 | df = pd.DataFrame(test_data) 78 | df.to_csv(f"{data_path}/movie/{field_name}/test.csv") 79 | 80 | for field_name in ['Director', 'Genre']: 81 | # by default save to ./movie/field_name, this is consistent with src/run_exp_movie.py#L362 82 | generate_dataset(field_name, data_path='./') -------------------------------------------------------------------------------- /lmapi/src/lmapi/lm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from abc import abstractmethod 5 | from collections.abc import AsyncGenerator, Mapping, Sequence 6 | from dataclasses import dataclass 7 | from typing import Protocol, TypeAlias 8 | 9 | import tiktoken 10 | 11 | 12 | class TokenWithLogprob(Protocol): 13 | @property 14 | def text(self) -> str: 15 | """The textual representation of the token. 16 | 17 | For tokens that don't constitute valid UTF-8, this will look like "bytes:\\x??\\x??". 18 | """ 19 | ... 20 | 21 | @property 22 | def bytes(self) -> bytes: 23 | "The bytes that make up the token." 24 | ... 25 | 26 | @property 27 | def logprob(self) -> float: 28 | ... 29 | 30 | @property 31 | def token_id(self) -> int: 32 | ... 33 | 34 | 35 | @dataclass 36 | class SampledToken: 37 | token: TokenWithLogprob 38 | top_choices: tuple[TokenWithLogprob, ...] 39 | 40 | 41 | # Types: 42 | # - rules for how we can interrupt the completion: stop, etc. 43 | # - whether we can get logprobs 44 | 45 | Completion: TypeAlias = Sequence[SampledToken] 46 | 47 | 48 | @dataclass 49 | class CompletionsSettings: 50 | """The meanings are the same as in https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#completions 51 | 52 | If any field is set to None, the defaults from that documentation page apply.""" 53 | 54 | temperature: float | None = None 55 | max_tokens: int | None = None 56 | n: int | None = None 57 | stop: str | Sequence[str] | None = None 58 | logprobs: int | None = None 59 | echo: bool | None = None 60 | logit_bias: Mapping[int, float] | None = None 61 | top_p: float | None=None 62 | 63 | 64 | class LM(Protocol): 65 | @abstractmethod 66 | async def completions( 67 | self, 68 | prompt: str | Sequence[int] | Sequence[str] | Sequence[Sequence[int]], 69 | settings: CompletionsSettings | None = None, 70 | ) -> Sequence[Completion]: 71 | """Returns completions given the prompt.""" 72 | ... 73 | 74 | @abstractmethod 75 | async def streaming_completions( 76 | self, prompt: str | Sequence[int], settings: CompletionsSettings | None = None 77 | ) -> Sequence[AsyncGenerator[SampledToken, None]]: 78 | """Returns completions in a streaming fashion as async generators of SampledText. 79 | 80 | To ensure that the HTTP connection to GPT is closed quickly, use `contextlib.aclosing`: 81 | ``` 82 | # streams is a sequence with length equal to `n` in the settings, 83 | # the number of completions to generate given the prompt. 84 | # In this example, we assume n = 1. 85 | streams = gpt.streaming_completions(prompt) 86 | [stream] = streams 87 | async with contextlib.aclosing(stream): 88 | async for sampled_text in stream: 89 | ... 90 | if some_condition: 91 | break 92 | ``` 93 | 94 | This pattern ensures that when the loop exits with the `break`, `it.aclose()` is called immediately, 95 | rather than only when `it` is garbage collected. 96 | """ 97 | ... 98 | 99 | @property 100 | def encoding(self) -> tiktoken.Encoding: 101 | """Returns the encoding scheme the model uses.""" 102 | ... 103 | -------------------------------------------------------------------------------- /lmapi/src/lmapi/async_tools/asyncitertools.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import asyncio 5 | import contextlib 6 | from collections.abc import AsyncGenerator, Callable, Iterable 7 | from typing import Any, TypeVar 8 | 9 | T = TypeVar("T") 10 | U = TypeVar("U") 11 | 12 | 13 | class FinishedMarker: 14 | pass 15 | 16 | 17 | async def bucket( 18 | num_buckets: int, 19 | it: AsyncGenerator[T, None], 20 | process: Callable[[T], Iterable[tuple[int, U]]], 21 | ) -> list[AsyncGenerator[U, None]]: 22 | """ 23 | Wraps `it` and distributes its items into `num_buckets` buckets based on the 24 | result of `process`. 25 | 26 | Similar to https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.bucket 27 | but with the following differences: 28 | - The buckets are integers in [0, num_buckets). 29 | - The `process` callable returns both the bucket and a transformed version 30 | of the element from `it`. 31 | """ 32 | queues: list[asyncio.Queue[U | FinishedMarker]] = [ 33 | asyncio.Queue() for _ in range(num_buckets) 34 | ] 35 | # Keep track of how many consumers are still open, so if they're all closed, 36 | # `fill_queues` can finish and close `it` early. 37 | num_open_consumers = num_buckets 38 | 39 | async def fill_queues() -> None: 40 | try: 41 | async with contextlib.aclosing(it): 42 | async for item in it: 43 | if num_open_consumers == 0: 44 | break 45 | for bucket_index, processed_item in process(item): 46 | queues[bucket_index].put_nowait(processed_item) 47 | finally: 48 | for queue in queues: 49 | queue.put_nowait(FinishedMarker()) 50 | 51 | fill_queues_task = asyncio.create_task(fill_queues()) 52 | 53 | async def gen(queue: asyncio.Queue[U | FinishedMarker]) -> AsyncGenerator[U, None]: 54 | nonlocal num_open_consumers 55 | try: 56 | # Below, we will run the generator for one iteration for returning it, 57 | # to ensure that it starts executing. We would like this so that when 58 | # the generator is closed, it runs the finally block below to decrement 59 | # `num_open_consumers`. If the generator is closed before it has run any 60 | # iterations, then the finalizer below will not run, and 61 | # `num_open_consumers` will not be decremented. 62 | # Therefore, we yield nothing once here. 63 | yield # type: ignore 64 | while True: 65 | next_item = await queue.get() 66 | if isinstance(next_item, FinishedMarker): 67 | break 68 | yield next_item 69 | 70 | # Check that there were no exceptions in the fill_queues_task 71 | # TODO(richard): If there was an exception, the printed traceback isn't very useful; improve that. 72 | await fill_queues_task 73 | finally: 74 | num_open_consumers -= 1 75 | 76 | # Run each generator for one iteration. 77 | result = [gen(queue) for queue in queues] 78 | for result_item in result: 79 | await anext(result_item) 80 | return result 81 | 82 | 83 | async def consume(it: AsyncGenerator[Any, None]) -> None: 84 | """Fully executes the async generator `it` while discarding all output.""" 85 | async with contextlib.aclosing(it): 86 | async for _ in it: 87 | pass 88 | -------------------------------------------------------------------------------- /lmapi/tests/test_lmapi/async_tools/test_asyncitertools.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import asyncio 5 | import itertools 6 | from collections.abc import AsyncGenerator 7 | 8 | import pytest 9 | 10 | from lmapi.async_tools.asyncitertools import bucket 11 | 12 | 13 | @pytest.mark.asyncio 14 | async def test_bucket_normal() -> None: 15 | async def gen() -> AsyncGenerator[int, None]: 16 | for i in range(4): 17 | if i % 2 == 0: 18 | await asyncio.sleep(0) 19 | yield i 20 | 21 | it_indices_permutations = set(itertools.permutations(i % 3 for i in range(4))) 22 | for it_indices in it_indices_permutations: 23 | bucketed_its = await bucket(3, gen(), lambda x: [(x % 3, x)]) 24 | regular_its = [iter([0, 3]), iter([1]), iter([2])] 25 | for it_index in it_indices: 26 | assert await anext(bucketed_its[it_index]) == next(regular_its[it_index]) 27 | 28 | for it in bucketed_its: 29 | with pytest.raises(StopAsyncIteration): 30 | await anext(it) 31 | 32 | for it in regular_its: 33 | with pytest.raises(StopIteration): 34 | next(it) 35 | 36 | 37 | class ExceptionForTest(Exception): 38 | pass 39 | 40 | 41 | @pytest.mark.asyncio 42 | async def test_bucket_exception() -> None: 43 | async def gen() -> AsyncGenerator[int, None]: 44 | for i in range(6): 45 | if i >= 5: 46 | raise ExceptionForTest 47 | 48 | if i % 2 == 0: 49 | await asyncio.sleep(0) 50 | yield i 51 | 52 | it0, it1, it2 = await bucket(3, gen(), lambda x: [(x % 3, x)]) 53 | 54 | assert await anext(it0) == 0 55 | assert await anext(it0) == 3 56 | with pytest.raises(ExceptionForTest): 57 | await anext(it0) 58 | 59 | assert await anext(it1) == 1 60 | assert await anext(it1) == 4 61 | with pytest.raises(ExceptionForTest): 62 | await anext(it1) 63 | 64 | assert await anext(it2) == 2 65 | with pytest.raises(ExceptionForTest): 66 | await anext(it2) 67 | 68 | 69 | @pytest.mark.asyncio 70 | async def test_bucket_closing() -> None: 71 | """Tests that `bucket` closes the generator when all consumers are closed.""" 72 | 73 | lock = asyncio.Lock() 74 | max_i = None 75 | gen_was_closed = False 76 | 77 | async def gen() -> AsyncGenerator[int, None]: 78 | nonlocal max_i, gen_was_closed 79 | try: 80 | for i in range(9): 81 | async with lock: 82 | max_i = i 83 | yield i 84 | await asyncio.sleep(0) 85 | except GeneratorExit: 86 | gen_was_closed = True 87 | 88 | async with lock: 89 | it0, it1, it2 = await bucket(3, gen(), lambda x: [(x % 3, x)]) 90 | await it0.aclose() 91 | await it2.aclose() 92 | 93 | # Give up control so that `gen` executes one iteration, yielding 0 94 | await asyncio.sleep(0) 95 | # Give up control so that `gen` executes one iteration, yielding 1 96 | await asyncio.sleep(0) 97 | 98 | async with lock: 99 | assert await anext(it1) == 1 100 | await it1.aclose() 101 | 102 | # `fill_queues` will run more iteration of `gen` and then realize that 103 | # `num_open_consumers == 0`, thus closing `gen` as well 104 | await asyncio.sleep(0) 105 | assert max_i == 2 106 | assert gen_was_closed 107 | 108 | # There should be no changes due to this additional sleep 109 | await asyncio.sleep(0) 110 | assert max_i == 2 111 | assert gen_was_closed 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### From https://github.com/github/gitignore/blob/main/Python.gitignore 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | #MACOS 9 | *.DS_Store 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 
| downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /src/dp_few_shot_generation/lm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | from collections.abc import Sequence, Set 6 | 7 | from lmapi.async_tools import limits 8 | from lmapi.auth import OpenAiApiKey 9 | from lmapi.lm import LM, CompletionsSettings 10 | from lmapi.openai import OpenAI 11 | 12 | from dp_few_shot_generation.prob_utils import log_normalize 13 | 14 | 15 | def api_openai_com(model_name: str) -> OpenAI: 16 | return OpenAI.create( 17 | "https://api.openai.com/v1/completions", 18 | # This list is taken from 19 | # https://github.com/openai/tiktoken/blob/095924e02c85617df6889698d94515f91666c7ea/tiktoken/model.py#L13-L53 20 | # and modified, currently to accommodate how text-davinci-003 can actually produce <|fim_...|> tokens. 21 | { 22 | # chat 23 | "gpt-4": "cl100k_base", 24 | "gpt-3.5-turbo": "cl100k_base", 25 | # text 26 | "text-davinci-003": "p50k_edit", 27 | "text-davinci-002": "p50k_base", 28 | "text-davinci-001": "r50k_base", 29 | "text-curie-001": "r50k_base", 30 | "text-babbage-001": "r50k_base", 31 | "text-ada-001": "r50k_base", 32 | "davinci": "r50k_base", 33 | "curie": "r50k_base", 34 | "babbage": "r50k_base", 35 | "ada": "r50k_base", 36 | # code 37 | "code-davinci-002": "p50k_base", 38 | "code-davinci-001": "p50k_base", 39 | "code-cushman-002": "p50k_base", 40 | "code-cushman-001": "p50k_base", 41 | "davinci-codex": "p50k_base", 42 | "cushman-codex": "p50k_base", 43 | # edit 44 | "text-davinci-edit-001": "p50k_edit", 45 | "code-davinci-edit-001": "p50k_edit", 46 | # embeddings 47 | "text-embedding-ada-002": "cl100k_base", 48 | # old embeddings 49 | "text-similarity-davinci-001": "r50k_base", 50 | "text-similarity-curie-001": "r50k_base", 51 | "text-similarity-babbage-001": "r50k_base", 52 | "text-similarity-ada-001": "r50k_base", 53 | "text-search-davinci-doc-001": "r50k_base", 54 | "text-search-curie-doc-001": "r50k_base", 55 | "text-search-babbage-doc-001": "r50k_base", 56 | "text-search-ada-doc-001": "r50k_base", 57 | "code-search-babbage-code-001": "r50k_base", 58 | "code-search-ada-code-001": "r50k_base", 59 | # open source 60 | "gpt2": "gpt2", 61 | }[model_name], 62 | OpenAiApiKey(os.environ["OPENAI_API_KEY"]), 63 | limits.AdaptiveLimiter(), 64 | {"model": model_name}, 65 | ) 66 | 67 | 68 | MAX_TOP_LOGPROBS = 100 69 | MAX_LOGIT_BIAS = 100 70 | MIN_LOGIT_BIAS = -100 71 | 72 | 73 | async def next_logprobs( 74 | self: LM, prompt: str | Sequence[int], top_p=1 75 | ) -> dict[int, float]: 76 | # TODO: Don't hardcode "100" here 77 | [sampled_tokens] = await self.completions( 78 | prompt, 79 | CompletionsSettings(n=1, max_tokens=1, logprobs=100, stop=[""]), 80 | ) 81 | if len(sampled_tokens) == 0: 82 | if isinstance(prompt, str): 83 | prompt += "<|endoftext|>" 84 | else: 85 | prompt = [*prompt, self.encoding.encode_single_token("<|endoftext|>")] 86 | [[*_prev_tokens, sampled_token]] = await self.completions( 87 | prompt, 88 | CompletionsSettings( 89 | n=1, max_tokens=0, logprobs=100, echo=True, top_p=top_p 90 | ), 91 | ) 92 | else: 93 | [sampled_token] = sampled_tokens 94 | 95 | return {tlp.token_id: tlp.logprob for tlp in sampled_token.top_choices} 96 | 97 | 98 | async def normalized_logprobs_for_chosen_tokens( 
99 | self: LM, prompt: Sequence[int], chosen_tokens: Set[int], top_p: float 100 | ) -> dict[int, float]: 101 | """Compute the probability that the prompt will be continued with each of the chosen tokens. 102 | 103 | The returned probability distribution is normalized over just the chosen tokens.""" 104 | 105 | assert ( 106 | len(chosen_tokens) <= MAX_TOP_LOGPROBS 107 | ), f"chosen_tokens must be <= {MAX_TOP_LOGPROBS} in length" 108 | 109 | logit_bias = {token_id: MAX_LOGIT_BIAS for token_id in chosen_tokens} 110 | [sampled_tokens] = await self.completions( 111 | prompt, 112 | CompletionsSettings( 113 | n=1, 114 | max_tokens=1, 115 | logprobs=MAX_TOP_LOGPROBS, 116 | logit_bias=logit_bias, 117 | top_p=top_p, 118 | ), 119 | ) 120 | if len(sampled_tokens) == 0: 121 | # Fall back to querying over the set 122 | chosen_tokens_list = list(chosen_tokens) 123 | 124 | result = await self.completions( 125 | [[*prompt, token_id] for token_id in chosen_tokens_list], 126 | CompletionsSettings(n=1, max_tokens=0, logprobs=0, echo=True), 127 | ) 128 | unnormalized_logprobs = { 129 | sampled_token.token.token_id: sampled_token.token.logprob 130 | for [*_prev_tokens, sampled_token] in result 131 | } 132 | return log_normalize(unnormalized_logprobs) 133 | else: 134 | [sampled_token] = sampled_tokens 135 | biased_logprobs = { 136 | tlp.token_id: tlp.logprob for tlp in sampled_token.top_choices 137 | } 138 | biased_logprobs_for_tokens = { 139 | token_id: biased_logprobs.get(token_id, float("-inf")) 140 | for token_id in chosen_tokens 141 | } 142 | return log_normalize(biased_logprobs_for_tokens) 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Privacy-preserving in-context learning with differentially private few-shot generation 2 | This is a codebase to perform privacy-preserving in-context learning with differentially private few-shot generation. 3 | 4 | ## Experiments 5 | See `run.sh` for example commands for AGNEWS/DBPedia/TREC/MIT-G/MIT-D. Here are a few explanations for the parameters in `run.sh`. 6 | 7 | ``` 8 | --sigma 0.39 # noise parameter 9 | --openai-model "babbage" # openai model 10 | --num-private-train 20 # num_private_train=MN. MN=0 and M=0 with num_valid=4 will get epsilon=0 (4-shot) results. 11 | --set-num-public-train 0 # by default set to 0. set_num_public_train >0 indicates additional public data available. 12 | --num-valid 4 # num-valid=n. n samples to be generated for n-shot ICL 13 | --num-private-train-splits 10 # num_private_train_splits=M 14 | --num-test 1000 # if len(test_set)<1000, use the exact test set. Otherwise we sample 1000 samples from test set for evaluation 15 | --use-dp-prompts # generate prompts from private dataset 16 | --sample-same-label-prompts # sample_same_label_prompts=True, sample subsets from the sets with same targeted labels. 17 | --subsample-per-token # subsample_per_token=True, at each token generation, subsample a fresh new subset. 18 | --no-public-token # no_public_token=True, RVP=False 19 | --synth-seed 0 # random seed for subsampling in generation 20 | --eval-seed 0 # random seed for n-shot demonstrations sampling in evaluation 21 | ``` 22 | 23 | Note: Due to the randomness in generations caused by DP noise, the results may be slightly different from the reported values in the paper.
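The `--sigma` values above can be sanity-checked against the accounting code shipped in `privacy_analysis/`. The snippet below is an illustrative sketch (not a packaged entry point of this repository) that mirrors `privacy_analysis/subgaussian.py` for the AGNEWS configuration used in `run.sh`; it assumes the `prv-accountant` dependency declared in `pyproject.toml` is installed.

```python
from prv_accountant import PRVAccountant
from prv_accountant.privacy_random_variables import PoissonSubsampledGaussianMechanism

# AGNEWS setting from privacy_analysis/subgaussian.py and run.sh:
# 30000 training examples, subsets of size n=20, at most 100 generated tokens.
full_train_num, n, max_token_cnt, sigma = 30000, 20, 100, 0.39

prv = PoissonSubsampledGaussianMechanism(
    noise_multiplier=sigma, sampling_probability=n / full_train_num
)
accountant = PRVAccountant(
    prvs=[prv], max_self_compositions=[1000], eps_error=0.01, delta_error=1e-10
)
eps_low, eps_est, eps_up = accountant.compute_epsilon(
    delta=1 / full_train_num, num_self_compositions=[max_token_cnt]
)
print(f"epsilon estimate: {eps_est:.2f} (bounds: {eps_low:.2f}, {eps_up:.2f})")
```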
24 | 25 | ### About reproducing experiments with OpenAI models 26 | Our code uses the `logprobs` parameter of OpenAI's API (https://platform.openai.com/docs/api-reference/completions/create#logprobs) with a value of 100. 27 | By default, OpenAI currently allows up to 5 as the value for `logprobs`. Unless you obtain permission from OpenAI to use a larger value, the code will not work as-is. 28 | The existing code uses models which have been [deprecated by OpenAI](https://platform.openai.com/docs/deprecations/base-gpt-models) and may no longer be available in the future. 29 | 30 | As an alternative, you can consider using alternative LMs through software like https://github.com/vllm-project/vllm which provides an OpenAI-compatible API. 31 | It's also possible to use the `logit_bias` parameter (https://platform.openai.com/docs/api-reference/completions/create#logit_bias) to get top-k log probs for larger values of k 32 | by repeatedly querying the API with the same prefix while banning the most likely tokens obtained so far. 33 | 34 | 35 | ## Setup 36 | 1. Install Python 3.10. 37 | 38 | One way is to use [pyenv](https://github.com/pyenv/pyenv). 39 | Run `pyenv install --list | grep '^ *3.10' | tail -n1` to discover the most recent minor version of Python 3.10. 40 | Run `pyenv install 3.10.X` where `X` is the latest minor version available. 41 | 42 | 1. Install [Poetry](https://python-poetry.org/) following the [instructions](https://python-poetry.org/docs/#installation). 43 | 1. Configure Poetry to use your Python 3.10 installation. 44 | - If using `pyenv` setup above: run `poetry env use $(pyenv prefix 3.10.X)/bin/python` 45 | - Otherwise: run `poetry env use ` 46 | 1. Run `poetry install` to install the dependencies. 47 | 48 | ## IDE Setup 49 | ### IntelliJ/PyCharm 50 | - IntelliJ only: Install the Python plug-in. 51 | - Set up the Python interpreter 52 | - PyCharm: open Settings then go to Python Interpreter. 53 | - IntelliJ: go to `File -> Project Structure -> Project -> Project SDK`. 54 | - Ensure that the Python environment in the `.venv` directory is selected. If needed, you can add a new interpreter. 55 | Choose "Poetry Environment" as the interpreter type. Select to use "Existing environment" with the interpreter in `.venv/bin/python`. 56 | - Set up source folders 57 | - Right click the `src` folder and choose `Mark Directory As -> Sources Root`. 58 | - Right click the `tests` folder and choose `Mark Directory As -> Test Sources Root`. 59 | - Configure `pytest`: In `Preferences -> Tools -> Python Integrated Tools`, set the default test runner to `pytest`. 60 | - ruff: You can try a [plugin](https://plugins.jetbrains.com/plugin/20574-ruff) as a replacement for running `make lint` manually. 61 | 62 | ### Visual Studio Code 63 | - Install the Python extension. 64 | - Open the root directory of the project. 65 | - Open the Command Palette and choose "Python: Select Interpreter". Ensure the one in `.venv` is selected. 66 | If not, you can choose "Enter interpreter path..." and enter `./.venv/bin/python`. 67 | - Configure `pytest`: open the Command Palette and choose "Python: Configure Tests". Choose pytest. Choose `tests` as the root directory for tests. 68 | - ruff: You can try an [extension](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) if you want. 69 | 70 | ## Development 71 | We have automated code style checks, linting for common errors, type checking, and unit testing.
72 | - Run `make` to run formatting, linting, type checking, and unit testing, in that order. 73 | You can run each of the four checks separately with `make format-check`, `make lint`, `make pyright`, and `make test`. 74 | - Run `make format` and `make lint-fix` to automatically fix formatting errors and (some) linting errors. 75 | 76 | ## Project structure 77 | - `src/`: Python code for the project. 78 | - `tests/`: Unit tests for code in `src/`. 79 | - `data/`: Python code for data processing of the MIT dataset. 80 | - `privacy_analysis/`: Python code for calculating the noise parameter. 81 | - `lmapi/`: A custom wrapper for OpenAI's API. 82 | 83 | ## Acknowledgments 84 | This project is built upon the foundation of [Calibrate Before Use: Improving Few-Shot Performance of Language Models](https://github.com/tonyzhaozh/few-shot-learning). 85 | We would like to thank the contributors and maintainers of the original repository for their valuable work. 86 | 87 | ## Contributing 88 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 89 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 90 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 91 | 92 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 93 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 94 | provided by the bot. You will only need to do this once across all repos using our CLA. 95 | 96 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 97 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 98 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 99 | 100 | ## Trademarks 101 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 102 | trademarks or logos is subject to and must follow 103 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 104 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 105 | Any use of third-party trademarks or logos is subject to those third parties' policies. 106 | -------------------------------------------------------------------------------- /lmapi/src/lmapi/async_tools/limits.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import asyncio 5 | import collections 6 | import dataclasses 7 | import functools 8 | import time 9 | from collections.abc import Awaitable, Callable 10 | from contextlib import AbstractAsyncContextManager, asynccontextmanager 11 | from dataclasses import dataclass 12 | from typing import Any, TypeVar, cast 13 | 14 | _MICROS_RATIO = 1_000_000 15 | _NANOS_RATIO = 1_000_000_000 16 | 17 | 18 | @dataclass 19 | class TokenBucket: 20 | """Used to prevent too much resource consumption during any time period. 21 | 22 | These tokens are not words (like in NLP) and the bucket is not a map. 23 | Instead the tokens are like coins and the bucket is like a physical bucket.
24 | 25 | See https://en.wikipedia.org/wiki/Token_bucket#Algorithm for more details. 26 | However, unlike typical implementations of token buckets, this one allows 27 | the refill rate to change at any time. 28 | """ 29 | 30 | # Maximum number of tokens that can be in the bucket, multiplied by 1_000_000. 31 | capacity_micros: int 32 | # Amount of tokens added per second, multipled by 1_000_000. 33 | refill_rate_micros: int 34 | 35 | # The current amount of tokens in the bucket. 36 | level_micros: int = 0 37 | 38 | # Time when `level` was last updated, in nanoseconds. 39 | last_updated_ns: int = dataclasses.field(default_factory=time.monotonic_ns) 40 | # One entry per call to `claim` that's waiting for the bucket to refill. 41 | _waiting: collections.deque[tuple[asyncio.Event, int]] = dataclasses.field( 42 | default_factory=collections.deque 43 | ) 44 | # Set by the next `claim` in line so that it can be notified if the refill rate changes. 45 | _next_token_refill: asyncio.Event | None = None 46 | 47 | @staticmethod 48 | def create(capacity: float, refill_rate: float, initial_level: float = 0): 49 | assert capacity >= 0 50 | assert refill_rate >= 0 51 | assert initial_level >= 0 52 | 53 | return TokenBucket( 54 | int(capacity * _MICROS_RATIO), 55 | int(refill_rate * _MICROS_RATIO), 56 | int(initial_level * _MICROS_RATIO), 57 | ) 58 | 59 | async def claim(self, tokens: float) -> None: 60 | tokens_micros = int(tokens * _MICROS_RATIO) 61 | assert tokens_micros <= self.capacity_micros 62 | del tokens 63 | 64 | event: asyncio.Event | None = None 65 | try: 66 | # Check if we have sufficient tokens already 67 | self._update_level() 68 | # async with self._lock: 69 | if self.level_micros >= tokens_micros: 70 | self.level_micros -= tokens_micros 71 | return 72 | 73 | event = asyncio.Event() 74 | self._waiting.append((event, tokens_micros)) 75 | while True: 76 | # Check whether we're first in line 77 | first_event, _ = self._waiting[0] 78 | if event is first_event: 79 | # We might already have enough tokens. If so, don't wait for more tokens to fill. 80 | if self.level_micros >= tokens_micros: 81 | break 82 | try: 83 | # Wait for the bucket to fill up. 84 | self._next_token_refill = event 85 | 86 | # Compute number of microseconds to wait until bucket will be full, 87 | # if the rate doesn't change in the meantime. 88 | # Round up to the next microsecond. 89 | timeout_micros, remainder = divmod( 90 | (tokens_micros - self.level_micros) * _MICROS_RATIO, 91 | self.refill_rate_micros, 92 | ) 93 | if remainder: 94 | timeout_micros += 1 95 | 96 | # Sleep for the computed number of microseconds. 97 | # We assume that 1) the floating point value for timeout 98 | # has sufficient precision to represent the desired 99 | # number of microseconds, and 2) wait_for's timer has 100 | # sufficient resolution. 101 | # If the rate changes before the timeout, the event will fire. 102 | await asyncio.wait_for( 103 | event.wait(), timeout=timeout_micros / _MICROS_RATIO 104 | ) 105 | 106 | # If we reach here, before we refilled all the tokens, 107 | # the refill rate changed. Recompute how long we should 108 | # wait until the tokens are refilled, and try again. 109 | # No matter how the rate changed, we shouldn't have 110 | # enough tokens at this point. 
111 | event.clear() 112 | self._next_token_refill = None 113 | continue 114 | except asyncio.TimeoutError: 115 | # We should have collected enough tokens by now to take our turn 116 | self._update_level() 117 | break 118 | else: 119 | # Wait until we get to the start of the line 120 | await event.wait() 121 | event.clear() 122 | 123 | # async with self._lock: 124 | assert self.level_micros >= tokens_micros, ( 125 | self.level_micros, 126 | tokens_micros, 127 | ) 128 | self.level_micros -= tokens_micros 129 | 130 | self._waiting.popleft() 131 | if self._waiting: 132 | # Tell the next in line that they're first now. 133 | event, _ = self._waiting[0] 134 | event.set() 135 | return 136 | 137 | except asyncio.CancelledError: 138 | pass 139 | 140 | def add_to_rate( 141 | self, delta: float, min_rate: float | None = None, max_rate: float | None = None 142 | ) -> None: 143 | """Add to the refill rate. 144 | 145 | Supply a negative number to subtract from the rate. However, the 146 | negative number cannot cause the rate to become negative.""" 147 | new_rate = self.refill_rate_micros + round(delta * _MICROS_RATIO) 148 | assert new_rate >= 0, "Cannot have a negative refill rate" 149 | if min_rate is not None: 150 | new_rate = max(new_rate, round(min_rate * _MICROS_RATIO)) 151 | if max_rate is not None: 152 | new_rate = min(new_rate, round(max_rate * _MICROS_RATIO)) 153 | self.refill_rate_micros = new_rate 154 | self.reset_rate(new_rate) 155 | 156 | def multiply_rate( 157 | self, ratio: float, min_rate: float | None = None, max_rate: float | None = None 158 | ) -> None: 159 | """Apply a multiplicative factor to the refill rate.""" 160 | 161 | assert ratio >= 0, "Cannot have a negative ratio" 162 | new_rate = round(self.refill_rate_micros * ratio) 163 | if min_rate is not None: 164 | new_rate = max(new_rate, round(min_rate * _MICROS_RATIO)) 165 | if max_rate is not None: 166 | new_rate = min(new_rate, round(max_rate * _MICROS_RATIO)) 167 | self.reset_rate(new_rate) 168 | 169 | def reset_rate(self, new_rate: float) -> None: 170 | """Reset the refill rate.""" 171 | # Refill tokens with the previous rate, before updating the rate. 172 | self._update_level() 173 | 174 | self.refill_rate_micros = round(new_rate * _MICROS_RATIO) 175 | # Tell the task waiting for tokens to refill to wake up so that we can 176 | # recompute how much longer it should sleep, with the current rate. 177 | if self._next_token_refill: 178 | self._next_token_refill.set() 179 | 180 | def _update_level(self) -> None: 181 | now = time.monotonic_ns() 182 | elapsed_nanos = now - self.last_updated_ns 183 | self.level_micros = min( 184 | self.capacity_micros, 185 | self.level_micros 186 | + (elapsed_nanos * self.refill_rate_micros) // _NANOS_RATIO, 187 | ) 188 | self.last_updated_ns = now 189 | 190 | 191 | class RateLimitExceededError(Exception): 192 | pass 193 | 194 | 195 | AsyncCallableT = TypeVar("AsyncCallableT", bound=Callable[..., Awaitable[Any]]) 196 | AsyncContextManagerProducerT = TypeVar( 197 | "AsyncContextManagerProducerT", 198 | bound=Callable[..., AbstractAsyncContextManager[Any]], 199 | ) 200 | 201 | 202 | @dataclass 203 | class AdaptiveLimiter: 204 | """Adaptively rate-limit the wrapped function. 205 | 206 | When a call to the function is successful, we increase the rate additively; 207 | if it complains that the rate was exceeded, we decrease the rate multiplicatively. 
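    A hypothetical usage sketch (the wrapped coroutine and its parameters are
    illustrative; the wrapped function is expected to raise
    RateLimitExceededError when the remote service reports that the rate was
    exceeded):

        limiter = AdaptiveLimiter(initial_qps=5, max_qps=100, min_qps=1)

        @limiter.wrap_async_callable
        async def call_service(request: str) -> str:
            ...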
208 | """ 209 | 210 | initial_qps: float = 10 211 | max_qps: float = 500 212 | min_qps: float = 1 213 | bucket: TokenBucket = dataclasses.field(init=False) 214 | 215 | def __post_init__(self): 216 | self.bucket = TokenBucket.create(self.max_qps, self.initial_qps) 217 | 218 | def wrap_async_callable(self, func: AsyncCallableT) -> AsyncCallableT: 219 | @functools.wraps(func) 220 | async def wrapped(*args, **kwargs): 221 | while True: 222 | # Wait our turn 223 | await self.bucket.claim(1) 224 | # Try calling the function 225 | try: 226 | result = await func(*args, **kwargs) 227 | self.bucket.add_to_rate(1, max_rate=self.max_qps) 228 | return result 229 | except RateLimitExceededError: 230 | self.bucket.multiply_rate(0.9, min_rate=self.min_qps) 231 | 232 | return cast(AsyncCallableT, wrapped) 233 | 234 | def wrap_async_context_manager_producer( 235 | self, func: AsyncContextManagerProducerT 236 | ) -> AsyncContextManagerProducerT: 237 | @functools.wraps(func) 238 | @asynccontextmanager 239 | async def wrapped(*args, **kwargs): 240 | while True: 241 | # Wait our turn 242 | await self.bucket.claim(1) 243 | # Try calling the function 244 | try: 245 | async with func(*args, **kwargs) as cm_result: 246 | yield cm_result 247 | self.bucket.add_to_rate(1, max_rate=self.max_qps) 248 | break 249 | except RateLimitExceededError: 250 | self.bucket.multiply_rate(0.9, min_rate=self.min_qps) 251 | 252 | return cast(AsyncContextManagerProducerT, wrapped) 253 | -------------------------------------------------------------------------------- /lmapi/src/lmapi/openai.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import ast 5 | import dataclasses 6 | import json 7 | from collections.abc import ( 8 | AsyncGenerator, 9 | AsyncIterable, 10 | AsyncIterator, 11 | Callable, 12 | Iterator, 13 | Sequence, 14 | ) 15 | from contextlib import AbstractAsyncContextManager, asynccontextmanager 16 | from contextvars import ContextVar 17 | from dataclasses import dataclass 18 | from functools import cached_property 19 | from typing import TYPE_CHECKING, Any, Protocol, TypeVar 20 | 21 | import aiohttp 22 | import tiktoken 23 | 24 | from lmapi.async_tools import asyncitertools, limits 25 | from lmapi.async_tools.server_sent_events import ServerSentEvent, parse_event_stream 26 | from lmapi.auth import AuthorizationProvider 27 | from lmapi.lm import LM, Completion, CompletionsSettings, SampledToken, TokenWithLogprob 28 | 29 | client_session: ContextVar[aiohttp.ClientSession] = ContextVar("client_session") 30 | 31 | 32 | @dataclass(frozen=True) 33 | class OpenAIAPIError(Exception): 34 | """Indicates a general OpenAI call error.""" 35 | 36 | status_code: int 37 | text: str 38 | 39 | @property 40 | def user_message(self) -> str: 41 | return "Model communication failure" 42 | 43 | @cached_property 44 | def debug_message(self) -> str: 45 | return f"Unexpected status code: {self.status_code}. 
{self.text}" 46 | 47 | 48 | @dataclass(slots=True) 49 | class OpenAITokenWithLogprob: 50 | if TYPE_CHECKING: 51 | 52 | def _check_protocol(self) -> TokenWithLogprob: 53 | return self 54 | 55 | text: str 56 | logprob: float 57 | _encoding: tiktoken.Encoding 58 | 59 | _bytes: bytes | None = dataclasses.field(init=False, default=None) 60 | 61 | @property 62 | def bytes(self) -> bytes: 63 | if self._bytes is None: 64 | self._bytes = openai_token_to_bytes(self.text) 65 | return self._bytes 66 | 67 | _token_id: int | None = dataclasses.field(init=False, default=None) 68 | 69 | @property 70 | def token_id(self) -> int: 71 | if self._token_id is None: 72 | self._token_id = self._encoding.encode_single_token(self.bytes) 73 | return self._token_id 74 | 75 | 76 | class _NextLogprobsFunction(Protocol): 77 | async def __call__(self, prompt: str | Sequence[int]) -> dict[int, float]: 78 | ... 79 | 80 | 81 | class _CompletionsFunction(Protocol): 82 | async def __call__( 83 | self, 84 | prompt: str | Sequence[int] | Sequence[str] | Sequence[Sequence[int]], 85 | settings: CompletionsSettings | None = None, 86 | ) -> Sequence[Completion]: 87 | ... 88 | 89 | 90 | @dataclass(frozen=True) 91 | class OpenAI: 92 | """Implementation of the LM protocol for OpenAI models. 93 | 94 | Arguments: 95 | url: The URL of the model. 96 | auth_provider: The authorization provider to use. 97 | encoding: The encoding of the model. 98 | default_completion_settings: The default completion settings to use. 99 | """ 100 | 101 | if TYPE_CHECKING: 102 | 103 | def _check_protocol(self) -> LM: 104 | return self 105 | 106 | url: str 107 | auth_provider: AuthorizationProvider 108 | encoding: tiktoken.Encoding 109 | default_completion_settings: dict[str, Any] 110 | additional_headers: dict[str, str] 111 | request_limiter: limits.AdaptiveLimiter | None = None 112 | 113 | @staticmethod 114 | def create( 115 | url: str, 116 | encoding_or_name: str | tiktoken.Encoding, 117 | auth_provider: AuthorizationProvider, 118 | request_limiter: limits.AdaptiveLimiter | None = None, 119 | default_completion_settings: dict[str, Any] | None = None, 120 | additional_headers: dict[str, Any] | None = None, 121 | ) -> "OpenAI": 122 | encoding = ( 123 | tiktoken.get_encoding(encoding_or_name) 124 | if isinstance(encoding_or_name, str) 125 | else encoding_or_name 126 | ) 127 | 128 | return OpenAI( 129 | url, 130 | auth_provider, 131 | encoding, 132 | default_completion_settings or {}, 133 | additional_headers or {}, 134 | request_limiter, 135 | ) 136 | 137 | async def completions( 138 | self, 139 | prompt: str | Sequence[int] | Sequence[str] | Sequence[Sequence[int]], 140 | settings: CompletionsSettings | None = None, 141 | ) -> Sequence[Completion]: 142 | return await self._completions_maybe_limited(prompt, settings) 143 | 144 | @cached_property 145 | def _completions_maybe_limited(self) -> _CompletionsFunction: 146 | if self.request_limiter is None: 147 | return self._completions 148 | return self.request_limiter.wrap_async_callable(self._completions) 149 | 150 | async def _completions( 151 | self, 152 | prompt: str | Sequence[int] | Sequence[str] | Sequence[Sequence[int]], 153 | settings: CompletionsSettings | None = None, 154 | ) -> Sequence[Completion]: 155 | """ 156 | Inner implementation of `completions` before `self.request_limiter` is applied. 157 | 158 | Must be called in an `async with client_session:` block where 159 | `client_session` is the same one used to construct this object. 
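        A hypothetical call sketch (the prompt and settings are illustrative;
        `lm` stands for an `OpenAI` instance, and callers normally go through
        the public `completions` wrapper):

            client_session.set(aiohttp.ClientSession())
            async with client_session.get():
                [completion] = await lm.completions(
                    "Some prompt",
                    CompletionsSettings(temperature=0.0, max_tokens=16, n=1),
                )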
160 | """ 161 | params = self._make_params(prompt, settings) 162 | async with client_session.get().post( 163 | self.url, 164 | headers={**self.auth_provider.headers(), **self.additional_headers}, 165 | json=params, 166 | ) as response: 167 | if response.status != 200: 168 | if response.status in (408, 429, 500): 169 | raise limits.RateLimitExceededError() 170 | else: 171 | raise OpenAIAPIError(response.status, await response.text()) 172 | resp = await response.json(content_type=None) 173 | if resp is None: 174 | raise limits.RateLimitExceededError() 175 | result = [ 176 | extract_sampled_tokens(choice["logprobs"], self.encoding) 177 | for choice in resp["choices"] 178 | ] 179 | return result 180 | 181 | async def streaming_completions( 182 | self, prompt: str | Sequence[int], settings: CompletionsSettings | None = None 183 | ) -> Sequence[AsyncGenerator[SampledToken, None]]: 184 | """Please see docstring for Gpt.streaming_completions.""" 185 | 186 | params = self._make_params(prompt, settings) 187 | n = params.get("n", 1) 188 | 189 | async def drop_unneeded( 190 | events: AsyncIterable[ServerSentEvent], 191 | ) -> AsyncIterator[dict[str, Any]]: 192 | """Process the stream of ServerSentEvents to drop the ones that we don't need. 193 | 194 | We drop: 195 | - The [DONE] event, which normally occurs one at the end 196 | - Unexpected events where there are no tokens sampled 197 | """ 198 | async for event in events: 199 | assert event.data is not None 200 | if event.data == "[DONE]\n": 201 | continue 202 | 203 | data = json.loads(event.data) 204 | assert len(data["choices"]) == 1 205 | choice = data["choices"][0] 206 | 207 | if len(choice["logprobs"].get("tokens", [])) == 0: 208 | # Sometimes the API sends us an event even though no tokens were sampled. 209 | # Ignore those cases. 210 | continue 211 | 212 | yield choice 213 | 214 | def process_choice( 215 | choice: dict[str, Any] 216 | ) -> Iterator[tuple[int, SampledToken]]: 217 | """Extracts the data returned by OpenAI's API into the SampledText object.""" 218 | # choice["text"] is always a valid Unicode string, and may correspond to multiple tokens. 219 | # Sometimes, choice["text"] is empty while choice["logprobs"]["tokens"] is not. 220 | # This seems to happen when choice["logprobs"]["tokens"] doesn't concatenate into valid UTF-8. 221 | # sampled_text = SampledText(choice["text"], choice["finish_reason"]) 222 | for sampled_tokens in extract_sampled_tokens( 223 | choice["logprobs"], self.encoding 224 | ): 225 | yield choice["index"], sampled_tokens 226 | 227 | if n == 1: 228 | 229 | async def gen_1() -> AsyncGenerator[SampledToken, None]: 230 | async with self.streaming_completions_client_response( 231 | params 232 | ) as response: 233 | # We use "utf-8" because 234 | # https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation says: 235 | # > Streams must be decoded using the UTF-8 decode algorithm. 236 | events = parse_event_stream( 237 | _bytes_to_str(response.content, "utf-8") 238 | ) 239 | # `response.content` is an async iterable which returns sequences of 240 | # bytes ending in b'\n'. That is slightly inappropriate because the SSE 241 | # specification allows lines to end in b'\r\n' or b'\r' as well. 242 | # But in practice, the GPT server never seems to use \r\n or \r. 
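                # Each remaining event carries exactly one choice; flatten its
                # sampled tokens into the output stream.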
243 | async for c in drop_unneeded(events): 244 | for _, item in process_choice(c): 245 | yield item 246 | 247 | return [gen_1()] 248 | else: 249 | 250 | async def gen_n() -> AsyncGenerator[dict[str, Any], None]: 251 | async with self.streaming_completions_client_response( 252 | params 253 | ) as response: 254 | events = parse_event_stream( 255 | _bytes_to_str(response.content, "utf-8") 256 | ) 257 | async for event in drop_unneeded(events): 258 | yield event 259 | 260 | return await asyncitertools.bucket(n, gen_n(), process_choice) 261 | 262 | @cached_property 263 | def streaming_completions_client_response( 264 | self, 265 | ) -> Callable[ 266 | [dict[str, Any]], AbstractAsyncContextManager[aiohttp.ClientResponse] 267 | ]: 268 | """Helper function for `streaming_completions`. 269 | 270 | This method returns `aiohttp.ClientResponse`; the data retrieved from it is parsed by `streaming_completions`. 271 | """ 272 | if self.request_limiter is None: 273 | return self._streaming_completions_client_response 274 | return self.request_limiter.wrap_async_context_manager_producer( 275 | self._streaming_completions_client_response 276 | ) 277 | 278 | @asynccontextmanager 279 | async def _streaming_completions_client_response( 280 | self, params: dict[str, Any] 281 | ) -> AsyncIterator[aiohttp.ClientResponse]: 282 | """Call as: `with gpt_impl._streaming_completions_client_response(params) as response: ...""" 283 | 284 | async with client_session.get().post( 285 | self.url, 286 | headers={**self.auth_provider.headers(), **self.additional_headers}, 287 | json={**params, "stream": True}, 288 | ) as response: 289 | if response.status != 200: 290 | if response.status in (408, 429, 500): 291 | raise limits.RateLimitExceededError( 292 | response.status, await response.text() 293 | ) 294 | else: 295 | raise OpenAIAPIError(response.status, await response.text()) 296 | yield response 297 | 298 | def _make_params( 299 | self, 300 | prompt: str | Sequence[int] | Sequence[str] | Sequence[Sequence[int]], 301 | settings: CompletionsSettings | None, 302 | ) -> dict[str, Any]: 303 | params = {"prompt": prompt, "logprobs": 0, **self.default_completion_settings} 304 | if settings is not None: 305 | params.update(_filter_none_values(dataclasses.asdict(settings))) 306 | return params 307 | 308 | 309 | K = TypeVar("K") 310 | V = TypeVar("V") 311 | 312 | 313 | def _filter_none_values(d: dict[K, V | None]) -> dict[K, V]: 314 | return {k: v for k, v in d.items() if v is not None} 315 | 316 | 317 | async def _bytes_to_str( 318 | bs: AsyncIterable[bytes], encoding: str 319 | ) -> AsyncGenerator[str, None]: 320 | async for b in bs: 321 | yield b.decode(encoding) 322 | 323 | 324 | def openai_token_to_bytes(token: str) -> bytes: 325 | if token.startswith("bytes:"): 326 | return ast.literal_eval(f"b'{token[6:]}'") 327 | else: 328 | return token.encode("utf-8") 329 | 330 | 331 | def extract_sampled_tokens( 332 | logprobs_info: dict, encoding: tiktoken.Encoding 333 | ) -> list[SampledToken]: 334 | tokens = logprobs_info["tokens"] 335 | token_logprobs = logprobs_info["token_logprobs"] 336 | top_logprobs = logprobs_info.get("top_logprobs") 337 | if top_logprobs is None: 338 | top_logprobs = [{}] * len(tokens) 339 | 340 | sampled_tokens: list[SampledToken] = [] 341 | for token, token_logprob, top_logprobs_for_token in zip( 342 | tokens, token_logprobs, top_logprobs, strict=True 343 | ): 344 | if top_logprobs_for_token is None: 345 | top_logprobs_for_token = {} 346 | sampled_tokens.append( 347 | SampledToken( 348 | 
OpenAITokenWithLogprob(token, token_logprob, encoding),
349 |                 tuple(
350 |                     OpenAITokenWithLogprob(t, lp, encoding)
351 |                     for t, lp in top_logprobs_for_token.items()
352 |                 ),
353 |             )
354 |         )
355 | 
356 |     return sampled_tokens
357 | 
--------------------------------------------------------------------------------
/src/dp_few_shot_generation/run_exp_movie.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | import asyncio
5 | import math
6 | import re
7 | import sys
import time  # used by complete() below for retrying after API errors
8 | import traceback
9 | from collections.abc import Iterable, Set
10 | from typing import Annotated, cast
11 | 
12 | import aiohttp
13 | import more_itertools
14 | import numpy as np
15 | import openai
16 | import scipy.special
17 | import tqdm
18 | import typer
19 | from datasets import DatasetDict, load_dataset
20 | from lmapi.lm import LM, CompletionsSettings
21 | from lmapi.openai import client_session
22 | 
23 | from dp_few_shot_generation.lm import (
24 |     api_openai_com,
25 |     next_logprobs,
26 |     normalized_logprobs_for_chosen_tokens,
27 | )
28 | from dp_few_shot_generation.prob_utils import densify, log_max_normalize, log_normalize
29 | 
30 | DEFAULT_NUM_PRIVATE_TRAIN = 80
31 | DEFAULT_NUM_PUBLIC_TRAIN = 0
32 | DEFAULT_NUM_VALID = 4
33 | DEFAULT_NUM_PRIVATE_TRAIN_SPLITS = 20
34 | DEFAULT_NUM_TEST = -1
35 | 
36 | 
37 | def format_full_datum_for_prompt(field_name, datum: dict[str, str]):
38 |     return (
39 |         f'{field_name}: "{datum["label"]}"\nSentence: "{datum["content"] + " END"}"\n'
40 |     )
41 | 
42 | 
43 | def format_test_input_for_prompt(field_name, test_input: str):
44 |     return f'{field_name}: "{test_input}"\nSentence: "'
45 | 
46 | 
47 | def construct_prompt_same(train_examples, test_example, field_name):
48 |     prompt = ""  # prompt structure follows: https://github.com/tonyzhaozh/few-shot-learning/blob/main/data_utils.py#L427-L429
49 |     for train_example in train_examples:
50 |         prompt += "Sentence: " + train_example["content"] + "\n"
51 |         prompt += f"{field_name}: " + train_example["label"] + "\n\n"
52 |     prompt += "Sentence: " + test_example["content"] + "\n"
53 |     prompt += f"{field_name}:"
54 |     return prompt
55 | 
56 | 
57 | def complete(prompt, l, model_name, temp=0, num_log_probs=None, echo=False, n=None):
58 |     # call GPT-3 API until result is provided and then return it
59 |     response = None
60 |     received = False
61 |     while not received:
62 |         try:
63 |             response = openai.Completion.create(
64 |                 engine=model_name,
65 |                 prompt=prompt,
66 |                 max_tokens=l,
67 |                 temperature=temp,
68 |                 logprobs=num_log_probs,
69 |                 echo=echo,
70 |                 stop="\n",
71 |                 n=n,
72 |             )
73 |             received = True
74 |         except:
75 |             error = sys.exc_info()[0]
76 |             if (
77 |                 error == openai.error.InvalidRequestError
78 |             ): # something is wrong: e.g. 
prompt too long 79 | print(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n") 80 | assert False 81 | 82 | print("API error:", error) 83 | time.sleep(1) 84 | return response 85 | 86 | 87 | def chunks(lst, n): 88 | """Yield successive n-sized chunks from lst.""" 89 | for i in range(0, len(lst), n): 90 | yield lst[i : i + n] 91 | 92 | 93 | def get_model_response( 94 | data, 95 | test_examples, 96 | openai_model, 97 | field_name, 98 | max_token_to_fill=5, 99 | additional_tokens=None, 100 | ): 101 | all_raw_answers = [] 102 | 103 | prompts = [] 104 | train_examples = data 105 | 106 | for test_example in test_examples: 107 | prompts.append(construct_prompt_same(train_examples, test_example, field_name)) 108 | 109 | if additional_tokens is not None: 110 | assert len(additional_tokens) == len(prompts) 111 | for i in range(len(prompts)): 112 | prompts[i] += additional_tokens[i] 113 | 114 | chunked_prompts = list(chunks(prompts, 20)) 115 | for test_chunk in chunked_prompts: 116 | response = complete( 117 | test_chunk, l=max_token_to_fill, model_name=openai_model, num_log_probs=100 118 | ) 119 | 120 | for answer_id, answer in enumerate(response["choices"]): 121 | all_raw_answers.append(answer) 122 | 123 | return all_raw_answers 124 | 125 | 126 | def em_accuracy_helper(prediction, label): 127 | correctness_list = [] 128 | for pred, l in zip(prediction, label): 129 | pred = pred.split("\n")[0] 130 | if pred == l: 131 | correctness_list.append(1) 132 | else: 133 | correctness_list.append(0) 134 | return np.mean(correctness_list) 135 | 136 | 137 | def merge_logprobs_topk_mean( 138 | private_next_logprobs: list[dict[int, float]], 139 | public_next_logprobs: dict[int, float], 140 | n_vocab: int, 141 | no_public_token: bool, 142 | normalize_max: bool, 143 | ) -> np.ndarray: 144 | # Compute merged distribution 145 | # logsumexp - np.log(...): compute mean probability of distribution 146 | if normalize_max: 147 | normalize_func = ( 148 | log_max_normalize # normalize max probability to 1, Exponential mechanism 149 | ) 150 | else: 151 | normalize_func = ( 152 | log_normalize # normalize sum probability to 1, Gaussian mechanism 153 | ) 154 | if no_public_token: 155 | merged_next_logprobs = scipy.special.logsumexp( 156 | np.stack( 157 | [ 158 | # Turn into a 1D tensor of size n_vocab 159 | densify( 160 | n_vocab, 161 | # Normalize distribution 162 | normalize_func( 163 | # Filter to the top 100 most likely next tokens according to the public prompt 164 | {k: v for k, v in lps.items()} 165 | ), 166 | ) 167 | for lps in private_next_logprobs 168 | ] 169 | ), 170 | axis=0, 171 | ) - np.log(len(private_next_logprobs)) 172 | 173 | else: 174 | merged_next_logprobs = scipy.special.logsumexp( 175 | np.stack( 176 | [ 177 | # Turn into a 1D tensor of size n_vocab 178 | densify( 179 | n_vocab, 180 | # Normalize distribution 181 | normalize_func( 182 | # Filter to the top 100 most likely next tokens according to the public prompt 183 | {k: v for k, v in lps.items() if k in public_next_logprobs} 184 | ), 185 | ) 186 | for lps in private_next_logprobs 187 | ] 188 | ), 189 | axis=0, 190 | ) - np.log(len(private_next_logprobs)) 191 | merged_next_probs = np.exp(merged_next_logprobs) 192 | return merged_next_probs 193 | 194 | 195 | async def generate_with_private_prompts( 196 | trainset, 197 | num_private_train, 198 | num_private_train_splits, 199 | instruction, 200 | public_train_prompt: str, 201 | stop_tokens: Set[int], 202 | test_input: str, 203 | lm: LM, 204 | noise_rng: np.random.RandomState, 205 | sigma: float, 206 
| field_name: str, 207 | top_p, 208 | no_public_token: bool, 209 | subsample_per_token: bool, 210 | gen_seed: int, 211 | max_tokens: int, 212 | normalize_max: bool = False, 213 | ) -> list[int]: 214 | generated_token_ids: list[int] = [] 215 | 216 | stringified_test_datum = format_test_input_for_prompt(field_name, test_input) 217 | public_prompt = public_train_prompt + stringified_test_datum 218 | public_prompt_tokens = lm.encoding.encode(public_prompt) 219 | 220 | assert num_private_train_splits > 0 221 | 222 | train_subset = trainset.select(range(len(trainset)), keep_in_memory=True) 223 | 224 | if not subsample_per_token: 225 | private_train_subset = cast( 226 | Iterable[dict[str, str]], 227 | train_subset.shuffle(gen_seed, keep_in_memory=True).select( 228 | range(num_private_train), keep_in_memory=True 229 | ), 230 | ) 231 | private_train_splits = [ 232 | list(it) 233 | for it in more_itertools.distribute( 234 | num_private_train_splits, private_train_subset 235 | ) 236 | ] 237 | private_train_prompts = [ 238 | instruction 239 | + "\n".join( 240 | format_full_datum_for_prompt(field_name, datum) for datum in split 241 | ) 242 | for split in private_train_splits 243 | ] 244 | private_prompts = [ 245 | train_prompt + "\n" + stringified_test_datum 246 | for train_prompt in private_train_prompts 247 | ] 248 | private_prompts_tokens = [ 249 | lm.encoding.encode(prompt) for prompt in private_prompts 250 | ] 251 | 252 | cnt = 0 253 | for _ in tqdm.tqdm(range(max_tokens), total=float("inf"), unit=" tokens generated"): 254 | private_next_logprobs: list[dict[int, float]] 255 | public_next_logprobs: dict[int, float] 256 | # Split training dataset 257 | if subsample_per_token: 258 | private_train_subset = cast( 259 | Iterable[dict[str, str]], 260 | train_subset.shuffle(gen_seed + cnt, keep_in_memory=True).select( 261 | range(num_private_train), keep_in_memory=True 262 | ), 263 | ) 264 | cnt += 1 265 | private_train_splits = [ 266 | list(it) 267 | for it in more_itertools.distribute( 268 | num_private_train_splits, private_train_subset 269 | ) 270 | ] 271 | # Turn the data into prompts 272 | private_train_prompts = [ 273 | instruction 274 | + "\n".join( 275 | format_full_datum_for_prompt(field_name, datum) for datum in split 276 | ) 277 | for split in private_train_splits 278 | ] 279 | private_prompts = [ 280 | train_prompt + "\n" + stringified_test_datum 281 | for train_prompt in private_train_prompts 282 | ] 283 | private_prompts_tokens = [ 284 | lm.encoding.encode(prompt) for prompt in private_prompts 285 | ] 286 | if no_public_token: 287 | private_next_logprobs = await asyncio.gather( 288 | *( 289 | next_logprobs(lm, prompt + generated_token_ids, top_p=top_p) 290 | for prompt in private_prompts_tokens 291 | ) 292 | ) 293 | merged_next_probs = merge_logprobs_topk_mean( 294 | private_next_logprobs, 295 | None, 296 | lm.encoding.n_vocab, 297 | no_public_token, 298 | normalize_max, 299 | ) 300 | 301 | if normalize_max: 302 | # scale = 1/lambda 303 | noise = noise_rng.exponential(scale=sigma, size=lm.encoding.n_vocab) 304 | else: 305 | noise = noise_rng.normal(0, sigma, size=lm.encoding.n_vocab) 306 | merged_next_probs += noise 307 | else: 308 | public_next_logprobs = await next_logprobs( 309 | lm, public_prompt_tokens + generated_token_ids, top_p=top_p 310 | ) 311 | private_next_logprobs = await asyncio.gather( 312 | *( 313 | normalized_logprobs_for_chosen_tokens( 314 | lm, 315 | prompt + generated_token_ids, 316 | public_next_logprobs.keys(), 317 | top_p=top_p, 318 | ) 319 | for prompt in 
private_prompts_tokens 320 | ) 321 | ) 322 | merged_next_probs = merge_logprobs_topk_mean( 323 | private_next_logprobs, 324 | public_next_logprobs, 325 | lm.encoding.n_vocab, 326 | no_public_token, 327 | normalize_max, 328 | ) 329 | if normalize_max: 330 | # scale = 1/lambda 331 | noise = noise_rng.exponential( 332 | scale=sigma, size=len(public_next_logprobs) 333 | ) 334 | else: 335 | noise = noise_rng.normal(0, sigma, size=len(public_next_logprobs)) 336 | merged_next_probs[list(public_next_logprobs.keys())] += noise 337 | 338 | next_token_id = int(np.argmax(merged_next_probs)) 339 | 340 | if next_token_id in stop_tokens: 341 | break 342 | 343 | generated_token_ids.append(next_token_id) 344 | 345 | del next_token_id 346 | return generated_token_ids 347 | 348 | 349 | async def generate_with_public_prompt( 350 | public_train_prompt: str, 351 | stop_tokens: Set[str], 352 | test_input: str, 353 | lm: LM, 354 | field_name, 355 | max_tokens: int = 500, 356 | ) -> list[int]: 357 | public_prompt = public_train_prompt + format_test_input_for_prompt( 358 | field_name, test_input 359 | ) 360 | public_prompt_tokens = lm.encoding.encode(public_prompt) 361 | public_prompt_tokens = public_prompt 362 | 363 | [completion] = await lm.completions( 364 | public_prompt_tokens, 365 | CompletionsSettings( 366 | temperature=0.0, max_tokens=max_tokens, n=1, stop=list(stop_tokens) 367 | ), 368 | ) 369 | generated_tokens = [st.token.token_id for st in completion] 370 | return generated_tokens 371 | 372 | 373 | def _main( 374 | sigma: Annotated[float, typer.Option()], # noise parameters 375 | openai_model: Annotated[str, typer.Option()] = "babbage", 376 | print_prompts: Annotated[bool, typer.Option()] = False, 377 | # num_private_train=MN. MN=0 with num_valid=4 will get epsilon=0 (4-shot) results. 378 | num_private_train: Annotated[int, typer.Option()] = DEFAULT_NUM_PRIVATE_TRAIN, 379 | # by default set to 0. set_num_public_train >0 indicates additional public data available. 380 | set_num_public_train: Annotated[int, typer.Option()] = DEFAULT_NUM_PUBLIC_TRAIN, 381 | # num_valid=n. 
n samples to be generated for n-shot ICL 382 | num_valid: Annotated[int, typer.Option()] = DEFAULT_NUM_VALID, 383 | # num_private_train_splits=M 384 | num_private_train_splits: Annotated[ 385 | int, typer.Option() 386 | ] = DEFAULT_NUM_PRIVATE_TRAIN_SPLITS, 387 | num_test: Annotated[int, typer.Option()] = DEFAULT_NUM_TEST, 388 | # no_public_token=True, RVP=False; no_public_token=False, RVP=True 389 | no_public_token: Annotated[bool, typer.Option()] = False, 390 | # subsample_per_token=True: at each token generation, subsample a new test set 391 | subsample_per_token: Annotated[bool, typer.Option()] = False, 392 | use_dp_prompts: Annotated[bool, typer.Option()] = False, 393 | # normalize_max=True, Exponential mechanism; normalize_max=False, Gaussian mechanism 394 | normalize_max: Annotated[bool, typer.Option()] = False, 395 | # max_token_per_text=T_max 396 | max_token_per_text: Annotated[int, typer.Option()] = 20, 397 | # consistent with default parameters in the documentation https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#completions 398 | top_p: Annotated[float, typer.Option()] = 1, 399 | # random seed for subsampling in generation 400 | synth_seed: Annotated[int, typer.Option()] = 0, 401 | # random seed for n-shot demonstrations sampling in evaluation 402 | eval_seed: Annotated[int, typer.Option()] = 0, 403 | # choice bewteen ["Genre", "Director"] 404 | field_name: Annotated[str, typer.Option()] = "Genre", 405 | data_path: Annotated[str, typer.Option()] = "./../../data/movie", 406 | ): 407 | async def main(): 408 | if (num_private_train == 0) != (num_private_train_splits == 0): 409 | raise ValueError( 410 | "Either both or neither of --num-private-train and --num-private-train-splits can be 0" 411 | ) 412 | assert field_name in [ 413 | "Director", 414 | "Genre", 415 | ] # field_name from movie dataset include "Actor", "Award", "Character_Name", "Director", "Genre", "Opinion", "Origin", "Plot", "Quote", "Relationship", "Soundtrack", "Year"] 416 | command = ["python", sys.argv[0]] 417 | for x in sys.argv[1:]: 418 | if x.startswith("--"): 419 | assert '"' not in x and "'" not in x 420 | command.append(x) 421 | else: 422 | assert "'" not in x 423 | if re.match("^[a-zA-Z0-9_]+$", x): 424 | command.append("%s" % x) 425 | else: 426 | command.append("'%s'" % x) 427 | command = " ".join(command) 428 | print(command) 429 | 430 | if no_public_token: 431 | num_public_train = 0 432 | else: 433 | num_public_train = set_num_public_train 434 | 435 | lm = api_openai_com(openai_model) 436 | noise_rng = np.random.RandomState() 437 | 438 | data_files = {"train": "train.csv", "test": "test.csv"} 439 | data = cast( 440 | DatasetDict, 441 | load_dataset(f"{data_path}/{field_name}", data_files=data_files), 442 | ) 443 | 444 | trainset = data["train"].shuffle(seed=synth_seed, keep_in_memory=True) 445 | print("trainset length", len(trainset)) 446 | if num_public_train > 0: 447 | public_train_subset = cast( 448 | Iterable[dict[str, str]], 449 | trainset.select( 450 | range( 451 | len(trainset) - num_public_train, 452 | len(trainset), 453 | keep_in_memory=True, 454 | ) 455 | ), 456 | ) 457 | else: 458 | public_train_subset = [] 459 | 460 | trainset = trainset.select( 461 | range(len(trainset) - num_public_train), keep_in_memory=True 462 | ) 463 | query_subset = ( 464 | data["train"] 465 | .shuffle(seed=eval_seed, keep_in_memory=True) 466 | .select(range(num_valid), keep_in_memory=True) 467 | ) 468 | 469 | if use_dp_prompts: 470 | synthetic_examples = [] 471 | 472 | # Turn the data into 
prompts 473 | instruction = f"Given a propety of {field_name} for the film, generate a description accordingly and make sure to include the given {field_name} in the description.\n\n" 474 | print(instruction) 475 | 476 | public_train_prompt = instruction + "\n".join( 477 | format_full_datum_for_prompt(field_name, datum) 478 | for datum in public_train_subset 479 | ) 480 | 481 | if print_prompts: 482 | print(public_train_prompt) 483 | print("=========") 484 | 485 | if normalize_max: 486 | print("Exponential Mechanism") 487 | assert num_private_train == 0 or sigma > 0 488 | if num_private_train > 0: 489 | # scale == sigma_calib == 1/lambda. lambda for exponential distribution. 490 | sigma_calib = (2 / num_private_train_splits) * (1 / sigma) 491 | else: 492 | print("Gaussian Mechanism") 493 | if num_private_train_splits > 0: 494 | sigma_calib = math.sqrt(2) / num_private_train_splits * sigma 495 | else: 496 | sigma_calib = 0 497 | print( 498 | f"sigma in command {sigma}. sigma added according to sensitivity {sigma_calib}" 499 | ) 500 | 501 | stop_tokens = {"\n", "<|endoftext|>", " END"} 502 | stop_tokens_ids = {lm.encoding.encode_single_token(t) for t in stop_tokens} 503 | 504 | client_session.set(aiohttp.ClientSession()) 505 | 506 | async with client_session.get(): 507 | for i, test_datum in enumerate(query_subset, 1): 508 | print(f"# Example {i}") 509 | print(f'{field_name}: "{test_datum["label"]}"') 510 | 511 | np.random.seed(synth_seed + i) 512 | gen_seed = np.random.randint(100000) 513 | print(f"gen-seed: {gen_seed}") 514 | 515 | if num_private_train_splits > 0: 516 | generated_token_ids = await generate_with_private_prompts( 517 | trainset, 518 | num_private_train, 519 | num_private_train_splits, 520 | instruction, 521 | public_train_prompt, 522 | stop_tokens_ids, 523 | test_datum["label"], 524 | lm, 525 | noise_rng, 526 | sigma_calib, 527 | field_name, 528 | top_p, 529 | no_public_token, 530 | subsample_per_token, 531 | gen_seed, 532 | max_tokens=max_token_per_text 533 | - 1, # need one token length for EOS. 
534 | normalize_max=normalize_max, 535 | ) 536 | else: 537 | generated_token_ids = await generate_with_public_prompt( 538 | public_train_prompt, 539 | stop_tokens, 540 | test_datum["label"], 541 | lm, 542 | field_name, 543 | max_tokens=max_token_per_text, 544 | ) 545 | 546 | generated = lm.encoding.decode(generated_token_ids).rstrip('"') 547 | 548 | print(f"Generated: {generated}\n") 549 | output_datum = {} 550 | output_datum["content"] = generated.strip() 551 | output_datum["label"] = test_datum["label"] 552 | synthetic_examples.append(output_datum) 553 | 554 | if num_test > 0 and num_test <= len(data["test"]): 555 | test_subset = ( 556 | data["test"] 557 | .shuffle(seed=12345, keep_in_memory=True) 558 | .select(range(num_test), keep_in_memory=True) 559 | ) 560 | else: 561 | test_subset = data["test"] 562 | 563 | all_raw_answers_wout_DP = get_model_response( 564 | query_subset, test_subset, openai_model, field_name 565 | ) 566 | all_orig_ans = [] 567 | for resp in all_raw_answers_wout_DP: 568 | all_orig_ans.append(resp["text"]) 569 | all_orig_ans = [ans.strip() for ans in all_orig_ans] 570 | test_labels = test_subset["label"] 571 | orig_accuracy = em_accuracy_helper(all_orig_ans, test_labels) 572 | print(f"Accuracy (original) without DP: {orig_accuracy}") 573 | 574 | if use_dp_prompts: 575 | all_raw_answers_w_DP = get_model_response( 576 | synthetic_examples, test_subset, openai_model, field_name 577 | ) 578 | all_orig_ans = [] 579 | for resp in all_raw_answers_w_DP: 580 | all_orig_ans.append(resp["text"]) 581 | all_orig_ans = [ans.strip() for ans in all_orig_ans] 582 | test_labels = test_subset["label"] 583 | orig_accuracy = em_accuracy_helper(all_orig_ans, test_labels) 584 | print(f"Accuracy (original) with DP: {orig_accuracy}") 585 | 586 | try: 587 | asyncio.run(main()) 588 | except KeyboardInterrupt: 589 | traceback.print_exc() 590 | raise 591 | 592 | 593 | if __name__ == "__main__": 594 | typer.run(_main) 595 | -------------------------------------------------------------------------------- /src/dp_few_shot_generation/run_exp_agnews.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
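#
# Experiment script for DP few-shot generation on the AG News classification
# task: synthetic demonstrations are generated token by token from disjoint
# subsets of the private training data (with Gaussian- or exponential-mechanism
# noise added to the averaged next-token distribution), and n-shot in-context
# learning accuracy is then evaluated with and without the DP-generated
# demonstrations, optionally with contextual calibration.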
3 | 4 | import asyncio 5 | import math 6 | import re 7 | import sys 8 | import time 9 | import traceback 10 | from collections.abc import Iterable, Set 11 | from typing import Annotated, cast 12 | 13 | import aiohttp 14 | import more_itertools 15 | import numpy as np 16 | import openai 17 | import scipy.special 18 | import tqdm 19 | import typer 20 | from datasets import DatasetDict, load_dataset 21 | from lmapi.lm import LM, CompletionsSettings 22 | from lmapi.openai import client_session 23 | 24 | from dp_few_shot_generation.lm import ( 25 | api_openai_com, 26 | next_logprobs, 27 | normalized_logprobs_for_chosen_tokens, 28 | ) 29 | from dp_few_shot_generation.prob_utils import densify, log_max_normalize, log_normalize 30 | 31 | DEFAULT_NUM_PRIVATE_TRAIN = 20 32 | DEFAULT_NUM_PUBLIC_TRAIN = 0 33 | DEFAULT_NUM_VALID = 4 34 | DEFAULT_NUM_PRIVATE_TRAIN_SPLITS = 10 35 | DEFAULT_NUM_TEST = 1000 36 | 37 | labels = ["World", "Sport", "Business", "Technology"] 38 | label_dict = {0: ["World"], 1: ["Sports"], 2: ["Business"], 3: ["Technology"]} 39 | 40 | 41 | def format_full_datum_for_prompt(labels, datum: dict[str, str]): 42 | return f'News Type: "{labels[datum["label"]]}"\nText: "{datum["text"] + " END"}"\n' 43 | 44 | 45 | def format_test_input_for_prompt(labels, test_input: int): 46 | return f'News Type: "{labels[test_input]}"\nText: "' 47 | 48 | 49 | def construct_prompt_same(train_examples, test_example): 50 | prompt = "Classify the news articles into the categories of World, Sports, Business, and Technology.\n\n" 51 | for train_example in train_examples: 52 | prompt += "Article: " + train_example["text"] + "\n" 53 | prompt += "Answer: " + label_dict[train_example["label"]][0] + "\n\n" 54 | prompt += "Article: " + test_example["text"] + "\n" 55 | prompt += "Answer:" 56 | return prompt 57 | 58 | 59 | def complete(prompt, l, model_name, temp=0, num_log_probs=None, echo=False, n=None): 60 | # call GPT-3 API until result is provided and then return it 61 | response = None 62 | received = False 63 | while not received: 64 | try: 65 | response = openai.Completion.create( 66 | engine=model_name, 67 | prompt=prompt, 68 | max_tokens=l, 69 | temperature=temp, 70 | logprobs=num_log_probs, 71 | echo=echo, 72 | stop="\n", 73 | n=n, 74 | ) 75 | received = True 76 | except: 77 | error = sys.exc_info()[0] 78 | if ( 79 | error == openai.error.InvalidRequestError 80 | ): # something is wrong: e.g. prompt too long 81 | print(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n") 82 | assert False 83 | 84 | print("API error:", error) 85 | time.sleep(1) 86 | return response 87 | 88 | 89 | def chunks(lst, n): 90 | """Yield successive n-sized chunks from lst.""" 91 | for i in range(0, len(lst), n): 92 | yield lst[i : i + n] 93 | 94 | 95 | def get_model_response(data, test_examples, openai_model): 96 | all_raw_answers = [] 97 | 98 | prompts = [] 99 | train_examples = data 100 | 101 | for test_example in test_examples: 102 | prompts.append(construct_prompt_same(train_examples, test_example)) 103 | 104 | chunked_prompts = list(chunks(prompts, 20)) 105 | for test_chunk in chunked_prompts: 106 | response = complete(test_chunk, l=1, model_name=openai_model, num_log_probs=100) 107 | 108 | for answer_id, answer in enumerate(response["choices"]): 109 | all_raw_answers.append(answer) 110 | 111 | return all_raw_answers 112 | 113 | 114 | def get_label_probs(all_raw_answers, test_subset): 115 | """Obtain model's label probability for each of the test examples. 
The returned prob is NOT normalized""" 116 | num_classes = len(label_dict) 117 | approx = False 118 | assert len(all_raw_answers) == len(test_subset) 119 | 120 | # Fill in the labels that is in the top k prob 121 | all_label_probs = [] 122 | all_missing_positions = [] 123 | cnt = 0 124 | for i, ans in enumerate(all_raw_answers): 125 | try: 126 | top_logprobs = ans["logprobs"]["top_logprobs"][ 127 | 0 128 | ] # [0] since we only ask for complete one more token 129 | except: 130 | cnt += 1 # cnt for corner case 131 | label_probs = [0] * len(label_dict.keys()) 132 | for j, label_list in label_dict.items(): 133 | all_found = True 134 | for label in label_list: # each possible label correspond to the same class 135 | label = " " + label # notice prompt does not have space after 'A:' 136 | if label in top_logprobs: 137 | label_probs[j] += np.exp(top_logprobs[label]) 138 | else: 139 | all_found = False 140 | if not all_found: 141 | position = (i, j) # (which test example, which label) 142 | all_missing_positions.append(position) 143 | all_label_probs.append(label_probs) 144 | all_label_probs = np.array(all_label_probs) # prob not normalized 145 | 146 | return all_label_probs # NOT NORMALIZED 147 | 148 | 149 | def eval_accuracy(all_label_probs, test_labels, mode=None, p_cf=None): 150 | # evaluate the accuracy with and without contextual calibration 151 | num_classes = all_label_probs.shape[1] 152 | if p_cf is None: 153 | # do not calibrate 154 | W = np.identity(num_classes) 155 | b = np.zeros([num_classes, 1]) 156 | else: 157 | # calibrate 158 | if mode == "diagonal_W": 159 | W = np.linalg.inv(np.identity(num_classes) * p_cf) 160 | b = np.zeros([num_classes, 1]) 161 | elif mode == "identity_W": 162 | W = np.identity(num_classes) 163 | b = -1 * np.expand_dims(p_cf, axis=-1) 164 | else: 165 | assert False 166 | 167 | correctness_list = [] 168 | assert len(all_label_probs) == len(test_labels) 169 | for label_probs, true_label in zip(all_label_probs, test_labels): 170 | if np.sum(label_probs) > 0: # corner case np.sum(label_probs)=0. 
171 | label_probs = label_probs / np.sum(label_probs) # normalize to 1 172 | 173 | calibrate_label_probs = np.matmul(W, np.expand_dims(label_probs, axis=-1)) + b 174 | 175 | ans_label = np.argmax(calibrate_label_probs) 176 | if ans_label == true_label: 177 | correctness_list.append(1) 178 | else: 179 | correctness_list.append(0) 180 | return np.mean(correctness_list) 181 | 182 | 183 | def get_p_content_free(train_subset, openai_model, content_free_inputs=("N/A",)): 184 | """Query model with content free input, return its prediction probability for each label""" 185 | all_p_y = [] 186 | for content_free_input in content_free_inputs: 187 | prompt = construct_prompt_same(train_subset, content_free_input) 188 | p_y = [0] * len(label_dict) 189 | for i, answers in label_dict.items(): 190 | prob = 0 191 | for a in answers: 192 | prob += np.exp( 193 | complete( 194 | prompt + " " + a, 0, openai_model, echo=True, num_log_probs=1 195 | )["choices"][0]["logprobs"]["token_logprobs"][-1] 196 | ) 197 | p_y[i] = prob 198 | all_p_y.append(p_y) 199 | p_y = np.mean(np.array(all_p_y), axis=0) 200 | p_y = p_y / np.sum(p_y) # normalize 201 | return p_y 202 | 203 | 204 | def merge_logprobs_topk_mean( 205 | private_next_logprobs: list[dict[int, float]], 206 | public_next_logprobs: dict[int, float], 207 | n_vocab: int, 208 | no_public_token: bool, 209 | normalize_max: bool, 210 | ) -> np.ndarray: 211 | # Compute merged distribution 212 | # logsumexp - np.log(...): compute mean probability of distribution 213 | if normalize_max: 214 | normalize_func = ( 215 | log_max_normalize # normalize max probability to 1, Exponential mechanism 216 | ) 217 | else: 218 | normalize_func = ( 219 | log_normalize # normalize sum probability to 1, Gaussian mechanism 220 | ) 221 | if no_public_token: 222 | merged_next_logprobs = scipy.special.logsumexp( 223 | np.stack( 224 | [ 225 | # Turn into a 1D tensor of size n_vocab 226 | densify( 227 | n_vocab, 228 | # Normalize distribution 229 | normalize_func( 230 | # Filter to the top 100 most likely next tokens according to the public prompt 231 | {k: v for k, v in lps.items()} 232 | ), 233 | ) 234 | for lps in private_next_logprobs 235 | ] 236 | ), 237 | axis=0, 238 | ) - np.log(len(private_next_logprobs)) 239 | 240 | else: 241 | merged_next_logprobs = scipy.special.logsumexp( 242 | np.stack( 243 | [ 244 | # Turn into a 1D tensor of size n_vocab 245 | densify( 246 | n_vocab, 247 | # Normalize distribution 248 | normalize_func( 249 | # Filter to the top 100 most likely next tokens according to the public prompt 250 | {k: v for k, v in lps.items() if k in public_next_logprobs} 251 | ), 252 | ) 253 | for lps in private_next_logprobs 254 | ] 255 | ), 256 | axis=0, 257 | ) - np.log(len(private_next_logprobs)) 258 | merged_next_probs = np.exp(merged_next_logprobs) 259 | return merged_next_probs 260 | 261 | 262 | async def generate_with_private_prompts( 263 | trainset, 264 | num_private_train, 265 | num_private_train_splits, 266 | instruction, 267 | public_train_prompt: str, 268 | stop_tokens: Set[int], 269 | test_input: int, 270 | lm: LM, 271 | noise_rng: np.random.RandomState, 272 | sigma: float, 273 | labels, 274 | top_p, 275 | no_public_token: bool, 276 | subsample_per_token: bool, 277 | sample_same_label_prompts: bool, 278 | gen_seed: int, 279 | max_tokens: int = 100 - 1, 280 | normalize_max: bool = False, 281 | ) -> list[int]: 282 | generated_token_ids: list[int] = [] 283 | 284 | stringified_test_datum = format_test_input_for_prompt(labels, test_input) 285 | public_prompt = 
public_train_prompt + stringified_test_datum 286 | public_prompt_tokens = lm.encoding.encode(public_prompt) 287 | 288 | assert num_private_train_splits > 0 289 | if sample_same_label_prompts: 290 | select_list = [] 291 | for i in range(len(trainset)): 292 | if trainset[i]["label"] == test_input: 293 | select_list.append(i) 294 | train_subset = trainset.select(select_list, keep_in_memory=True) 295 | else: 296 | train_subset = trainset.select(range(len(trainset)), keep_in_memory=True) 297 | 298 | if not subsample_per_token: 299 | private_train_subset = cast( 300 | Iterable[dict[str, str]], 301 | train_subset.shuffle(gen_seed, keep_in_memory=True).select( 302 | range(num_private_train), keep_in_memory=True 303 | ), 304 | ) 305 | private_train_splits = [ 306 | list(it) 307 | for it in more_itertools.distribute( 308 | num_private_train_splits, private_train_subset 309 | ) 310 | ] 311 | private_train_prompts = [ 312 | instruction 313 | + "\n".join(format_full_datum_for_prompt(labels, datum) for datum in split) 314 | for split in private_train_splits 315 | ] 316 | private_prompts = [ 317 | train_prompt + "\n" + stringified_test_datum 318 | for train_prompt in private_train_prompts 319 | ] 320 | private_prompts_tokens = [ 321 | lm.encoding.encode(prompt) for prompt in private_prompts 322 | ] 323 | 324 | cnt = 0 325 | for _ in tqdm.tqdm(range(max_tokens), total=float("inf"), unit=" tokens generated"): 326 | private_next_logprobs: list[dict[int, float]] 327 | public_next_logprobs: dict[int, float] 328 | # Split training dataset 329 | if subsample_per_token: 330 | private_train_subset = cast( 331 | Iterable[dict[str, str]], 332 | train_subset.shuffle(gen_seed + cnt, keep_in_memory=True).select( 333 | range(num_private_train), keep_in_memory=True 334 | ), 335 | ) 336 | cnt += 1 337 | private_train_splits = [ 338 | list(it) 339 | for it in more_itertools.distribute( 340 | num_private_train_splits, private_train_subset 341 | ) 342 | ] 343 | # Turn the data into prompts 344 | private_train_prompts = [ 345 | instruction 346 | + "\n".join( 347 | format_full_datum_for_prompt(labels, datum) for datum in split 348 | ) 349 | for split in private_train_splits 350 | ] 351 | private_prompts = [ 352 | train_prompt + "\n" + stringified_test_datum 353 | for train_prompt in private_train_prompts 354 | ] 355 | private_prompts_tokens = [ 356 | lm.encoding.encode(prompt) for prompt in private_prompts 357 | ] 358 | if no_public_token: 359 | private_next_logprobs = await asyncio.gather( 360 | *( 361 | next_logprobs(lm, prompt + generated_token_ids, top_p=top_p) 362 | for prompt in private_prompts_tokens 363 | ) 364 | ) 365 | merged_next_probs = merge_logprobs_topk_mean( 366 | private_next_logprobs, 367 | None, 368 | lm.encoding.n_vocab, 369 | no_public_token, 370 | normalize_max, 371 | ) 372 | if normalize_max: 373 | # scale = 1/lambda 374 | noise = noise_rng.exponential(scale=sigma, size=lm.encoding.n_vocab) 375 | else: 376 | noise = noise_rng.normal(0, sigma, size=lm.encoding.n_vocab) 377 | merged_next_probs += noise 378 | else: 379 | public_next_logprobs = await next_logprobs( 380 | lm, public_prompt_tokens + generated_token_ids, top_p=top_p 381 | ) 382 | private_next_logprobs = await asyncio.gather( 383 | *( 384 | normalized_logprobs_for_chosen_tokens( 385 | lm, 386 | prompt + generated_token_ids, 387 | public_next_logprobs.keys(), 388 | top_p=top_p, 389 | ) 390 | for prompt in private_prompts_tokens 391 | ) 392 | ) 393 | merged_next_probs = merge_logprobs_topk_mean( 394 | private_next_logprobs, 395 | 
public_next_logprobs, 396 | lm.encoding.n_vocab, 397 | no_public_token, 398 | normalize_max, 399 | ) 400 | if normalize_max: 401 | # scale = 1/lambda 402 | noise = noise_rng.exponential( 403 | scale=sigma, size=len(public_next_logprobs) 404 | ) 405 | else: 406 | noise = noise_rng.normal(0, sigma, size=len(public_next_logprobs)) 407 | merged_next_probs[list(public_next_logprobs.keys())] += noise 408 | 409 | next_token_id = int(np.argmax(merged_next_probs)) 410 | 411 | if next_token_id in stop_tokens: 412 | break 413 | 414 | generated_token_ids.append(next_token_id) 415 | 416 | del next_token_id 417 | return generated_token_ids 418 | 419 | 420 | async def generate_with_public_prompt( 421 | public_train_prompt: str, 422 | stop_tokens: Set[str], 423 | test_input: str, 424 | lm: LM, 425 | labels, 426 | max_tokens: int = 500, 427 | ) -> list[int]: 428 | public_prompt = public_train_prompt + format_test_input_for_prompt( 429 | labels, test_input 430 | ) 431 | public_prompt_tokens = lm.encoding.encode(public_prompt) 432 | public_prompt_tokens = public_prompt 433 | 434 | [completion] = await lm.completions( 435 | public_prompt_tokens, 436 | CompletionsSettings( 437 | temperature=0.0, max_tokens=max_tokens, n=1, stop=list(stop_tokens) 438 | ), 439 | ) 440 | generated_tokens = [st.token.token_id for st in completion] 441 | return generated_tokens 442 | 443 | 444 | def select_uniform_n_shots_over_labels(data, n_shots): 445 | select_list = [] 446 | n_shots_per_label = math.ceil(n_shots / len(labels)) 447 | labels_counter = {label[1][0]: n_shots_per_label for label in label_dict.items()} 448 | n_shots_selected = 0 449 | for i in range(len(data)): 450 | label = label_dict[data[i]["label"]][0] 451 | if labels_counter[label] == 0: 452 | continue 453 | else: 454 | labels_counter[label] -= 1 455 | select_list.append(i) 456 | n_shots_selected += 1 457 | if n_shots_selected == n_shots: 458 | break 459 | query_subset = data.select(select_list, keep_in_memory=True) 460 | return query_subset 461 | 462 | 463 | def _main( 464 | sigma: Annotated[float, typer.Option()], # noise parameters 465 | openai_model: Annotated[str, typer.Option()] = "babbage", 466 | print_prompts: Annotated[bool, typer.Option()] = False, 467 | # num_private_train=MN. MN=0 with num_valid=4 will get epsilon=0 (4-shot) results. 468 | num_private_train: Annotated[int, typer.Option()] = DEFAULT_NUM_PRIVATE_TRAIN, 469 | # by default set to 0. set_num_public_train >0 indicates additional public data available. 470 | set_num_public_train: Annotated[int, typer.Option()] = DEFAULT_NUM_PUBLIC_TRAIN, 471 | # num_valid=n. n samples to be generated for n-shot ICL 472 | num_valid: Annotated[int, typer.Option()] = DEFAULT_NUM_VALID, 473 | # num_private_train_splits=M 474 | num_private_train_splits: Annotated[ 475 | int, typer.Option() 476 | ] = DEFAULT_NUM_PRIVATE_TRAIN_SPLITS, 477 | num_test: Annotated[int, typer.Option()] = DEFAULT_NUM_TEST, 478 | # no_public_token=True, RVP=False; no_public_token=False, RVP=True 479 | no_public_token: Annotated[bool, typer.Option()] = False, 480 | # subsample_per_token=True: at each token generation, subsample a new test set 481 | subsample_per_token: Annotated[bool, typer.Option()] = False, 482 | use_dp_prompts: Annotated[bool, typer.Option()] = False, 483 | # sample_same_label_prompts=True: sample subsets from the sets with same targeted labels. 
484 | sample_same_label_prompts: Annotated[bool, typer.Option()] = False, 485 | # normalize_max=True, Exponential mechanism; normalize_max=False, Gaussian mechanism 486 | normalize_max: Annotated[bool, typer.Option()] = False, 487 | # max_token_per_text=T_max 488 | max_token_per_text: Annotated[int, typer.Option()] = 100, 489 | # consistent with default parameters in the documentation https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#completions 490 | top_p: Annotated[float, typer.Option()] = 1, 491 | # random seed for subsampling in generation 492 | synth_seed: Annotated[int, typer.Option()] = 0, 493 | # random seed for n-shot demonstrations sampling in evaluation 494 | eval_seed: Annotated[int, typer.Option()] = 0, 495 | ): 496 | async def main(): 497 | if (num_private_train == 0) != (num_private_train_splits == 0): 498 | raise ValueError( 499 | "Either both or neither of --num-private-train and --num-private-train-splits can be 0" 500 | ) 501 | command = ["python", sys.argv[0]] 502 | for x in sys.argv[1:]: 503 | if x.startswith("--"): 504 | assert '"' not in x and "'" not in x 505 | command.append(x) 506 | else: 507 | assert "'" not in x 508 | if re.match("^[a-zA-Z0-9_]+$", x): 509 | command.append("%s" % x) 510 | else: 511 | command.append("'%s'" % x) 512 | command = " ".join(command) 513 | print(command) 514 | 515 | if no_public_token: 516 | num_public_train = 0 517 | else: 518 | num_public_train = set_num_public_train 519 | 520 | lm = api_openai_com(openai_model) 521 | noise_rng = np.random.RandomState() 522 | 523 | data = cast(DatasetDict, load_dataset("ag_news")) 524 | print(labels) 525 | 526 | trainset = data["train"].shuffle(seed=synth_seed, keep_in_memory=True) 527 | print("trainset length", len(trainset)) 528 | if num_public_train > 0: 529 | public_train_subset = cast( 530 | Iterable[dict[str, str]], 531 | trainset.select( 532 | range( 533 | len(trainset) - num_public_train, 534 | len(trainset), 535 | keep_in_memory=True, 536 | ) 537 | ), 538 | ) 539 | else: 540 | public_train_subset = [] 541 | 542 | trainset = trainset.select( 543 | range(len(trainset) - num_public_train), keep_in_memory=True 544 | ) 545 | queryset = data["train"].shuffle(seed=eval_seed, keep_in_memory=True) 546 | query_subset = select_uniform_n_shots_over_labels(queryset, num_valid) 547 | 548 | if use_dp_prompts: 549 | synthetic_examples = [] 550 | 551 | # Turn the data into prompts 552 | instruction = "Given a label of news type, generate the chosen type of news accordingly.\n\n" 553 | 554 | public_train_prompt = instruction + "\n".join( 555 | format_full_datum_for_prompt(labels, datum) 556 | for datum in public_train_subset 557 | ) 558 | 559 | if print_prompts: 560 | print(public_train_prompt) 561 | print("=========") 562 | 563 | if normalize_max: 564 | print("Exponential Mechanism") 565 | assert num_private_train == 0 or sigma > 0 566 | if num_private_train > 0: 567 | # scale == sigma_calib == 1/lambda. lambda for exponential distribution. 568 | sigma_calib = (2 / num_private_train_splits) * (1 / sigma) 569 | else: 570 | print("Gaussian Mechanism") 571 | if num_private_train_splits > 0: 572 | sigma_calib = math.sqrt(2) / num_private_train_splits * sigma 573 | else: 574 | sigma_calib = 0 575 | print( 576 | f"sigma in command {sigma}. 
sigma added according to sensitivity {sigma_calib}" 577 | ) 578 | 579 | stop_tokens = {"\n", "<|endoftext|>", " END"} 580 | stop_tokens_ids = {lm.encoding.encode_single_token(t) for t in stop_tokens} 581 | 582 | client_session.set(aiohttp.ClientSession()) 583 | 584 | async with client_session.get(): 585 | for i, test_datum in enumerate(query_subset, 1): 586 | print(f"# Example {i}") 587 | print(f'News Type: "{labels[test_datum["label"]]}"') 588 | print(f'References:\n "{test_datum["text"]}"') 589 | 590 | np.random.seed(synth_seed + i) 591 | gen_seed = np.random.randint(100000) 592 | print(f"gen-seed: {gen_seed}") 593 | 594 | if num_private_train_splits > 0: 595 | generated_token_ids = await generate_with_private_prompts( 596 | trainset, 597 | num_private_train, 598 | num_private_train_splits, 599 | instruction, 600 | public_train_prompt, 601 | stop_tokens_ids, 602 | test_datum["label"], 603 | lm, 604 | noise_rng, 605 | sigma_calib, 606 | labels, 607 | top_p, 608 | no_public_token, 609 | subsample_per_token, 610 | sample_same_label_prompts, 611 | gen_seed, 612 | max_tokens=max_token_per_text 613 | - 1, # need one token length for EOS. 614 | normalize_max=normalize_max, 615 | ) 616 | else: 617 | generated_token_ids = await generate_with_public_prompt( 618 | public_train_prompt, 619 | stop_tokens, 620 | test_datum["label"], 621 | lm, 622 | labels, 623 | max_tokens=max_token_per_text, 624 | ) 625 | 626 | generated = lm.encoding.decode(generated_token_ids).rstrip('"') 627 | 628 | print(f"Generated: {generated}\n") 629 | output_datum = {} 630 | output_datum["text"] = generated.strip() 631 | output_datum["label"] = test_datum["label"] 632 | synthetic_examples.append(output_datum) 633 | 634 | if num_test > 0: 635 | test_subset = ( 636 | data["test"] 637 | .shuffle(seed=12345, keep_in_memory=True) 638 | .select(range(num_test), keep_in_memory=True) 639 | ) 640 | test_labels = [test_example["label"] for test_example in test_subset] 641 | 642 | content_free_inputs = [{"text": "N/A"}, {"text": ""}, {"text": "[MASK]"}] 643 | p_cf_wout_DP = get_p_content_free( 644 | query_subset, openai_model, content_free_inputs=content_free_inputs 645 | ) 646 | 647 | all_raw_answers_wout_DP = get_model_response( 648 | query_subset, test_subset, openai_model 649 | ) 650 | all_label_probs_wout_DP = get_label_probs( 651 | all_raw_answers_wout_DP, test_subset 652 | ) 653 | 654 | acc_original_wout_DP = eval_accuracy(all_label_probs_wout_DP, test_labels) 655 | acc_calibrated_wout_DP = eval_accuracy( 656 | all_label_probs_wout_DP, 657 | test_labels, 658 | mode="diagonal_W", 659 | p_cf=p_cf_wout_DP, 660 | ) 661 | 662 | print(f"Accuracy (original) without DP: {acc_original_wout_DP}") 663 | print(f"Accuracy (calibrated) without DP: {acc_calibrated_wout_DP}") 664 | 665 | if use_dp_prompts: 666 | p_cf_w_DP = get_p_content_free( 667 | synthetic_examples, 668 | openai_model, 669 | content_free_inputs=content_free_inputs, 670 | ) 671 | 672 | all_raw_answers_w_DP = get_model_response( 673 | synthetic_examples, test_subset, openai_model 674 | ) 675 | all_label_probs_w_DP = get_label_probs( 676 | all_raw_answers_w_DP, test_subset 677 | ) 678 | 679 | acc_original_w_DP = eval_accuracy(all_label_probs_w_DP, test_labels) 680 | acc_calibrated_w_DP = eval_accuracy( 681 | all_label_probs_w_DP, test_labels, mode="diagonal_W", p_cf=p_cf_w_DP 682 | ) 683 | 684 | print(f"Accuracy (original) with DP: {acc_original_w_DP}") 685 | print(f"Accuracy (calibrated) with DP: {acc_calibrated_w_DP}") 686 | 687 | try: 688 | asyncio.run(main()) 689 | except 
KeyboardInterrupt: 690 | traceback.print_exc() 691 | raise 692 | 693 | 694 | if __name__ == "__main__": 695 | typer.run(_main) 696 | -------------------------------------------------------------------------------- /src/dp_few_shot_generation/run_exp_trec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import asyncio 5 | import math 6 | import re 7 | import sys 8 | import time 9 | import traceback 10 | from collections.abc import Iterable, Set 11 | from typing import Annotated, cast 12 | 13 | import aiohttp 14 | import more_itertools 15 | import numpy as np 16 | import openai 17 | import scipy.special 18 | import tqdm 19 | import typer 20 | from datasets import DatasetDict, load_dataset 21 | from lmapi.lm import LM, CompletionsSettings 22 | from lmapi.openai import client_session 23 | 24 | from dp_few_shot_generation.lm import ( 25 | api_openai_com, 26 | next_logprobs, 27 | normalized_logprobs_for_chosen_tokens, 28 | ) 29 | from dp_few_shot_generation.prob_utils import densify, log_max_normalize, log_normalize 30 | 31 | DEFAULT_NUM_PRIVATE_TRAIN = 80 32 | DEFAULT_NUM_PUBLIC_TRAIN = 0 33 | DEFAULT_NUM_VALID = 4 34 | DEFAULT_NUM_PRIVATE_TRAIN_SPLITS = 80 35 | DEFAULT_NUM_TEST = 1000 36 | 37 | labels = ["Ab", "Entity", "Description", "Person", "Location", "Number"] 38 | label_dict = { 39 | 0: ["Ab"], 40 | 1: ["Entity"], 41 | 2: ["Description"], 42 | 3: ["Person"], 43 | 4: ["Location"], 44 | 5: ["Number"], 45 | } 46 | 47 | 48 | def format_full_datum_for_prompt(labels, datum: dict[str, str]): 49 | return f'Answer Type: "{labels[datum["coarse_label"]]}"\nText: "{datum["text"] + " END"}"\n' 50 | 51 | 52 | def format_test_input_for_prompt(labels, test_input: int): 53 | return f'Answer Type: "{labels[test_input]}"\nText: "' 54 | 55 | 56 | def construct_prompt_same(train_examples, test_example): 57 | prompt = "Classify the questions based on whether their answer type is a Number, Location, Person, Description, Entity, or Abbreviation.\n\n" 58 | for train_example in train_examples: 59 | prompt += "Question: " + train_example["text"] + "\n" 60 | prompt += ( 61 | "Answer Type: " + label_dict[train_example["coarse_label"]][0] + "\n\n" 62 | ) 63 | prompt += "Question: " + test_example["text"] + "\n" 64 | prompt += "Answer Type:" 65 | return prompt 66 | 67 | 68 | def complete(prompt, l, model_name, temp=0, num_log_probs=None, echo=False, n=None): 69 | # call GPT-3 API until result is provided and then return it 70 | response = None 71 | received = False 72 | while not received: 73 | try: 74 | response = openai.Completion.create( 75 | engine=model_name, 76 | prompt=prompt, 77 | max_tokens=l, 78 | temperature=temp, 79 | logprobs=num_log_probs, 80 | echo=echo, 81 | stop="\n", 82 | n=n, 83 | ) 84 | received = True 85 | except: 86 | error = sys.exc_info()[0] 87 | if ( 88 | error == openai.error.InvalidRequestError 89 | ): # something is wrong: e.g. 
prompt too long 90 | print(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n") 91 | assert False 92 | 93 | print("API error:", error) 94 | time.sleep(1) 95 | return response 96 | 97 | 98 | def chunks(lst, n): 99 | """Yield successive n-sized chunks from lst.""" 100 | for i in range(0, len(lst), n): 101 | yield lst[i : i + n] 102 | 103 | 104 | def get_model_response(data, test_examples, openai_model): 105 | all_raw_answers = [] 106 | 107 | prompts = [] 108 | train_examples = data 109 | 110 | for test_example in test_examples: 111 | prompts.append(construct_prompt_same(train_examples, test_example)) 112 | 113 | chunked_prompts = list(chunks(prompts, 20)) 114 | for test_chunk in chunked_prompts: 115 | response = complete(test_chunk, l=1, model_name=openai_model, num_log_probs=100) 116 | 117 | for answer_id, answer in enumerate(response["choices"]): 118 | all_raw_answers.append(answer) 119 | 120 | return all_raw_answers 121 | 122 | 123 | def get_label_probs(all_raw_answers, test_subset): 124 | """Obtain model's label probability for each of the test examples. The returned prob is NOT normalized""" 125 | num_classes = len(label_dict) 126 | approx = False 127 | assert len(all_raw_answers) == len(test_subset) 128 | 129 | # Fill in the labels that is in the top k prob 130 | all_label_probs = [] 131 | all_missing_positions = [] 132 | cnt = 0 133 | for i, ans in enumerate(all_raw_answers): 134 | try: 135 | top_logprobs = ans["logprobs"]["top_logprobs"][ 136 | 0 137 | ] # [0] since we only ask for complete one more token 138 | except: 139 | cnt += 1 # cnt for corner case 140 | label_probs = [0] * len(label_dict.keys()) 141 | for j, label_list in label_dict.items(): 142 | all_found = True 143 | for label in label_list: # each possible label correspond to the same class 144 | label = " " + label # notice prompt does not have space after 'A:' 145 | if label in top_logprobs: 146 | label_probs[j] += np.exp(top_logprobs[label]) 147 | else: 148 | all_found = False 149 | if not all_found: 150 | position = (i, j) # (which test example, which label) 151 | all_missing_positions.append(position) 152 | all_label_probs.append(label_probs) 153 | all_label_probs = np.array(all_label_probs) # prob not normalized 154 | 155 | return all_label_probs # NOT NORMALIZED 156 | 157 | 158 | def eval_accuracy(all_label_probs, test_labels, mode=None, p_cf=None): 159 | # evaluate the accuracy with and without contextual calibration 160 | num_classes = all_label_probs.shape[1] 161 | if p_cf is None: 162 | # do not calibrate 163 | W = np.identity(num_classes) 164 | b = np.zeros([num_classes, 1]) 165 | else: 166 | # calibrate 167 | if mode == "diagonal_W": 168 | W = np.linalg.inv(np.identity(num_classes) * p_cf) 169 | b = np.zeros([num_classes, 1]) 170 | elif mode == "identity_W": 171 | W = np.identity(num_classes) 172 | b = -1 * np.expand_dims(p_cf, axis=-1) 173 | else: 174 | assert False 175 | 176 | correctness_list = [] 177 | assert len(all_label_probs) == len(test_labels) 178 | for label_probs, true_label in zip(all_label_probs, test_labels): 179 | if np.sum(label_probs) > 0: # corner case np.sum(label_probs)=0. 
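        # Worked illustration of the calibration applied just below (made-up numbers,
        # two classes for brevity): with a content-free prior p_cf = [0.7, 0.3] and
        # mode="diagonal_W", W = inv(diag(p_cf)) = diag(1/0.7, 1/0.3), so raw
        # label_probs [0.6, 0.4] become W @ p = [0.86, 1.33] and the calibrated
        # argmax flips from class 0 to class 1. With p_cf=None, W is the identity
        # and the argmax of the raw (normalized) probabilities is used unchanged.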
180 | label_probs = label_probs / np.sum(label_probs) # normalize to 1 181 | 182 | calibrate_label_probs = np.matmul(W, np.expand_dims(label_probs, axis=-1)) + b 183 | 184 | ans_label = np.argmax(calibrate_label_probs) 185 | if ans_label == true_label: 186 | correctness_list.append(1) 187 | else: 188 | correctness_list.append(0) 189 | return np.mean(correctness_list) 190 | 191 | 192 | def get_p_content_free(train_subset, openai_model, content_free_inputs=("N/A",)): 193 | """Query model with content free input, return its prediction probability for each label""" 194 | all_p_y = [] 195 | for content_free_input in content_free_inputs: 196 | prompt = construct_prompt_same(train_subset, content_free_input) 197 | p_y = [0] * len(label_dict) 198 | for i, answers in label_dict.items(): 199 | prob = 0 200 | for a in answers: 201 | prob += np.exp( 202 | complete( 203 | prompt + " " + a, 0, openai_model, echo=True, num_log_probs=1 204 | )["choices"][0]["logprobs"]["token_logprobs"][-1] 205 | ) 206 | p_y[i] = prob 207 | all_p_y.append(p_y) 208 | p_y = np.mean(np.array(all_p_y), axis=0) 209 | p_y = p_y / np.sum(p_y) # normalize 210 | return p_y 211 | 212 | 213 | def merge_logprobs_topk_mean( 214 | private_next_logprobs: list[dict[int, float]], 215 | public_next_logprobs: dict[int, float], 216 | n_vocab: int, 217 | no_public_token: bool, 218 | normalize_max: bool, 219 | ) -> np.ndarray: 220 | # Compute merged distribution 221 | # logsumexp - np.log(...): compute mean probability of distribution 222 | if normalize_max: 223 | normalize_func = ( 224 | log_max_normalize # normalize max probability to 1, Exponential mechanism 225 | ) 226 | else: 227 | normalize_func = ( 228 | log_normalize # normalize sum probability to 1, Gaussian mechanism 229 | ) 230 | if no_public_token: 231 | merged_next_logprobs = scipy.special.logsumexp( 232 | np.stack( 233 | [ 234 | # Turn into a 1D tensor of size n_vocab 235 | densify( 236 | n_vocab, 237 | # Normalize distribution 238 | normalize_func( 239 | # Filter to the top 100 most likely next tokens according to the public prompt 240 | {k: v for k, v in lps.items()} 241 | ), 242 | ) 243 | for lps in private_next_logprobs 244 | ] 245 | ), 246 | axis=0, 247 | ) - np.log(len(private_next_logprobs)) 248 | 249 | else: 250 | merged_next_logprobs = scipy.special.logsumexp( 251 | np.stack( 252 | [ 253 | # Turn into a 1D tensor of size n_vocab 254 | densify( 255 | n_vocab, 256 | # Normalize distribution 257 | normalize_func( 258 | # Filter to the top 100 most likely next tokens according to the public prompt 259 | {k: v for k, v in lps.items() if k in public_next_logprobs} 260 | ), 261 | ) 262 | for lps in private_next_logprobs 263 | ] 264 | ), 265 | axis=0, 266 | ) - np.log(len(private_next_logprobs)) 267 | merged_next_probs = np.exp(merged_next_logprobs) 268 | return merged_next_probs 269 | 270 | 271 | async def generate_with_private_prompts( 272 | trainset, 273 | num_private_train, 274 | num_private_train_splits, 275 | instruction, 276 | public_train_prompt: str, 277 | stop_tokens: Set[int], 278 | test_input: int, 279 | lm: LM, 280 | noise_rng: np.random.RandomState, 281 | sigma: float, 282 | labels, 283 | top_p, 284 | no_public_token: bool, 285 | subsample_per_token: bool, 286 | sample_same_label_prompts: bool, 287 | gen_seed: int, 288 | max_tokens: int = 100 - 1, 289 | normalize_max: bool = False, 290 | ) -> list[int]: 291 | generated_token_ids: list[int] = [] 292 | 293 | stringified_test_datum = format_test_input_for_prompt(labels, test_input) 294 | public_prompt = 
public_train_prompt + stringified_test_datum 295 | public_prompt_tokens = lm.encoding.encode(public_prompt) 296 | 297 | assert num_private_train_splits > 0 298 | if sample_same_label_prompts: 299 | select_list = [] 300 | for i in range(len(trainset)): 301 | if trainset[i]["coarse_label"] == test_input: 302 | select_list.append(i) 303 | train_subset = trainset.select(select_list, keep_in_memory=True) 304 | else: 305 | train_subset = trainset.select(range(len(trainset)), keep_in_memory=True) 306 | 307 | if not subsample_per_token: 308 | private_train_subset = cast( 309 | Iterable[dict[str, str]], 310 | train_subset.shuffle(gen_seed, keep_in_memory=True).select( 311 | range(num_private_train), keep_in_memory=True 312 | ), 313 | ) 314 | private_train_splits = [ 315 | list(it) 316 | for it in more_itertools.distribute( 317 | num_private_train_splits, private_train_subset 318 | ) 319 | ] 320 | private_train_prompts = [ 321 | instruction 322 | + "\n".join(format_full_datum_for_prompt(labels, datum) for datum in split) 323 | for split in private_train_splits 324 | ] 325 | private_prompts = [ 326 | train_prompt + "\n" + stringified_test_datum 327 | for train_prompt in private_train_prompts 328 | ] 329 | private_prompts_tokens = [ 330 | lm.encoding.encode(prompt) for prompt in private_prompts 331 | ] 332 | 333 | cnt = 0 334 | for _ in tqdm.tqdm(range(max_tokens), total=float("inf"), unit=" tokens generated"): 335 | private_next_logprobs: list[dict[int, float]] 336 | public_next_logprobs: dict[int, float] 337 | # Split training dataset 338 | if subsample_per_token: 339 | private_train_subset = cast( 340 | Iterable[dict[str, str]], 341 | train_subset.shuffle(gen_seed + cnt, keep_in_memory=True).select( 342 | range(num_private_train), keep_in_memory=True 343 | ), 344 | ) 345 | cnt += 1 346 | private_train_splits = [ 347 | list(it) 348 | for it in more_itertools.distribute( 349 | num_private_train_splits, private_train_subset 350 | ) 351 | ] 352 | # Turn the data into prompts 353 | private_train_prompts = [ 354 | instruction 355 | + "\n".join( 356 | format_full_datum_for_prompt(labels, datum) for datum in split 357 | ) 358 | for split in private_train_splits 359 | ] 360 | private_prompts = [ 361 | train_prompt + "\n" + stringified_test_datum 362 | for train_prompt in private_train_prompts 363 | ] 364 | private_prompts_tokens = [ 365 | lm.encoding.encode(prompt) for prompt in private_prompts 366 | ] 367 | if no_public_token: 368 | private_next_logprobs = await asyncio.gather( 369 | *( 370 | next_logprobs(lm, prompt + generated_token_ids, top_p=top_p) 371 | for prompt in private_prompts_tokens 372 | ) 373 | ) 374 | merged_next_probs = merge_logprobs_topk_mean( 375 | private_next_logprobs, 376 | None, 377 | lm.encoding.n_vocab, 378 | no_public_token, 379 | normalize_max, 380 | ) 381 | if normalize_max: 382 | # scale = 1/lambda 383 | noise = noise_rng.exponential(scale=sigma, size=lm.encoding.n_vocab) 384 | else: 385 | noise = noise_rng.normal(0, sigma, size=lm.encoding.n_vocab) 386 | merged_next_probs += noise 387 | else: 388 | public_next_logprobs = await next_logprobs( 389 | lm, public_prompt_tokens + generated_token_ids, top_p=top_p 390 | ) 391 | private_next_logprobs = await asyncio.gather( 392 | *( 393 | normalized_logprobs_for_chosen_tokens( 394 | lm, 395 | prompt + generated_token_ids, 396 | public_next_logprobs.keys(), 397 | top_p=top_p, 398 | ) 399 | for prompt in private_prompts_tokens 400 | ) 401 | ) 402 | merged_next_probs = merge_logprobs_topk_mean( 403 | private_next_logprobs, 404 | 
public_next_logprobs, 405 | lm.encoding.n_vocab, 406 | no_public_token, 407 | normalize_max, 408 | ) 409 | if normalize_max: 410 | # scale = 1/lambda 411 | noise = noise_rng.exponential( 412 | scale=sigma, size=len(public_next_logprobs) 413 | ) 414 | else: 415 | noise = noise_rng.normal(0, sigma, size=len(public_next_logprobs)) 416 | merged_next_probs[list(public_next_logprobs.keys())] += noise 417 | 418 | next_token_id = int(np.argmax(merged_next_probs)) 419 | 420 | if next_token_id in stop_tokens: 421 | break 422 | 423 | generated_token_ids.append(next_token_id) 424 | 425 | del next_token_id 426 | return generated_token_ids 427 | 428 | 429 | async def generate_with_public_prompt( 430 | public_train_prompt: str, 431 | stop_tokens: Set[str], 432 | test_input: str, 433 | lm: LM, 434 | labels, 435 | max_tokens: int = 500, 436 | ) -> list[int]: 437 | public_prompt = public_train_prompt + format_test_input_for_prompt( 438 | labels, test_input 439 | ) 440 | public_prompt_tokens = lm.encoding.encode(public_prompt) 441 | public_prompt_tokens = public_prompt 442 | 443 | [completion] = await lm.completions( 444 | public_prompt_tokens, 445 | CompletionsSettings( 446 | temperature=0.0, max_tokens=max_tokens, n=1, stop=list(stop_tokens) 447 | ), 448 | ) 449 | generated_tokens = [st.token.token_id for st in completion] 450 | return generated_tokens 451 | 452 | 453 | def select_uniform_n_shots_over_labels(data, n_shots): 454 | select_list = [] 455 | n_shots_per_label = math.ceil(n_shots / len(labels)) 456 | labels_counter = {label[1][0]: n_shots_per_label for label in label_dict.items()} 457 | n_shots_selected = 0 458 | for i in range(len(data)): 459 | label = label_dict[data[i]["coarse_label"]][0] 460 | if labels_counter[label] == 0 or data[i]["coarse_label"] == 0: 461 | continue 462 | else: 463 | labels_counter[label] -= 1 464 | select_list.append(i) 465 | n_shots_selected += 1 466 | if n_shots_selected == n_shots: 467 | break 468 | query_subset = data.select(select_list, keep_in_memory=True) 469 | return query_subset 470 | 471 | 472 | def _main( 473 | sigma: Annotated[float, typer.Option()], # noise parameters 474 | openai_model: Annotated[str, typer.Option()] = "babbage", 475 | print_prompts: Annotated[bool, typer.Option()] = False, 476 | # num_private_train=MN. MN=0 with num_valid=4 will get epsilon=0 (4-shot) results. 477 | num_private_train: Annotated[int, typer.Option()] = DEFAULT_NUM_PRIVATE_TRAIN, 478 | # by default set to 0. set_num_public_train >0 indicates additional public data available. 479 | set_num_public_train: Annotated[int, typer.Option()] = DEFAULT_NUM_PUBLIC_TRAIN, 480 | # num_valid=n. n samples to be generated for n-shot ICL 481 | num_valid: Annotated[int, typer.Option()] = DEFAULT_NUM_VALID, 482 | # num_private_train_splits=M 483 | num_private_train_splits: Annotated[ 484 | int, typer.Option() 485 | ] = DEFAULT_NUM_PRIVATE_TRAIN_SPLITS, 486 | num_test: Annotated[int, typer.Option()] = DEFAULT_NUM_TEST, 487 | # no_public_token=True, RVP=False; no_public_token=False, RVP=True 488 | no_public_token: Annotated[bool, typer.Option()] = False, 489 | # subsample_per_token=True: at each token generation, subsample a new test set 490 | subsample_per_token: Annotated[bool, typer.Option()] = False, 491 | use_dp_prompts: Annotated[bool, typer.Option()] = False, 492 | # sample_same_label_prompts=True: sample subsets from the sets with same targeted labels. 
493 | sample_same_label_prompts: Annotated[bool, typer.Option()] = False, 494 | # normalize_max=True, Exponential mechanism; normalize_max=False, Gaussian mechanism 495 | normalize_max: Annotated[bool, typer.Option()] = False, 496 | # max_token_per_text=T_max 497 | max_token_per_text: Annotated[int, typer.Option()] = 15, 498 | # consistent with default parameters in the documentation https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#completions 499 | top_p: Annotated[float, typer.Option()] = 1, 500 | # random seed for subsampling in generation 501 | synth_seed: Annotated[int, typer.Option()] = 0, 502 | # random seed for n-shot demonstrations sampling in evaluation 503 | eval_seed: Annotated[int, typer.Option()] = 0, 504 | ): 505 | async def main(): 506 | if (num_private_train == 0) != (num_private_train_splits == 0): 507 | raise ValueError( 508 | "Either both or neither of --num-private-train and --num-private-train-splits can be 0" 509 | ) 510 | command = ["python", sys.argv[0]] 511 | for x in sys.argv[1:]: 512 | if x.startswith("--"): 513 | assert '"' not in x and "'" not in x 514 | command.append(x) 515 | else: 516 | assert "'" not in x 517 | if re.match("^[a-zA-Z0-9_]+$", x): 518 | command.append("%s" % x) 519 | else: 520 | command.append("'%s'" % x) 521 | command = " ".join(command) 522 | print(command) 523 | 524 | if no_public_token: 525 | num_public_train = 0 526 | else: 527 | num_public_train = set_num_public_train 528 | 529 | lm = api_openai_com(openai_model) 530 | noise_rng = np.random.RandomState() 531 | 532 | data = cast(DatasetDict, load_dataset("trec")) 533 | print(labels) 534 | 535 | trainset = data["train"].shuffle(seed=synth_seed, keep_in_memory=True) 536 | print("trainset length", len(trainset)) 537 | if num_public_train > 0: 538 | public_train_subset = cast( 539 | Iterable[dict[str, str]], 540 | trainset.select( 541 | range( 542 | len(trainset) - num_public_train, 543 | len(trainset), 544 | keep_in_memory=True, 545 | ) 546 | ), 547 | ) 548 | else: 549 | public_train_subset = [] 550 | 551 | trainset = trainset.select( 552 | range(len(trainset) - num_public_train), keep_in_memory=True 553 | ) 554 | queryset = data["train"].shuffle(seed=eval_seed, keep_in_memory=True) 555 | query_subset = select_uniform_n_shots_over_labels(queryset, num_valid) 556 | 557 | if use_dp_prompts: 558 | synthetic_examples = [] 559 | 560 | # Turn the data into prompts 561 | instruction = "Given a label of answer type, generate a question based on the given answer type accordingly.\n\n" 562 | 563 | public_train_prompt = instruction + "\n".join( 564 | format_full_datum_for_prompt(labels, datum) 565 | for datum in public_train_subset 566 | ) 567 | 568 | if print_prompts: 569 | print(public_train_prompt) 570 | print("=========") 571 | 572 | if normalize_max: 573 | print("Exponential Mechanism") 574 | assert num_private_train == 0 or sigma > 0 575 | if num_private_train > 0: 576 | # scale == sigma_calib == 1/lambda. lambda for exponential distribution. 577 | sigma_calib = (2 / num_private_train_splits) * (1 / sigma) 578 | else: 579 | print("Gaussian Mechanism") 580 | if num_private_train_splits > 0: 581 | sigma_calib = math.sqrt(2) / num_private_train_splits * sigma 582 | else: 583 | sigma_calib = 0 584 | print( 585 | f"sigma in command {sigma}. 
sigma added according to sensitivity {sigma_calib}" 586 | ) 587 | 588 | stop_tokens = {"\n", "<|endoftext|>", " END"} 589 | stop_tokens_ids = {lm.encoding.encode_single_token(t) for t in stop_tokens} 590 | 591 | client_session.set(aiohttp.ClientSession()) 592 | 593 | len_token = [] 594 | async with client_session.get(): 595 | for i, test_datum in enumerate(query_subset, 1): 596 | print(f"# Example {i}") 597 | print(f'Answer Type: "{labels[test_datum["coarse_label"]]}"') 598 | print(f'References:\n "{test_datum["text"]}"') 599 | 600 | np.random.seed(synth_seed + i) 601 | gen_seed = np.random.randint(100000) 602 | print(f"gen-seed: {gen_seed}") 603 | 604 | if num_private_train_splits > 0: 605 | generated_token_ids = await generate_with_private_prompts( 606 | trainset, 607 | num_private_train, 608 | num_private_train_splits, 609 | instruction, 610 | public_train_prompt, 611 | stop_tokens_ids, 612 | test_datum["coarse_label"], 613 | lm, 614 | noise_rng, 615 | sigma_calib, 616 | labels, 617 | top_p, 618 | no_public_token, 619 | subsample_per_token, 620 | sample_same_label_prompts, 621 | gen_seed, 622 | max_tokens=max_token_per_text 623 | - 1, # need one token length for EOS. 624 | normalize_max=normalize_max, 625 | ) 626 | else: 627 | generated_token_ids = await generate_with_public_prompt( 628 | public_train_prompt, 629 | stop_tokens, 630 | test_datum["coarse_label"], 631 | lm, 632 | labels, 633 | max_tokens=max_token_per_text, 634 | ) 635 | 636 | generated = lm.encoding.decode(generated_token_ids).rstrip('"') 637 | 638 | print(f"Generated: {generated}\n") 639 | output_datum = {} 640 | output_datum["text"] = generated.strip() 641 | output_datum["coarse_label"] = test_datum["coarse_label"] 642 | synthetic_examples.append(output_datum) 643 | 644 | if num_test > 0: 645 | test_subset = data["test"] 646 | test_labels = [test_example["coarse_label"] for test_example in test_subset] 647 | 648 | content_free_inputs = [{"text": "N/A"}, {"text": ""}, {"text": "[MASK]"}] 649 | p_cf_wout_DP = get_p_content_free( 650 | query_subset, openai_model, content_free_inputs=content_free_inputs 651 | ) 652 | 653 | all_raw_answers_wout_DP = get_model_response( 654 | query_subset, test_subset, openai_model 655 | ) 656 | all_label_probs_wout_DP = get_label_probs( 657 | all_raw_answers_wout_DP, test_subset 658 | ) 659 | 660 | acc_original_wout_DP = eval_accuracy(all_label_probs_wout_DP, test_labels) 661 | acc_calibrated_wout_DP = eval_accuracy( 662 | all_label_probs_wout_DP, 663 | test_labels, 664 | mode="diagonal_W", 665 | p_cf=p_cf_wout_DP, 666 | ) 667 | 668 | print(f"Accuracy (original) without DP: {acc_original_wout_DP}") 669 | print(f"Accuracy (calibrated) without DP: {acc_calibrated_wout_DP}") 670 | 671 | if use_dp_prompts: 672 | p_cf_w_DP = get_p_content_free( 673 | synthetic_examples, 674 | openai_model, 675 | content_free_inputs=content_free_inputs, 676 | ) 677 | 678 | all_raw_answers_w_DP = get_model_response( 679 | synthetic_examples, test_subset, openai_model 680 | ) 681 | all_label_probs_w_DP = get_label_probs( 682 | all_raw_answers_w_DP, test_subset 683 | ) 684 | 685 | acc_original_w_DP = eval_accuracy(all_label_probs_w_DP, test_labels) 686 | acc_calibrated_w_DP = eval_accuracy( 687 | all_label_probs_w_DP, test_labels, mode="diagonal_W", p_cf=p_cf_w_DP 688 | ) 689 | 690 | print(f"Accuracy (original) with DP: {acc_original_w_DP}") 691 | print(f"Accuracy (calibrated) with DP: {acc_calibrated_w_DP}") 692 | 693 | try: 694 | asyncio.run(main()) 695 | except KeyboardInterrupt: 696 | traceback.print_exc() 697 | 
raise 698 | 699 | 700 | if __name__ == "__main__": 701 | typer.run(_main) 702 | -------------------------------------------------------------------------------- /src/dp_few_shot_generation/run_exp_dbpedia.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import asyncio 5 | import math 6 | import re 7 | import sys 8 | import time 9 | import traceback 10 | from collections.abc import Iterable, Set 11 | from typing import Annotated, cast 12 | 13 | import aiohttp 14 | import more_itertools 15 | import numpy as np 16 | import openai 17 | import scipy.special 18 | import tqdm 19 | import typer 20 | from datasets import DatasetDict, load_dataset 21 | from lmapi.lm import LM, CompletionsSettings 22 | from lmapi.openai import client_session 23 | 24 | from dp_few_shot_generation.lm import ( 25 | api_openai_com, 26 | next_logprobs, 27 | normalized_logprobs_for_chosen_tokens, 28 | ) 29 | from dp_few_shot_generation.prob_utils import densify, log_max_normalize, log_normalize 30 | 31 | DEFAULT_NUM_PRIVATE_TRAIN = 80 32 | DEFAULT_NUM_PUBLIC_TRAIN = 0 33 | DEFAULT_NUM_VALID = 4 34 | DEFAULT_NUM_PRIVATE_TRAIN_SPLITS = 40 35 | DEFAULT_NUM_TEST = 1000 36 | 37 | labels = [ 38 | "Company", 39 | "School", 40 | "Artist", 41 | "Ath", 42 | "Polit", 43 | "Transportation", 44 | "Building", 45 | "Nature", 46 | "Village", 47 | "Animal", 48 | "Plant", 49 | "Album", 50 | "Film", 51 | "Book", 52 | ] 53 | label_dict = { 54 | 0: ["Company"], 55 | 1: ["School"], 56 | 2: ["Artist"], 57 | 3: ["Ath"], 58 | 4: ["Polit"], 59 | 5: ["Transportation"], 60 | 6: ["Building"], 61 | 7: ["Nature"], 62 | 8: ["Village"], 63 | 9: ["Animal"], 64 | 10: ["Plant"], 65 | 11: ["Album"], 66 | 12: ["Film"], 67 | 13: ["Book"], 68 | } 69 | 70 | 71 | def format_full_datum_for_prompt(labels, datum: dict[str, str]): 72 | return f'Document Type: "{labels[datum["label"]]}"\nText: "{datum["content"] + " END"}"\n' 73 | 74 | 75 | def format_test_input_for_prompt(labels, test_input: int): 76 | return f'Document Type: "{labels[test_input]}"\nText: "' 77 | 78 | 79 | def construct_prompt_same(train_examples, test_example): 80 | labels_str = ", ".join(labels) 81 | prompt = ( 82 | f"Classify the documents based on whether they are about a {labels_str}.\n\n" 83 | ) 84 | for train_example in train_examples: 85 | prompt += "Article: " + train_example["content"] + "\n" 86 | prompt += "Answer: " + label_dict[train_example["label"]][0] + "\n\n" 87 | prompt += "Article: " + test_example["content"] + "\n" 88 | prompt += "Answer:" 89 | return prompt 90 | 91 | 92 | def complete(prompt, l, model_name, temp=0, num_log_probs=None, echo=False, n=None): 93 | # call GPT-3 API until result is provided and then return it 94 | response = None 95 | received = False 96 | while not received: 97 | try: 98 | response = openai.Completion.create( 99 | engine=model_name, 100 | prompt=prompt, 101 | max_tokens=l, 102 | temperature=temp, 103 | logprobs=num_log_probs, 104 | echo=echo, 105 | stop="\n", 106 | n=n, 107 | ) 108 | received = True 109 | except: 110 | error = sys.exc_info()[0] 111 | if ( 112 | error == openai.error.InvalidRequestError 113 | ): # something is wrong: e.g. 
prompt too long 114 | print(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n") 115 | assert False 116 | 117 | print("API error:", error) 118 | time.sleep(1) 119 | return response 120 | 121 | 122 | def chunks(lst, n): 123 | """Yield successive n-sized chunks from lst.""" 124 | for i in range(0, len(lst), n): 125 | yield lst[i : i + n] 126 | 127 | 128 | def get_model_response(data, test_examples, openai_model): 129 | all_raw_answers = [] 130 | 131 | prompts = [] 132 | train_examples = data 133 | 134 | for test_example in test_examples: 135 | prompts.append(construct_prompt_same(train_examples, test_example)) 136 | 137 | chunked_prompts = list(chunks(prompts, 20)) 138 | for test_chunk in chunked_prompts: 139 | response = complete(test_chunk, l=1, model_name=openai_model, num_log_probs=100) 140 | 141 | for answer_id, answer in enumerate(response["choices"]): 142 | all_raw_answers.append(answer) 143 | 144 | return all_raw_answers 145 | 146 | 147 | def get_label_probs(all_raw_answers, test_subset): 148 | """Obtain model's label probability for each of the test examples. The returned prob is NOT normalized""" 149 | num_classes = len(label_dict) 150 | approx = False 151 | assert len(all_raw_answers) == len(test_subset) 152 | 153 | # Fill in the labels that is in the top k prob 154 | all_label_probs = [] 155 | all_missing_positions = [] 156 | cnt = 0 157 | for i, ans in enumerate(all_raw_answers): 158 | try: 159 | top_logprobs = ans["logprobs"]["top_logprobs"][ 160 | 0 161 | ] # [0] since we only ask for complete one more token 162 | except: 163 | cnt += 1 # cnt for corner case 164 | label_probs = [0] * len(label_dict.keys()) 165 | for j, label_list in label_dict.items(): 166 | all_found = True 167 | for label in label_list: # each possible label correspond to the same class 168 | label = " " + label # notice prompt does not have space after 'A:' 169 | if label in top_logprobs: 170 | label_probs[j] += np.exp(top_logprobs[label]) 171 | else: 172 | all_found = False 173 | if not all_found: 174 | position = (i, j) # (which test example, which label) 175 | all_missing_positions.append(position) 176 | all_label_probs.append(label_probs) 177 | all_label_probs = np.array(all_label_probs) # prob not normalized 178 | 179 | return all_label_probs # NOT NORMALIZED 180 | 181 | 182 | def eval_accuracy(all_label_probs, test_labels, mode=None, p_cf=None): 183 | # evaluate the accuracy with and without contextual calibration 184 | num_classes = all_label_probs.shape[1] 185 | if p_cf is None: 186 | # do not calibrate 187 | W = np.identity(num_classes) 188 | b = np.zeros([num_classes, 1]) 189 | else: 190 | # calibrate 191 | if mode == "diagonal_W": 192 | W = np.linalg.inv(np.identity(num_classes) * p_cf) 193 | b = np.zeros([num_classes, 1]) 194 | elif mode == "identity_W": 195 | W = np.identity(num_classes) 196 | b = -1 * np.expand_dims(p_cf, axis=-1) 197 | else: 198 | assert False 199 | 200 | correctness_list = [] 201 | assert len(all_label_probs) == len(test_labels) 202 | for label_probs, true_label in zip(all_label_probs, test_labels): 203 | if np.sum(label_probs) > 0: # corner case np.sum(label_probs)=0. 
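        # Note on the guard above: get_label_probs leaves label_probs as all zeros
        # when none of the 14 label strings appear among the top-100 logprobs
        # returned by complete(); this branch then skips normalization, the zero
        # vector flows into the matmul below, and (since b = 0 in the "diagonal_W"
        # mode used here) np.argmax falls back to class index 0.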
204 | label_probs = label_probs / np.sum(label_probs) # normalize to 1 205 | 206 | calibrate_label_probs = np.matmul(W, np.expand_dims(label_probs, axis=-1)) + b 207 | 208 | ans_label = np.argmax(calibrate_label_probs) 209 | if ans_label == true_label: 210 | correctness_list.append(1) 211 | else: 212 | correctness_list.append(0) 213 | return np.mean(correctness_list) 214 | 215 | 216 | def get_p_content_free(train_subset, openai_model, content_free_inputs=("N/A",)): 217 | """Query model with content free input, return its prediction probability for each label""" 218 | all_p_y = [] 219 | for content_free_input in content_free_inputs: 220 | prompt = construct_prompt_same(train_subset, content_free_input) 221 | p_y = [0] * len(label_dict) 222 | for i, answers in label_dict.items(): 223 | prob = 0 224 | for a in answers: 225 | prob += np.exp( 226 | complete( 227 | prompt + " " + a, 0, openai_model, echo=True, num_log_probs=1 228 | )["choices"][0]["logprobs"]["token_logprobs"][-1] 229 | ) 230 | p_y[i] = prob 231 | all_p_y.append(p_y) 232 | p_y = np.mean(np.array(all_p_y), axis=0) 233 | p_y = p_y / np.sum(p_y) # normalize 234 | return p_y 235 | 236 | 237 | def merge_logprobs_topk_mean( 238 | private_next_logprobs: list[dict[int, float]], 239 | public_next_logprobs: dict[int, float], 240 | n_vocab: int, 241 | no_public_token: bool, 242 | normalize_max: bool, 243 | ) -> np.ndarray: 244 | # Compute merged distribution 245 | # logsumexp - np.log(...): compute mean probability of distribution 246 | if normalize_max: 247 | normalize_func = ( 248 | log_max_normalize # normalize max probability to 1, Exponential mechanism 249 | ) 250 | else: 251 | normalize_func = ( 252 | log_normalize # normalize sum probability to 1, Gaussian mechanism 253 | ) 254 | if no_public_token: 255 | merged_next_logprobs = scipy.special.logsumexp( 256 | np.stack( 257 | [ 258 | # Turn into a 1D tensor of size n_vocab 259 | densify( 260 | n_vocab, 261 | # Normalize distribution 262 | normalize_func( 263 | # Filter to the top 100 most likely next tokens according to the public prompt 264 | {k: v for k, v in lps.items()} 265 | ), 266 | ) 267 | for lps in private_next_logprobs 268 | ] 269 | ), 270 | axis=0, 271 | ) - np.log(len(private_next_logprobs)) 272 | 273 | else: 274 | merged_next_logprobs = scipy.special.logsumexp( 275 | np.stack( 276 | [ 277 | # Turn into a 1D tensor of size n_vocab 278 | densify( 279 | n_vocab, 280 | # Normalize distribution 281 | normalize_func( 282 | # Filter to the top 100 most likely next tokens according to the public prompt 283 | {k: v for k, v in lps.items() if k in public_next_logprobs} 284 | ), 285 | ) 286 | for lps in private_next_logprobs 287 | ] 288 | ), 289 | axis=0, 290 | ) - np.log(len(private_next_logprobs)) 291 | merged_next_probs = np.exp(merged_next_logprobs) 292 | return merged_next_probs 293 | 294 | 295 | async def generate_with_private_prompts( 296 | trainset, 297 | num_private_train, 298 | num_private_train_splits, 299 | instruction, 300 | public_train_prompt: str, 301 | stop_tokens: Set[int], 302 | test_input: int, 303 | lm: LM, 304 | noise_rng: np.random.RandomState, 305 | sigma: float, 306 | labels, 307 | top_p, 308 | no_public_token: bool, 309 | subsample_per_token: bool, 310 | sample_same_label_prompts: bool, 311 | gen_seed: int, 312 | max_tokens: int, 313 | normalize_max: bool = False, 314 | ) -> list[int]: 315 | generated_token_ids: list[int] = [] 316 | 317 | stringified_test_datum = format_test_input_for_prompt(labels, test_input) 318 | public_prompt = public_train_prompt + 
stringified_test_datum 319 | public_prompt_tokens = lm.encoding.encode(public_prompt) 320 | 321 | assert num_private_train_splits > 0 322 | if sample_same_label_prompts: 323 | select_list = [] 324 | for i in range(len(trainset)): 325 | if trainset[i]["label"] == test_input: 326 | select_list.append(i) 327 | train_subset = trainset.select(select_list, keep_in_memory=True) 328 | else: 329 | train_subset = trainset.select(range(len(trainset)), keep_in_memory=True) 330 | 331 | if not subsample_per_token: 332 | private_train_subset = cast( 333 | Iterable[dict[str, str]], 334 | train_subset.shuffle(gen_seed, keep_in_memory=True).select( 335 | range(num_private_train), keep_in_memory=True 336 | ), 337 | ) 338 | private_train_splits = [ 339 | list(it) 340 | for it in more_itertools.distribute( 341 | num_private_train_splits, private_train_subset 342 | ) 343 | ] 344 | private_train_prompts = [ 345 | instruction 346 | + "\n".join(format_full_datum_for_prompt(labels, datum) for datum in split) 347 | for split in private_train_splits 348 | ] 349 | private_prompts = [ 350 | train_prompt + "\n" + stringified_test_datum 351 | for train_prompt in private_train_prompts 352 | ] 353 | private_prompts_tokens = [ 354 | lm.encoding.encode(prompt) for prompt in private_prompts 355 | ] 356 | 357 | cnt = 0 358 | for _ in tqdm.tqdm(range(max_tokens), total=float("inf"), unit=" tokens generated"): 359 | private_next_logprobs: list[dict[int, float]] 360 | public_next_logprobs: dict[int, float] 361 | # Split training dataset 362 | if subsample_per_token: 363 | private_train_subset = cast( 364 | Iterable[dict[str, str]], 365 | train_subset.shuffle(gen_seed + cnt, keep_in_memory=True).select( 366 | range(num_private_train), keep_in_memory=True 367 | ), 368 | ) 369 | cnt += 1 370 | private_train_splits = [ 371 | list(it) 372 | for it in more_itertools.distribute( 373 | num_private_train_splits, private_train_subset 374 | ) 375 | ] 376 | # Turn the data into prompts 377 | private_train_prompts = [ 378 | instruction 379 | + "\n".join( 380 | format_full_datum_for_prompt(labels, datum) for datum in split 381 | ) 382 | for split in private_train_splits 383 | ] 384 | private_prompts = [ 385 | train_prompt + "\n" + stringified_test_datum 386 | for train_prompt in private_train_prompts 387 | ] 388 | private_prompts_tokens = [ 389 | lm.encoding.encode(prompt) for prompt in private_prompts 390 | ] 391 | if no_public_token: 392 | private_next_logprobs = await asyncio.gather( 393 | *( 394 | next_logprobs(lm, prompt + generated_token_ids, top_p=top_p) 395 | for prompt in private_prompts_tokens 396 | ) 397 | ) 398 | merged_next_probs = merge_logprobs_topk_mean( 399 | private_next_logprobs, 400 | None, 401 | lm.encoding.n_vocab, 402 | no_public_token, 403 | normalize_max, 404 | ) 405 | if normalize_max: 406 | # scale = 1/lambda 407 | noise = noise_rng.exponential(scale=sigma, size=lm.encoding.n_vocab) 408 | else: 409 | noise = noise_rng.normal(0, sigma, size=lm.encoding.n_vocab) 410 | merged_next_probs += noise 411 | else: 412 | public_next_logprobs = await next_logprobs( 413 | lm, public_prompt_tokens + generated_token_ids, top_p=top_p 414 | ) 415 | private_next_logprobs = await asyncio.gather( 416 | *( 417 | normalized_logprobs_for_chosen_tokens( 418 | lm, 419 | prompt + generated_token_ids, 420 | public_next_logprobs.keys(), 421 | top_p=top_p, 422 | ) 423 | for prompt in private_prompts_tokens 424 | ) 425 | ) 426 | merged_next_probs = merge_logprobs_topk_mean( 427 | private_next_logprobs, 428 | public_next_logprobs, 429 | 
lm.encoding.n_vocab, 430 | no_public_token, 431 | normalize_max, 432 | ) 433 | if normalize_max: 434 | # scale = 1/lambda 435 | noise = noise_rng.exponential( 436 | scale=sigma, size=len(public_next_logprobs) 437 | ) 438 | else: 439 | noise = noise_rng.normal(0, sigma, size=len(public_next_logprobs)) 440 | merged_next_probs[list(public_next_logprobs.keys())] += noise 441 | 442 | next_token_id = int(np.argmax(merged_next_probs)) 443 | 444 | if next_token_id in stop_tokens: 445 | break 446 | 447 | generated_token_ids.append(next_token_id) 448 | 449 | del next_token_id 450 | return generated_token_ids 451 | 452 | 453 | async def generate_with_public_prompt( 454 | public_train_prompt: str, 455 | stop_tokens: Set[str], 456 | test_input: str, 457 | lm: LM, 458 | labels, 459 | max_tokens: int = 500, 460 | ) -> list[int]: 461 | public_prompt = public_train_prompt + format_test_input_for_prompt( 462 | labels, test_input 463 | ) 464 | public_prompt_tokens = lm.encoding.encode(public_prompt) 465 | public_prompt_tokens = public_prompt 466 | 467 | [completion] = await lm.completions( 468 | public_prompt_tokens, 469 | CompletionsSettings( 470 | temperature=0.0, max_tokens=max_tokens, n=1, stop=list(stop_tokens) 471 | ), 472 | ) 473 | generated_tokens = [st.token.token_id for st in completion] 474 | return generated_tokens 475 | 476 | 477 | def select_uniform_n_shots_over_labels(data, n_shots): 478 | select_list = [] 479 | n_shots_per_label = math.ceil(n_shots / len(labels)) 480 | labels_counter = {label[1][0]: n_shots_per_label for label in label_dict.items()} 481 | n_shots_selected = 0 482 | for i in range(len(data)): 483 | label = label_dict[data[i]["label"]][0] 484 | if labels_counter[label] == 0: 485 | continue 486 | else: 487 | labels_counter[label] -= 1 488 | select_list.append(i) 489 | n_shots_selected += 1 490 | if n_shots_selected == n_shots: 491 | break 492 | query_subset = data.select(select_list, keep_in_memory=True) 493 | return query_subset 494 | 495 | 496 | def _main( 497 | sigma: Annotated[float, typer.Option()], # noise parameters 498 | openai_model: Annotated[str, typer.Option()] = "babbage", 499 | print_prompts: Annotated[bool, typer.Option()] = False, 500 | # num_private_train=MN. MN=0 with num_valid=4 will get epsilon=0 (4-shot) results. 501 | num_private_train: Annotated[int, typer.Option()] = DEFAULT_NUM_PRIVATE_TRAIN, 502 | # by default set to 0. set_num_public_train >0 indicates additional public data available. 503 | set_num_public_train: Annotated[int, typer.Option()] = DEFAULT_NUM_PUBLIC_TRAIN, 504 | # num_valid=n. n samples to be generated for n-shot ICL 505 | num_valid: Annotated[int, typer.Option()] = DEFAULT_NUM_VALID, 506 | # num_private_train_splits=M 507 | num_private_train_splits: Annotated[ 508 | int, typer.Option() 509 | ] = DEFAULT_NUM_PRIVATE_TRAIN_SPLITS, 510 | num_test: Annotated[int, typer.Option()] = DEFAULT_NUM_TEST, 511 | # no_public_token=True, RVP=False; no_public_token=False, RVP=True 512 | no_public_token: Annotated[bool, typer.Option()] = False, 513 | # subsample_per_token=True: at each token generation, subsample a new test set 514 | subsample_per_token: Annotated[bool, typer.Option()] = False, 515 | use_dp_prompts: Annotated[bool, typer.Option()] = False, 516 | # sample_same_label_prompts=True: sample subsets from the sets with same targeted labels. 
517 | sample_same_label_prompts: Annotated[bool, typer.Option()] = False, 518 | # normalize_max=True, Exponential mechanism; normalize_max=False, Gaussian mechanism 519 | normalize_max: Annotated[bool, typer.Option()] = False, 520 | # max_token_per_text=T_max 521 | max_token_per_text: Annotated[int, typer.Option()] = 100, 522 | # consistent with default parameters in the documentation https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#completions 523 | top_p: Annotated[float, typer.Option()] = 1, 524 | # random seed for subsampling in generation 525 | synth_seed: Annotated[int, typer.Option()] = 0, 526 | # random seed for n-shot demonstrations sampling in evaluation 527 | eval_seed: Annotated[int, typer.Option()] = 0, 528 | ): 529 | async def main(): 530 | if (num_private_train == 0) != (num_private_train_splits == 0): 531 | raise ValueError( 532 | "Either both or neither of --num-private-train and --num-private-train-splits can be 0" 533 | ) 534 | command = ["python", sys.argv[0]] 535 | for x in sys.argv[1:]: 536 | if x.startswith("--"): 537 | assert '"' not in x and "'" not in x 538 | command.append(x) 539 | else: 540 | assert "'" not in x 541 | if re.match("^[a-zA-Z0-9_]+$", x): 542 | command.append("%s" % x) 543 | else: 544 | command.append("'%s'" % x) 545 | command = " ".join(command) 546 | print(command) 547 | 548 | if no_public_token: 549 | num_public_train = 0 550 | else: 551 | num_public_train = set_num_public_train 552 | 553 | lm = api_openai_com(openai_model) 554 | noise_rng = np.random.RandomState() 555 | 556 | data = cast(DatasetDict, load_dataset("dbpedia_14")) 557 | print(labels) 558 | 559 | trainset = data["train"].shuffle(seed=synth_seed, keep_in_memory=True) 560 | print("trainset length", len(trainset)) 561 | if num_public_train > 0: 562 | public_train_subset = cast( 563 | Iterable[dict[str, str]], 564 | trainset.select( 565 | range( 566 | len(trainset) - num_public_train, 567 | len(trainset), 568 | keep_in_memory=True, 569 | ) 570 | ), 571 | ) 572 | else: 573 | public_train_subset = [] 574 | 575 | trainset = trainset.select( 576 | range(len(trainset) - num_public_train), keep_in_memory=True 577 | ) 578 | queryset = data["train"].shuffle(seed=eval_seed, keep_in_memory=True) 579 | query_subset = select_uniform_n_shots_over_labels(queryset, num_valid) 580 | 581 | if use_dp_prompts: 582 | synthetic_examples = [] 583 | 584 | # Turn the data into prompts 585 | instruction = "Given a label of document type, generate the chosen type of document accordingly.\n\n" 586 | 587 | public_train_prompt = instruction + "\n".join( 588 | format_full_datum_for_prompt(labels, datum) 589 | for datum in public_train_subset 590 | ) 591 | 592 | if print_prompts: 593 | print(public_train_prompt) 594 | print("=========") 595 | 596 | if normalize_max: 597 | print("Exponential Mechanism") 598 | assert num_private_train == 0 or sigma > 0 599 | if num_private_train > 0: 600 | # scale == sigma_calib == 1/lambda. lambda for exponential distribution. 601 | sigma_calib = (2 / num_private_train_splits) * (1 / sigma) 602 | else: 603 | print("Gaussian Mechanism") 604 | if num_private_train_splits > 0: 605 | sigma_calib = math.sqrt(2) / num_private_train_splits * sigma 606 | else: 607 | sigma_calib = 0 608 | print( 609 | f"sigma in command {sigma}. 
sigma added according to sensitivity {sigma_calib}" 610 | ) 611 | 612 | stop_tokens = {"\n", "<|endoftext|>", " END"} 613 | stop_tokens_ids = {lm.encoding.encode_single_token(t) for t in stop_tokens} 614 | 615 | client_session.set(aiohttp.ClientSession()) 616 | 617 | async with client_session.get(): 618 | for i, test_datum in enumerate(query_subset, 1): 619 | print(f"# Example {i}") 620 | print(f'Document Type: "{labels[test_datum["label"]]}"') 621 | print(f'References:\n "{test_datum["content"]}"') 622 | 623 | np.random.seed(synth_seed + i) 624 | gen_seed = np.random.randint(100000) 625 | print(f"gen-seed: {gen_seed}") 626 | 627 | if num_private_train_splits > 0: 628 | generated_token_ids = await generate_with_private_prompts( 629 | trainset, 630 | num_private_train, 631 | num_private_train_splits, 632 | instruction, 633 | public_train_prompt, 634 | stop_tokens_ids, 635 | test_datum["label"], 636 | lm, 637 | noise_rng, 638 | sigma_calib, 639 | labels, 640 | top_p, 641 | no_public_token, 642 | subsample_per_token, 643 | sample_same_label_prompts, 644 | gen_seed, 645 | max_tokens=max_token_per_text 646 | - 1, # need one token length for EOS. 647 | normalize_max=normalize_max, 648 | ) 649 | else: 650 | generated_token_ids = await generate_with_public_prompt( 651 | public_train_prompt, 652 | stop_tokens, 653 | test_datum["label"], 654 | lm, 655 | labels, 656 | max_tokens=max_token_per_text, 657 | ) 658 | 659 | generated = lm.encoding.decode(generated_token_ids).rstrip('"') 660 | 661 | print(f"Generated: {generated}\n") 662 | output_datum = {} 663 | output_datum["content"] = generated.strip() 664 | output_datum["label"] = test_datum["label"] 665 | synthetic_examples.append(output_datum) 666 | 667 | if num_test > 0: 668 | test_subset = ( 669 | data["test"] 670 | .shuffle(seed=12345, keep_in_memory=True) 671 | .select(range(num_test), keep_in_memory=True) 672 | ) 673 | test_labels = [test_example["label"] for test_example in test_subset] 674 | 675 | content_free_inputs = [ 676 | {"content": "N/A"}, 677 | {"content": ""}, 678 | {"content": "[MASK]"}, 679 | ] 680 | p_cf_wout_DP = get_p_content_free( 681 | query_subset, openai_model, content_free_inputs=content_free_inputs 682 | ) 683 | 684 | all_raw_answers_wout_DP = get_model_response( 685 | query_subset, test_subset, openai_model 686 | ) 687 | all_label_probs_wout_DP = get_label_probs( 688 | all_raw_answers_wout_DP, test_subset 689 | ) 690 | 691 | acc_original_wout_DP = eval_accuracy(all_label_probs_wout_DP, test_labels) 692 | acc_calibrated_wout_DP = eval_accuracy( 693 | all_label_probs_wout_DP, 694 | test_labels, 695 | mode="diagonal_W", 696 | p_cf=p_cf_wout_DP, 697 | ) 698 | 699 | print(f"Accuracy (original) without DP: {acc_original_wout_DP}") 700 | print(f"Accuracy (calibrated) without DP: {acc_calibrated_wout_DP}") 701 | 702 | if use_dp_prompts: 703 | p_cf_w_DP = get_p_content_free( 704 | synthetic_examples, 705 | openai_model, 706 | content_free_inputs=content_free_inputs, 707 | ) 708 | all_raw_answers_w_DP = get_model_response( 709 | synthetic_examples, test_subset, openai_model 710 | ) 711 | 712 | all_label_probs_w_DP = get_label_probs( 713 | all_raw_answers_w_DP, test_subset 714 | ) 715 | 716 | acc_original_w_DP = eval_accuracy(all_label_probs_w_DP, test_labels) 717 | acc_calibrated_w_DP = eval_accuracy( 718 | all_label_probs_w_DP, test_labels, mode="diagonal_W", p_cf=p_cf_w_DP 719 | ) 720 | 721 | print(f"Accuracy (original) with DP: {acc_original_w_DP}") 722 | print(f"Accuracy (calibrated) with DP: {acc_calibrated_w_DP}") 723 | 724 | 
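    # Illustrative invocation (hypothetical values; typer exposes the keyword
    # parameters of _main as dashed CLI options, matching the
    # --num-private-train / --num-private-train-splits names checked in main()):
    #
    #   python src/dp_few_shot_generation/run_exp_dbpedia.py \
    #       --sigma 1.0 --num-private-train 80 --num-private-train-splits 40 \
    #       --num-valid 4 --num-test 1000 --use-dp-prompts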
try: 725 | asyncio.run(main()) 726 | except KeyboardInterrupt: 727 | traceback.print_exc() 728 | raise 729 | 730 | 731 | if __name__ == "__main__": 732 | typer.run(_main) 733 | --------------------------------------------------------------------------------
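The three run_exp_*.py scripts above share the same per-token aggregation: each disjoint private split proposes a next-token distribution, the renormalized distributions are averaged, calibrated noise is added, and the argmax token is emitted. The following minimal sketch, using made-up token ids and probabilities and only numpy, mirrors what merge_logprobs_topk_mean plus the Gaussian-noise branch of generate_with_private_prompts compute for a single step; it is an illustration under those assumptions, not the repository's API.

import numpy as np


def aggregate_next_token(private_logprobs, candidate_ids, sigma, rng):
    """One generation step: average per-split distributions over the public
    candidate tokens, add Gaussian noise, and pick the argmax token."""
    m = len(private_logprobs)
    probs = np.zeros((m, len(candidate_ids)))
    for row, lps in enumerate(private_logprobs):
        for col, tok in enumerate(candidate_ids):
            probs[row, col] = np.exp(lps.get(tok, float("-inf")))
        probs[row] /= probs[row].sum()  # sum-normalize, as log_normalize does
    mean_probs = probs.mean(axis=0)  # matches exp(logsumexp(...) - log(m))
    # noise scale follows the Gaussian branch of _main: sigma_calib = sqrt(2) / m * sigma
    noise = rng.normal(0.0, np.sqrt(2) / m * sigma, size=len(candidate_ids))
    return candidate_ids[int(np.argmax(mean_probs + noise))]


rng = np.random.RandomState(0)
splits = [
    {11: np.log(0.6), 22: np.log(0.3), 33: np.log(0.1)},
    {11: np.log(0.5), 22: np.log(0.4), 33: np.log(0.1)},
]
print(aggregate_next_token(splits, [11, 22, 33], sigma=0.1, rng=rng))  # prints 11 for this seed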