├── dln ├── __init__.py ├── templates │ ├── thoughts_forward.yaml │ ├── analysis_forward.yaml │ ├── classify_residual.yaml │ ├── suffix_forward.yaml │ ├── suffix_forward_tbs.yaml │ ├── classify_forward.yaml │ ├── suffix_forward_y.yaml │ ├── suffix_forward_tbs_y.yaml │ ├── thoughts_forward_y.yaml │ ├── analysis_forward_y.yaml │ ├── thoughts_backward.yaml │ ├── analysis_backward.yaml │ ├── suffix_backward_h_np_y.yaml │ ├── q_action_prompt_seq.yaml │ └── q_action_prompt.yaml ├── postprocessing.py ├── template.py ├── loss.py ├── vi │ ├── utils.py │ ├── layers.py │ └── sampler.py ├── dataset_info.yaml └── score.py ├── scripts ├── __init__.py ├── utils.py ├── setup_data.sh ├── split_bigbench_navigate.py ├── split_bigbench_date_understanding.py ├── split_bigbench_logical_deduction_seven_objects.py └── split_bigbench_hyperbaton.py ├── projects ├── demo │ ├── requirements.txt │ ├── dln_demo_sample.png │ ├── Dockerfile │ ├── README.md │ └── demo.py ├── guide │ ├── requirements.txt │ ├── templates │ │ ├── classify_forward.yaml │ │ └── guided_backward.yaml │ ├── Dockerfile │ ├── README.md │ └── app.py └── vi_dln │ ├── connections.yaml │ ├── scripts │ ├── one_layer │ │ ├── gpt4_date.sh │ │ ├── gpt4_DLN-1.sh │ │ ├── any_DLN-1.sh │ │ ├── one_layer.py │ │ ├── any_Nshot.sh │ │ ├── gpt4_Nshot.sh │ │ ├── phi2_DLN-1.sh │ │ ├── phi2_Nshot.sh │ │ └── sweep_bbh.sh │ ├── two_layers_e2e │ │ ├── navigate.sh │ │ ├── date_understanding.sh │ │ ├── logical_deduction_seven_objects.sh │ │ ├── navigate_vllm.sh │ │ ├── phi2_gpt35_DLN-2.sh │ │ ├── phi2_phi2_DLN-2.sh │ │ ├── navigate_sweep.sh │ │ ├── gsm8k.sh │ │ ├── logical_deduction_seven_objects_sweep.sh │ │ ├── sweep_nlu.sh │ │ └── sweep_bbh.sh │ ├── two_layers_fix_2nd │ │ ├── navigate.sh │ │ ├── data_understanding.sh │ │ └── logical_deduction_seven_objects.sh │ ├── two_layers_ft_2nd │ │ ├── navigate.sh │ │ ├── date_understanding.sh │ │ └── logical_deduction_seven_objects.sh │ ├── sweep.sh │ └── sweep_phi2_gpt35.sh │ ├── README.md │ └── read_results.py ├── .gitattributes ├── .vscode └── settings.json ├── requirements.txt ├── CODE_OF_CONDUCT.md ├── SUPPORT.md ├── .github └── workflows │ ├── tests.yml │ └── codeql.yml ├── pyproject.toml ├── LICENSE ├── tests ├── test_dln_score.py ├── test_dln_postprocessing.py ├── test_dln_templates.py ├── test_vi_sampler.py ├── test_dln_losses.py ├── test_vi_layers.py └── conftest.py ├── SECURITY.md ├── .gitignore └── README.md /dln/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/demo/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | altair 3 | streamlit 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /projects/guide/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | git+https://github.com/microsoft/deep-language-networks.git#egg=dln 3 | 
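The `dln` package in the tree above pairs the YAML prompt templates under `dln/templates/` with a small loader in `dln/template.py`. As a quick orientation, the following is a minimal sketch of that loader in use — `load_template` and `render` mirror their calls in `tests/test_dln_templates.py` further down this listing, while the input and prompt strings are invented for illustration and are not repository content.

```python
# Minimal sketch, not a file from this repository: render a packaged template.
# load_template() and DLNTemplate.render() follow their usage in
# tests/test_dln_templates.py below; the strings passed here are made up.
from dln.template import load_template

# dln/templates/suffix_forward.yaml renders "{{ input }}\n\n{{ prompt }}".
template = load_template("suffix_forward")
rendered = template.render(
    input="Review: the plot was thin, but the acting was superb.",
    prompt="Is the review positive or negative? Answer:",
)
print(rendered)
```

For the import to resolve, the package can be installed the way `projects/guide/requirements.txt` does (directly from the Git repository) or with `pip install -e .` using the `pyproject.toml` shown later in this dump.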
-------------------------------------------------------------------------------- /projects/guide/templates/classify_forward.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ prompt }} 4 | 5 | {{ input }} 6 | -------------------------------------------------------------------------------- /dln/templates/thoughts_forward.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ prompt }} 4 | 5 | {{ input }} 6 | 7 | Thoughts: -------------------------------------------------------------------------------- /dln/templates/analysis_forward.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ prompt }} 4 | 5 | {{ input }} 6 | 7 | Brief analysis: -------------------------------------------------------------------------------- /dln/templates/classify_residual.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ input }} 4 | Your thoughts were: 5 | {{ output }} 6 | -------------------------------------------------------------------------------- /dln/templates/suffix_forward.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | stop_tokens: ["##", "#"] 3 | template: |- 4 | {{ input }} 5 | 6 | {{ prompt }} -------------------------------------------------------------------------------- /dln/templates/suffix_forward_tbs.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ input }} 4 | 5 | {{ prompt }} Let's think step by step. 6 | -------------------------------------------------------------------------------- /projects/demo/dln_demo_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/deep-language-networks/HEAD/projects/demo/dln_demo_sample.png -------------------------------------------------------------------------------- /dln/templates/classify_forward.yaml: -------------------------------------------------------------------------------- 1 | v3.0: 2 | template: |- 3 | {{ prompt }} 4 | 5 | {{ input }} 6 | 7 | Answer: 8 | {{ answer }} -------------------------------------------------------------------------------- /dln/templates/suffix_forward_y.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ input }} 4 | 5 | Given that the answer is: 6 | {{ y }} 7 | 8 | {{ prompt }} 9 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests" 4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true 7 | } -------------------------------------------------------------------------------- /dln/templates/suffix_forward_tbs_y.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ input }} 4 | 5 | Given that the answer is: 6 | {{ y }} 7 | 8 | {{ prompt }} Let's think step by step. 
9 | -------------------------------------------------------------------------------- /dln/templates/thoughts_forward_y.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ prompt }} 4 | 5 | {{ input }} 6 | 7 | Given that the answer is: 8 | {{ y }} 9 | 10 | Thoughts: -------------------------------------------------------------------------------- /dln/templates/analysis_forward_y.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | {{ prompt }} 4 | 5 | {{ input }} 6 | 7 | Given that the answer is: 8 | {{ y }} 9 | 10 | Brief analysis: -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | def download_json_from_url(url): 5 | response = requests.get(url) 6 | response.raise_for_status() 7 | json_data = response.json() 8 | return json_data 9 | -------------------------------------------------------------------------------- /projects/demo/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-bookworm 2 | 3 | WORKDIR /app 4 | COPY requirements.txt . 5 | RUN pip install --no-cache-dir -r requirements.txt 6 | 7 | COPY demo.py . 8 | COPY data.json . 9 | 10 | EXPOSE 8501 11 | 12 | CMD streamlit run demo.py 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | packaging==23.1 2 | torch==2.0.0 3 | openai==1.14.1 4 | jinja2==3.1.4 5 | tenacity==8.2.2 6 | tensorboard==2.13.0 7 | click==8.1.3 8 | termcolor==2.3.0 9 | tiktoken==0.4.0 10 | scipy==1.10.1 11 | numpy 12 | requests==2.31.0 13 | pyyaml 14 | transformers 15 | datasets 16 | pytest 17 | pytest-asyncio -------------------------------------------------------------------------------- /projects/guide/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-bookworm 2 | 3 | WORKDIR /app 4 | COPY requirements.txt . 5 | 6 | RUN apt update && apt install -y git 7 | RUN pip install --upgrade pip 8 | RUN pip install --no-cache-dir -r requirements.txt 9 | 10 | COPY app.py . 11 | COPY guided_search.py . 12 | COPY templates ./templates 13 | 14 | CMD streamlit run app.py 15 | -------------------------------------------------------------------------------- /projects/vi_dln/connections.yaml: -------------------------------------------------------------------------------- 1 | - name: gpt-35-bwd 2 | model: gpt-35-turbo-instruct 3 | api_key: ${GPT_API_KEY} 4 | base_url: ${GPT_API_BASE} 5 | api_type: ${GPT_API_TYPE} 6 | api_version: ${GPT_API_VERSION} 7 | temperature: 0.7 8 | max_tokens: 512 9 | 10 | - name: phi-2-fwd 11 | model: microsoft/phi-2 12 | api_key: EMPTY 13 | base_url: ${PHI2_API_BASE} 14 | api_type: null 15 | api_version: null 16 | max_tokens: 256 17 | temperature: 0 18 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. 6 | Please search the existing issues before filing new issues to avoid duplicates. 7 | To report a bug, request a feature, or ask a question about using this project, 8 | please file a new issue on GitHub. 9 | 10 | ## Microsoft Support Policy 11 | 12 | Support for this project is limited to the resources listed above. 13 | -------------------------------------------------------------------------------- /dln/templates/thoughts_backward.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | This is useful context to solve the problem: 4 | {{ next_prompt }} 5 | 6 | This is the input: 7 | {{ input }} 8 | 9 | These were your thoughts about the input: 10 | {{ h }} 11 | 12 | Given that this is the correct answer: 13 | {{ y }} 14 | 15 | {{ message }} 16 | Thoughts: 17 | message_alternatives: 18 | - "Refine your thoughts by thinking about what would be useful to add given the context and the correct answer. Be concise." 19 | - "Improve your thoughts about the input to reflect the correct answer and the context." 20 | -------------------------------------------------------------------------------- /dln/templates/analysis_backward.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | This is useful context to solve the problem: 4 | {{ next_prompt }} 5 | 6 | This is the input: 7 | {{ input }} 8 | 9 | This was your analysis about the input: 10 | {{ h }} 11 | 12 | Given that this is the correct answer: 13 | {{ y }} 14 | 15 | {{ message }} 16 | Brief analysis: 17 | message_alternatives: 18 | - "Reflect and refine your analysis. What would be useful to add given the context and the correct answer? Be concise." 19 | - "Improve your analysis about the input to make it clearer and such that it reflects the correct answer and the context." 20 | -------------------------------------------------------------------------------- /dln/templates/suffix_backward_h_np_y.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | template: |- 3 | This is useful context to solve the problem: 4 | {{ next_prompt }} 5 | 6 | This is the input: 7 | {{ input }} 8 | 9 | These were your thoughts about the input: 10 | {{ h }} 11 | 12 | Given that this is the correct answer: 13 | {{ y }} 14 | 15 | {{ message }} 16 | Thoughts: 17 | message_alternatives: 18 | - "Reflect and refine your thoughts for this problem by adding detailed explanations." 19 | - "Fix the errors in your reasoning. Add examples to illustrate your thoughts. Be concise." 20 | - "Improve your thoughts to make them clearer and more concise."
21 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: "Tests" 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | # The branches below must be a subset of the branches above 8 | branches: [ "main" ] 9 | schedule: 10 | - cron: '20 17 * * 5' 11 | 12 | jobs: 13 | tests: 14 | name: Test 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | language: [ "python" ] 20 | 21 | steps: 22 | - name: Checkout repository 23 | uses: actions/checkout@v3 24 | - name: Set up python 25 | uses: actions/setup-python@v4 26 | with: 27 | python-version: '3.10' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install -e '.[dev]' 32 | - name: Test with pytest 33 | run: | 34 | pytest -vv -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "dln" 7 | version = "0.0.1" 8 | description = "Deep Language Networks" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | dynamic = ["dependencies"] 12 | 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: MIT License", 16 | "Operating System :: OS Independent", 17 | ] 18 | 19 | [tool.setuptools.dynamic] 20 | dependencies = {file = ["requirements.txt"]} 21 | 22 | [tool.setuptools] 23 | packages = ["dln"] 24 | 25 | [project.optional-dependencies] 26 | dev = [ 27 | "pytest", 28 | "pytest-asyncio", 29 | ] 30 | 31 | [project.urls] 32 | "Homepage" = "https://github.com/microsoft/deep-language-networks" 33 | "Bug Tracker" = "https://github.com/microsoft/deep-language-networks/issues" 34 | -------------------------------------------------------------------------------- /scripts/setup_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create data folder 4 | mkdir -p data 5 | 6 | # Get Ordered Prompt data 7 | wget https://github.com/yaolu/Ordered-Prompt/archive/refs/heads/main.zip 8 | unzip main.zip 9 | mv ordered-prompt-main/data data/ordered_prompt 10 | rm -rf ordered-prompt-main 11 | rm -f main.zip 12 | 13 | # Get Leopard data 14 | wget https://github.com/iesl/leopard/archive/refs/heads/master.zip 15 | unzip master.zip 16 | mv leopard-master/data/json data/leopard 17 | rm -rf leopard-master 18 | rm -f master.zip 19 | 20 | # Get BBH data 21 | wget https://github.com/suzgunmirac/BIG-Bench-Hard/archive/refs/heads/main.zip 22 | unzip main.zip 23 | mv BIG-Bench-Hard-main/bbh data/ 24 | rm -rf BIG-Bench-Hard-main 25 | rm -f main.zip 26 | 27 | # Preprocess BBH data removing points from BigBench to avoid data contamination 28 | python scripts/split_bigbench_date_understanding.py 29 | python scripts/split_bigbench_hyperbaton.py 30 | python scripts/split_bigbench_logical_deduction_seven_objects.py 31 | python scripts/split_bigbench_navigate.py 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/gpt4_date.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=date_understanding 3 | p_class_tpl="classify_forward:3.0" 4 | iters=20 5 | batch_size=10 6 | num_p_samples=15 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=2 10 | tolerance=2 11 | num_h_samples=5 12 | q_prompt_tpl="q_action_prompt:v3.5" 13 | logp_penalty=2. 14 | posterior_temp=1. 15 | fwd_model_type="gpt-4" 16 | one_layer=True 17 | 18 | dir=log/one_layer_gpt_4_temp_0.7_new/${dataset} 19 | /bin/rm -rf ${dir} 20 | 21 | for seed in 13 42 25; do 22 | python vi_main.py \ 23 | --balance_batch \ 24 | --one_layer ${one_layer} \ 25 | --num_p_samples ${num_p_samples} \ 26 | --bwd_temp ${bwd_temp} \ 27 | --iters ${iters} \ 28 | --q_prompt ${q_prompt_tpl} \ 29 | --p_class ${p_class_tpl} \ 30 | --out_dir ${dir} \ 31 | --batch_size ${batch_size} \ 32 | --seed ${seed} \ 33 | --dataset ${dataset} \ 34 | --use_memory ${use_memory} \ 35 | --tolerance ${tolerance} \ 36 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 37 | --output_scoring_function accuracy \ 38 | --fwd_model_type ${fwd_model_type} 39 | done 40 | -------------------------------------------------------------------------------- /tests/test_dln_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dln.score import LogProbsScore, OutputClasses 4 | 5 | 6 | def test_logprobs_score_with_output_classes(score_requests, top_logprobs, mock_llm_func): 7 | mock_llm = mock_llm_func("text-davinci-003") 8 | mock_llm._generate = top_logprobs 9 | logprobs_score = LogProbsScore(mock_llm) 10 | 11 | logprobs = logprobs_score.score_requests( 12 | score_requests, output_classes=OutputClasses(protos=["a|A", "b|B"]) 13 | ) 14 | 15 | np.testing.assert_almost_equal(logprobs.logp_targets, [-8.6746863, -0.4428973]) 16 | np.testing.assert_almost_equal( 17 | logprobs.distribution, 18 | [ 19 | [9.99829143e-01, 1.70856546e-04], 20 | [6.42173164e-01, 3.57826836e-01], 21 | ], 22 | ) 23 | 24 | 25 | def test_logprobs_score_without_output_classes(score_requests, raw_logprobs, mock_llm_func): 26 | mock_llm = mock_llm_func("text-davinci-003") 27 | mock_llm._generate = raw_logprobs 28 | logprobs_score = LogProbsScore(mock_llm) 29 | 30 | 
logprobs = logprobs_score.score_requests(score_requests) 31 | 32 | np.testing.assert_almost_equal(logprobs.logp_targets, [-0.7682657, -0.7632834]) 33 | np.testing.assert_almost_equal(logprobs.distribution, [-2.8217665, -2.73069]) 34 | -------------------------------------------------------------------------------- /dln/postprocessing.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def remove_extra_spaces(input, remove_new_line=False): 5 | assert isinstance(input, str) 6 | output = input 7 | if remove_new_line: 8 | output = output.replace("\n", " ") 9 | # remove extra spaces 10 | while True: 11 | if len(output) == 0 or " " not in output: 12 | break 13 | output = output.replace(" ", " ") 14 | # remove extra new lines 15 | while True: 16 | if len(output) == 0 or "\n\n" not in output: 17 | break 18 | output = output.replace("\n\n", "\n") 19 | return output 20 | 21 | 22 | def postprocess_prediction(input): 23 | assert isinstance(input, str) 24 | output = input 25 | output = re.sub(r"\W+", " ", output) # remove non word 26 | output = re.sub(r"\d+", " ", output) # remove digits 27 | output = remove_extra_spaces(output) 28 | output = output.lower() 29 | 30 | output = output.split() 31 | if len(output) == 0: 32 | return "" 33 | 34 | if len(output) == 1: 35 | return output[0] 36 | 37 | # More than one word 38 | 39 | # Useful when the model predicts "Option (A)" instead of (A). 40 | if "option" == output[0]: 41 | return output[1] 42 | 43 | # Return the first word 44 | return output[0] 45 | -------------------------------------------------------------------------------- /tests/test_dln_postprocessing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dln.postprocessing import postprocess_prediction, remove_extra_spaces 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "input,expected", 8 | [ 9 | ("foo bar", "foo bar"), 10 | (" foo bar ", " foo bar "), 11 | ("foo\n\nbar", "foo\nbar"), 12 | ("\nfoo\n\nbar\n", "\nfoo\nbar\n"), 13 | ], 14 | ) 15 | def test_remove_extra_spaces(input, expected): 16 | assert remove_extra_spaces(input, False) == expected 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "input,expected", 21 | [ 22 | ("foo bar", "foo bar"), 23 | (" foo bar ", " foo bar "), 24 | ("foo\n\nbar", "foo bar"), 25 | ("\nfoo\n\nbar\n", " foo bar "), 26 | ], 27 | ) 28 | def test_remove_extra_spaces_and_replace_new_lines(input, expected): 29 | assert remove_extra_spaces(input, True) == expected 30 | 31 | 32 | @pytest.mark.parametrize( 33 | "input,expected", 34 | [ 35 | ("foo@bar", "foo"), 36 | ("foo123bar", "foo"), 37 | (" Foo Bar ", "foo"), 38 | ("", ""), 39 | ("Foo", "foo"), 40 | ("Option (A)", "a"), 41 | ("Option (A, B)", "a"), 42 | ("Foo Bar", "foo"), 43 | ], 44 | ) 45 | def test_postprocess_prediction(input, expected): 46 | assert postprocess_prediction(input) == expected 47 | -------------------------------------------------------------------------------- /projects/demo/README.md: -------------------------------------------------------------------------------- 1 | # Deep Language Networks Demo 2 | 3 | ![Image showing DLN Demo](./dln_demo_sample.png) 4 | 5 | 6 | ## Install dependencies 7 | From the demo directory, run the following command to install the additional dependencies: 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Generate plot for the given data.json file 14 | 15 | ``` 16 | streamlit run demo.py 17 | ``` 18 | 19 | ## Visualizing your own results 20 
| 1. Please see the [Variational Inference README](../vi_dln/README.md) for information on how to run experiments. 21 | 22 | Your results will be stored in the log directory by default (projects/vi_dln/log/result_data.json). 23 | 24 | Alternatively, you can specify the output directory using the --result_data_path argument. 25 | 26 | 2. Run the demo.py file to generate the plot for your results. 27 | 28 | ``` 29 | streamlit run demo.py <path_to_log_dir>/result_data.json 30 | ``` 31 | 32 | You can load multiple result_data.json files by specifying multiple paths. 33 | 34 | ``` 35 | streamlit run demo.py <path_1>/result_data.json <path_2>/result_data.json 36 | ``` 37 | 38 | ## Serving from Docker 39 | Run the following commands to serve the demo from a Docker container: 40 | ``` 41 | docker build -t dln-demo . 42 | docker run --name dln-demo --restart unless-stopped -d -p 8001:8501 dln-demo 43 | ``` 44 | Open your browser and navigate to http://localhost:8001. 45 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/gpt4_DLN-1.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | p_class_tpl="classify_forward:3.0" 3 | q_prompt_tpl="q_action_prompt:v3.5" 4 | iters=20 5 | batch_size=20 6 | num_p_samples=20 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=2 10 | tolerance=2 11 | num_h_samples=5 12 | logp_penalty=2. 13 | posterior_temp=1. 14 | fwd_model_type="gpt-4" 15 | one_layer=True 16 | 17 | # dataset 18 | for dataset in subj; do 19 | 20 | dir=log/gpt4/${dataset}/dln-1-20bsz/ 21 | 22 | if [ ! -f ${dir}/done.txt ]; then 23 | /bin/rm -rf ${dir} 24 | 25 | # seed 26 | for seed in 13 42 25; do 27 | python vi_main.py \ 28 | --do_first_eval \ 29 | --balance_batch \ 30 | --one_layer ${one_layer} \ 31 | --train_p1 False \ 32 | --train_p2 True \ 33 | --num_p_samples ${num_p_samples} \ 34 | --bwd_temp ${bwd_temp} \ 35 | --iters ${iters} \ 36 | --q_prompt ${q_prompt_tpl} \ 37 | --p_class ${p_class_tpl} \ 38 | --out_dir ${dir} \ 39 | --batch_size ${batch_size} \ 40 | --seed ${seed} \ 41 | --dataset ${dataset} \ 42 | --use_memory ${use_memory} \ 43 | --tolerance ${tolerance} \ 44 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 45 | --output_scoring_function accuracy \ 46 | --fwd_model_type ${fwd_model_type} 47 | # seed 48 | done 49 | 50 | touch ${dir}/done.txt 51 | fi 52 | # dataset 53 | done -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/navigate.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=navigate 3 | p_class_tpl="classify_forward:3.0" 4 | iters=20 5 | batch_size=10 6 | num_p_samples=10 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=5 10 | tolerance=2 11 | num_h_samples=10 12 | q_prompt_tpl="q_action_prompt:v3.5" 13 | logp_penalty=2. 14 | posterior_temp=1. 15 | trust_factor=5.
16 | p_hidden_tpl="suffix_forward_tbs" 17 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 18 | 19 | dir=log/two_layers_e2e/${dataset} 20 | /bin/rm -rf ${dir} 21 | 22 | for seed in 13 42 25; do 23 | python vi_main.py \ 24 | --balance_batch \ 25 | --num_p_samples ${num_p_samples} \ 26 | --num_h_samples ${num_h_samples} \ 27 | --bwd_temp ${bwd_temp} \ 28 | --iters ${iters} \ 29 | --p_hidden ${p_hidden_tpl} \ 30 | --q_hidden ${q_hidden_tpl} \ 31 | --q_prompt ${q_prompt_tpl} \ 32 | --p_class ${p_class_tpl} \ 33 | --out_dir ${dir} \ 34 | --batch_size ${batch_size} \ 35 | --seed ${seed} \ 36 | --dataset ${dataset} \ 37 | --use_memory ${use_memory} \ 38 | --tolerance ${tolerance} \ 39 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 40 | --trust_factor ${trust_factor} \ 41 | --train_p1 True \ 42 | --train_p2 True \ 43 | --forward_use_classes True \ 44 | --logp_penalty ${logp_penalty} \ 45 | --posterior_temp ${posterior_temp} \ 46 | --strip_options_for_hidden True 47 | done 48 | -------------------------------------------------------------------------------- /dln/templates/q_action_prompt_seq.yaml: -------------------------------------------------------------------------------- 1 | v1.0: 2 | stop_tokens: ['4.', '#'] 3 | template: |- 4 | A student is completing a task that requires producing a text output from a text input. The student receives an instruction that describes how to produce the output given each input. 5 | The student has made some errors. Your task is to improve the instruction such that the student can fix the errors. 6 | 7 | This was the instruction. 8 | {{ prompt }} 9 | 10 | # Student successes 11 | {% for backward_info in backward_infos %} {% if backward_info.loss == 0.0 %} 12 | ## Input: 13 | > {{ backward_info.input }} 14 | ## Correct Output: 15 | > {{ backward_info.target }} 16 | {% endif %} {% endfor %} 17 | 18 | # Student errors 19 | {% for backward_info in backward_infos %} {% if backward_info.loss > 0.0 %} 20 | ## Input: 21 | > {{ backward_info.input }} 22 | ## Student Output: 23 | > {{ backward_info.output }} 24 | ## Correct Output: 25 | > {{ backward_info.target }} 26 | {% endif %} {% endfor %} 27 | 28 | Improve the instruction to fix the student errors. {{ message }}. 29 | Propose 3 new instructions: 30 | 1. 31 | message_alternatives: 32 | - Clarify the instruction by adding few words or a short sentence. Be concise. 33 | - Improve the instruction by providing examples on how to solve the task. Be concise. 34 | - Shorten the instruction by removing superflous words or sentences. 35 | - Rewrite the instruction by providing detailed information to avoid ambiguity. Be concise. 36 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_fix_2nd/navigate.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | best_weights_file_path="one-layer-dln-hp-search-result.json" 3 | dataset=navigate 4 | p_class_tpl="classify_forward:3.0" 5 | iters=20 6 | batch_size=20 7 | num_p_samples=20 8 | num_h_samples=5 9 | bwd_temp=0.7 10 | use_h_argmax=False 11 | trust_factor=0. 
12 | held_out_prompt_ranking=True 13 | 14 | q_prompt_tpl="q_action_prompt:v3.0" 15 | logp_penalty=0.5 16 | use_memory=2 17 | tolerance=2 18 | p_hidden_tpl="suffix_forward_tbs" 19 | q_hidden_tpl="suffix_backward_h_np_y" 20 | 21 | dir=log/two_layers_fix_2nd/${dataset} 22 | /bin/rm -rf ${dir} 23 | 24 | for seed in 13 42 25; do 25 | python vi_main.py \ 26 | --balance_batch \ 27 | --num_p_samples ${num_p_samples} \ 28 | --num_h_samples ${num_h_samples} \ 29 | --bwd_temp ${bwd_temp} \ 30 | --iters ${iters} \ 31 | --p_hidden ${p_hidden_tpl} \ 32 | --q_hidden ${q_hidden_tpl} \ 33 | --q_prompt ${q_prompt_tpl} \ 34 | --p_class ${p_class_tpl} \ 35 | --out_dir ${dir} \ 36 | --batch_size ${batch_size} \ 37 | --seed ${seed} \ 38 | --dataset ${dataset} \ 39 | --use_h_argmax ${use_h_argmax} \ 40 | --use_memory ${use_memory} \ 41 | --tolerance ${tolerance} \ 42 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 43 | --trust_factor ${trust_factor} \ 44 | --train_p1 True \ 45 | --train_p2 False \ 46 | --init_p2 ${best_weights_file_path} \ 47 | --logp_penalty ${logp_penalty} \ 48 | --do_first_eval 49 | done 50 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/any_DLN-1.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | p_class_tpl="classify_forward:3.0" 3 | iters=20 4 | batch_size=20 5 | num_p_samples=10 6 | bwd_temp=0.7 7 | held_out_prompt_ranking=True 8 | use_memory=2 9 | tolerance=2 10 | num_h_samples=5 11 | q_prompt_tpl="q_action_prompt:v3.5" 12 | logp_penalty=2. 13 | posterior_temp=1. 14 | fwd_model_type="text-davinci-003" 15 | bwd_model_type="text-davinci-003" 16 | one_layer=True 17 | 18 | # dataset 19 | for dataset in navigate subj logical_deduction_seven_objects; do 20 | 21 | dir=log/any/${dataset}/DLN-1/ 22 | 23 | if [ ! 
-f ${dir}/done.txt ]; then 24 | /bin/rm -rf ${dir} 25 | 26 | # seed 27 | for seed in 13 42 25; do 28 | python vi_main.py \ 29 | --balance_batch \ 30 | --one_layer ${one_layer} \ 31 | --do_first_eval \ 32 | --num_p_samples ${num_p_samples} \ 33 | --bwd_temp ${bwd_temp} \ 34 | --iters ${iters} \ 35 | --q_prompt ${q_prompt_tpl} \ 36 | --p_class ${p_class_tpl} \ 37 | --out_dir ${dir} \ 38 | --batch_size ${batch_size} \ 39 | --seed ${seed} \ 40 | --dataset ${dataset} \ 41 | --use_memory ${use_memory} \ 42 | --tolerance ${tolerance} \ 43 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 44 | --output_scoring_function logprobs \ 45 | --fwd_model_type ${fwd_model_type} \ 46 | --bwd_model_type ${bwd_model_type} 47 | # seed 48 | done 49 | 50 | touch ${dir}/done.txt 51 | fi 52 | # dataset 53 | done 54 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/one_layer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | import click 5 | 6 | 7 | @click.command() 8 | @click.option( 9 | "--dataset", 10 | type=str, 11 | help="Dataset name", 12 | required=True, 13 | ) 14 | def main(dataset): 15 | config_file="one-layer-dln-hp-search-result.json" 16 | with open(config_file) as f: 17 | config_data = json.load(f) 18 | 19 | config = config_data[dataset] 20 | q_prompt_tpl = config["hyperparam"]["q_prompt_tpl"] 21 | tolerance = config["hyperparam"]["tolerance"] 22 | use_memory = config["hyperparam"]["use_memory"] 23 | held_out_prompt_ranking = config["hyperparam"]["held_out_prompt_ranking"] 24 | 25 | output_dir = f"log/one_layer/{dataset}" 26 | for seed in [13, 42, 25]: 27 | command = list(map(str, [ 28 | "python", 29 | "vi_main.py", 30 | "--balance_batch", 31 | "--num_p_samples", 20, 32 | "--bwd_temp", 0.7, 33 | "--iters", 20, 34 | "--p_class", "classify_forward:3.0", 35 | "--q_prompt", q_prompt_tpl, 36 | "--out_dir", output_dir, 37 | "--batch_size", 20, 38 | "--seed", seed, 39 | "--dataset", dataset, 40 | "--tolerance", tolerance, 41 | "--use_memory", use_memory, 42 | "--held_out_prompt_ranking", held_out_prompt_ranking, 43 | "--one_layer", True, 44 | "--do_first_eval", 45 | ])) 46 | print(' '.join(command)) 47 | subprocess.run(command) 48 | 49 | if __name__ == "__main__": 50 | main() -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_fix_2nd/data_understanding.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | best_weights_file_path="one-layer-dln-hp-search-result.json" 3 | dataset=date_understanding 4 | p_class_tpl="classify_forward:3.0" 5 | iters=20 6 | batch_size=20 7 | num_p_samples=20 8 | num_h_samples=5 9 | bwd_temp=0.7 10 | use_h_argmax=False 11 | trust_factor=0. 
12 | held_out_prompt_ranking=True 13 | 14 | q_prompt_tpl="q_action_prompt:v3.0" 15 | logp_penalty=0.5 16 | use_memory=2 17 | tolerance=2 18 | p_hidden_tpl="suffix_forward_tbs" 19 | q_hidden_tpl="suffix_backward_h_np_y" 20 | 21 | dir=log/two_layers_fix_2nd/${dataset} 22 | /bin/rm -rf ${dir} 23 | 24 | for seed in 13 42 25; do 25 | python vi_main.py \ 26 | --balance_batch \ 27 | --num_p_samples ${num_p_samples} \ 28 | --num_h_samples ${num_h_samples} \ 29 | --bwd_temp ${bwd_temp} \ 30 | --iters ${iters} \ 31 | --p_hidden ${p_hidden_tpl} \ 32 | --q_hidden ${q_hidden_tpl} \ 33 | --q_prompt ${q_prompt_tpl} \ 34 | --p_class ${p_class_tpl} \ 35 | --out_dir ${dir} \ 36 | --batch_size ${batch_size} \ 37 | --seed ${seed} \ 38 | --dataset ${dataset} \ 39 | --use_h_argmax ${use_h_argmax} \ 40 | --use_memory ${use_memory} \ 41 | --tolerance ${tolerance} \ 42 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 43 | --trust_factor ${trust_factor} \ 44 | --train_p1 True \ 45 | --train_p2 False \ 46 | --init_p2 ${best_weights_file_path} \ 47 | --logp_penalty ${logp_penalty} \ 48 | --do_first_eval 49 | done 50 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/date_understanding.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=date_understanding 3 | p_class_tpl="classify_forward:3.0" 4 | iters=20 5 | batch_size=10 6 | num_p_samples=10 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=2 10 | tolerance=2 11 | num_h_samples=5 12 | trust_factor=1. 13 | q_prompt_tpl="q_action_prompt:v3.5" 14 | logp_penalty=1. 15 | posterior_temp=1. 16 | p_hidden_tpl="suffix_forward_tbs" 17 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 18 | 19 | dir=log/two_layers_e2e/${dataset} 20 | /bin/rm -rf ${dir} 21 | 22 | for seed in 13 42 25; do 23 | python vi_main.py \ 24 | --do_first_eval \ 25 | --balance_batch \ 26 | --num_p_samples ${num_p_samples} \ 27 | --num_h_samples ${num_h_samples} \ 28 | --bwd_temp ${bwd_temp} \ 29 | --iters ${iters} \ 30 | --p_hidden ${p_hidden_tpl} \ 31 | --q_hidden ${q_hidden_tpl} \ 32 | --q_prompt ${q_prompt_tpl} \ 33 | --p_class ${p_class_tpl} \ 34 | --out_dir ${dir} \ 35 | --batch_size ${batch_size} \ 36 | --seed ${seed} \ 37 | --dataset ${dataset} \ 38 | --use_memory ${use_memory} \ 39 | --tolerance ${tolerance} \ 40 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 41 | --trust_factor ${trust_factor} \ 42 | --train_p1 True \ 43 | --train_p2 True \ 44 | --init_p2 "" \ 45 | --forward_use_classes True \ 46 | --logp_penalty ${logp_penalty} \ 47 | --posterior_temp ${posterior_temp} \ 48 | --strip_options_for_hidden True 49 | done 50 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/any_Nshot.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=logical_deduction_seven_objects 3 | p_class_tpl="classify_forward:3.0" 4 | iters=20 5 | batch_size=10 6 | num_p_samples=15 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=2 10 | tolerance=2 11 | num_h_samples=5 12 | q_prompt_tpl="q_action_prompt:v3.5" 13 | logp_penalty=2. 14 | posterior_temp=1. 
15 | fwd_model_type="local-2" 16 | one_layer=True 17 | 18 | # dataset 19 | for dataset in navigate subj logical_deduction_seven_objects; do 20 | # n_shot 21 | for n_shot in 0 5; do 22 | 23 | dir=log/any/${dataset}/${n_shot}shot/ 24 | 25 | if [ ! -f ${dir}/done.txt ]; then 26 | /bin/rm -rf ${dir} 27 | 28 | # seed 29 | for seed in 13 42; do 30 | python vi_main.py \ 31 | --balance_batch \ 32 | --n_shots ${n_shot} \ 33 | --one_layer ${one_layer} \ 34 | --num_p_samples ${num_p_samples} \ 35 | --bwd_temp ${bwd_temp} \ 36 | --iters ${iters} \ 37 | --q_prompt ${q_prompt_tpl} \ 38 | --p_class ${p_class_tpl} \ 39 | --out_dir ${dir} \ 40 | --batch_size ${batch_size} \ 41 | --seed ${seed} \ 42 | --dataset ${dataset} \ 43 | --use_memory ${use_memory} \ 44 | --tolerance ${tolerance} \ 45 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 46 | --output_scoring_function accuracy \ 47 | --fwd_model_type ${fwd_model_type} 48 | # seed 49 | done 50 | 51 | touch ${dir}/done.txt 52 | fi 53 | # n shot 54 | done 55 | # dataset 56 | done 57 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/gpt4_Nshot.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | p_class_tpl="classify_forward:3.0" 3 | iters=20 4 | batch_size=10 5 | num_p_samples=15 6 | bwd_temp=0.7 7 | held_out_prompt_ranking=True 8 | use_memory=2 9 | tolerance=2 10 | num_h_samples=5 11 | q_prompt_tpl="q_action_prompt:v3.5" 12 | logp_penalty=2. 13 | posterior_temp=1. 14 | fwd_model_type="gpt-4" 15 | one_layer=True 16 | 17 | # dataset 18 | for dataset in date_understanding logical_deduction_seven_objects subj navigate; do 19 | # n_shot 20 | for n_shot in 0 5 16; do 21 | 22 | dir=log/gpt4/cost/${dataset}/${n_shot}shot/ 23 | 24 | if [ ! -f ${dir}/done.txt ]; then 25 | /bin/rm -rf ${dir} 26 | 27 | # seed 28 | for seed in 13 42 25; do 29 | python vi_main.py \ 30 | --cost_only \ 31 | --balance_batch \ 32 | --n_shots ${n_shot} \ 33 | --one_layer ${one_layer} \ 34 | --num_p_samples ${num_p_samples} \ 35 | --bwd_temp ${bwd_temp} \ 36 | --iters ${iters} \ 37 | --q_prompt ${q_prompt_tpl} \ 38 | --p_class ${p_class_tpl} \ 39 | --out_dir ${dir} \ 40 | --batch_size ${batch_size} \ 41 | --seed ${seed} \ 42 | --dataset ${dataset} \ 43 | --use_memory ${use_memory} \ 44 | --tolerance ${tolerance} \ 45 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 46 | --output_scoring_function accuracy \ 47 | --fwd_model_type ${fwd_model_type} 48 | # seed 49 | done 50 | 51 | touch ${dir}/done.txt 52 | fi 53 | # n shot 54 | done 55 | # dataset 56 | done -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_fix_2nd/logical_deduction_seven_objects.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | best_weights_file_path="one-layer-dln-hp-search-result.json" 3 | dataset=logical_deduction_seven_objects 4 | p_class_tpl="classify_forward:3.0" 5 | iters=20 6 | batch_size=20 7 | num_p_samples=20 8 | num_h_samples=5 9 | bwd_temp=0.7 10 | use_h_argmax=False 11 | trust_factor=0. 
12 | held_out_prompt_ranking=True 13 | 14 | q_prompt_tpl="q_action_prompt:v3.0" 15 | logp_penalty=0.5 16 | use_memory=2 17 | tolerance=2 18 | p_hidden_tpl="suffix_forward_tbs" 19 | q_hidden_tpl="suffix_backward_h_np_y" 20 | 21 | dir=log/two_layers_fix_2nd/${dataset} 22 | /bin/rm -rf ${dir} 23 | 24 | for seed in 13 42 25; do 25 | python vi_main.py \ 26 | --balance_batch \ 27 | --num_p_samples ${num_p_samples} \ 28 | --num_h_samples ${num_h_samples} \ 29 | --bwd_temp ${bwd_temp} \ 30 | --iters ${iters} \ 31 | --p_hidden ${p_hidden_tpl} \ 32 | --q_hidden ${q_hidden_tpl} \ 33 | --q_prompt ${q_prompt_tpl} \ 34 | --p_class ${p_class_tpl} \ 35 | --out_dir ${dir} \ 36 | --batch_size ${batch_size} \ 37 | --seed ${seed} \ 38 | --dataset ${dataset} \ 39 | --use_h_argmax ${use_h_argmax} \ 40 | --use_memory ${use_memory} \ 41 | --tolerance ${tolerance} \ 42 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 43 | --trust_factor ${trust_factor} \ 44 | --train_p1 True \ 45 | --train_p2 False \ 46 | --init_p2 ${best_weights_file_path} \ 47 | --logp_penalty ${logp_penalty} \ 48 | --do_first_eval 49 | done 50 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_ft_2nd/navigate.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | best_weights_file_path="one-layer-dln-hp-search-result.json" 3 | dataset=navigate 4 | p_class_tpl="classify_forward:3.0" 5 | iters=20 6 | batch_size=20 7 | num_p_samples=20 8 | num_h_samples=5 9 | bwd_temp=0.7 10 | use_h_argmax=False 11 | trust_factor=0.25 12 | deactivate_vi=False 13 | held_out_prompt_ranking=True 14 | 15 | q_prompt_tpl="q_action_prompt:v3.5" 16 | logp_penalty=2. 17 | use_memory=2 18 | tolerance=2 19 | p_hidden_tpl="suffix_forward_tbs" 20 | q_hidden_tpl="suffix_backward_h_np_y" 21 | 22 | dir=log/two_layers_ft_2nd/${dataset} 23 | /bin/rm -rf ${dir} 24 | 25 | for seed in 13 42 25; do 26 | python vi_main.py \ 27 | --balance_batch \ 28 | --num_p_samples ${num_p_samples} \ 29 | --num_h_samples ${num_h_samples} \ 30 | --bwd_temp ${bwd_temp} \ 31 | --iters ${iters} \ 32 | --p_hidden ${p_hidden_tpl} \ 33 | --q_hidden ${q_hidden_tpl} \ 34 | --q_prompt ${q_prompt_tpl} \ 35 | --p_class ${p_class_tpl} \ 36 | --out_dir ${dir} \ 37 | --batch_size ${batch_size} \ 38 | --seed ${seed} \ 39 | --dataset ${dataset} \ 40 | --use_h_argmax ${use_h_argmax} \ 41 | --use_memory ${use_memory} \ 42 | --tolerance ${tolerance} \ 43 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 44 | --trust_factor ${trust_factor} \ 45 | --train_p1 True \ 46 | --train_p2 True \ 47 | --init_p2 ${best_weights_file_path} \ 48 | --logp_penalty ${logp_penalty} \ 49 | --forward_use_classes True \ 50 | --do_first_eval 51 | done 52 | -------------------------------------------------------------------------------- /tests/test_dln_templates.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import yaml 3 | 4 | from dln.template import DLNTemplate, Templates, load_template 5 | 6 | 7 | def test_DLNTemplate_render(): 8 | template = DLNTemplate(template="{{ message }}") 9 | rendered = template.render(message="Foo bar!") 10 | assert rendered == "Foo bar!" 
11 | 12 | 13 | def test_DLNTemplate_render_default_message(): 14 | template = DLNTemplate(template="{{ message }}", message="Default foo bar") 15 | rendered = template.render() 16 | assert rendered == "Default foo bar" 17 | 18 | 19 | def test_template_get_template(): 20 | suffix_forward = Templates.get("suffix_forward") 21 | assert suffix_forward.template == "{{ input }}\n\n{{ prompt }}" 22 | 23 | 24 | def test_template_template_not_found(): 25 | with pytest.raises(KeyError): 26 | Templates.get("foo") 27 | 28 | 29 | def test_load_template(): 30 | template = load_template("suffix_forward") 31 | rendered = template.render(input="input test", prompt="prompt test") 32 | assert rendered == ("""input test\n\nprompt test""") 33 | 34 | 35 | def test_custom_template_directory(tmp_path): 36 | custom_template_dir = tmp_path / "templates" 37 | custom_template_dir.mkdir() 38 | template_file = custom_template_dir / "custom_template.yaml" 39 | template_content = {"v1.0": {"template": "Custom template: {{ message }}"}} 40 | with open(template_file, "w") as f: 41 | f.write(yaml.dump(template_content)) 42 | template = load_template("custom_template", template_directory=custom_template_dir) 43 | rendered = template.render(message="my message!") 44 | assert rendered == "Custom template: my message!" 45 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_ft_2nd/date_understanding.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | best_weights_file_path="one-layer-dln-hp-search-result.json" 3 | dataset=date_understanding 4 | p_class_tpl="classify_forward:3.0" 5 | iters=20 6 | batch_size=20 7 | num_p_samples=20 8 | num_h_samples=5 9 | bwd_temp=0.7 10 | use_h_argmax=False 11 | trust_factor=0.25 12 | deactivate_vi=False 13 | held_out_prompt_ranking=True 14 | 15 | q_prompt_tpl="q_action_prompt:v3.5" 16 | logp_penalty=0.5 17 | use_memory=2 18 | tolerance=2 19 | p_hidden_tpl="suffix_forward_tbs" 20 | q_hidden_tpl="suffix_backward_h_np_y" 21 | 22 | dir=log/two_layers_ft_2nd/${dataset} 23 | /bin/rm -rf ${dir} 24 | 25 | for seed in 13 42 25; do 26 | python vi_main.py \ 27 | --balance_batch \ 28 | --num_p_samples ${num_p_samples} \ 29 | --num_h_samples ${num_h_samples} \ 30 | --bwd_temp ${bwd_temp} \ 31 | --iters ${iters} \ 32 | --p_hidden ${p_hidden_tpl} \ 33 | --q_hidden ${q_hidden_tpl} \ 34 | --q_prompt ${q_prompt_tpl} \ 35 | --p_class ${p_class_tpl} \ 36 | --out_dir ${dir} \ 37 | --batch_size ${batch_size} \ 38 | --seed ${seed} \ 39 | --dataset ${dataset} \ 40 | --use_h_argmax ${use_h_argmax} \ 41 | --use_memory ${use_memory} \ 42 | --tolerance ${tolerance} \ 43 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 44 | --trust_factor ${trust_factor} \ 45 | --train_p1 True \ 46 | --train_p2 True \ 47 | --init_p2 ${best_weights_file_path} \ 48 | --logp_penalty ${logp_penalty} \ 49 | --forward_use_classes True \ 50 | --do_first_eval 51 | done 52 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_ft_2nd/logical_deduction_seven_objects.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | best_weights_file_path="one-layer-dln-hp-search-result.json" 3 | dataset=logical_deduction_seven_objects 4 | p_class_tpl="classify_forward:3.0" 5 | iters=20 6 | batch_size=20 7 | num_p_samples=20 8 | num_h_samples=5 9 | bwd_temp=0.7 10 | 
use_h_argmax=False 11 | trust_factor=0.25 12 | deactivate_vi=False 13 | held_out_prompt_ranking=True 14 | 15 | q_prompt_tpl="q_action_prompt:v3.5" 16 | logp_penalty=0.5 17 | use_memory=2 18 | tolerance=2 19 | p_hidden_tpl="suffix_forward_tbs" 20 | q_hidden_tpl="suffix_backward_h_np_y" 21 | 22 | dir=log/two_layers_ft_2nd/${dataset} 23 | /bin/rm -rf ${dir} 24 | 25 | for seed in 13 42 25; do 26 | python vi_main.py \ 27 | --balance_batch \ 28 | --num_p_samples ${num_p_samples} \ 29 | --num_h_samples ${num_h_samples} \ 30 | --bwd_temp ${bwd_temp} \ 31 | --iters ${iters} \ 32 | --p_hidden ${p_hidden_tpl} \ 33 | --q_hidden ${q_hidden_tpl} \ 34 | --q_prompt ${q_prompt_tpl} \ 35 | --p_class ${p_class_tpl} \ 36 | --out_dir ${dir} \ 37 | --batch_size ${batch_size} \ 38 | --seed ${seed} \ 39 | --dataset ${dataset} \ 40 | --use_h_argmax ${use_h_argmax} \ 41 | --use_memory ${use_memory} \ 42 | --tolerance ${tolerance} \ 43 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 44 | --trust_factor ${trust_factor} \ 45 | --train_p1 True \ 46 | --train_p2 True \ 47 | --init_p2 ${best_weights_file_path} \ 48 | --logp_penalty ${logp_penalty} \ 49 | --forward_use_classes True \ 50 | --do_first_eval 51 | done 52 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/logical_deduction_seven_objects.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=logical_deduction_seven_objects 3 | p_class_tpl="classify_forward:3.0" 4 | iters=20 5 | batch_size=10 6 | num_p_samples=10 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=2 10 | tolerance=2 11 | num_h_samples=5 12 | fwd_max_tokens=512 13 | bwd_max_tokens=1024 14 | trust_factor=0. 15 | q_prompt_tpl="q_action_prompt:v3.5" 16 | logp_penalty=1. 17 | posterior_temp=0.75 18 | p_hidden_tpl="suffix_forward_tbs" 19 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 20 | 21 | dir=log/two_layers_e2e/${dataset} 22 | /bin/rm -rf ${dir} 23 | 24 | for seed in 13 42 25; do 25 | python vi_main.py \ 26 | --do_first_eval \ 27 | --balance_batch \ 28 | --num_p_samples ${num_p_samples} \ 29 | --num_h_samples ${num_h_samples} \ 30 | --bwd_temp ${bwd_temp} \ 31 | --iters ${iters} \ 32 | --p_hidden ${p_hidden_tpl} \ 33 | --q_hidden ${q_hidden_tpl} \ 34 | --q_prompt ${q_prompt_tpl} \ 35 | --p_class ${p_class_tpl} \ 36 | --out_dir ${dir} \ 37 | --batch_size ${batch_size} \ 38 | --seed ${seed} \ 39 | --dataset ${dataset} \ 40 | --use_memory ${use_memory} \ 41 | --tolerance ${tolerance} \ 42 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 43 | --trust_factor ${trust_factor} \ 44 | --train_p1 True \ 45 | --train_p2 True \ 46 | --init_p2 "" \ 47 | --forward_use_classes True \ 48 | --logp_penalty ${logp_penalty} \ 49 | --posterior_temp ${posterior_temp} \ 50 | --strip_options_for_hidden True \ 51 | --fwd_max_tokens ${fwd_max_tokens} \ 52 | --bwd_max_tokens ${bwd_max_tokens} 53 | done 54 | -------------------------------------------------------------------------------- /projects/vi_dln/README.md: -------------------------------------------------------------------------------- 1 | # Variational Inference 2 | 3 | ## Setup 4 | 5 | Please follow the instructions from the [main README](../../README.md). 6 | 7 | 8 | ## Reproducing results 9 | 10 | 11 | ### One-Layer DLN 12 | 13 | See [one-layer-dln-hp-search-result.json](./one-layer-dln-hp-search-result.json) for available datasets and the hyperparameter search results. 
Then run: 14 | 15 | python scripts/one_layer/one_layer.py --dataset <dataset> 16 | 17 | 18 | ### Two-Layer DLN 19 | 20 | > :warning: **Warning:** Setting `echo` and `logprobs` simultaneously is no longer supported for certain OpenAI models. 21 | However, optimizing prompts jointly for 2-DLN using variational inference requires both settings. 22 | To run 2-DLN experiments, consider hosting your own model (see [self-hosted models](../../README.md#setup-self-hosted-models-vllm)). 23 | Alternatively, you can run 1-DLN by setting `--output_scoring_function accuracy` and `--one_layer True`. 24 | 25 | 26 | For two-layer DLN, you can select one of the following training strategies: 27 | 28 | - `two_layers_fix_2nd`: Load pretrained prompts from [one-layer-dln-hp-search-result.json](./one-layer-dln-hp-search-result.json) to the second layer and train only the first layer. 29 | 30 | - `two_layers_ft_2nd`: Load pretrained prompts from [one-layer-dln-hp-search-result.json](./one-layer-dln-hp-search-result.json) to the second layer and train both the first and second layers. 31 | 32 | - `two_layers_e2e`: Train the two layers from scratch. 33 | 34 | Then, run the following command: 35 | 36 | bash scripts/<training_strategy>/<dataset>.sh 37 | 38 | Results will be saved in `log/<training_strategy>/<dataset>` (e.g., `log/two_layers_e2e/navigate`). 39 | 40 | 41 | ## Running your own experiments 42 | 43 | If you decide to run your own experiments, please check: 44 | 45 | python vi_main.py --help 46 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/navigate_vllm.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=navigate 3 | p_class_tpl="classify_forward:3.0" 4 | iters=10 5 | batch_size=5 6 | num_p_samples=10 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=2 10 | tolerance=2 11 | num_h_samples=5 12 | q_prompt_tpl="q_action_prompt:v3.5" 13 | logp_penalty=2. 14 | posterior_temp=1. 15 | trust_factor=5.
16 | p_hidden_tpl="suffix_forward_tbs" 17 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 18 | fwd_model_type="meta-llama/Llama-2-70b-chat-hf" 19 | fwd_max_tokens=512 20 | bwd_max_tokens=1024 21 | p1_max_tokens=512 22 | p2_max_tokens=512 23 | 24 | dir=log/{fwd_model_type}/two_layers_e2e/${dataset} 25 | 26 | 27 | for seed in 42; do 28 | python vi_main.py \ 29 | --do_first_eval \ 30 | --balance_batch \ 31 | --fwd_model_type ${fwd_model_type} \ 32 | --num_p_samples ${num_p_samples} \ 33 | --num_h_samples ${num_h_samples} \ 34 | --bwd_temp ${bwd_temp} \ 35 | --iters ${iters} \ 36 | --p_hidden ${p_hidden_tpl} \ 37 | --q_hidden ${q_hidden_tpl} \ 38 | --q_prompt ${q_prompt_tpl} \ 39 | --p_class ${p_class_tpl} \ 40 | --out_dir ${dir} \ 41 | --batch_size ${batch_size} \ 42 | --seed ${seed} \ 43 | --dataset ${dataset} \ 44 | --use_memory ${use_memory} \ 45 | --tolerance ${tolerance} \ 46 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 47 | --trust_factor ${trust_factor} \ 48 | --train_p1 True \ 49 | --train_p2 True \ 50 | --forward_use_classes True \ 51 | --logp_penalty ${logp_penalty} \ 52 | --posterior_temp ${posterior_temp} \ 53 | --fwd_max_tokens ${fwd_max_tokens} \ 54 | --bwd_max_tokens ${bwd_max_tokens} \ 55 | --p1_max_tokens ${p1_max_tokens} \ 56 | --p2_max_tokens ${p2_max_tokens} \ 57 | --strip_options_for_hidden True 58 | done 59 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/phi2_gpt35_DLN-2.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=navigate 3 | p_class_tpl="classify_forward:3.0" 4 | iters=20 5 | batch_size=10 6 | num_p_samples=10 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=5 10 | tolerance=2 11 | num_h_samples=10 12 | q_prompt_tpl="q_action_prompt:v3.5" 13 | logp_penalty=2. 14 | posterior_temp=1. 15 | trust_factor=5. 16 | p_hidden_tpl="suffix_forward_tbs" 17 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 18 | fwd_model_type="/vllm/phi-2" 19 | bwd_model_type="gpt-35-turbo" 20 | 21 | # dataset 22 | for dataset in hyperbaton navigate date_understanding logical_deduction_seven_objects mpqa trec subj disaster airline; do 23 | 24 | dir=log/phi2/${dataset}/DLN-2/ 25 | 26 | if [ ! 
-f ${dir}/done.txt ]; then 27 | /bin/rm -rf ${dir} 28 | 29 | # seed 30 | for seed in 13 42 25; do 31 | python vi_main.py \ 32 | --balance_batch \ 33 | --do_first_eval \ 34 | --num_p_samples ${num_p_samples} \ 35 | --num_h_samples ${num_h_samples} \ 36 | --bwd_temp ${bwd_temp} \ 37 | --iters ${iters} \ 38 | --p_hidden ${p_hidden_tpl} \ 39 | --q_hidden ${q_hidden_tpl} \ 40 | --q_prompt ${q_prompt_tpl} \ 41 | --p_class ${p_class_tpl} \ 42 | --out_dir ${dir} \ 43 | --batch_size ${batch_size} \ 44 | --seed ${seed} \ 45 | --dataset ${dataset} \ 46 | --use_memory ${use_memory} \ 47 | --tolerance ${tolerance} \ 48 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 49 | --trust_factor ${trust_factor} \ 50 | --train_p1 True \ 51 | --train_p2 True \ 52 | --logp_penalty ${logp_penalty} \ 53 | --posterior_temp ${posterior_temp} \ 54 | --output_scoring_function logprobs \ 55 | --fwd_model_type ${fwd_model_type} \ 56 | --bwd_model_type ${bwd_model_type} 57 | done 58 | fi 59 | 60 | # dataset 61 | done -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/phi2_phi2_DLN-2.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=navigate 3 | p_class_tpl="classify_forward:3.0" 4 | iters=20 5 | batch_size=10 6 | num_p_samples=10 7 | bwd_temp=0.7 8 | held_out_prompt_ranking=True 9 | use_memory=5 10 | tolerance=2 11 | num_h_samples=10 12 | q_prompt_tpl="q_action_prompt:v3.6" 13 | logp_penalty=2. 14 | posterior_temp=1. 15 | trust_factor=5. 16 | p_hidden_tpl="suffix_forward_tbs" 17 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 18 | fwd_model_type="microsoft/phi-2" 19 | bwd_model_type="microsoft/phi-2" 20 | 21 | # dataset 22 | for dataset in hyperbaton navigate date_understanding logical_deduction_seven_objects mpqa trec subj disaster airline; do 23 | 24 | dir=log/phi2/${dataset}/DLN-2/ 25 | 26 | if [ ! -f ${dir}/done.txt ]; then 27 | /bin/rm -rf ${dir} 28 | 29 | # seed 30 | for seed in 13 42 25; do 31 | python vi_main.py \ 32 | --balance_batch \ 33 | --do_first_eval \ 34 | --num_p_samples ${num_p_samples} \ 35 | --num_h_samples ${num_h_samples} \ 36 | --bwd_temp ${bwd_temp} \ 37 | --iters ${iters} \ 38 | --p_hidden ${p_hidden_tpl} \ 39 | --q_hidden ${q_hidden_tpl} \ 40 | --q_prompt ${q_prompt_tpl} \ 41 | --p_class ${p_class_tpl} \ 42 | --out_dir ${dir} \ 43 | --batch_size ${batch_size} \ 44 | --seed ${seed} \ 45 | --dataset ${dataset} \ 46 | --use_memory ${use_memory} \ 47 | --tolerance ${tolerance} \ 48 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 49 | --trust_factor ${trust_factor} \ 50 | --train_p1 True \ 51 | --train_p2 True \ 52 | --logp_penalty ${logp_penalty} \ 53 | --posterior_temp ${posterior_temp} \ 54 | --output_scoring_function logprobs \ 55 | --fwd_model_type ${fwd_model_type} \ 56 | --bwd_model_type ${bwd_model_type} 57 | done 58 | fi 59 | 60 | # dataset 61 | done -------------------------------------------------------------------------------- /projects/guide/templates/guided_backward.yaml: -------------------------------------------------------------------------------- 1 | v3.0: 2 | stop_tokens: ['\n\n', '[END]', '#'] 3 | template: |- 4 | 5 | A teacher is providing an instruction to guide their student to complete a task that requires producing a text output from a text input. Based on the student's performance, the teacher is trying to improve their instruction so the student can better understand and perform the task. 
6 | 7 | This was the instruction given to the student: 8 | ## Instruction 9 | > {{ prompt }} 10 | 11 | This was the students' input and output pairs: 12 | {% for example_output in example_outputs %} 13 | ## Input: 14 | > {{ example_output.example }} 15 | ## Output: 16 | > {{ example_output.output }} 17 | {% endfor %} 18 | 19 | The teacher believes that the outputs can be improved in the following way: 20 | > {{ feedback }} 21 | 22 | Therefore, to generate more proper outputs, an improved version of the instruction can be: 23 | ## Instruction 24 | > 25 | 26 | v5.0: 27 | stop_tokens: ['\n\n', '[END]', '#'] 28 | template: |- 29 | A teacher is providing an instruction to guide their student to complete a task that requires producing a text output from a text input. Based on the student's performance, the teacher is trying to improve their instruction so the student can better understand and perform the task. 30 | 31 | These were the inputs, instructions, generated outputs, ratings, and optionally some per example feedback from the teacher: 32 | {% for example_output in example_outputs %}{% if example_output.rating.value %} 33 | ## Input: 34 | > {{ example_output.example }} 35 | ## Instruction: 36 | > {{ example_output.prompt }} 37 | ## Output: 38 | > {{ example_output.output }} 39 | ## Rating: 40 | > {{ example_output.rating.name.capitalize() }} 41 | {% if example_output.feedback %}## Feedback: 42 | > {{ example_output.feedback }}{% endif %} 43 | {% endif %}{% endfor %} 44 | 45 | Therefore, to encourage the student to generate better outputs, an improved version of the instruction can be: 46 | ## Instruction 47 | > 48 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/navigate_sweep.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=navigate 3 | iters=20 4 | p_class_tpl="classify_forward:3.0" 5 | q_prompt_tpl="q_action_prompt:v3.5" 6 | batch_size=20 7 | num_p_samples=20 8 | bwd_temp=0.7 9 | held_out_prompt_ranking=True 10 | use_memory=5 11 | tolerance=2 12 | num_h_samples=5 13 | num_p1_steps=1 14 | posterior_temp=1.0 15 | trust_factor=5. 16 | p_hidden_tpl="suffix_forward_tbs" 17 | q_hidden_tpl='"suffix_forward_tbs_y|suffix_forward_tbs"' 18 | 19 | # remove temp jobs file 20 | rm -rf /tmp/jobs.txt 21 | 22 | for logp_penalty in 0. 1. 3. 
5.; do 23 | for posterior_temp in 0.15 0.25 1.0; do 24 | for batch_size in 20; do 25 | for num_p_samples in 20; do 26 | for num_h_samples in 5; do 27 | 28 | dir=log/two_layers_e2e_32ex/${dataset}/logp${logp_penalty}_nodecay_bsz${batch_size}_np${num_p_samples}_nh${num_h_samples}_pt${posterior_temp} 29 | /bin/rm -rf ${dir} 30 | 31 | for seed in 13 42 25; do 32 | echo "python vi_main.py \ 33 | --val_freq 2 \ 34 | --num_p1_steps ${num_p1_steps} \ 35 | --train_p1 True \ 36 | --train_p2 True \ 37 | --balance_batch \ 38 | --num_p_samples ${num_p_samples} \ 39 | --num_h_samples ${num_h_samples} \ 40 | --bwd_temp ${bwd_temp} \ 41 | --iters ${iters} \ 42 | --p_hidden ${p_hidden_tpl} \ 43 | --q_hidden ${q_hidden_tpl} \ 44 | --q_prompt ${q_prompt_tpl} \ 45 | --p_class ${p_class_tpl} \ 46 | --out_dir ${dir} \ 47 | --batch_size ${batch_size} \ 48 | --seed ${seed} \ 49 | --dataset ${dataset} \ 50 | --use_memory ${use_memory} \ 51 | --tolerance ${tolerance} \ 52 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 53 | --trust_factor ${trust_factor} \ 54 | --forward_use_classes True \ 55 | --logp_penalty ${logp_penalty} \ 56 | --posterior_temp ${posterior_temp} \ 57 | --strip_options_for_hidden True \ 58 | --strip_prefix_for_hidden False \ 59 | --decay_logp_penalty False" >> /tmp/jobs.txt 60 | #seed 61 | done 62 | done 63 | done 64 | done 65 | done 66 | done 67 | done 68 | 69 | # launch 70 | parallel -j 15 < /tmp/jobs.txt -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/gsm8k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x # print commands to terminal 3 | 4 | fwd_model_type="gpt-35-turbo-instruct" 5 | bwd_model_type="gpt-35-turbo-instruct" 6 | dataset="gsm8k" 7 | loss_function="number_presence_loss" 8 | output_scoring_function="logprobs" 9 | max_train_size=300 10 | max_dev_size=300 11 | max_test_size=300 12 | p_class_tpl="classify_forward:3.0" 13 | iters=10 14 | batch_size=10 15 | num_p_samples=10 16 | num_h_samples=5 17 | bwd_temp=0.7 18 | forward_use_classes=False 19 | held_out_prompt_ranking=True 20 | use_memory=2 21 | tolerance=2 22 | trust_factor=1. 23 | q_prompt_tpl="q_action_prompt" 24 | logp_penalty=1. 25 | posterior_temp=1. 
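The GSM8K run configured here swaps the classification losses used elsewhere for number_presence_loss, since its answers are free-form numbers rather than option labels. Below is a minimal sketch of what that switch means, using the LossRegistry interface from dln/loss.py later in this listing; the prediction strings are invented for illustration only:

```python
from dln.loss import LossRegistry

# The two loss functions the scripts choose between via --loss_function.
exact_match = LossRegistry.instantiate("exact_match_loss")
number_presence = LossRegistry.instantiate("number_presence_loss")

predictions = ["The answer is 72.", "He pays $72.50 in total."]
targets = ["72", "72"]

# Exact match penalizes any wording around the answer (loss of 1 for both)...
print(exact_match(predictions, targets))      # [1. 1.]
# ...while number presence only asks whether the target appears as a
# standalone number in the prediction (first example passes, second fails).
print(number_presence(predictions, targets))  # [0. 1.]
```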
26 | p_hidden_tpl="suffix_forward_tbs" 27 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 28 | fwd_max_tokens=512 29 | bwd_max_tokens=1024 30 | 31 | dir=log/longer_context/${fwd_model_type}_${bwd_model_type}/two_layers_e2e/${dataset} 32 | 33 | for seed in 13 42 25; do 34 | python vi_main.py \ 35 | --do_first_eval \ 36 | --init_p2 "Solve the math word problem."\ 37 | --fwd_model_type ${fwd_model_type} \ 38 | --bwd_model_type ${bwd_model_type} \ 39 | --dataset ${dataset} \ 40 | --loss_function ${loss_function} \ 41 | --max_train_size ${max_train_size} \ 42 | --max_dev_size ${max_dev_size} \ 43 | --max_test_size ${max_test_size} \ 44 | --p_class ${p_class_tpl} \ 45 | --iters ${iters} \ 46 | --batch_size ${batch_size} \ 47 | --num_p_samples ${num_p_samples} \ 48 | --num_h_samples ${num_h_samples} \ 49 | --bwd_temp ${bwd_temp} \ 50 | --forward_use_classes ${forward_use_classes} \ 51 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 52 | --out_dir ${dir} \ 53 | --use_memory ${use_memory} \ 54 | --tolerance ${tolerance} \ 55 | --trust_factor ${trust_factor} \ 56 | --q_prompt ${q_prompt_tpl} \ 57 | --logp_penalty ${logp_penalty} \ 58 | --posterior_temp ${posterior_temp} \ 59 | --p_hidden ${p_hidden_tpl} \ 60 | --q_hidden ${q_hidden_tpl} \ 61 | --seed ${seed} \ 62 | --fwd_max_tokens ${fwd_max_tokens} \ 63 | --bwd_max_tokens ${bwd_max_tokens} \ 64 | --output_scoring_function ${output_scoring_function} 65 | done 66 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/logical_deduction_seven_objects_sweep.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=logical_deduction_seven_objects 3 | iters=20 4 | p_class_tpl="classify_forward:3.0" 5 | q_prompt_tpl="q_action_prompt:v3.5" 6 | batch_size=20 7 | num_p_samples=20 8 | bwd_temp=0.7 9 | held_out_prompt_ranking=True 10 | use_memory=5 11 | tolerance=2 12 | num_h_samples=5 13 | num_p1_steps=1 14 | posterior_temp=1.0 15 | trust_factor=5. 16 | p_hidden_tpl="suffix_forward_tbs" 17 | q_hidden_tpl='"suffix_forward_tbs_y|suffix_forward_tbs"' 18 | 19 | # remove temp jobs file 20 | rm -rf /tmp/jobs_${dataset}.txt 21 | 22 | for logp_penalty in 0. 1. 3. 
5.; do 23 | for posterior_temp in 1.0; do 24 | for batch_size in 20; do 25 | for num_p_samples in 20; do 26 | for num_h_samples in 5; do 27 | for strip in False; do 28 | for decay in False; do 29 | 30 | dir=log/two_layers_e2e/${dataset}/stripopt${strip}_decay${decay}_logp${logp_penalty}_bsz${batch_size}_np${num_p_samples}_nh${num_h_samples}_pt${posterior_temp} 31 | /bin/rm -rf ${dir} 32 | 33 | for seed in 13 42 25; do 34 | echo "python vi_main.py \ 35 | --val_freq 1 \ 36 | --do_first_eval \ 37 | --num_p1_steps ${num_p1_steps} \ 38 | --train_p1 True \ 39 | --train_p2 True \ 40 | --balance_batch \ 41 | --p1_max_tokens 768 \ 42 | --num_p_samples ${num_p_samples} \ 43 | --num_h_samples ${num_h_samples} \ 44 | --bwd_temp ${bwd_temp} \ 45 | --iters ${iters} \ 46 | --p_hidden ${p_hidden_tpl} \ 47 | --q_hidden ${q_hidden_tpl} \ 48 | --q_prompt ${q_prompt_tpl} \ 49 | --p_class ${p_class_tpl} \ 50 | --out_dir ${dir} \ 51 | --batch_size ${batch_size} \ 52 | --seed ${seed} \ 53 | --dataset ${dataset} \ 54 | --use_memory ${use_memory} \ 55 | --tolerance ${tolerance} \ 56 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 57 | --trust_factor ${trust_factor} \ 58 | --forward_use_classes True \ 59 | --logp_penalty ${logp_penalty} \ 60 | --posterior_temp ${posterior_temp} \ 61 | --strip_options_for_hidden ${strip} \ 62 | --decay_logp_penalty ${decay}" >> /tmp/jobs_${dataset}.txt 63 | #seed 64 | done 65 | done 66 | done 67 | done 68 | done 69 | done 70 | done 71 | 72 | # launch 73 | parallel -j 15 < /tmp/jobs_${dataset}.txt -------------------------------------------------------------------------------- /scripts/split_bigbench_navigate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import codecs 4 | 5 | from utils import download_json_from_url 6 | 7 | 8 | bb_raw_json_url = "https://raw.githubusercontent.com/google/BIG-bench/main/bigbench/benchmark_tasks/navigate/task.json" 9 | bbh_file_path = "data/bbh/navigate.json" 10 | option_list = ["Yes", "No"] 11 | option_map = {"True": "- Yes", "False": "- No"} 12 | 13 | bbh_sentence, bbh_label = [], [] 14 | bbh_dict = {} 15 | with open(bbh_file_path) as fin: 16 | data = json.load(fin) 17 | data = data["examples"] 18 | 19 | for i in range(len(data)): 20 | input, target = data[i]["input"], data[i]["target"] 21 | bbh_sentence.append(input) 22 | bbh_label.append(target) 23 | assert input not in bbh_dict 24 | bbh_dict[input] = len(bbh_sentence) - 1 25 | 26 | data = download_json_from_url(bb_raw_json_url) 27 | data = data["examples"] 28 | bb_sentence, bb_label = [], [] 29 | bb_len = 0 30 | in_bbh = 0 31 | for i in range(len(data)): 32 | bb_len += 1 33 | input, target = data[i]["input"], data[i]["target_scores"] 34 | res_input = ["If you follow these instructions, do you return to the starting point? 
" + input, "Options:"] 35 | res_target = "" 36 | for j, key in enumerate(target): 37 | if target[key] == 1: 38 | res_target = option_list[j] 39 | res_input.append(option_map[key]) 40 | assert len(res_target) > 0 41 | res_input = "\n".join(res_input) 42 | if res_input in bbh_dict: 43 | in_bbh += 1 44 | continue 45 | 46 | bb_sentence.append(res_input) 47 | bb_label.append(res_target) 48 | assert in_bbh == len(bbh_dict), "%s, %s" % (str(in_bbh), str(len(bbh_dict))) 49 | 50 | data = [] 51 | for i in range(len(bb_sentence)): 52 | data.append({"input": bb_sentence[i], "target": bb_label[i]}) 53 | print("there are %s data points in big bench" % str(bb_len)) 54 | print("there are %s data points in big bench hard" % str(len(bbh_dict))) 55 | print("collected %s data points from big bench - big bench hard" % str(len(data))) 56 | 57 | bb_minus_bbh_file_path = "data/bb_minus_bbh/" 58 | print("writing data to ", bb_minus_bbh_file_path) 59 | if not os.path.exists(bb_minus_bbh_file_path): 60 | os.makedirs(bb_minus_bbh_file_path) 61 | with codecs.open(bb_minus_bbh_file_path + "/navigate.json", 'w', encoding='utf-8') as json_file: 62 | json.dump({"examples": data}, json_file, ensure_ascii=False) 63 | -------------------------------------------------------------------------------- /scripts/split_bigbench_date_understanding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import codecs 4 | 5 | from utils import download_json_from_url 6 | 7 | 8 | bb_raw_json_url = "https://raw.githubusercontent.com/google/BIG-bench/main/bigbench/benchmark_tasks/date_understanding/task.json" 9 | bbh_file_path = "data/bbh/date_understanding.json" 10 | option_list = ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)", "(G)", "(H)", "(I)", "(J)", "(K)", "(L)"] 11 | 12 | bbh_sentence, bbh_label = [], [] 13 | bbh_dict = {} 14 | with open(bbh_file_path) as fin: 15 | data = json.load(fin) 16 | data = data["examples"] 17 | 18 | for i in range(len(data)): 19 | input, target = data[i]["input"], data[i]["target"] 20 | bbh_sentence.append(input) 21 | bbh_label.append(target) 22 | q = input.split("\n")[0] 23 | assert q not in bbh_dict 24 | bbh_dict[q] = len(bbh_sentence) - 1 25 | 26 | data = download_json_from_url(bb_raw_json_url) 27 | data = data["examples"] 28 | bb_sentence, bb_label = [], [] 29 | bb_len = 0 30 | in_bbh = 0 31 | for i in range(len(data)): 32 | bb_len += 1 33 | input, target = data[i]["input"], data[i]["target_scores"] 34 | if input in bbh_dict: 35 | in_bbh += 1 36 | continue 37 | res_input = [input, "Options:"] 38 | res_target = "" 39 | for j, key in enumerate(target): 40 | if target[key] == 1: 41 | res_target = option_list[j] 42 | res_input.append(" ".join([option_list[j], key])) 43 | assert len(res_target) > 0 44 | res_input = "\n".join(res_input) 45 | bb_sentence.append(res_input) 46 | bb_label.append(res_target) 47 | assert in_bbh == len(bbh_dict), "%s, %s" % (str(in_bbh), str(len(bbh_dict))) 48 | 49 | data = [] 50 | for i in range(len(bb_sentence)): 51 | data.append({"input": bb_sentence[i], "target": bb_label[i]}) 52 | print("there are %s data points in big bench" % str(bb_len)) 53 | print("there are %s data points in big bench hard" % str(len(bbh_dict))) 54 | print("collected %s data points from big bench - big bench hard" % str(len(data))) 55 | 56 | bb_minus_bbh_file_path = "data/bb_minus_bbh/" 57 | print("writing data to ", bb_minus_bbh_file_path) 58 | if not os.path.exists(bb_minus_bbh_file_path): 59 | os.makedirs(bb_minus_bbh_file_path) 60 | with 
codecs.open(bb_minus_bbh_file_path + "/date_understanding.json", 'w', encoding='utf-8') as json_file: 61 | json.dump({"examples": data}, json_file, ensure_ascii=False) 62 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | # The branches below must be a subset of the branches above 8 | branches: [ "main" ] 9 | schedule: 10 | - cron: '20 17 * * 5' 11 | 12 | jobs: 13 | analyze: 14 | name: Analyze 15 | runs-on: ubuntu-latest 16 | permissions: 17 | actions: read 18 | contents: read 19 | security-events: write 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | language: [ "python" ] 25 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 26 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 27 | 28 | steps: 29 | - name: Checkout repository 30 | uses: actions/checkout@v4 31 | 32 | # Initializes the CodeQL tools for scanning. 33 | - name: Initialize CodeQL 34 | uses: github/codeql-action/init@v3 35 | with: 36 | languages: ${{ matrix.language }} 37 | # If you wish to specify custom queries, you can do so here or in a config file. 38 | # By default, queries listed here will override any specified in a config file. 39 | # Prefix the list here with "+" to use these queries and those in the config file. 40 | 41 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 42 | # queries: security-extended,security-and-quality 43 | 44 | 45 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 46 | # If this step fails, then you should remove it and run the build manually (see below) 47 | - name: Autobuild 48 | uses: github/codeql-action/autobuild@v3 49 | 50 | # ℹ️ Command-line programs to run using the OS shell. 51 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 52 | 53 | # If the Autobuild fails above, remove it and uncomment the following three lines. 54 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
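For reference, each of these split_bigbench_* scripts writes the same structure under data/bb_minus_bbh/: a JSON object whose "examples" list holds input/target pairs. A short sketch of reading one of the resulting files, assuming the navigate split has already been generated from the repository root:

```python
import json

# Path written by split_bigbench_navigate.py above.
with open("data/bb_minus_bbh/navigate.json", encoding="utf-8") as fin:
    examples = json.load(fin)["examples"]

for ex in examples[:2]:  # peek at a couple of entries
    print(ex["input"])   # question text followed by an "Options:" block
    print(ex["target"])  # one of the listed options, e.g. "Yes" or "No"
```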
55 | 56 | # - run: | 57 | # echo "Run, Build Application using script" 58 | # ./location_of_script_within_repo/buildscript.sh 59 | 60 | - name: Perform CodeQL Analysis 61 | uses: github/codeql-action/analyze@v3 62 | with: 63 | category: "/language:${{matrix.language}}" 64 | -------------------------------------------------------------------------------- /scripts/split_bigbench_logical_deduction_seven_objects.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import codecs 4 | 5 | from utils import download_json_from_url 6 | 7 | 8 | bb_raw_json_url = "https://raw.githubusercontent.com/google/BIG-bench/main/bigbench/benchmark_tasks/logical_deduction/seven_objects/task.json" 9 | bbh_file_path = "data/bbh/logical_deduction_seven_objects.json" 10 | option_list = ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)", "(G)", "(H)", "(I)", "(J)", "(K)", "(L)"] 11 | 12 | bbh_sentence, bbh_label = [], [] 13 | bbh_dict = {} 14 | with open(bbh_file_path) as fin: 15 | data = json.load(fin) 16 | data = data["examples"] 17 | 18 | for i in range(len(data)): 19 | input, target = data[i]["input"], data[i]["target"] 20 | bbh_sentence.append(input) 21 | bbh_label.append(target) 22 | assert input not in bbh_dict 23 | bbh_dict[input] = len(bbh_sentence) - 1 24 | 25 | data = download_json_from_url(bb_raw_json_url) 26 | data = data["examples"] 27 | bb_sentence, bb_label = [], [] 28 | bb_len = 0 29 | in_bbh = 0 30 | for i in range(len(data)): 31 | bb_len += 1 32 | input, target = data[i]["input"], data[i]["target_scores"] 33 | res_input = ["The following paragraphs each describe a set of seven objects arranged in a fixed order. The statements are logically consistent within each paragraph. " + input, "Options:"] 34 | res_target = "" 35 | for j, key in enumerate(target): 36 | if target[key] == 1: 37 | res_target = option_list[j] 38 | res_input.append(" ".join([option_list[j], key[:-1]])) # no periods in bbh 39 | assert len(res_target) > 0 40 | res_input = "\n".join(res_input) 41 | if res_input in bbh_dict: 42 | in_bbh += 1 43 | continue 44 | bb_sentence.append(res_input) 45 | bb_label.append(res_target) 46 | assert in_bbh == len(bbh_dict), "%s, %s" % (str(in_bbh), str(len(bbh_dict))) 47 | 48 | data = [] 49 | for i in range(len(bb_sentence)): 50 | data.append({"input": bb_sentence[i], "target": bb_label[i]}) 51 | print("there are %s data points in big bench" % str(bb_len)) 52 | print("there are %s data points in big bench hard" % str(len(bbh_dict))) 53 | print("collected %s data points from big bench - big bench hard" % str(len(data))) 54 | 55 | bb_minus_bbh_file_path = "data/bb_minus_bbh/" 56 | print("writing data to ", bb_minus_bbh_file_path) 57 | if not os.path.exists(bb_minus_bbh_file_path): 58 | os.makedirs(bb_minus_bbh_file_path) 59 | with codecs.open(bb_minus_bbh_file_path + "/logical_deduction_seven_objects.json", 'w', encoding='utf-8') as json_file: 60 | json.dump({"examples": data}, json_file, ensure_ascii=False) 61 | -------------------------------------------------------------------------------- /scripts/split_bigbench_hyperbaton.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import codecs 4 | 5 | from utils import download_json_from_url 6 | 7 | 8 | bb_raw_json_url = "https://raw.githubusercontent.com/google/BIG-bench/main/bigbench/benchmark_tasks/hyperbaton/task.json" 9 | bbh_file_path = "data/bbh/hyperbaton.json" 10 | option_list = ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)", 
"(G)", "(H)", "(I)", "(J)", "(K)", "(L)"] 11 | 12 | bbh_sentence, bbh_label = [], [] 13 | bbh_dict = {} 14 | with open(bbh_file_path) as fin: 15 | data = json.load(fin) 16 | data = data["examples"] 17 | 18 | for i in range(len(data)): 19 | input, target = data[i]["input"], data[i]["target"] 20 | bbh_sentence.append(input) 21 | bbh_label.append(target) 22 | assert input not in bbh_dict 23 | bbh_dict[input] = len(bbh_sentence) - 1 24 | 25 | data = download_json_from_url(bb_raw_json_url) 26 | data = data["examples"] 27 | bb_sentence, bb_label = [], [] 28 | bb_len = 0 29 | in_bbh = 0 30 | for i in range(len(data)): 31 | bb_len += 1 32 | input, target = data[i]["input"], data[i]["target_scores"] 33 | candidates = input.split("Which sentence has the correct adjective order:")[-1].strip() 34 | input = "Which sentence has the correct adjective order:" 35 | candidates = [candidates.split("\"")[1].strip(), candidates.split("\"")[3].strip()] 36 | res_input = [input, "Options:"] 37 | res_target = "" 38 | for j, key in enumerate(target): 39 | if target[key] == 1: 40 | res_target = option_list[j] 41 | res_input.append(" ".join([option_list[j], candidates[j]])) 42 | assert len(res_target) > 0 43 | res_input = "\n".join(res_input) 44 | if res_input in bbh_dict: 45 | in_bbh += 1 46 | continue 47 | 48 | bb_sentence.append(res_input) 49 | bb_label.append(res_target) 50 | assert in_bbh == len(bbh_dict), "%s, %s" % (str(in_bbh), str(len(bbh_dict))) 51 | 52 | data = [] 53 | for i in range(len(bb_sentence)): 54 | data.append({"input": bb_sentence[i], "target": bb_label[i]}) 55 | print("there are %s data points in big bench" % str(bb_len)) 56 | print("there are %s data points in big bench hard" % str(len(bbh_dict))) 57 | print("collected %s data points from big bench - big bench hard" % str(len(data))) 58 | 59 | bb_minus_bbh_file_path = "data/bb_minus_bbh/" 60 | print("writing data to ", bb_minus_bbh_file_path) 61 | if not os.path.exists(bb_minus_bbh_file_path): 62 | os.makedirs(bb_minus_bbh_file_path) 63 | with codecs.open(bb_minus_bbh_file_path + "/hyperbaton.json", 'w', encoding='utf-8') as json_file: 64 | json.dump({"examples": data}, json_file, ensure_ascii=False) 65 | -------------------------------------------------------------------------------- /dln/template.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import os 4 | import glob 5 | import yaml 6 | import logging 7 | from jinja2 import Template 8 | from packaging import version as pkg_version 9 | 10 | 11 | @dataclass 12 | class DLNTemplate: 13 | template: str 14 | stop_tokens: List[str] = None 15 | version: int = "latest" 16 | description: str = None 17 | message: str = None 18 | message_alternatives: List[str] = None 19 | 20 | def render(self, **kwargs): 21 | if kwargs.get("message") is None: 22 | kwargs["message"] = self.message 23 | 24 | return Template(self.template).render(**kwargs).lstrip().rstrip() 25 | 26 | 27 | class Templates: 28 | 29 | def __init__(self, template_directory=None): 30 | if template_directory is None: 31 | template_directory = os.path.join(os.path.dirname(__file__), 'templates/') 32 | self._data = {} 33 | for filename in glob.glob(f"{template_directory}/*.yaml"): 34 | template_name = os.path.basename(filename).split(".")[0] 35 | template = yaml.safe_load(open(filename, "r")) 36 | 37 | self._data[template_name] = [] 38 | for tversion, ttemplate in template.items(): 39 | if "v" not in tversion: 40 | raise ValueError("Version 
must be in the format v1, v1.2, etc.") 41 | 42 | ttemplate["version"] = pkg_version.parse(tversion.split("v")[-1]) 43 | if "stop_tokens" in ttemplate: 44 | # strip the first \ of \\n from the stop tokens 45 | for i, stop_token in enumerate(ttemplate["stop_tokens"]): 46 | ttemplate["stop_tokens"][i] = ttemplate["stop_tokens"][ 47 | i 48 | ].replace("\\n", "\n") 49 | self._data[template_name].append(DLNTemplate(**ttemplate)) 50 | 51 | @staticmethod 52 | def get(template_name, template_directory=None): 53 | template_name, _, version = template_name.partition(":") 54 | if not version: 55 | version = "latest" 56 | 57 | templates = Templates(template_directory)._data[template_name] 58 | 59 | if version == "latest": 60 | template = max(templates, key=lambda x: x.version) 61 | else: 62 | template = [ 63 | t for t in templates if t.version == pkg_version.parse(version) 64 | ][0] 65 | 66 | logging.info(f"Loaded template {template_name} v{template.version}") 67 | return template 68 | 69 | 70 | def load_template(template_name, template_directory=None): 71 | return Templates.get(template_name, template_directory) 72 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/sweep.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=${1} 3 | iters=20 4 | batch_size=20 5 | num_p_samples=20 6 | held_out_prompt_ranking=True 7 | use_memory=5 8 | tolerance=2 9 | num_h_samples=5 10 | trust_factor=5. 11 | p_class_tpl="classify_forward:3.0" 12 | q_prompt_tpl="q_action_prompt:v3.5" 13 | p_hidden_tpl="suffix_forward_tbs" 14 | q_hidden_tpl='"suffix_forward_tbs_y|suffix_forward_tbs"' 15 | fwd_model_type="text-davinci-003" 16 | bwd_model_type="text-davinci-003" 17 | one_layer=False 18 | 19 | # sweep space 20 | bwd_temps=(0.7) 21 | posterior_temps=(1.0) 22 | log_penalties=(0.0) 23 | batch_sizes=(20) 24 | strip_options=(False) 25 | logp_decays=(True) 26 | 27 | # remove temp jobs file 28 | rm -rf /tmp/jobs_${dataset}.txt 29 | 30 | for logp_penalty in ${log_penalties[@]}; do 31 | for posterior_temp in ${posterior_temps[@]}; do 32 | for bwd_temp in ${bwd_temps[@]}; do 33 | for batch_size in ${batch_sizes[@]}; do 34 | for strip in ${strip_options[@]}; do 35 | for decay in ${logp_decays[@]}; do 36 | 37 | dir=log/one_layer${one_layer}_e2e/${dataset}/fmt${model_type}_bmt${bwd_model_type}_stripopt${strip}_decay${decay}_logp${logp_penalty}_bwdt${bwd_temp}_bsz${batch_size}_np${num_p_samples}_nh${num_h_samples}_pt${posterior_temp} 38 | /bin/rm -rf ${dir} 39 | 40 | for seed in 13 42 25; do 41 | echo "python vi_main.py \ 42 | --val_freq 2 \ 43 | --do_first_eval \ 44 | --one_layer ${one_layer} \ 45 | --train_p1 True \ 46 | --train_p2 True \ 47 | --balance_batch \ 48 | --p1_max_tokens 512 \ 49 | --num_p_samples ${num_p_samples} \ 50 | --num_h_samples ${num_h_samples} \ 51 | --bwd_temp ${bwd_temp} \ 52 | --iters ${iters} \ 53 | --p_hidden ${p_hidden_tpl} \ 54 | --q_hidden ${q_hidden_tpl} \ 55 | --q_prompt ${q_prompt_tpl} \ 56 | --p_class ${p_class_tpl} \ 57 | --out_dir ${dir} \ 58 | --batch_size ${batch_size} \ 59 | --seed ${seed} \ 60 | --dataset ${dataset} \ 61 | --use_memory ${use_memory} \ 62 | --tolerance ${tolerance} \ 63 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 64 | --trust_factor ${trust_factor} \ 65 | --fwd_model_type ${fwd_model_type} \ 66 | --bwd_model_type ${bwd_model_type} \ 67 | --forward_use_classes True \ 68 | --logp_penalty ${logp_penalty} \ 69 | --posterior_temp 
${posterior_temp} \ 70 | --strip_options_for_hidden ${strip} \ 71 | --decay_logp_penalty ${decay}" >> /tmp/jobs_${dataset}.txt 72 | #seed 73 | done 74 | done 75 | done 76 | done 77 | done 78 | done 79 | done 80 | 81 | # launch 82 | # parallel -j 15 < /tmp/jobs_${dataset}.txt 83 | head -n 1 /tmp/jobs_${dataset}.txt -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/phi2_DLN-1.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | 3 | fwd_model_type="microsoft/phi-2" 4 | bwd_model_type="microsoft/phi-2" 5 | output_scoring_function="logprobs" # "accuracy" 6 | max_train_size=300 7 | max_dev_size=300 8 | max_test_size=300 9 | p_class_tpl="classify_forward:3.0" 10 | iters=10 11 | batch_size=10 12 | num_p_samples=10 13 | num_h_samples=5 14 | bwd_temp=0.7 15 | forward_use_classes=False 16 | held_out_prompt_ranking=True 17 | use_memory=2 18 | tolerance=2 19 | trust_factor=1. 20 | q_prompt_tpl="q_action_prompt" 21 | logp_penalty=1. 22 | posterior_temp=1. 23 | p_hidden_tpl="suffix_forward_tbs" 24 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 25 | fwd_max_tokens=256 26 | bwd_max_tokens=512 27 | 28 | one_layer=True 29 | 30 | # dataset 31 | for dataset in gsm8k hyperbaton navigate date_understanding logical_deduction_seven_objects mpqa trec subj disaster airline; do 32 | 33 | dir=log/phi2_hg_4/${dataset}/DLN-1/ 34 | 35 | # If dataset is gsm8k, then set forward_use_classes to False 36 | if [ ${dataset} == "gsm8k" ]; then 37 | forward_use_classes=False 38 | loss_function="number_presence_loss" 39 | else 40 | forward_use_classes=True 41 | loss_function="exact_match_loss" 42 | fi 43 | 44 | if [ ! -f ${dir}/done.txt ]; then 45 | /bin/rm -rf ${dir} 46 | 47 | # seed 48 | for seed in 13 42 25; do 49 | python vi_main.py \ 50 | --one_layer ${one_layer} \ 51 | --do_first_eval \ 52 | --fwd_model_type ${fwd_model_type} \ 53 | --bwd_model_type ${bwd_model_type} \ 54 | --dataset ${dataset} \ 55 | --loss_function ${loss_function} \ 56 | --max_train_size ${max_train_size} \ 57 | --max_dev_size ${max_dev_size} \ 58 | --max_test_size ${max_test_size} \ 59 | --p_class ${p_class_tpl} \ 60 | --iters ${iters} \ 61 | --batch_size ${batch_size} \ 62 | --num_p_samples ${num_p_samples} \ 63 | --num_h_samples ${num_h_samples} \ 64 | --bwd_temp ${bwd_temp} \ 65 | --forward_use_classes ${forward_use_classes} \ 66 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 67 | --out_dir ${dir} \ 68 | --use_memory ${use_memory} \ 69 | --tolerance ${tolerance} \ 70 | --trust_factor ${trust_factor} \ 71 | --q_prompt ${q_prompt_tpl} \ 72 | --logp_penalty ${logp_penalty} \ 73 | --posterior_temp ${posterior_temp} \ 74 | --p_hidden ${p_hidden_tpl} \ 75 | --q_hidden ${q_hidden_tpl} \ 76 | --seed ${seed} \ 77 | --fwd_max_tokens ${fwd_max_tokens} \ 78 | --bwd_max_tokens ${bwd_max_tokens} \ 79 | --output_scoring_function ${output_scoring_function} 80 | 81 | # seed 82 | done 83 | 84 | touch ${dir}/done.txt 85 | fi 86 | # dataset 87 | done 88 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/phi2_Nshot.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | 3 | fwd_model_type="/vllm/phi-2" 4 | bwd_model_type="/vllm/phi-2" 5 | dataset="gsm8k" 6 | loss_function="number_presence_loss" 7 | output_scoring_function="accuracy" 8 | max_train_size=300 9 | max_dev_size=300 10 | max_test_size=300 11 | p_class_tpl="classify_forward:3.0" 12 | iters=10 13 | batch_size=10 14 | num_p_samples=10 15 | num_h_samples=5 16 | bwd_temp=0.7 17 | forward_use_classes=False 18 | held_out_prompt_ranking=True 19 | use_memory=2 20 | tolerance=2 21 | trust_factor=1. 
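Every script in these directories identifies templates with strings such as classify_forward:3.0 or a bare q_action_prompt. These are resolved by the loader in dln/template.py shown earlier: an explicit :version suffix selects that entry from the template's YAML file, while a bare name falls back to the highest version available. A minimal sketch of the lookup, assuming the dln package is importable from a checkout:

```python
from dln.template import load_template

# Explicit version: selects the v3.0 entry in dln/templates/classify_forward.yaml.
p_class = load_template("classify_forward:3.0")

# No version given: resolves to the highest version found in q_action_prompt.yaml.
q_prompt = load_template("q_action_prompt")

print(p_class.version, p_class.stop_tokens)
print(q_prompt.version)
```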
22 | q_prompt_tpl="q_action_prompt" 23 | logp_penalty=1. 24 | posterior_temp=1. 25 | p_hidden_tpl="suffix_forward_tbs" 26 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 27 | fwd_max_tokens=256 28 | bwd_max_tokens=512 29 | one_layer=True 30 | 31 | # dataset 32 | for dataset in gsm8k hyperbaton navigate date_understanding logical_deduction_seven_objects mpqa trec subj disaster airline; do 33 | # n_shot 34 | for n_shot in 0 5; do 35 | 36 | # If dataset is gsm8k, then set forward_use_classes to False 37 | if [ ${dataset} == "gsm8k" ]; then 38 | forward_use_classes=False 39 | else 40 | forward_use_classes=True 41 | fi 42 | 43 | dir=log/phi2/${dataset}/${n_shot}shot/ 44 | 45 | if [ ! -f ${dir}/done.txt ]; then 46 | /bin/rm -rf ${dir} 47 | 48 | # seed 49 | for seed in 13 42 25; do 50 | python vi_main.py \ 51 | --one_layer ${one_layer} \ 52 | --do_first_eval \ 53 | --do_zero_shot \ 54 | --n_shots ${n_shot} \ 55 | --fwd_model_type ${fwd_model_type} \ 56 | --bwd_model_type ${bwd_model_type} \ 57 | --dataset ${dataset} \ 58 | --loss_function ${loss_function} \ 59 | --max_train_size ${max_train_size} \ 60 | --max_dev_size ${max_dev_size} \ 61 | --max_test_size ${max_test_size} \ 62 | --p_class ${p_class_tpl} \ 63 | --iters ${iters} \ 64 | --batch_size ${batch_size} \ 65 | --num_p_samples ${num_p_samples} \ 66 | --num_h_samples ${num_h_samples} \ 67 | --bwd_temp ${bwd_temp} \ 68 | --forward_use_classes ${forward_use_classes} \ 69 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 70 | --out_dir ${dir} \ 71 | --use_memory ${use_memory} \ 72 | --tolerance ${tolerance} \ 73 | --trust_factor ${trust_factor} \ 74 | --q_prompt ${q_prompt_tpl} \ 75 | --logp_penalty ${logp_penalty} \ 76 | --posterior_temp ${posterior_temp} \ 77 | --p_hidden ${p_hidden_tpl} \ 78 | --q_hidden ${q_hidden_tpl} \ 79 | --seed ${seed} \ 80 | --fwd_max_tokens ${fwd_max_tokens} \ 81 | --bwd_max_tokens ${bwd_max_tokens} \ 82 | --output_scoring_function ${output_scoring_function} 83 | 84 | # seed 85 | done 86 | 87 | touch ${dir}/done.txt 88 | fi 89 | # n shot 90 | done 91 | # dataset 92 | done 93 | -------------------------------------------------------------------------------- /projects/guide/README.md: -------------------------------------------------------------------------------- 1 | # GUIDE - Guided Meta-Prompt Search with Human Feedback 2 | 3 | GUIDE is a prototype tool developed as an illustrative use-case for DLN and designed to support researchers and practitioners who wish to explore different ways to reformulate their meta-prompts to guide LLMs’ behavior. It does so by providing an intuitive interface that collects and tries to integrate user feedback to help refine and direct the search process for alternative meta-prompts. To search for alternative meta-prompts, GUIDE relies on DLN for prompt optimization. 4 | 5 | 6 | ## Limitations and Risks of Using GUIDE 7 | 8 | While GUIDE is just a research prototype to show-case how DLN can be used, we believe it is important to acknowledge the limitations and potential risks associated with the use of GUIDE in its current form. 9 | 10 | While GUIDE provides an illustration of a process for exploring different meta-prompt designs, we have not evaluated whether this process yields better meta-prompts given some end user goal. While for the current prototype we have explored a few feedback mechanisms, we have also not evaluated their impact on the resulting meta-prompts or on how the meta-prompt impacts the model behavior at inference time. 
Users of GUIDE will thus need to devise their own meta-prompt evaluations. 11 | 12 | Given that GUIDE relies on LLMs to generate alternative meta-prompt formulations, the many concerns existing literature has raised about LLMs will hold for GUIDE as well. These include concerns related to generating statements making inaccurate or misleading claims, to surfacing harmful biases and stereotypes, or to generating violent speech, among others. 13 | 14 | Finally, given that in practice users are likely to only provide a small number of input examples, GUIDE might end up overfitting on the input examples when generating alternative meta-prompts. We have not evaluated the impact that the number or type of examples might have on optimization process or the resulting meta-prompt suggestions. 15 | 16 | 17 | ## Getting Started 18 | 19 | ### Installation 20 | 21 | 1. Setup a virtual environment using conda or venv 22 | 2. Install the requirements using `pip install -r requirements.txt` 23 | 24 | ### Set your OpenAI API key 25 | 26 | Export your key or put it in your *shrc, e.g.: `export OPENAI_API_KEY='...your...key...'` 27 | 28 | Please refer to the [DLN main page](../../README.md#set-your-openai-api-key) for instructions on configuring your OpenAI API key for Azure endpoints. 29 | 30 | 31 | ### Usage 32 | 33 | Start streamlit app using `streamlit run app.py` 34 | 35 | 36 | ## Serve with Docker 37 | 38 | Build the docker image using: 39 | 40 | ``` 41 | docker build -t guide . 42 | ``` 43 | 44 | Run the docker image making sure to pass in your OpenAI API information: 45 | 46 | ``` 47 | docker run --name guide \ 48 | --restart unless-stopped \ 49 | -d \ 50 | -p 8001:8501 \ 51 | -e OPENAI_API_KEY=$OPENAI_API_KEY \ 52 | -e OPENAI_API_TYPE=$OPENAI_API_TYPE \ 53 | -e OPENAI_API_BASE=$OPENAI_API_BASE \ 54 | -e OPENAI_API_VERSION=$OPENAI_API_VERSION \ 55 | guide 56 | ``` 57 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/one_layer/sweep_bbh.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=${1} 3 | iters=50 4 | batch_size=20 5 | num_p_samples=20 6 | use_memory=5 7 | tolerance=5 8 | num_h_samples=5 9 | p_class_tpl="classify_forward:v3.0" 10 | q_prompt_tpl="q_action_prompt:v3.5" 11 | p_hidden_tpl="suffix_forward_tbs" 12 | q_hidden_tpl='"suffix_forward_tbs_y|suffix_forward_tbs"' 13 | p1_max_tokens=256 14 | 15 | # sweep space 16 | bwd_temps=(0.7 0.9) 17 | posterior_temps=(1.0) 18 | logp_decays=(True) 19 | num_examples=(-1) 20 | batch_sizes=(20) 21 | trust_factors=(0) 22 | held_out_prompt_ranking=(False) 23 | one_layer=True 24 | fwd_model_type="text-davinci-003" 25 | bwd_model_type="text-davinci-003" 26 | log_penalties=(0.0) 27 | train_p2s=(True) 28 | h_max=(True) 29 | 30 | # remove temp jobs file 31 | rm -rf /tmp/jobs_${dataset}.txt 32 | 33 | for logp_penalty in ${log_penalties[@]}; do 34 | for posterior_temp in ${posterior_temps[@]}; do 35 | for bwd_temp in ${bwd_temps[@]}; do 36 | for batch_size in ${batch_sizes[@]}; do 37 | for decay in ${logp_decays[@]}; do 38 | for num_example in ${num_examples[@]}; do 39 | for tf in ${trust_factors[@]}; do 40 | for train_p2 in ${train_p2s[@]}; do 41 | for hmax in ${h_max[@]}; do 42 | for hout in ${held_out_prompt_ranking[@]}; do 43 | 44 | 
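GUIDE's meta-prompt search is driven by the guided_backward.yaml templates listed earlier, rendered through the same dln template loader. Here is a minimal sketch of how the v3.0 variant might be filled in; the sentiment task, the feedback string, and the template_directory path are illustrative assumptions, not taken from the repository:

```python
from types import SimpleNamespace
from dln.template import load_template

# Load v3.0 of the guide's backward template from the project's template folder.
backward_tpl = load_template(
    "guided_backward:3.0",
    template_directory="projects/guide/templates",
)

# v3.0 expects the current instruction, the student's input/output pairs,
# and a free-form feedback string collected from the user.
example_outputs = [
    SimpleNamespace(example="I loved this movie!", output="negative"),
    SimpleNamespace(example="Terrible pacing and a weak script.", output="negative"),
]

meta_prompt = backward_tpl.render(
    prompt="Classify the sentiment of the review.",
    example_outputs=example_outputs,
    feedback="The first review is clearly positive; pay attention to praise words.",
)

# The rendered text is what the backward LLM receives when proposing a new instruction.
print(meta_prompt)
```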
dir=log/one_layer${one_layer}_e2e/${dataset}/${fwd_model_type}_${bwd_model_type}_trp2${train_p2}_hmax${hmax}_tf${tf}_heldoutpromptrank${hout}_nex${num_example}_stripoptFalse_decay${decay}_logp${logp_penalty}_bwdt${bwd_temp}_bsz${batch_size}_np${num_p_samples}_nh${num_h_samples}_pt${posterior_temp} 45 | /bin/rm -rf ${dir} 46 | 47 | for seed in 13 42 25; do 48 | echo "python vi_main.py \ 49 | --rewrite_loss_only True \ 50 | --val_freq 1 \ 51 | --do_first_eval \ 52 | --one_layer ${one_layer} \ 53 | --train_p1 True \ 54 | --train_p2 ${train_p2} \ 55 | --balance_batch \ 56 | --num_train_examples ${num_example} \ 57 | --p1_max_tokens ${p1_max_tokens} \ 58 | --num_p_samples ${num_p_samples} \ 59 | --num_h_samples ${num_h_samples} \ 60 | --bwd_temp ${bwd_temp} \ 61 | --iters ${iters} \ 62 | --p_hidden ${p_hidden_tpl} \ 63 | --q_hidden ${q_hidden_tpl} \ 64 | --q_prompt ${q_prompt_tpl} \ 65 | --p_class ${p_class_tpl} \ 66 | --out_dir ${dir} \ 67 | --batch_size ${batch_size} \ 68 | --seed ${seed} \ 69 | --dataset ${dataset} \ 70 | --use_memory ${use_memory} \ 71 | --tolerance ${tolerance} \ 72 | --held_out_prompt_ranking ${hout} \ 73 | --trust_factor ${tf} \ 74 | --fwd_model_type ${fwd_model_type} \ 75 | --bwd_model_type ${bwd_model_type} \ 76 | --forward_use_classes True \ 77 | --logp_penalty ${logp_penalty} \ 78 | --posterior_temp ${posterior_temp} \ 79 | --use_h_argmax ${hmax} \ 80 | --decay_logp_penalty ${decay}" >> /tmp/jobs_${dataset}.txt 81 | #seed 82 | done 83 | done 84 | done 85 | done 86 | done 87 | done 88 | done 89 | done 90 | done 91 | done 92 | done 93 | 94 | # launch 95 | # head -n 1 /tmp/jobs_${dataset}.txt 96 | parallel -j 15 < /tmp/jobs_${dataset}.txt 97 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/sweep_nlu.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=${1} 3 | iters=50 4 | batch_size=20 5 | num_p_samples=20 6 | use_memory=5 7 | tolerance=2 8 | num_h_samples=5 9 | p_class_tpl="classify_forward:v3.0" 10 | q_prompt_tpl="q_action_prompt:v3.5" 11 | p_hidden_tpl="analysis_forward" 12 | q_hidden_tpl='"analysis_forward|analysis_forward_y"' 13 | p1_max_tokens=256 14 | 15 | # sweep space 16 | bwd_temps=(0.7) 17 | posterior_temps=(1.0) 18 | logp_decays=(True) 19 | num_examples=(-1) 20 | batch_sizes=(20) 21 | trust_factors=(5) 22 | held_out_prompt_ranking=(False) 23 | one_layer=False 24 | fwd_model_type="text-davinci-003" 25 | bwd_model_type="text-davinci-003" 26 | log_penalties=(0.0 2.0 5.0 10.0) 27 | train_p2s=(True) 28 | h_max=(True False) 29 | 30 | # remove temp jobs file 31 | rm -rf /tmp/jobs_${dataset}.txt 32 | 33 | for logp_penalty in ${log_penalties[@]}; do 34 | for posterior_temp in ${posterior_temps[@]}; do 35 | for bwd_temp in ${bwd_temps[@]}; do 36 | for batch_size in ${batch_sizes[@]}; do 37 | for decay in ${logp_decays[@]}; do 38 | for num_example in ${num_examples[@]}; do 39 | for tf in ${trust_factors[@]}; do 40 | for train_p2 in ${train_p2s[@]}; do 41 | for hmax in ${h_max[@]}; do 42 | for hout in ${held_out_prompt_ranking[@]}; do 43 | 44 | dir=log/one_layer${one_layer}_e2e/${dataset}/${fwd_model_type}_${bwd_model_type}_trp2${train_p2}_hmax${hmax}_tf${tf}_heldoutpromptrank${hout}_nex${num_example}_stripoptFalse_decay${decay}_logp${logp_penalty}_bwdt${bwd_temp}_bsz${batch_size}_np${num_p_samples}_nh${num_h_samples}_pt${posterior_temp} 45 | /bin/rm -rf ${dir} 46 | 47 | for seed in 13 42 25; do 48 | echo "python 
vi_main.py \ 49 | --rewrite_loss_only True \ 50 | --val_freq 2 \ 51 | --do_first_eval \ 52 | --one_layer ${one_layer} \ 53 | --train_p1 True \ 54 | --train_p2 ${train_p2} \ 55 | --balance_batch \ 56 | --num_train_examples ${num_example} \ 57 | --p1_max_tokens ${p1_max_tokens} \ 58 | --num_p_samples ${num_p_samples} \ 59 | --num_h_samples ${num_h_samples} \ 60 | --bwd_temp ${bwd_temp} \ 61 | --iters ${iters} \ 62 | --p_hidden ${p_hidden_tpl} \ 63 | --q_hidden ${q_hidden_tpl} \ 64 | --q_prompt ${q_prompt_tpl} \ 65 | --p_class ${p_class_tpl} \ 66 | --out_dir ${dir} \ 67 | --batch_size ${batch_size} \ 68 | --seed ${seed} \ 69 | --dataset ${dataset} \ 70 | --use_memory ${use_memory} \ 71 | --tolerance ${tolerance} \ 72 | --held_out_prompt_ranking ${hout} \ 73 | --trust_factor ${tf} \ 74 | --fwd_model_type ${fwd_model_type} \ 75 | --bwd_model_type ${bwd_model_type} \ 76 | --forward_use_classes True \ 77 | --logp_penalty ${logp_penalty} \ 78 | --posterior_temp ${posterior_temp} \ 79 | --use_h_argmax ${hmax} \ 80 | --decay_logp_penalty ${decay}" >> /tmp/jobs_${dataset}.txt 81 | #seed 82 | done 83 | done 84 | done 85 | done 86 | done 87 | done 88 | done 89 | done 90 | done 91 | done 92 | done 93 | 94 | # launch 95 | head -n 1 /tmp/jobs_${dataset}.txt 96 | # parallel -j 15 < /tmp/jobs_${dataset}.txt 97 | -------------------------------------------------------------------------------- /projects/vi_dln/scripts/sweep_phi2_gpt35.sh: -------------------------------------------------------------------------------- 1 | # set -x # print commands to terminal 2 | dataset=${1} 3 | iters=20 4 | held_out_prompt_ranking=True 5 | use_memory=5 6 | tolerance=2 7 | p_class_tpl="classify_forward:3.0" 8 | q_prompt_tpl="q_action_prompt:v3.5" 9 | p_hidden_tpl="suffix_forward_tbs" 10 | q_hidden_tpl="suffix_forward_tbs_y|suffix_forward_tbs" 11 | fwd_model_type="microsoft/phi-2" 12 | bwd_model_type="gpt-35-turbo-instruct" 13 | one_layer=False 14 | 15 | # sweep space 16 | batch_sizes=(20) 17 | trust_factors=(5.) 
18 | strip_options=(True) 19 | logp_decays=(False) 20 | bwd_temps=(0.7) 21 | posterior_temps=(1.0) 22 | num_p_samples_sweep=(10 20 50) 23 | num_h_samples_sweep=(05 10 20) 24 | log_penalties=(0.0 2.0 5.0 7.0) 25 | 26 | 27 | for trust_factor in ${trust_factors[@]}; do 28 | for posterior_temp in ${posterior_temps[@]}; do 29 | for bwd_temp in ${bwd_temps[@]}; do 30 | for batch_size in ${batch_sizes[@]}; do 31 | for strip in ${strip_options[@]}; do 32 | for decay in ${logp_decays[@]}; do 33 | for num_p_samples in ${num_p_samples_sweep[@]}; do 34 | for num_h_samples in ${num_h_samples_sweep[@]}; do 35 | for logp_penalty in ${log_penalties[@]}; do 36 | 37 | dir=log/one_layer${one_layer}_e2e/${dataset}/fmt${model_type}_bmt${bwd_model_type}_stripopt${strip}_decay${decay}_logp${logp_penalty}_bwdt${bwd_temp}_bsz${batch_size}_np${num_p_samples}_nh${num_h_samples}_pt${posterior_temp}_tf${trust_factor} 38 | 39 | for seed in 13 42 25; do 40 | python vi_main.py \ 41 | --val_freq 2 \ 42 | --do_first_eval \ 43 | --one_layer ${one_layer} \ 44 | --train_p1 True \ 45 | --train_p2 True \ 46 | --balance_batch \ 47 | --p1_max_tokens 512 \ 48 | --num_p_samples ${num_p_samples} \ 49 | --num_h_samples ${num_h_samples} \ 50 | --bwd_temp ${bwd_temp} \ 51 | --iters ${iters} \ 52 | --p_hidden ${p_hidden_tpl} \ 53 | --q_hidden ${q_hidden_tpl} \ 54 | --q_prompt ${q_prompt_tpl} \ 55 | --p_class ${p_class_tpl} \ 56 | --out_dir ${dir} \ 57 | --batch_size ${batch_size} \ 58 | --seed ${seed} \ 59 | --dataset ${dataset} \ 60 | --use_memory ${use_memory} \ 61 | --tolerance ${tolerance} \ 62 | --held_out_prompt_ranking ${held_out_prompt_ranking} \ 63 | --trust_factor ${trust_factor} \ 64 | --fwd_model_type ${fwd_model_type} \ 65 | --bwd_model_type ${bwd_model_type} \ 66 | --forward_use_classes True \ 67 | --logp_penalty ${logp_penalty} \ 68 | --posterior_temp ${posterior_temp} \ 69 | --strip_options_for_hidden ${strip} \ 70 | --decay_logp_penalty ${decay} 71 | #seed 72 | done 73 | done 74 | done 75 | done 76 | done 77 | done 78 | done 79 | done 80 | done 81 | done 82 | 83 | # launch 84 | # parallel -j 15 < /tmp/jobs_${dataset}.txt 85 | # head -n 1 /tmp/jobs_${dataset}.txt 86 | 87 | 88 | 89 | # hyperbaton navigate date_understanding logical_deduction_seven_objects mpqa trec subj disaster airline 90 | 91 | # nohup bash scripts/sweep.sh hyperbaton & 92 | # nohup bash scripts/sweep.sh navigate & 93 | # nohup bash scripts/sweep.sh date_understanding & 94 | # nohup bash scripts/sweep.sh logical_deduction_seven_objects & 95 | # nohup bash scripts/sweep.sh mpqa & 96 | # nohup bash scripts/sweep.sh trec & 97 | # nohup bash scripts/sweep.sh subj & 98 | # nohup bash scripts/sweep.sh disaster & 99 | # nohup bash scripts/sweep.sh airline & -------------------------------------------------------------------------------- /projects/vi_dln/scripts/two_layers_e2e/sweep_bbh.sh: -------------------------------------------------------------------------------- 1 | set -x # print commands to terminal 2 | dataset=${1} 3 | iters=50 4 | batch_size=20 5 | num_p_samples=20 6 | use_memory=5 7 | tolerance=5 8 | num_h_samples=5 9 | p_class_tpl="classify_forward:v3.0" 10 | q_prompt_tpl="q_action_prompt:v3.5" 11 | p_hidden_tpl="suffix_forward_tbs" 12 | q_hidden_tpl='"suffix_forward_tbs_y|suffix_forward_tbs"' 13 | 14 | # if dataset == logical_deduction_seven_objects, set this to 512 15 | if [ ${dataset} == "logical_deduction_seven_objects" ]; then 16 | p1_max_tokens=512 17 | else 18 | p1_max_tokens=256 19 | fi 20 | 21 | # sweep space 22 | bwd_temps=(0.7) 23 | 
posterior_temps=(1.0) 24 | logp_decays=(False True) 25 | num_examples=(-1) 26 | batch_sizes=(20) 27 | trust_factors=(0 5) 28 | held_out_prompt_ranking=(False True) 29 | one_layer=False 30 | fwd_model_type="text-davinci-003" 31 | bwd_model_type="text-davinci-003" 32 | log_penalties=(0.0 1.0 3.0 5.0) 33 | train_p2s=(True) 34 | h_max=(False) 35 | 36 | # remove temp jobs file 37 | rm -rf /tmp/jobs_${dataset}.txt 38 | 39 | for logp_penalty in ${log_penalties[@]}; do 40 | for posterior_temp in ${posterior_temps[@]}; do 41 | for bwd_temp in ${bwd_temps[@]}; do 42 | for batch_size in ${batch_sizes[@]}; do 43 | for decay in ${logp_decays[@]}; do 44 | for num_example in ${num_examples[@]}; do 45 | for tf in ${trust_factors[@]}; do 46 | for train_p2 in ${train_p2s[@]}; do 47 | for hmax in ${h_max[@]}; do 48 | for hout in ${held_out_prompt_ranking[@]}; do 49 | 50 | dir=log/one_layer${one_layer}_e2e/${dataset}_new_sweep/${fwd_model_type}_${bwd_model_type}_trp2${train_p2}_hmax${hmax}_tf${tf}_heldoutpromptrank${hout}_nex${num_example}_stripoptFalse_decay${decay}_logp${logp_penalty}_bwdt${bwd_temp}_bsz${batch_size}_np${num_p_samples}_nh${num_h_samples}_pt${posterior_temp} 51 | /bin/rm -rf ${dir} 52 | 53 | for seed in 13 42 25; do 54 | echo "python vi_main.py \ 55 | --rewrite_loss_only False \ 56 | --val_freq 2 \ 57 | --do_first_eval \ 58 | --one_layer ${one_layer} \ 59 | --train_p1 True \ 60 | --train_p2 ${train_p2} \ 61 | --balance_batch \ 62 | --p1_max_tokens ${p1_max_tokens} \ 63 | --num_p_samples ${num_p_samples} \ 64 | --num_h_samples ${num_h_samples} \ 65 | --bwd_temp ${bwd_temp} \ 66 | --iters ${iters} \ 67 | --p_hidden ${p_hidden_tpl} \ 68 | --q_hidden ${q_hidden_tpl} \ 69 | --q_prompt ${q_prompt_tpl} \ 70 | --p_class ${p_class_tpl} \ 71 | --out_dir ${dir} \ 72 | --batch_size ${batch_size} \ 73 | --seed ${seed} \ 74 | --dataset ${dataset} \ 75 | --use_memory ${use_memory} \ 76 | --tolerance ${tolerance} \ 77 | --held_out_prompt_ranking ${hout} \ 78 | --trust_factor ${tf} \ 79 | --fwd_model_type ${fwd_model_type} \ 80 | --bwd_model_type ${bwd_model_type} \ 81 | --forward_use_classes True \ 82 | --logp_penalty ${logp_penalty} \ 83 | --posterior_temp ${posterior_temp} \ 84 | --use_h_argmax ${hmax} \ 85 | --decay_logp_penalty ${decay}" >> /tmp/jobs_${dataset}.txt 86 | done 87 | done 88 | done 89 | done 90 | done 91 | done 92 | done 93 | done 94 | done 95 | done 96 | done 97 | 98 | # launch 99 | # head -n 1 /tmp/jobs_${dataset}.txt 100 | parallel -j 15 < /tmp/jobs_${dataset}.txt 101 | -------------------------------------------------------------------------------- /tests/test_vi_sampler.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unittest.mock import MagicMock 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | from dln.vi.sampler import PosteriorSampler, PromptSampler 8 | 9 | 10 | def test_sample_q_p(backward_info, mock_llm): 11 | inputs, y, y_hat, losses = backward_info 12 | sampler = PromptSampler(mock_llm) 13 | mock_eval_fn = MagicMock(return_value=["new prompt 1", "new prompt 2"]) 14 | sampler.evaluate_func = mock_eval_fn 15 | prompt = "test prompt" 16 | num_samples = 2 17 | held_out_half = False 18 | prompts = sampler.sample_q_p( 19 | inputs, y, y_hat, losses, prompt, num_samples, held_out_half 20 | ) 21 | 22 | q_prompt = mock_eval_fn.call_args[0][0][ 23 | 0 24 | ] # rendered template sent to evaluate_func 25 | 26 | success_block = re.findall(r"# Student successes(.*?)\n\n", q_prompt, re.DOTALL)[0] 27 | assert "test_1" in 
success_block 28 | assert "test_3" in success_block 29 | assert "test_2" not in success_block 30 | assert "test_4" not in success_block 31 | 32 | error_block = re.findall(r"# Student errors(.*?)\n\n", q_prompt, re.DOTALL)[0] 33 | assert "test_2" in error_block 34 | assert "test_4" in error_block 35 | assert "test_1" not in error_block 36 | assert "test_3" not in error_block 37 | 38 | np.testing.assert_array_equal( 39 | prompts, ["test prompt", "new prompt 1", "new prompt 2"] 40 | ) 41 | 42 | 43 | def test_sample_q_p_hold_out_half(backward_info, mock_llm): 44 | inputs, y, y_hat, losses = backward_info 45 | sampler = PromptSampler(mock_llm) 46 | mock_eval_fn = MagicMock(return_value=["new prompt 1", "new prompt 2"]) 47 | sampler.evaluate_func = mock_eval_fn 48 | prompt = "test prompt" 49 | num_samples = 2 50 | held_out_half = True 51 | prompts = sampler.sample_q_p( 52 | inputs, y, y_hat, losses, prompt, num_samples, held_out_half 53 | ) 54 | 55 | q_prompt = mock_eval_fn.call_args[0][0][ 56 | 0 57 | ] # rendered template sent to evaluate_func 58 | 59 | success_block = re.findall(r"# Student successes(.*?)\n\n", q_prompt, re.DOTALL)[0] 60 | error_block = re.findall(r"# Student errors(.*?)\n\n", q_prompt, re.DOTALL)[0] 61 | 62 | success_examples = [i for i in y if i in success_block] 63 | error_examples = [i for i in y_hat if i in error_block] 64 | 65 | assert len(success_examples + error_examples) == 2 66 | assert "test_2" not in success_block 67 | assert "test_4" not in success_block 68 | assert "test_1" not in error_block 69 | assert "test_3" not in error_block 70 | np.testing.assert_array_equal( 71 | prompts, ["test prompt", "new prompt 1", "new prompt 2"] 72 | ) 73 | 74 | 75 | def test_sample_q_h(backward_info, mock_llm): 76 | inputs, y, _, _ = backward_info 77 | h = ["test 1", "test2", "test 3", "test4"] 78 | num_samples = 2 79 | sampler = PosteriorSampler(mock_llm, "suffix_forward_tbs") 80 | mock_eval_fn = MagicMock( 81 | # h * num_samples 82 | return_value=[ 83 | "test 1.1", 84 | "test 1.2", 85 | "test 2.1", 86 | "test 2.2", 87 | "test 3.1", 88 | "test 3.2", 89 | "test 4.1", 90 | "test 4.2", 91 | ] 92 | ) 93 | sampler.evaluate_func = mock_eval_fn 94 | prompt = "test prompt" 95 | next_prompt = "test next prompt" 96 | h_hat = sampler.sample_q_h( 97 | inputs, 98 | y, 99 | h, 100 | prompt, 101 | next_prompt, 102 | num_samples, 103 | ) 104 | np.testing.assert_equal( 105 | h_hat, 106 | [ 107 | ["test 1.1", "test 1.2"], 108 | ["test 2.1", "test 2.2"], 109 | ["test 3.1", "test 3.2"], 110 | ["test 4.1", "test 4.2"], 111 | ], 112 | ) 113 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .DS_Store 163 | log/ 164 | data/ 165 | wandb/ 166 | # demo data 167 | result_data.json 168 | -------------------------------------------------------------------------------- /tests/test_dln_losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from dln.loss import LLoss, LossRegistry, NumberPresenceLoss, ExactMatchLoss 4 | 5 | 6 | def test_available_losses(): 7 | assert "number_presence_loss" in LossRegistry.available_losses() 8 | for loss in LossRegistry.available_losses(): 9 | assert isinstance(LossRegistry.instantiate(loss), LLoss) 10 | 11 | 12 | def test_instantiate_loss_from_str(): 13 | loss = LossRegistry.instantiate("exact_match_loss") 14 | assert isinstance(loss, ExactMatchLoss) 15 | with pytest.raises(ValueError, match="Unknown loss type: UnknownLoss"): 16 | LossRegistry.instantiate("UnknownLoss") 17 | 18 | 19 | def test_exact_match_loss(): 20 | y = ["a", "b", "c", "a", "b", "c"] 21 | y_hat = ["a", "a", "a", "b", "b", "c"] 22 | exact_match_loss = ExactMatchLoss(lambda x: x) 23 | losses = exact_match_loss(y, y_hat) 24 | np.testing.assert_array_equal(losses, [0.0, 1.0, 1.0, 1.0, 0.0, 0.0]) 25 | 26 | 27 | def test_exact_match_loss_no_postproc(): 28 | y = ["A", "B", "C", "a", "b", "c"] 29 | y_hat = ["a", "a", "a", "b", "b", "c"] 30 | exact_match_loss = ExactMatchLoss() 31 | losses = exact_match_loss(y, y_hat) 32 | np.testing.assert_array_equal(losses, [1.0, 1.0, 1.0, 1.0, 0.0, 0.0]) 33 | 34 | 35 | def test_exact_match_loss_postproc(): 36 | y = ["A", "B", "C", "A", "B", "C"] 37 | y_hat = ["a", "a", "a", "b", "b", "c"] 38 | exact_match_loss = ExactMatchLoss(lambda x: x.lower()) 39 | losses = exact_match_loss(y, y_hat) 40 | np.testing.assert_array_equal(losses, [0.0, 1.0, 1.0, 1.0, 0.0, 0.0]) 41 | assert y == ["A", "B", "C", "A", "B", "C"] # no side effect 42 | 43 | 44 | def test_exact_match_loss_postproc_property(): 45 | exact_match_loss = ExactMatchLoss(lambda x: x.upper()) 46 | assert exact_match_loss.postproc("abc") == "ABC" 47 | 48 | exact_match_loss = ExactMatchLoss() 49 | assert exact_match_loss.postproc("abc") == "abc" 50 | 51 | 52 | def test_number_presence_loss(): 53 | number_presence_loss = NumberPresenceLoss() 54 | inputs = [ 55 | "1234", 56 | "01234", 57 | "Answer 1234 in it.", 58 | "Answer\n1234", 59 | "Answer 01234 in it.", 60 | "12 34", 61 | "12340", 62 | "Answer 12340 not in it.", 63 | "Answer test_1234 not in it.", 64 | "Answer 101234 not in it." 
65 | ] 66 | targets = [1234] * 10 67 | loss = number_presence_loss(inputs, targets) 68 | np.testing.assert_array_equal( 69 | loss, [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 70 | ) 71 | 72 | 73 | def test_number_presence_loss_single_value(): 74 | number_presence_loss = NumberPresenceLoss() 75 | assert not number_presence_loss("1234", 1234) 76 | assert not number_presence_loss("01234", "1234") 77 | assert not number_presence_loss("1234.0", 1234) 78 | assert not number_presence_loss("1234", 1234.0) 79 | assert not number_presence_loss("01234,000", 1234000) 80 | assert not number_presence_loss("Answer 1234 in it.", "1234") 81 | assert not number_presence_loss("Answer\n1234", 1234) 82 | assert not number_presence_loss("Answer 01234 in it.", 1234) 83 | assert not number_presence_loss("Answer 01234,000 in it.", 1234000) 84 | assert not number_presence_loss("Answer 01234.0 in it.", 1234) 85 | assert not number_presence_loss("Answer 01234 in it.", 1234.0) 86 | assert not number_presence_loss("Answer 01234 in it.", 1234.0) 87 | assert not number_presence_loss("Answer=01234.", 1234) 88 | assert not number_presence_loss("$1234.00", 1234) 89 | assert not number_presence_loss("$1234.50", 1234.5) 90 | assert number_presence_loss("$1234.50", 1234) 91 | assert number_presence_loss("$12.34", 1234) 92 | assert number_presence_loss("12340", "1234") 93 | assert number_presence_loss("Answer 12340 not in it.", 1234) 94 | assert number_presence_loss("Answer test_1234 not in it.", "1234") 95 | assert number_presence_loss("Answer 101234 not in it.", 1234) 96 | assert number_presence_loss("Answer 0123.4 not in it.", 1234) 97 | -------------------------------------------------------------------------------- /dln/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import re 3 | from typing import Callable, Iterable, Optional, Union 4 | 5 | import numpy as np 6 | 7 | 8 | class LossRegistry: 9 | 10 | _available_losses = {} 11 | 12 | @classmethod 13 | def register(cls, loss_type: str): 14 | def inner(lloss): 15 | cls._available_losses[loss_type] = lloss 16 | return lloss 17 | return inner 18 | 19 | @classmethod 20 | def available_losses(cls): 21 | return list(cls._available_losses.keys()) 22 | 23 | @classmethod 24 | def instantiate(cls, loss_type: str, postproc: Optional[Callable] = None) -> "LLoss": 25 | try: 26 | return cls._available_losses[loss_type](postproc) 27 | except KeyError: 28 | raise ValueError(f'Unknown loss type: {loss_type}') 29 | 30 | 31 | class LLoss(ABC): 32 | 33 | def __init__(self, postproc: Optional[Callable] = None): 34 | """ 35 | Args: 36 | postproc: a function that takes and returns a string to be apply before calculating the loss 37 | """ 38 | self._postproc = postproc 39 | 40 | @property 41 | def postproc(self): 42 | """ 43 | Returns the post-processing function for the loss function. 44 | If the post-processing function has not been set, returns the identity function 45 | """ 46 | if self._postproc is None: 47 | return lambda x: x 48 | return self._postproc 49 | 50 | @abstractmethod 51 | def loss(self, inputs: Iterable[str], targets: Iterable[str]) -> np.array: 52 | """ 53 | Computes the loss between the input and target 54 | Args: 55 | input: The predicted outputs 56 | target: The true outputs 57 | Returns: 58 | The computed loss 59 | """ 60 | pass 61 | 62 | def __call__( 63 | self, 64 | inputs: Union[str, Iterable[str]], 65 | targets: Union[str, Iterable[str]], 66 | ) -> np.array: 67 | """ 68 | Calls the loss function. 
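For instance (illustrative, using the subclasses defined below), ExactMatchLoss()("cat", "dog") returns an array equal to [1.0], while ExactMatchLoss()("cat", "cat") returns [0.0].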
If inputs or targets are not iterables, they are converted to lists 69 | Args: 70 | inputs: The predicted outputs 71 | targets: The true outputs 72 | Returns: 73 | The computed loss as an np.array 74 | """ 75 | if isinstance(inputs, str) or not isinstance(inputs, Iterable): 76 | inputs = [inputs] 77 | if isinstance(targets, str) or not isinstance(targets, Iterable): 78 | targets = [targets] 79 | if self._postproc: 80 | inputs = [self.postproc(i) for i in inputs] 81 | targets = [self.postproc(t) for t in targets] 82 | losses = self.loss(inputs, targets) 83 | return losses 84 | 85 | 86 | @LossRegistry.register("exact_match_loss") 87 | class ExactMatchLoss(LLoss): 88 | """ 89 | Calculates the exact match loss between the predicted and target outputs, 90 | where 0 indicates a correct prediction and 1 indicates an incorrect prediction. 91 | """ 92 | def loss(self, inputs: Iterable[str], targets: Iterable[str]) -> np.array: 93 | losses = (np.array(inputs) != np.array(targets)).astype("float32") 94 | return losses 95 | 96 | 97 | @LossRegistry.register("number_presence_loss") 98 | class NumberPresenceLoss(LLoss): 99 | """ 100 | Calculates the loss based on the presence of a number in a string. 101 | 0 if the target number is present in the input, 1 otherwise. 102 | """ 103 | def loss(self, inputs: Iterable[str], targets: Iterable[str]) -> np.array: 104 | losses = [] 105 | for i, t in zip(inputs, targets): 106 | # Convert the target to float 107 | number = float(str(t).replace(",", "")) 108 | # Extract all numbers from the input 109 | numbers_in_text = re.findall(r'\b\d*\.?\d+\b', i.replace(",", "")) 110 | # Try to convert each extracted string into a float and compare it to the number 111 | # we can match just the last number if we consider that the last number is the answer 112 | _loss = 1 113 | for num_str in numbers_in_text: 114 | if float(num_str) == number: 115 | _loss = 0 116 | losses.append(_loss) 117 | return np.array(losses).astype("float32") 118 | -------------------------------------------------------------------------------- /dln/templates/q_action_prompt.yaml: -------------------------------------------------------------------------------- 1 | v3.0: 2 | stop_tokens: ['\n\n', '[END]', '#'] 3 | template: |- 4 | A student is completing a task that requires producing a text output from a text input. The student receives an instruction that describes how to produce the output given each input. 5 | The student has made some errors. Your task is to improve the instruction such that the student can fix the errors. 6 | 7 | {-% if prompt != '' %} 8 | This was the instruction. 9 | ## Instruction 10 | > {{ prompt }} 11 | [END] 12 | {% endif %-} 13 | 14 | # Student successes 15 | {% for backward_info in backward_infos %} {% if backward_info.loss == 0.0 %} 16 | ## Input: 17 | > {{ backward_info.input }} 18 | ## Correct Output: 19 | > {{ backward_info.target }} 20 | {% endif %} {% endfor %} 21 | 22 | # Student errors 23 | {% for backward_info in backward_infos %} {% if backward_info.loss > 0.0 %} 24 | ## Input: 25 | > {{ backward_info.input }} 26 | ## Student Output: 27 | > {{ backward_info.output }} 28 | ## Correct Output: 29 | > {{ backward_info.target }} 30 | {% endif %} {% endfor %} 31 | 32 | Improve the instruction to fix the student errors. {{ message }} 33 | ## Instruction 34 | > 35 | message_alternatives: 36 | - Clarify the instruction by adding few words or a short sentence. 37 | - Improve the instruction by providing examples on how to solve the task. 
38 | - Shorten the instruction by removing superflous words or sentences. 39 | - Rewrite the instruction by providing detailed information to avoid ambiguity. 40 | v3.5: 41 | stop_tokens: ['\n\n', '[END]', '#'] 42 | template: |- 43 | A student is completing a task that requires producing a text output from a text input. The student receives an instruction that describes how to produce the output given each input. 44 | The student has made some errors. Your task is to improve the instruction such that the student can fix the errors. 45 | 46 | This was the instruction. 47 | ## Instruction 48 | > {{ prompt }} 49 | [END] 50 | 51 | # Student successes 52 | {% for backward_info in backward_infos %} {% if backward_info.loss == 0.0 %} 53 | ## Input: 54 | > {{ backward_info.input }} 55 | ## Correct Output: 56 | > {{ backward_info.target }} 57 | {% endif %} {% endfor %} 58 | 59 | # Student errors 60 | {% for backward_info in backward_infos %} {% if backward_info.loss > 0.0 %} 61 | ## Input: 62 | > {{ backward_info.input }} 63 | ## Student Output: 64 | > {{ backward_info.output }} 65 | ## Correct Output: 66 | > {{ backward_info.target }} 67 | {% endif %} {% endfor %} 68 | 69 | Improve the instruction to fix the student errors. {{ message }} 70 | ## Instruction 71 | > 72 | message_alternatives: 73 | - Clarify the instruction by adding few words or a short sentence. Be concise. 74 | - Improve the instruction by providing examples on how to solve the task. Be concise. 75 | - Shorten the instruction by removing superflous words or sentences. 76 | - Rewrite the instruction by providing detailed information to avoid ambiguity. Be concise. 77 | v3.6: 78 | stop_tokens: ['[END]'] 79 | template: |- 80 | A student is completing a task that requires producing a text output from a text input. 81 | The student receives an instruction that describes how to produce the output given each input. 82 | Your task is to improve the instruction such that the student can identify and correct any errors. 83 | {%- if prompt %} 84 | 85 | 86 | This was the instruction: 87 | ## Instruction 88 | > {{ prompt }} 89 | [END] 90 | {%- endif %} 91 | {%- set success_list = backward_infos | selectattr('loss', 'equalto', 0.0) | list %} 92 | {%- set error_list = backward_infos | selectattr('loss', 'greaterthan', 0.0) | list %} 93 | {%- if success_list %} 94 | 95 | 96 | # Student successes 97 | {%- for backward_info in success_list %} 98 | 99 | ## Input: 100 | > {{ backward_info.input }} 101 | ## Correct Output: 102 | > {{ backward_info.target }} 103 | {%- endfor %} 104 | {%- endif %} 105 | {%- if error_list %} 106 | 107 | 108 | # Student errors 109 | {%- for backward_info in error_list %} 110 | 111 | ## Input: 112 | > {{ backward_info.input }} 113 | ## Student Output: 114 | > {{ backward_info.output }} 115 | ## Correct Output: 116 | > {{ backward_info.target }} 117 | {%- endfor %} 118 | {%- endif %} 119 | 120 | 121 | Improve the instruction by being concise and avoiding unnecessary code generation. 
122 | ## Instruction 123 | > -------------------------------------------------------------------------------- /tests/test_vi_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dln.score import OutputClasses 4 | from dln.vi.layers import PriorLayer, ResidualPriorLayer 5 | 6 | 7 | def test_apply_residual_without_template(mock_logprobs_score, mock_llm): 8 | inputs = np.array(["input1", "input2", "input3"]) 9 | outputs = np.array(["output1", "output2", "output3"]) 10 | residual_prior_layer = ResidualPriorLayer( 11 | logprobs_score=mock_logprobs_score, 12 | forward_evaluate=mock_llm, 13 | forward_template="suffix_forward", 14 | init="A task description", 15 | ) 16 | result = residual_prior_layer.apply_residual(outputs, inputs) 17 | expected_outputs = np.array( 18 | [ 19 | "input1\nYour thoughts were:\noutput1", 20 | "input2\nYour thoughts were:\noutput2", 21 | "input3\nYour thoughts were:\noutput3", 22 | ] 23 | ) 24 | np.testing.assert_equal(result, expected_outputs) 25 | 26 | 27 | def test_apply_residual_with_template(mock_logprobs_score, mock_llm): 28 | inputs = np.array(["input1", "input2", "input3"]) 29 | outputs = np.array(["output1", "output2", "output3"]) 30 | residual_prior_layer = ResidualPriorLayer( 31 | logprobs_score=mock_logprobs_score, 32 | forward_evaluate=mock_llm, 33 | forward_template="suffix_forward", 34 | init="A task description", 35 | ) 36 | result = residual_prior_layer.apply_residual(outputs, inputs, use_template=True) 37 | expected_outputs = np.array( 38 | [ 39 | "input1\n\nA task description\nYour thoughts were:\noutput1", 40 | "input2\n\nA task description\nYour thoughts were:\noutput2", 41 | "input3\n\nA task description\nYour thoughts were:\noutput3", 42 | ] 43 | ) 44 | np.testing.assert_equal(result, expected_outputs) 45 | 46 | 47 | def test_log_p_with_output_classes(top_logprobs, mock_logprobs_score, mock_llm): 48 | mock_logprobs_score.forward_evaluate._generate = top_logprobs 49 | inputs = ["1 + 1", "1 * 1"] 50 | outputs = ["B", "A"] 51 | output_classes = OutputClasses(protos=["a|A", "b|B"]) 52 | prior_layer = PriorLayer( 53 | logprobs_score=mock_logprobs_score, 54 | forward_evaluate=mock_llm, 55 | forward_template="suffix_forward", 56 | init="", 57 | ) 58 | logp = prior_layer.log_p( 59 | inputs, outputs, output_classes=output_classes 60 | ) 61 | np.testing.assert_almost_equal(logp.logp_targets, [-8.67468626, -0.44289729]) 62 | np.testing.assert_almost_equal( 63 | logp.distribution, 64 | [ 65 | [9.99829143e-01, 1.70856546e-04], 66 | [6.42173164e-01, 3.57826836e-01], 67 | ], 68 | ) 69 | 70 | 71 | def test_log_p_without_output_classes(raw_logprobs, score_requests, mock_logprobs_score, mock_llm): 72 | mock_logprobs_score.forward_evaluate._generate = raw_logprobs 73 | inputs = [s.context for s in score_requests] 74 | outputs = ["B", "A"] 75 | prior_layer = PriorLayer( 76 | logprobs_score=mock_logprobs_score, 77 | forward_evaluate=mock_llm, 78 | forward_template="suffix_forward", 79 | init="", 80 | ) 81 | logp = prior_layer.log_p(inputs, outputs) 82 | np.testing.assert_almost_equal(logp.logp_targets, [-0.7682657, -0.7632834]) 83 | 84 | 85 | def test_forward_with_output_class(top_logprobs, mock_logprobs_score, mock_llm): 86 | mock_logprobs_score.forward_evaluate._generate = top_logprobs 87 | inputs = ["1 + 1", "1 * 1"] 88 | output_classes = OutputClasses(protos=["A|a", "B|b"]) 89 | prior_layer = PriorLayer( 90 | logprobs_score=mock_logprobs_score, 91 | forward_evaluate=mock_llm, 92 | 
forward_template="suffix_forward", 93 | init="", 94 | ) 95 | result = prior_layer.forward(inputs, output_classes) 96 | np.testing.assert_equal(result, ["A", "A"]) 97 | 98 | 99 | def test_forward_without_output_class(text_outputs, mock_logprobs_score, mock_llm): 100 | mock_llm._generate = text_outputs 101 | inputs = ["1 + 1", "1 * 1"] 102 | prior_layer = PriorLayer( 103 | logprobs_score=mock_logprobs_score, 104 | forward_evaluate=mock_llm, 105 | forward_template="suffix_forward", 106 | init="", 107 | ) 108 | result = prior_layer.forward(inputs) 109 | np.testing.assert_equal(result, ["A", "A"]) 110 | 111 | 112 | def test_forward_strip_double_newlines(mock_logprobs_score, mock_llm): 113 | text_output = lambda *args, **kwargs: ["A\n\n"] 114 | mock_llm._generate = text_output 115 | inputs = ["1 + 1"] 116 | prior_layer = PriorLayer( 117 | logprobs_score=mock_logprobs_score, 118 | forward_evaluate=mock_llm, 119 | forward_template="suffix_forward", 120 | init="", 121 | ) 122 | result = prior_layer.forward(inputs, strip_double_newlines=True) 123 | np.testing.assert_equal(result, ["A\n"]) 124 | -------------------------------------------------------------------------------- /projects/vi_dln/read_results.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import sys 3 | import re 4 | import json 5 | from itertools import combinations 6 | 7 | import pandas as pd 8 | import numpy as np 9 | import scipy.stats 10 | 11 | 12 | 13 | def escape_ansi(line): 14 | ansi_escape = re.compile(r"(?:\x1B[@-_]|[\x80-\x9F])[\[0-9:;<=>?\]]*[ -/]*[@-~]") 15 | return ansi_escape.sub("", line) 16 | 17 | 18 | root = sys.argv[1] 19 | results = {} 20 | 21 | for method in glob.glob(root + "/**/output.log", recursive=True): 22 | # time should be -2, method -3 23 | name = "/".join(method.split("/")[:-2]) 24 | if name not in results: 25 | results[name] = {"dev": [], "test": [], "cost": [], "args": None, "init_dev": []} 26 | 27 | dev_accuracy = [] 28 | test_accuracy = [] 29 | token_cost = [] 30 | init_dev_accuracy = [] 31 | 32 | with open(method, "r") as f: 33 | lines = f.readlines() 34 | for i, line in enumerate(lines): 35 | if i == 0: 36 | args = line[line.find("{") :] 37 | json_args = json.loads(args) 38 | results[name]["args"] = json_args 39 | elif "INIT DEV ACC:" in line: 40 | line = escape_ansi(line).partition("INIT DEV ACC:")[-1] 41 | init_dev_accuracy.append(float(line.strip())) 42 | elif "BEST DEV ACC:" in line: 43 | line = escape_ansi(line).partition("BEST DEV ACC:")[-1] 44 | dev_accuracy.append(float(line.strip())) 45 | elif "TEST ACC:" in line: 46 | line = escape_ansi(line).partition("TEST ACC:")[-1] 47 | test_accuracy.append(float(line.strip())) 48 | elif "COST:" in line: 49 | line = escape_ansi(line).partition("COST:")[-1] 50 | token_cost.append(float(line.strip())) 51 | 52 | # skip jobs not completed 53 | if dev_accuracy and test_accuracy: 54 | results[name]["dev"].append(np.max(dev_accuracy)) 55 | results[name]["test"].append(test_accuracy[0]) 56 | results[name]["init_dev"].append(init_dev_accuracy[0] if init_dev_accuracy else np.nan) 57 | results[name]["cost"].append(token_cost[0] if token_cost else np.nan) 58 | 59 | 60 | def mean_confidence_interval(data, confidence=0.95): 61 | a = 1.0 * np.array(data) 62 | n = len(a) 63 | m, se = np.mean(a), scipy.stats.sem(a) 64 | h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1) 65 | return h 66 | 67 | 68 | def top_k(data, k): 69 | """ 70 | Generate all combinations of k seeded-runs 71 | Take the argmax dev score for 
each combination 72 | Report mean and std of the test score over all combinations (given their argmax). 73 | 74 | Conceptually, in the case of top-1, it is simply the average of all the 75 | runs and for top-10 (given 10 seeds total) it is equivalent to taking 76 | the argmax over all the seeds and reporting its test score. 77 | """ 78 | test_scores = np.array(data['test']) 79 | dev_scores = np.array(data['dev']) 80 | 81 | if k <= 0 or k > len(test_scores): 82 | return np.nan, np.nan 83 | 84 | indices = np.arange(len(dev_scores)) 85 | combination_indices = combinations(indices, k) 86 | 87 | max_test_scores = [] 88 | for c in combination_indices: 89 | argmax_dev_index = np.argmax(dev_scores[list(c)]) 90 | max_test_scores.append(test_scores[list(c)[argmax_dev_index]]) 91 | 92 | mean_test_score = np.mean(max_test_scores) 93 | std_test_score = np.std(max_test_scores) 94 | 95 | return mean_test_score, std_test_score 96 | 97 | def find_dataset_name(dataset_path): 98 | dataset_names = ["gsm8k", "navigate", "hyperbaton", "date_understanding", "logical_deduction_seven_objects", "disaster", "airline", "mpqa", "trec", "subj"] 99 | for name in dataset_names: 100 | if name in dataset_path: 101 | return name.split("_")[0] 102 | 103 | data = [] 104 | for k, v in results.items(): 105 | top_k_1_mean_test, top_k_1_std_test = top_k(v, 1) 106 | top_k_3_mean_test, top_k_3_std_test = top_k(v, 3) 107 | top_k_argmax_mean_test, _ = top_k(v, len(v["test"])) 108 | 109 | # TODO: receive min_seeds as parameter 110 | if len(v["dev"]) <= 2: 111 | continue 112 | data.append( 113 | { 114 | "name": find_dataset_name(k), 115 | "seeds": len(v["dev"]), 116 | "init_dev": np.mean(v["init_dev"]) if "init_dev" in v else None, 117 | "dev": np.mean(v["dev"]), 118 | "test": np.mean(v["test"]), 119 | "cost": np.mean(v["cost"]), 120 | "dstd": np.std(v["dev"]), 121 | "tstd": np.std(v["test"]), 122 | "tcf": mean_confidence_interval(v["test"]), 123 | "top_k_1_mean_test": top_k_1_mean_test, 124 | "top_k_1_std_test": top_k_1_std_test, 125 | "top_k_3_mean_test": top_k_3_mean_test, 126 | "top_k_3_std_test": top_k_3_std_test, 127 | "top_k_argmax_mean_test": top_k_argmax_mean_test, 128 | "log_path": k, 129 | } 130 | ) 131 | 132 | print(pd.DataFrame.from_records(data).sort_values(by=["name", "dev"], ascending=[True, False]).to_markdown(index=False)) 133 | -------------------------------------------------------------------------------- /dln/vi/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import copy 4 | import json 5 | from typing import Optional 6 | 7 | import numpy as np 8 | 9 | 10 | # Use this function to log messages to the file 11 | def log_message(*messages): 12 | print(*messages) 13 | logging.info(" ".join(map(str, messages))) 14 | 15 | 16 | def compute_pairwise_kl(lps): 17 | """We make updates constrained by the KL divergence between the function induced by the current prompt 18 | and the function induced by the new prompt. We compute the KL divergence 19 | between the functions induced by the prompts. 20 | 21 | The distribution induced by the current prompt is the first element of the second axis in lps. 
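    Illustrative example (hypothetical values): lps has shape (batch, num_prompts, num_classes), with the current prompt at index 0 of the second axis. For a single example scored under two prompts, lps = [[[0.9, 0.1], [0.6, 0.4]]] returns approximately [0.0, 0.226]: the first entry compares the reference prompt with itself, and the second is KL([0.9, 0.1] || [0.6, 0.4]) = 0.9*log(0.9/0.6) + 0.1*log(0.1/0.4) ≈ 0.226.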
22 | """ 23 | # compute pairwise kl, considers reference always as the first prompt 24 | return ( 25 | (lps[:, :1, :] * (np.log(lps[:, :1, :]) - np.log(lps[:, :, :]))).sum(-1).mean(0) 26 | ) 27 | 28 | class ResultLogEntry(): 29 | def __init__(self): 30 | self.hiddens = None 31 | self.candidates = [[], []] 32 | self.metrics = {} 33 | self.outputs = [] 34 | 35 | def log_metric(self, metric: str, value: Optional[float]): 36 | if value is not None: 37 | value = float(value) 38 | 39 | self.metrics[metric] = value 40 | 41 | def log_outputs(self, outputs): 42 | self.outputs = outputs 43 | 44 | def log_hiddens(self, hiddens, size): 45 | self.hiddens = [[]] * size if hiddens is None else [[h] for h in hiddens] 46 | 47 | def log_candidates(self, p_tilde_2, p2_elbo, p_tilde_1=None, p1_elbo=None): 48 | """ 49 | If one_layer, p_tilde_1 and p1_elbo are None, 50 | and we only store the two-layer candidates in the 0th list element. 1st element stays []. 51 | If two_layer, we store the first layer candidates in the 0th list element 52 | and the second layer candidates in the 1st list element. 53 | """ 54 | self.candidates = [[],[]] 55 | if p_tilde_1 is not None: 56 | for i in range(p_tilde_1.shape[0]): 57 | self.candidates[0].append({ 58 | "layer": p_tilde_1[i], 59 | "score": float(p1_elbo[i]), 60 | }) 61 | p2_ind = 1 62 | else: 63 | p2_ind = 0 64 | for i in range(p_tilde_2.shape[0]): 65 | self.candidates[p2_ind].append({ 66 | "layer": p_tilde_2[i], 67 | "score": float(p2_elbo[i]), 68 | }) 69 | 70 | 71 | class ResultLogWriter(object): 72 | def __init__(self, name: str, path: str): 73 | """ 74 | Args: 75 | name: File name 76 | path: File location 77 | Returns: 78 | A ResultLogWriter object 79 | """ 80 | self.name = name 81 | self.path = path 82 | self.result_dict = {} 83 | self.result_dict[self.name] = {'training': [], 'examples': []} 84 | 85 | def write_result(self, step, layers, metrics, candidates): 86 | self.result_dict[self.name]['training'].append({'step': step}) 87 | self.result_dict[self.name]['training'][-1]['layers'] = copy.deepcopy(layers) 88 | self.result_dict[self.name]['training'][-1]['metrics'] = copy.deepcopy(metrics) 89 | self.result_dict[self.name]['training'][-1]['candidates'] = copy.deepcopy(candidates) 90 | 91 | def write_examples(self, step, inputs, labels, outputs, hiddens): 92 | """ 93 | Args: 94 | step: The iteration number 95 | inputs: A list of input strings 96 | labels: A list of label strings 97 | outputs: A list of output strings 98 | hiddens: A list of hidden strings for two-layer-dlns 99 | An element of the "examples" list in the json file looks like: 100 | { 101 | "input": "Do cats sit on mats?", 102 | "label": "Yes", 103 | "trace": [ 104 | { 105 | "step": 0, 106 | "hiddens": ["Cats are picky."], 107 | "output": "No" 108 | }, 109 | { 110 | "step": 1, 111 | "hiddens": ["Cats would sit anywhere."], 112 | "output": "Yes" 113 | } 114 | ] 115 | } 116 | """ 117 | for inp, lab, outp, hidden in zip(inputs, labels, outputs, hiddens): 118 | # Get the element in the list of examples that matches the input 119 | example = next((ex for ex in self.result_dict[self.name]['examples'] if ex['input'] == inp), None) 120 | if example is None: 121 | self.result_dict[self.name]['examples'].append({ 122 | "input": inp, 123 | "label": lab, 124 | "trace": [{"step": step, "hiddens": hidden, "output": outp}], 125 | }) 126 | else: 127 | example['trace'].append({"step": step, "hiddens": hidden, "output": outp}) 128 | 129 | def save_to_json_file(self): 130 | # self.path is a path to a file 131 | 
os.makedirs(os.path.dirname(self.path), exist_ok=True) 132 | try: 133 | with open(self.path, 'r') as f: 134 | print('Loading existing json file %s' % self.path) 135 | loaded_dict = json.load(f) 136 | except FileNotFoundError: 137 | loaded_dict = {} 138 | if self.name not in loaded_dict: 139 | # Append or add the json dictionary if the result doesn't exist 140 | loaded_dict[self.name] = self.result_dict[self.name] 141 | with open(self.path, 'w') as f: 142 | json.dump(loaded_dict, f, indent=4) 143 | else: 144 | print(f"Result named {self.name} already exists in {self.path}!") -------------------------------------------------------------------------------- /dln/dataset_info.yaml: -------------------------------------------------------------------------------- 1 | ### 2 | # From Fantastically Ordered Prompts 3 | # https://arxiv.org/abs/2104.08786 4 | ### 5 | mpqa: 6 | # output space: ["negative", "positive"] 7 | # raw data example: 8 | # a decade of dramatic economic decline 9 | # 0 10 | # tremendous opportunities 11 | # 1 12 | label_mapping: {'0': 'negative', '1': 'positive'} 13 | instruction: "Read the following review, then choose whether it is negative or positive." 14 | output_type: "single word" 15 | 16 | trec: 17 | # output space: ["description", "entity", "expression", "human", "location", "number"] 18 | # raw data example: 19 | # What is considered the costliest disaster the insurance industry has ever faced ? 20 | # 1 21 | # Who do Herb and Tootsie live next door to ? 22 | # 3 23 | label_mapping: {'0': 'description', '1': 'entity', '2': 'expression', '3': 'human','4': 'location', '5': 'number'} 24 | instruction: "Read the following question, then choose whether it is about a description, entity, expression, human, location or number." 25 | output_type: "single word" 26 | 27 | subj: 28 | # output space: ["subjective", "objective"] 29 | # raw data example: 30 | # \"claude chabrol has here a thriller without thrills , but that's okay .\" 31 | # 0 32 | # a team of scientists is recruited in a crash project to send a ship and bomb into the center of the earth to prevent the catastrophe . 33 | # 1 34 | prefix: "" 35 | label_mapping: {'0': 'subjective', '1': 'objective'} 36 | instruction: "Read the following sentence, then choose whether it is subjective or objective." 37 | output_type: "single word" 38 | 39 | ### 40 | # From Big Bench Hard 41 | # https://arxiv.org/abs/2210.09261 42 | ### 43 | date_understanding: 44 | # output space: ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)"] 45 | # raw data example: 46 | # On May 9th, 2017 Jane bought 40 eggs. She ate one per day. Today she ran out of eggs. What is the date 24 hours later in MM/DD/YYYY?\nOptions:\n(A) 06/19/2017\n(B) 07/17/2017\n(C) 06/20/2017\n(D) 06/18/2017\n(E) 06/15/2017\n(F) 07/10/2017 47 | # (A) 48 | prefix: "Infer the date from context." 49 | instruction: "Infer the date from context." 
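  # The output_type and protos fields below describe the expected answer: in protos, alternatives for one class are separated by "|" (e.g. "a|A" accepts either casing) and the first alternative is used as the canonical prototype (see OutputClasses in dln/score.py).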
50 | output_type: "single word" 51 | protos: ["a|A", "b|B", "c|C", "d|D", "e|E", "f|F"] 52 | 53 | hyperbaton: 54 | # output space: ["(A)", "(B)"] 55 | # raw data example: 56 | # Which sentence has the correct adjective order:\nOptions:\n(A) pyramidal American glass exercise surfboard\n(B) glass exercise American pyramidal surfboard 57 | # (A) 58 | prefix: "Which sentence has the correct adjective order:\n" 59 | instruction: "Which sentence has the correct adjective order:" 60 | output_type: "single word" 61 | protos: ["a|A", "b|B"] 62 | 63 | logical_deduction_seven_objects: 64 | # output space: ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)", "(G)"] 65 | # raw data example: 66 | # The following paragraphs each describe a set of seven objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a branch, there are seven birds: a hummingbird, a cardinal, a blue jay, an owl, a raven, a quail, and a robin. The hummingbird is to the left of the quail. The robin is to the left of the cardinal. The blue jay is the leftmost. The cardinal is the fourth from the left. The raven is the third from the right. The owl is the third from the left.\nOptions:\n(A) The hummingbird is the second from the right\n(B) The cardinal is the second from the right\n(C) The blue jay is the second from the right\n(D) The owl is the second from the right\n(E) The raven is the second from the right\n(F) The quail is the second from the right\n(G) The robin is the second from the right 67 | # (A) 68 | prefix: "The following paragraphs each describe a set of seven objects arranged in a fixed order. The statements are logically consistent within each paragraph." 69 | instruction: "A seven-object logical deduction task which requires deducing the order of a sequence of objects." 70 | output_type: "single word" 71 | protos: ["a|A", "b|B", "c|C", "d|D", "e|E", "f|F", "g|G"] 72 | 73 | navigate: 74 | # output space: ["Yes", "No"] 75 | # raw data example: 76 | # If you follow these instructions, do you return to the starting point? Always face forward. Take 1 step right. Take 3 steps left. Take 2 steps right.\nOptions:\n- Yes\n- No 77 | # Yes 78 | prefix: "If you follow these instructions, do you return to the starting point?" 79 | instruction: "Read the following sentence, then determine whether you return to the starting point." 80 | output_type: "single word" 81 | protos: ["yes|Yes", "no|No"] 82 | 83 | ### 84 | # From Leopard 85 | # https://arxiv.org/abs/1911.03863 86 | ### 87 | disaster: 88 | # output space: ["yes", "no"] 89 | # raw data example: 90 | # Worried about how the CA drought might affect you? Extreme Weather: Does it Dampen Our Economy? http://t.co/fDzzuMyW8i 91 | # Relevant 92 | # #golf McIlroy fuels PGA speculation after video: Injured world number one Rory McIlroy fueled speculatio... http://t.co/dCyYJVmXHR #news 93 | # Not Relevant 94 | label_mapping: {'Not Relevant': 'no', 'Relevant': 'yes'} 95 | instruction: "Read the following sentence, then choose whether it is relevant to a disaster." 96 | output_type: "single word" 97 | 98 | airline: 99 | # output space: ["positive", "negative", "neutral"] 100 | # raw data example: 101 | # @SouthwestAir Great, thank you. Best of luck dealing with this horrible winter. 102 | # positive 103 | # @JetBlue Cancelled Flighted _\u00d9\u00f7\u00a2 104 | # negative 105 | # @USAirways I'm flying with you this Summer. Will I be able to leave Miami Airport during my 12 hour stopover there? 
106 | # neutral 107 | label_mapping: {'positive': 'positive', 'negative': 'negative', 'neutral': 'neutral'} 108 | instruction: "Read the following sentence, then choose whether it is positive, negative, or neutral." 109 | output_type: "single word" 110 | protos: ["positive|Positive", "negative|Negative", "neutral|Neutral"] 111 | -------------------------------------------------------------------------------- /dln/vi/layers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | from dln.loss import LLoss 6 | from dln.operator import LLM 7 | from dln.score import LogProbs, LogProbsScore, OutputClasses, ScoreRequest 8 | from dln.template import load_template 9 | from dln.vi.utils import log_message 10 | 11 | 12 | class PriorLayer: 13 | 14 | def __init__( 15 | self, 16 | logprobs_score: LogProbsScore, 17 | forward_evaluate: LLM, 18 | forward_template: str, 19 | init: str = "", 20 | ): 21 | self.forward_template = load_template( 22 | forward_template 23 | ) 24 | log_message("Forward template:\n", f"{repr(self.forward_template.template)}") 25 | self.weight = init 26 | self.logprobs_score = logprobs_score 27 | self.forward_evaluate = forward_evaluate 28 | 29 | def __call__(self, *args, **kwargs): 30 | return self.forward(*args, **kwargs) 31 | 32 | def forward( 33 | self, 34 | inputs, 35 | output_classes: OutputClasses = None, 36 | temperature=0.0, 37 | strip_double_newlines=True, 38 | max_tokens=256, 39 | ) -> np.array: 40 | """Forward pass throught this layer. 41 | 42 | Args: 43 | output_classes: if not None, compute the constrained forward pass on the output classes, pick the highest probability amongst 44 | the prototypes. 45 | temperature: temperature to use for the forward pass 46 | strip_double_newlines: if True, strip any "\n\n" that might have been added 47 | max_tokens: cap the max length for the forward pass 48 | """ 49 | if output_classes is None: 50 | tpl_inputs = [ 51 | self.forward_template.render(input=input, prompt=self.weight) 52 | for input in inputs 53 | ] 54 | log_message(tpl_inputs[0]) 55 | outputs = self.forward_evaluate( 56 | tpl_inputs, 57 | stop=self.forward_template.stop_tokens, 58 | temperature=temperature, 59 | max_tokens=max_tokens, 60 | async_generation=True, 61 | ) 62 | else: 63 | if self.forward_evaluate.has_logprobs: 64 | # compute log p of each output class, second return value is the p(class) 65 | targets = [output_classes.prototype(0) for _ in inputs] 66 | lp = self.log_p( 67 | inputs, targets, output_classes=output_classes, agg="sum" 68 | ).distribution 69 | # best output class index 70 | best_output_class_index = np.argmax(lp, axis=1) 71 | # get the best output class token 72 | outputs = [output_classes.prototype(idx) for idx in best_output_class_index] 73 | else: 74 | tpl_inputs = [ 75 | self.forward_template.render(input=input, prompt=self.weight) 76 | for input in inputs 77 | ] 78 | logit_bias = {} 79 | max_len = 0 80 | 81 | for i in range(len(output_classes)): 82 | for realization in output_classes.verbalizers(i): 83 | token_ids = self.forward_evaluate.encode(realization) 84 | max_len = max(max_len, len(token_ids)) 85 | logit_bias[token_ids[0]] = 100 86 | 87 | outputs = self.forward_evaluate( 88 | tpl_inputs, 89 | stop=self.forward_template.stop_tokens, 90 | temperature=temperature, 91 | max_tokens=max_len, 92 | logit_bias=logit_bias, 93 | ) 94 | 95 | # strip any "\n\n" that might have been added 96 | if strip_double_newlines: 97 | outputs = [o.replace("\n\n", 
"\n") for o in outputs] 98 | return np.asarray(outputs) 99 | 100 | def log_p_request(self, input: str, target: str, prompt: str) -> ScoreRequest: 101 | # build up a set of score requests 102 | context = self.forward_template.render(input=input, prompt=prompt) 103 | return ScoreRequest(context=context, target=target, payload=target) 104 | 105 | def log_p( 106 | self, 107 | inputs: List[str], 108 | targets: List[str], 109 | prompts=None, 110 | output_classes=None, 111 | agg="max", 112 | ) -> LogProbs: 113 | requests = [] 114 | 115 | if prompts is None: 116 | prompts = [self.weight for _ in inputs] 117 | 118 | for input, target, prompt in zip(inputs, targets, prompts): 119 | requests.append(self.log_p_request(input, target, prompt=prompt)) 120 | 121 | # build up a set of score requests 122 | logprobs = self.logprobs_score.score_requests(requests, output_classes, agg=agg) 123 | return logprobs 124 | 125 | def accuracy( 126 | self, 127 | inputs: List[str], 128 | targets: List[str], 129 | loss: LLoss, 130 | prompts=None, 131 | num_samples=1, 132 | max_tokens=10, 133 | ) -> LogProbs: 134 | requests = [] 135 | 136 | if prompts is None: 137 | prompts = [self.weight for _ in inputs] 138 | 139 | for _ in range(num_samples): 140 | for input, _, prompt in zip(inputs, targets, prompts): 141 | requests.append(self.forward_template.render(input=input, prompt=prompt)) 142 | 143 | # build up a set of score requests 144 | outputs = self.forward_evaluate( 145 | requests, 146 | stop=self.forward_template.stop_tokens, 147 | temperature=1.0 if num_samples > 1 else 0., 148 | max_tokens=max_tokens, 149 | ) 150 | targets = np.array([t for t in targets] * num_samples) 151 | losses = loss(outputs, targets).reshape(-1, num_samples) 152 | accuracy = (1. - losses).mean(1) 153 | return accuracy 154 | 155 | 156 | class ResidualPriorLayer(PriorLayer): 157 | 158 | def __init__( 159 | self, 160 | logprobs_score: LogProbsScore, 161 | forward_evaluate: LLM, 162 | forward_template, 163 | init="", 164 | residual_template="classify_residual" 165 | ): 166 | super().__init__(logprobs_score, forward_evaluate, forward_template, init=init) 167 | self.residual_template = load_template(residual_template) 168 | log_message("Residual template:\n", f"{repr(self.residual_template.template)}") 169 | 170 | def forward(self, inputs, **kwargs) -> np.array: 171 | outputs = super().forward(inputs, **kwargs) 172 | return outputs 173 | 174 | def apply_residual( 175 | self, outputs: np.array, inputs: np.array, use_template=False 176 | ) -> np.array: 177 | outputs_ = [] 178 | if use_template: 179 | for output, input in zip(outputs, inputs): 180 | tpl_input = self.forward_template.render( 181 | input=input, prompt=self.weight 182 | ) 183 | outputs_.append( 184 | self.residual_template.render( 185 | input=tpl_input, output=output 186 | ) 187 | ) 188 | else: 189 | for output, input in zip(outputs, inputs): 190 | outputs_.append( 191 | self.residual_template.render( 192 | input=input, output=output 193 | ) 194 | ) 195 | return np.array(outputs_) 196 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from dln.operator import LLM, instantiate_tokenizer 4 | 5 | from dln.score import LogProbsScore, ScoreRequest 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def unset_env_vars(monkeypatch): 10 | """Unset all environment variables that could be set by the user.""" 11 | 
monkeypatch.delenv("OPENAI_API_KEY", raising=False) 12 | monkeypatch.delenv("OPENAI_API_TYPE", raising=False) 13 | monkeypatch.delenv("OPENAI_API_BASE", raising=False) # deprecated in favor of OPENAI_BASE_URL 14 | monkeypatch.delenv("OPENAI_BASE_URL", raising=False) 15 | monkeypatch.delenv("OPENAI_API_VERSION", raising=False) 16 | 17 | 18 | @pytest.fixture 19 | def mock_llm_func(): 20 | 21 | def instantiate_llm(model_name=None): 22 | class MockEncoder: 23 | def encode(self, string): 24 | return string 25 | 26 | class MockLLM(LLM): 27 | 28 | def __init__(self, model_name): 29 | if model_name is None: 30 | model_name = "MockLLM" 31 | self.encoder = MockEncoder() 32 | else: 33 | self.encoder = instantiate_tokenizer(model_name) 34 | super().__init__(model_name) 35 | 36 | def _generate(self, inputs, **kwargs): 37 | return inputs 38 | 39 | def encode(self, string): 40 | return self.encoder.encode(string) 41 | 42 | def clean_text(self, string): 43 | return string 44 | 45 | @property 46 | def has_logprobs(self): 47 | return True 48 | 49 | return MockLLM(model_name) 50 | 51 | return instantiate_llm 52 | 53 | 54 | @pytest.fixture 55 | def mock_llm(mock_llm_func): 56 | return mock_llm_func() 57 | 58 | 59 | @pytest.fixture 60 | def mock_logprobs_score(mock_llm_func): 61 | llm = mock_llm_func("text-davinci-003") 62 | logprobs_score = LogProbsScore(llm) 63 | return logprobs_score 64 | 65 | 66 | @pytest.fixture 67 | def backward_info(): 68 | inputs = np.array(["test-1", "test-2", "test-3", "test-4"]) 69 | y = np.array(["test_1", "test_2", "test_3", "test_4"]) 70 | y_hat = np.array(["test_1", "test2", "test_3", "test4"]) 71 | losses = np.array([0.0, 1.0, 0.0, 1.0]) 72 | return inputs, y, y_hat, losses 73 | 74 | 75 | @pytest.fixture 76 | def score_requests(): 77 | return [ 78 | ScoreRequest( 79 | context="1 + 1 is:\n(A) 1\n(B) 2\n\nAnswer:", 80 | target="B", 81 | payload="B", 82 | ), 83 | ScoreRequest( 84 | context="1 * 1 is:\n(A) 1\n(B) 2\n\nAnswer:", 85 | target="A", 86 | payload="A", 87 | ), 88 | ] 89 | 90 | 91 | @pytest.fixture 92 | def text_outputs(): 93 | def logprobs_fn(contexts, *args, **kwargs): 94 | # return logprobs in the same order it was requested (contexts) 95 | logprobs = { 96 | "1 + 1": "A", 97 | "1 * 1": "A", 98 | } 99 | return [logprobs[context[:5]] for context in contexts] 100 | 101 | return logprobs_fn 102 | 103 | 104 | @pytest.fixture 105 | def top_logprobs(): 106 | def logprobs_fn(contexts, *args, **kwargs): 107 | # return logprobs in the same order it was requested (contexts) 108 | logprobs = { 109 | "1 + 1": { 110 | "Option": -8.876863, 111 | "Result": -17.299635, 112 | "choice": -17.710045, 113 | "<": -17.075796, 114 | "=": -15.760291, 115 | "correct": -13.988989, 116 | " A": -10.678262, 117 | "A": -3.663905, 118 | "All": -16.454699, 119 | "B": -12.343077, 120 | }, 121 | "1 * 1": { 122 | "Option": -8.315238, 123 | "=": -16.698154, 124 | "A": -11.863415, 125 | "B": -12.451943, 126 | "Answer": -7.4255853, 127 | "answer": -14.647212, 128 | "Correct": -8.74908, 129 | "Choice": -13.000805, 130 | "Yes": -14.741361, 131 | "b": -17.22967, 132 | }, 133 | } 134 | ordered_log_p = [logprobs[context[:5]] for context in contexts] 135 | return [ 136 | ["0", [ordered_log_p[0]], 2], 137 | ["0", [ordered_log_p[1]], 2], 138 | ] 139 | 140 | return logprobs_fn 141 | 142 | 143 | @pytest.fixture 144 | def raw_logprobs(): 145 | def logprobs_fn(contexts, *args, **kwargs): 146 | # return logprobs in the same order it was requested (contexts) 147 | logprobs = { 148 | "1 + 1": [ 149 | "1 + 1 is:\n(A) 1\n(B) 
2\n\nAnswer:\nB", 150 | [ 151 | None, 152 | -5.550775, 153 | -3.194002, 154 | -8.062983, 155 | -1.9706848, 156 | -0.9759903, 157 | -11.239477, 158 | -2.745899, 159 | -0.030587194, 160 | -1.4996661, 161 | -0.068833716, 162 | -0.009404114, 163 | -0.0001532674, 164 | -6.5041706e-05, 165 | -0.056048736, 166 | -0.05334273, 167 | -8.41094, 168 | -6.9211907, 169 | -0.001781753, 170 | -0.053041545, 171 | -1.4834975, 172 | ], 173 | [ 174 | "1", 175 | " +", 176 | " 1", 177 | " is", 178 | ":", 179 | "\\n", 180 | "(", 181 | "A", 182 | ")", 183 | " 1", 184 | "\\n", 185 | "(", 186 | "B", 187 | ")", 188 | " 2", 189 | "\\n", 190 | "\\n", 191 | "Answer", 192 | ":", 193 | "\\n", 194 | "B", 195 | ], 196 | ], 197 | "1 * 1": [ 198 | "1 * 1 is:\n(A) 1\n(B) 2\n\nAnswer: A", 199 | [ 200 | None, 201 | -6.06174, 202 | -4.7931056, 203 | -8.253801, 204 | -2.3915708, 205 | -0.5870681, 206 | -10.741921, 207 | -3.3388677, 208 | -0.011392174, 209 | -0.86958236, 210 | -0.11698982, 211 | -0.48095098, 212 | -0.002377014, 213 | -8.3404535e-05, 214 | -1.417262, 215 | -0.027041545, 216 | -5.510647, 217 | -4.546986, 218 | -0.0010610583, 219 | -0.053041545, 220 | -1.4735329, 221 | ], 222 | [ 223 | "1", 224 | " *", 225 | " 1", 226 | " is", 227 | ":", 228 | "\\n", 229 | "(", 230 | "A", 231 | ")", 232 | " 1", 233 | "\\n", 234 | "(", 235 | "B", 236 | ")", 237 | " 2", 238 | "\\n", 239 | "\\n", 240 | "Answer", 241 | ":", 242 | "\\n", 243 | "A", 244 | ], 245 | ], 246 | } 247 | return [logprobs[context[:5]] for context in contexts] 248 | 249 | return logprobs_fn 250 | -------------------------------------------------------------------------------- /dln/score.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import Any, List 4 | 5 | import numpy as np 6 | 7 | from dln.operator import LLM 8 | 9 | 10 | @dataclass 11 | class ScoreRequest: 12 | context: str 13 | target: str 14 | payload: Any = None 15 | 16 | 17 | @dataclass 18 | class OutputClasses: 19 | protos: List[str] 20 | 21 | def __iter__(self): 22 | return iter(self.protos) 23 | 24 | def __len__(self): 25 | return len(self.protos) 26 | 27 | def verbalizers(self, i): 28 | return self.protos[i].split("|") 29 | 30 | def prototype(self, i): 31 | return self.protos[i].split("|")[0] 32 | 33 | 34 | @dataclass 35 | class LogProbs: 36 | logp_targets: np.ndarray 37 | distribution: np.ndarray 38 | 39 | 40 | class LogProbsScore: 41 | cache = {} 42 | eval_cache = {} 43 | 44 | def __init__(self, forward_evaluate: LLM, burn_in_ratio: float = 0.05): 45 | self.forward_evaluate = forward_evaluate 46 | self.burn_in_ratio = burn_in_ratio 47 | self.cache = {} 48 | 49 | def score_requests(self, requests, output_classes=None, agg="max") -> LogProbs: 50 | # create the batched inputs for the model 51 | if output_classes is not None: 52 | return self._forward_logprobs_score_api_with_classes( 53 | [b.context for b in requests], 54 | [b.target for b in requests], 55 | output_classes, 56 | agg=agg, 57 | ) 58 | return self._forward_logprobs_score_api( 59 | [b.context for b in requests], 60 | [b.target for b in requests], 61 | ) 62 | 63 | def _forward_logprobs_score_api_with_classes( 64 | self, contexts, targets, output_classes, agg="max" 65 | ) -> LogProbs: 66 | eval_kwargs = { 67 | "temperature": 0., 68 | "max_tokens": 1, 69 | "echo": False, 70 | "return_logprobs": True, 71 | "raw_logprobs": False, 72 | "top_logprobs": 100, 73 | } 74 | 75 | output_logprobs = [] 76 | output_distribs = [] 77 | 78 | to_eval = 
[] 79 | for context in contexts: 80 | if context not in self.cache: 81 | context_ = f"{context}\n" 82 | to_eval.append(context_) 83 | 84 | print("# Scoring requests = {}".format(len(contexts))) 85 | print("# Scoring non cached requests = {}".format(len(to_eval))) 86 | 87 | partial_results = self.forward_evaluate( 88 | to_eval, 89 | async_generation=True, 90 | **eval_kwargs, 91 | ) 92 | for context, result in zip(to_eval, partial_results): 93 | assert context not in self.cache 94 | self.cache[context.strip()] = result 95 | 96 | eval_results = [] 97 | for context in contexts: 98 | eval_results.append(self.cache[context]) 99 | 100 | top_logprobs = [] 101 | for context, result in zip(contexts, eval_results): 102 | context_top_logprobs = result[1][0] 103 | top_logprobs.append(dict(context_top_logprobs)) 104 | 105 | output_logprobs = [] 106 | output_distribs = [] 107 | for context, target, context_top_logprobs in zip(contexts, targets, top_logprobs): 108 | # make this fixed 109 | if context_top_logprobs: 110 | min_prob = np.exp(np.min(list(context_top_logprobs.values()))) 111 | else: 112 | min_prob = 1e-6 113 | 114 | output_classes_scores = np.asarray([min_prob for _ in output_classes]) 115 | # remove tokenization artifacts 116 | clean_text_map = { 117 | self.forward_evaluate.clean_text(i): i 118 | for i in context_top_logprobs.keys() 119 | } 120 | # accumulate probability mass for each class verbalizer 121 | # the class verbalizer can be either " a" or "a" (with or without space) 122 | for i in range(len(output_classes)): 123 | verbalizers = output_classes.verbalizers(i) 124 | verbalizers.extend([f" {v}" for v in verbalizers]) 125 | verbalizers = set(verbalizers) 126 | verbalizers_scores = [0.] 127 | for verbalizer in verbalizers: 128 | if verbalizer in clean_text_map: 129 | prob_orig = np.exp( 130 | context_top_logprobs[clean_text_map[verbalizer]] 131 | ) 132 | else: 133 | prob_orig = min_prob 134 | verbalizers_scores.append(prob_orig) 135 | if agg == "max": 136 | output_classes_scores[i] += np.max(verbalizers_scores) 137 | else: 138 | output_classes_scores[i] += np.sum(verbalizers_scores) 139 | output_class_index = [i for i, output_class in enumerate(output_classes) if target in output_class.split("|")] 140 | assert ( 141 | len(output_class_index) == 1 142 | ), "The target shouldn't appear in two output classes! 
{}".format(target) 143 | # accuracy here 144 | output_classes_scores = output_classes_scores / output_classes_scores.sum() 145 | output_logprobs.append(np.log(output_classes_scores[output_class_index[0]])) 146 | output_distribs.append(output_classes_scores) 147 | return LogProbs(np.asarray(output_logprobs), np.asarray(output_distribs)) 148 | 149 | def _forward_logprobs_score_api(self, contexts, targets) -> LogProbs: 150 | logging.info("# Scoring requests = {}".format(len(contexts))) 151 | eval_kwargs = { 152 | "temperature": 0, 153 | "max_tokens": 0, 154 | "echo": True, 155 | "return_logprobs": True, 156 | "raw_logprobs": True, 157 | } 158 | 159 | def get_eval_key(context, target): 160 | return f"{context}\n{target}" 161 | 162 | eval_keys = [] 163 | for context, target in zip(contexts, targets): 164 | to_eval = get_eval_key(context, target) 165 | if to_eval not in self.eval_cache: 166 | eval_keys.append(to_eval) 167 | 168 | print("# Scoring requests = {}".format(len(contexts))) 169 | print("# Scoring non cached requests = {}".format(len(eval_keys))) 170 | 171 | # there might be doubles in the eval_batch, so we need to 172 | # only perform unique evals 173 | eval_results = self.forward_evaluate( 174 | eval_keys, 175 | async_generation=True, 176 | **eval_kwargs, 177 | ) 178 | 179 | for eval_key, eval_result in zip(eval_keys, eval_results): 180 | self.eval_cache[eval_key] = eval_result 181 | 182 | # get the results in the same order as the eval_batch 183 | eval_results = [] 184 | for context, target in zip(contexts, targets): 185 | to_eval = get_eval_key(context, target) 186 | eval_results.append(self.eval_cache[to_eval]) 187 | 188 | # get the nll results 189 | log_probs = [eval_result[1] for eval_result in eval_results] 190 | 191 | # get the logprobs results 192 | output_logprobs = [] 193 | context_logprobs = [] 194 | all_output_logprobs = [] 195 | 196 | for context, token_log_probs in zip(contexts, log_probs): 197 | num_tokens_prompt = len(self.forward_evaluate.encode(context)) 198 | burn_in_tokens = int(self.burn_in_ratio * (len(token_log_probs) - num_tokens_prompt)) 199 | target_log_probs = token_log_probs[num_tokens_prompt + burn_in_tokens:] 200 | context_log_probs = token_log_probs[1:num_tokens_prompt] 201 | all_output_logprobs.append(target_log_probs) 202 | output_logprobs.append(sum(target_log_probs) / (len(target_log_probs) + 1e-5)) 203 | context_logprobs.append(sum(context_log_probs) / (len(context_log_probs) + 1e-5)) 204 | 205 | return LogProbs(np.asarray(output_logprobs), np.asarray(context_logprobs)) 206 | -------------------------------------------------------------------------------- /projects/demo/demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import textwrap 4 | 5 | import altair as alt 6 | import streamlit as st 7 | import numpy as np 8 | import pandas as pd 9 | from jinja2 import Template 10 | 11 | 12 | forward_template_L1 = Template( 13 | "{{ input }}\n\n{{ prompt }} Let's think step by step." 
14 | ) # Loaded template suffix_forward_tbs v1.0 15 | forward_template_L2 = Template( 16 | "{{ prompt }}\n\n{{ input }}\n\nAnswer:" 17 | ) # Loaded template classify_forward v3.0 18 | 19 | DATASETS = [ 20 | ("subj", "1 Layer - Subj"), 21 | ("hyperbaton", "1 Layer - Hyperbaton"), 22 | ("navigate", "2 Layers - Navigate"), 23 | ] 24 | 25 | 26 | def wrap_text(text, width=100): 27 | text = text.replace("\n\n", "\n") 28 | return "\n".join("\n".join(textwrap.wrap(line, width)) for line in text.split("\n")) 29 | 30 | 31 | def load_data(logs, dataset): 32 | logs = logs[dataset] 33 | 34 | flattened_data = [] 35 | flattened_candidates = [] 36 | for item in logs["training"]: 37 | flat_item = {"step": item["step"]} 38 | flat_item.update( 39 | { 40 | metric: value if value is not None else np.nan 41 | for metric, value in item["metrics"].items() 42 | } 43 | ) 44 | flat_item.update( 45 | {f"layer_{i}": wrap_text(l) for i, l in enumerate(item["layers"], 1)} 46 | ) 47 | flattened_data.append(flat_item) 48 | candidates_data = {} 49 | for layer, candidates in enumerate(item["candidates"], 1): 50 | for idx, candidate in enumerate(candidates): 51 | candidate_data = candidates_data.setdefault(idx, {"step": item["step"]}) 52 | candidate_data[f"Layer {layer} candidate"] = candidate["layer"] 53 | candidate_data[f"Layer {layer} score"] = candidate["score"] 54 | flattened_candidates += list(candidates_data.values()) 55 | 56 | flattened_examples = [] 57 | for i, example in enumerate(logs["examples"]): 58 | for item in example["trace"]: 59 | flat_item = { 60 | "id": i + 1, 61 | "input": wrap_text(example["input"]), 62 | "label": example["label"], 63 | "step": item["step"], 64 | "hidden": wrap_text(item["hiddens"][0]) if item["hiddens"] else "", 65 | "output": wrap_text(item["output"]), 66 | } 67 | flattened_examples.append(flat_item) 68 | 69 | return ( 70 | pd.DataFrame(flattened_data), 71 | pd.DataFrame(flattened_candidates).dropna(), 72 | pd.DataFrame(flattened_examples), 73 | ) 74 | 75 | 76 | def load_logfiles(logfiles): 77 | if not logfiles: 78 | return None 79 | logs = {} 80 | for logfile in logfiles: 81 | with open(logfile, "r") as f: 82 | data = json.load(f) 83 | assert len(data.keys()) == 1 84 | new_key = key = list(data.keys())[0] 85 | if key in logs: 86 | new_key = f"{key} ({len([1 for k in logs if k.startswith(key)]) + 1})" 87 | logs[new_key] = data[key] 88 | 89 | return logs 90 | 91 | 92 | def extract_dataset_names(logs): 93 | return [(x, x) for x in list(logs.keys())] 94 | 95 | 96 | def main(args): 97 | logs = load_logfiles(args.logfiles or ["data.json"]) 98 | datasets = extract_dataset_names(logs) if logs else DATASETS 99 | st.set_page_config(layout="wide") 100 | st.markdown("

<h1 style='text-align: center;'>Deep Language Networks</h1>

", unsafe_allow_html=True) 101 | col1, col2 = st.columns(2) 102 | with col1: 103 | # find navigate dataset index or default to 0 104 | selectbox_index = next((i for i, (dataset_id, _) in enumerate(datasets) if dataset_id == 'navigate'), 0) 105 | dataset_selectbox = st.selectbox("Dataset", datasets, index=selectbox_index, format_func=lambda x: x[1]) 106 | dataset_selectbox = dataset_selectbox[0] 107 | df, candidates, examples = load_data(logs, dataset_selectbox) 108 | 109 | # st.slider does not support non-uniform steps. Using an index slider and then index into steps. 110 | steps = examples['step'].unique() 111 | steps = df['step'].unique() 112 | highlight_example = steps[st.selectbox("Example", [i for i in range(len(steps) - 1)], format_func=lambda x: x + 1)] 113 | highlight_step = steps[st.slider("Step", 0, len(steps) - 1)] 114 | 115 | show_example = any(examples['step'] == highlight_step) 116 | 117 | st.write("") 118 | table_data = [] 119 | table_data.append(f"| **Input:** | {examples[examples['step'] == highlight_step]['input'].iloc[highlight_example] if show_example else 'N/A'}") 120 | table_data.append(f"| **Layer 1 prompt:** | {df[df['step'] == highlight_step]['layer_1'].values[0]}") 121 | if 'layer_2' in df.columns: 122 | table_data.append(f"| **Hidden:** | {examples[examples['step'] == highlight_step]['hidden'].iloc[highlight_example] if show_example else 'N/A'}") 123 | table_data.append(f"| **Layer 2 prompt:** | {df[df['step'] == highlight_step]['layer_2'].values[0]}") 124 | table_data.append(f"| **Output:** | {examples[examples['step'] == highlight_step]['output'].iloc[highlight_example] if show_example else 'N/A'}") 125 | table_data.append(f"| **Label:** | {examples[examples['step'] == highlight_step]['label'].iloc[highlight_example] if show_example else 'N/A'}") 126 | table_data = [x.replace('\n', '
') for x in table_data] 127 | table_data_str = "\n".join(table_data) 128 | 129 | st.markdown(f"\n| | |\n| --- | --- |\n{table_data_str}", unsafe_allow_html=True) 130 | st.write("") 131 | 132 | with col2: 133 | melted_df = df.melt(id_vars=['step'], value_vars=['acc', 'run_acc', 'dev_acc'], var_name='metric', value_name='value') 134 | melted_df['metric'] = melted_df['metric'].replace(['acc', 'run_acc', 'dev_acc'], ['Train Batch', 'Train Run Avg', 'Dev Avg']) 135 | combined_chart = alt.Chart(melted_df).mark_line().encode( 136 | y=alt.Y('value:Q', title="accuracy", scale=alt.Scale( 137 | domain=[melted_df['value'].min(), melted_df['value'].max()] 138 | )), 139 | x='step:Q', 140 | color=alt.Color( 141 | 'metric:N', 142 | scale=alt.Scale(domain=['Train Batch', 'Train Run Avg', 'Dev Avg'], range=['steelblue', 'lightblue', 'orange']), 143 | legend=alt.Legend(title="Train/Dev Accuracy") 144 | ), 145 | ).transform_filter( # Marked NaN values as invalid so they can be ignored. Ref: https://stackoverflow.com/a/72306402 146 | 'isValid(datum.value)' 147 | ) 148 | 149 | # Add a vertical rule at the specific step 150 | highlight_rule = alt.Chart(pd.DataFrame({'step': [highlight_step]})).mark_rule(color='red').encode(x='step:Q') 151 | 152 | # Combine the line chart, vertical rule, and text label 153 | alt_acc = alt.layer( 154 | combined_chart, highlight_rule, data=melted_df 155 | ).properties(height=500) 156 | 157 | st.altair_chart(alt_acc, use_container_width=True) 158 | 159 | activate_elbo = st.toggle("Elbo") 160 | if activate_elbo: 161 | # elbo = df[["step", "elbo", "run_elbo"]] 162 | melted_elbo = df.melt(id_vars=['step'], value_vars=['elbo', 'run_elbo'], var_name='metric', value_name='value') 163 | melted_elbo['metric'] = melted_elbo['metric'].replace(['elbo', 'run_elbo'], ['Batch', 'Run Avg']) 164 | elbo_chart = alt.Chart(melted_elbo).mark_line().encode( 165 | y=alt.Y('value:Q', title="elbo", scale=alt.Scale( 166 | domain=[melted_elbo['value'].min(), melted_elbo['value'].max()] 167 | )), 168 | x='step:Q', 169 | color=alt.Color( 170 | 'metric:N', 171 | scale=alt.Scale(domain=['Batch', 'Run Avg'], range=['steelblue', 'lightblue']), 172 | legend=alt.Legend(title="Train Elbo") 173 | ), 174 | ) 175 | # Combine the elbo line chart and the highlight rule 176 | alt_elbo = alt.layer( 177 | elbo_chart, highlight_rule, data=melted_elbo 178 | ).properties(height=500) 179 | st.altair_chart(alt_elbo, use_container_width=True) 180 | 181 | prompt_candidates = st.toggle("Prompt Candidates") 182 | if prompt_candidates: 183 | # list all columns from candidates dataframe except the 'step' columns 184 | cols = [col for col in candidates.columns if col != 'step'] 185 | st.dataframe( 186 | candidates[candidates["step"] == highlight_step][cols], 187 | hide_index=True, 188 | use_container_width=True, 189 | ) 190 | 191 | 192 | if __name__ == "__main__": 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument("logfiles", nargs="*", help="Log file to use (JSON).") 195 | args = parser.parse_args() 196 | 197 | main(args) 198 | -------------------------------------------------------------------------------- /projects/guide/app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from guided_search import GuidedSearchController, Rating 3 | 4 | 5 | LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 6 | 7 | # --------------- # 8 | # Utils functions # 9 | # --------------- # 10 | 11 | # Keep track of the current screen. 
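# Screen state machine (summary of the functions below): reset() returns to the setup screen (0)
# and creates a fresh GuidedSearchController; next_screen() advances through the screens and wraps
# back to screen 1 after screen 2, so the feedback and rating screens alternate once setup is done.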
12 | def reset(): 13 | st.session_state.screen = 0 14 | st.session_state.subscreen = 0 15 | st.session_state.initialized = True 16 | st.session_state.gs = GuidedSearchController() 17 | 18 | # Define a function that will be called to process the next screen. 19 | def next_screen(): 20 | st.session_state.screen += 1 21 | st.session_state.subscreen = 0 22 | if st.session_state.screen > 2: 23 | st.session_state.screen = 1 24 | 25 | # Update the configs of the guided search. 26 | def update_bwd_configs(): 27 | st.session_state.gs.update_bwd_configs( 28 | st.session_state.num_candidates, 29 | st.session_state.bwd_temperature, 30 | st.session_state.bwd_max_tokens, 31 | ) 32 | 33 | # Update the configs of the outputs generation. 34 | def update_fwd_configs(): 35 | st.session_state.gs.update_fwd_configs( 36 | st.session_state.fwd_model, 37 | st.session_state.fwd_temperature, 38 | st.session_state.fwd_max_tokens, 39 | ) 40 | 41 | # Initialize the GuidedSearch 42 | def setup_guided_search(): 43 | st.session_state.gs.setup( 44 | st.session_state.meta_prompt, 45 | st.session_state.input_examples 46 | ) 47 | next_screen() 48 | 49 | def provide_feedback(): 50 | st.session_state.gs.feedback = st.session_state.feedback 51 | next_screen() 52 | st.session_state.gs.prompt_proposal_step() 53 | st.session_state.gs.inference_candidates_per_example() 54 | 55 | def submit_ratings(): 56 | selected_outputs = { 57 | (j, i): ( 58 | getattr(st.session_state, f"score_{i}_{j}"), 59 | getattr(st.session_state, f"feedback_{i}_{j}", None) 60 | ) 61 | for i, input in enumerate(st.session_state.gs.examples) 62 | for j, prompt in enumerate(st.session_state.gs.prompt_candidates_outputs) 63 | } 64 | st.session_state.gs.consolidate_prompts(selected_outputs) 65 | next_screen() 66 | 67 | if 'initialized' not in st.session_state: 68 | reset() 69 | 70 | 71 | # ------------------ # 72 | # Main Streamlit app # 73 | # ------------------ # 74 | st.header('GUIDE: A tool for guided meta-prompt search') 75 | 76 | ############ 77 | # Sidebar # 78 | ############ 79 | 80 | with st.sidebar: 81 | with st.form(key="bwd-configs"): 82 | st.markdown("**GUIDE Settings**", help="Good luck") 83 | st.slider("Number of meta-prompt candidates", 1, 5, st.session_state.gs.num_candidates, key="num_candidates") 84 | st.slider("Search diversity", 0.0, 1.0, st.session_state.gs.bwd_temperature, key="bwd_temperature", 85 | help="Intuitively, lower 'Search diversity' values will lead to more deterministic and less creative meta-prompts.") 86 | st.slider("Max meta-prompt length", 0, 500, st.session_state.gs.bwd_max_tokens, key="bwd_max_tokens", help="aka max tokens") 87 | st.toggle("Show meta-prompts", key="show_metaprompts", help="Show the meta-prompts used to generate the outputs.") 88 | st.form_submit_button( 89 | label="Apply", 90 | on_click=update_bwd_configs, 91 | type="primary", 92 | use_container_width=True, 93 | ) 94 | with st.form(key="fwd-configs"): 95 | st.markdown("#### Output Generation Settings", help="Settings used in your application") 96 | st.selectbox("Model", st.session_state.gs.available_models, key="fwd_model", help="LLM used in your application to generate output.") 97 | st.slider("Temperature", 0.0, 1.0, st.session_state.gs.fwd_temperature, key="fwd_temperature", help="Temperature used for the generated output.") 98 | st.slider("Max tokens", 0, 500, st.session_state.gs.fwd_max_tokens, key="fwd_max_tokens", help="Max number of tokens for generated output.") 99 | st.form_submit_button( 100 | label="Apply", 101 | on_click=update_fwd_configs, 
102 | type="primary", 103 | use_container_width=True, 104 | ) 105 | st.markdown(f"Search iteration {st.session_state.gs.optimization_step}") 106 | with st.expander("#### Meta-Prompt History", expanded=True): 107 | for idx, (prompt, example_outputs) in enumerate(st.session_state.gs.history.items(), 1): 108 | st.markdown(f"##### Meta-Prompt {idx}: {prompt}") 109 | for example_output in example_outputs: 110 | st.markdown(f"
{example_output.example} | {example_output.output}", unsafe_allow_html=True) 111 | st.markdown("
1: 163 | # After the first iteration of optimization, let the user know we are starting a new round of search/optimization. 164 | st.info("Here is the new meta-prompt based on your preference ratings. You can check out the outputs and choose to do another round of search.") 165 | 166 | with st.form(key="feedback_form"): 167 | # Display the meta-prompt used in a code format. 168 | st.markdown(f"\n| Current meta-prompt |\n|-|\n|{st.session_state.gs.meta_prompt}|") 169 | st.markdown("", unsafe_allow_html=True) 170 | 171 | # Display the results of the inference formatted in a markdown table. Each row is an input-output pair. 172 | # The first column is the input, and the second column is the output. 173 | st.write("") 174 | st.markdown("""\ 175 | Input | Output 176 | --- | --- 177 | """ + "\n".join(f'{out.example} | {out.output}' for out in st.session_state.gs.example_outputs)) 178 | 179 | # Add newlines 180 | st.markdown('') 181 | st.markdown('') 182 | st.write('Based on the results, provide feedback to guide the search for new meta-prompts.') 183 | # Add text area for feedback on the metaprompt. 184 | feedback = st.text_area( 185 | label="Based on the results, provide feedback to guide the meta-prompt search:", 186 | key="feedback", 187 | label_visibility="collapsed", 188 | value=st.session_state.gs.feedback or "", 189 | placeholder="", 190 | ) 191 | 192 | submit_button = st.form_submit_button( 193 | label="Generate prompt candidates", 194 | on_click=provide_feedback, 195 | use_container_width=True, 196 | type="primary", 197 | ) 198 | 199 | 200 | ############ 201 | # Screen 2 # 202 | ############ 203 | if st.session_state.screen == 2: 204 | st.subheader('Rate outputs generated by different meta-prompt candidates') 205 | st.markdown( 206 | 'Guide the Meta-Prompt search by providing feedback on ' 207 | 'the outputs generated with different meta-prompt candidates. ' 208 | 'You are not required to provide feedback on all example outputs. ' 209 | 'The outputs skipped will not inform the generation of the new meta prompt.' 210 | ) 211 | 212 | st.write("---") 213 | 214 | # Go over each input example and generate the output from inference_per_example function in guided_search.py. 215 | # Then, display the input, output, and a checkbox for the user to indicate whether they prefer the output. 216 | st.session_state.selected_output = [] 217 | for i, input in enumerate(st.session_state.gs.examples): 218 | # Display the input as label. 
219 | st.markdown(f":red[**Input {i + 1}:** {input}]") 220 | 221 | outputs = st.session_state.gs.find_outputs_by_example(input) 222 | for j, output in enumerate(outputs): 223 | st.caption(f"Meta-prompt {LETTERS[j]}\: " + list(st.session_state.gs.prompt_candidates_outputs)[j] if st.session_state.get("show_metaprompts") else "") 224 | st.markdown(f"**Output {LETTERS[j]}:** {output}") 225 | 226 | radio = st.radio( 227 | "Rating", 228 | Rating.values(), 229 | format_func=lambda e: Rating.icons(e), 230 | horizontal=True, 231 | index=2, 232 | key=f"score_{i}_{j}", 233 | label_visibility="collapsed" 234 | ) 235 | if radio in (1, -1): 236 | st.text_input( 237 | label="Provide feedback on the output:", 238 | key=f"feedback_{i}_{j}", 239 | label_visibility="collapsed", 240 | value="", 241 | placeholder="Provide feedback on the output (optional)", 242 | ) 243 | 244 | st.write("---") 245 | 246 | submit_button = st.button( 247 | label="Submit my preferences", 248 | on_click=submit_ratings, 249 | use_container_width=True, 250 | type="primary", 251 | ) 252 | -------------------------------------------------------------------------------- /dln/vi/sampler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import Union, List 4 | 5 | import numpy as np 6 | 7 | from dln.operator import LLM 8 | from dln.template import load_template 9 | from dln.vi.utils import log_message 10 | 11 | 12 | @dataclass 13 | class Info: 14 | input: str = None 15 | output: str = None 16 | target: str = None 17 | loss: float = 0.0 18 | 19 | 20 | class PromptSampler: 21 | def __init__(self, evaluate_func: LLM, p_template: str = "q_action_prompt:v3.5"): 22 | self.prompt_template = load_template(p_template) 23 | log_message("Prompt template:\n", f"{repr(self.prompt_template.template)}") 24 | log_message( 25 | "Message alternatives:\n", f"{self.prompt_template.message_alternatives}" 26 | ) 27 | 28 | self.evaluate_func = evaluate_func 29 | self.prompt_history = [] 30 | 31 | @staticmethod 32 | def create(template): 33 | if "seq" in template: 34 | return SequentialPromptSampler() 35 | return PromptSampler(template) 36 | 37 | def sample_q_p( 38 | self, 39 | inputs: np.array, 40 | y: np.array, 41 | y_hat: np.array, 42 | losses: np.array, 43 | prompt: Union[str, List[str]], 44 | num_samples=1, 45 | held_out_half=False, 46 | ): 47 | """ 48 | Args: 49 | inputs: input sequences 50 | y: target sequences 51 | y_hat: predicted sequences 52 | losses: losses for each sequence 53 | prompt: prompt to use for sampling 54 | num_samples: number of samples to generate 55 | held_out_half: if True, only use the first half of the data points for sampling prompts 56 | """ 57 | infos = [ 58 | Info(input=input_i, output=y_hat_i, target=y_i, loss=loss) 59 | for input_i, y_i, y_hat_i, loss in zip(inputs, y, y_hat, losses) 60 | ] 61 | log_message("Generating {} ~p proposals...".format(num_samples)) 62 | while True: 63 | try: 64 | tpls = [] 65 | for i in range(num_samples - 1): 66 | template_infos = {} 67 | if self.prompt_template.message_alternatives is None: 68 | message = None 69 | else: 70 | message = self.prompt_template.message_alternatives[ 71 | i % len(self.prompt_template.message_alternatives) 72 | ] 73 | 74 | indices = np.random.permutation(np.arange(len(infos))) 75 | if held_out_half: 76 | infos_ = [infos[i] for i in indices[: len(infos) // 2]] 77 | else: 78 | infos_ = [infos[i] for i in indices] 79 | 80 | template_infos["message"] = message 81 | 
template_infos["backward_infos"] = infos_ 82 | template_infos["prompt"] = ( 83 | prompt[i % len(prompt)] if type(prompt) == list else prompt 84 | ) 85 | log_message(self.prompt_template.render(**template_infos)) 86 | tpls.append(self.prompt_template.render(**template_infos)) 87 | 88 | new_prompts = self.evaluate_func( 89 | tpls, stop=self.prompt_template.stop_tokens, n=1 90 | ) 91 | 92 | if type(prompt) == list: 93 | prompts = np.array(prompt + list(new_prompts)) 94 | else: 95 | prompts = np.array([prompt] + list(new_prompts)) 96 | return prompts 97 | except KeyboardInterrupt: 98 | break 99 | except: 100 | if len(infos) > 1: 101 | infos = infos[1:] 102 | logging.info("DROPPING A DATA POINT...") 103 | else: 104 | error_message = ( 105 | "Still exeeding context length after shrinking backward_infos." 106 | ) 107 | logging.info(error_message) 108 | raise ValueError(error_message) 109 | 110 | 111 | class SequentialPromptSampler(PromptSampler): 112 | def __init__(self): 113 | super().__init__(p_template="q_action_prompt_seq") 114 | 115 | def sample_q_p( 116 | self, 117 | inputs: np.array, 118 | y: np.array, 119 | y_hat: np.array, 120 | losses: np.array, 121 | prompt: str, 122 | num_samples=1, 123 | held_out_half=False, 124 | ): 125 | """ 126 | Args: 127 | inputs: input sequences 128 | y: target sequences 129 | y_hat: predicted sequences 130 | losses: losses for each sequence 131 | prompt: prompt to use for sampling 132 | num_samples: number of samples to generate 133 | held_out_half: if True, only use the first half of the data points for sampling prompts 134 | """ 135 | self.prompt_history.append(prompt) 136 | 137 | infos = [ 138 | Info(input=input_i, output=y_hat_i, target=y_i, loss=loss) 139 | for input_i, y_i, y_hat_i, loss in zip(inputs, y, y_hat, losses) 140 | ] 141 | while True: 142 | try: 143 | tpls = [] 144 | for i in range((num_samples - 1) // 3): 145 | if self.prompt_template.message_alternatives is None: 146 | message = None 147 | else: 148 | message = self.prompt_template.message_alternatives[ 149 | i % len(self.prompt_template.message_alternatives) 150 | ] 151 | indices = np.random.permutation(np.arange(len(infos))) 152 | if held_out_half: 153 | infos_ = [infos[i] for i in indices[: len(infos) // 2]] 154 | else: 155 | infos_ = [infos[i] for i in indices] 156 | tpls.append( 157 | self.prompt_template.render( 158 | backward_infos=infos_, 159 | prompt=self.prompt_history[-1], 160 | message=message, 161 | ) 162 | ) 163 | 164 | log_message("Generating {} ~p proposals...".format(num_samples)) 165 | 166 | prompts = self.evaluate_func( 167 | tpls, 168 | stop=self.prompt_template.stop_tokens, 169 | n=1, 170 | async_generation=True, 171 | ) 172 | log_message("DONE...") 173 | 174 | # each prompt is prefix by 1., 2. and 3., so flatten the sequentially sampled prompts 175 | prompts_ = [] 176 | for prompt_ in prompts: 177 | sub_prompts_ = prompt_.split("\n") 178 | sub_prompts_ = [sub_prompts_[0].strip()] + [ 179 | p_[2:].strip() for p_ in sub_prompts_[1:] 180 | ] 181 | sub_prompts_ = list(set(sub_prompts_)) 182 | prompts_.extend(sub_prompts_) 183 | 184 | prompts = np.array([prompt] + list(prompts_)) 185 | return prompts 186 | except KeyboardInterrupt: 187 | break 188 | except: 189 | if len(infos) > 1: 190 | infos = infos[1:] 191 | logging.info("DROPPING A DATA POINT...") 192 | else: 193 | error_message = ( 194 | "Still exeeding context length after shrinking backward_infos." 
195 | ) 196 | logging.info(error_message) 197 | raise ValueError(error_message) 198 | 199 | 200 | class PosteriorSampler: 201 | def __init__(self, evaluate_func: LLM, q_template: str): 202 | self.q_templates = [] 203 | 204 | for q_template in q_template.split("|"): 205 | self.q_templates.append(load_template(q_template)) 206 | 207 | for q_template in self.q_templates: 208 | log_message("Q template:", f"{repr(q_template.template)}") 209 | 210 | self.stop_tokens = self.q_templates[0].stop_tokens 211 | self.evaluate_func = evaluate_func 212 | self.rng = np.random.RandomState(0) 213 | 214 | def sample_q_h( 215 | self, 216 | x: np.array, 217 | y: np.array, 218 | h: np.array, 219 | prompt: str, 220 | next_prompt: str, 221 | num_samples=1, 222 | strip_double_newlines=True, 223 | return_logprobs=False, 224 | ): 225 | """ 226 | Sample a new hidden state from the posterior distribution. 227 | 228 | Args: 229 | x: inputs 230 | y: labels 231 | y_hat: model predictions for the forward pass 232 | h: hidden states for the forward pass 233 | prompt: prompt for the layer that generated h 234 | next_prompt: prompt for the layer above h 235 | num_samples: number of samples to generate 236 | strip_double_newlines: strip double new lines from the output samples 237 | return_logprobs: return the log probabilities of the samples 238 | Returns 239 | (batch_size, num_samples) array of hidden states 240 | """ 241 | tpls = [] 242 | 243 | for i, (x_i, h_i, y_i) in enumerate(zip(x, h, y)): 244 | for j in range(num_samples): 245 | # pick a template at random 246 | q_template = self.q_templates[ 247 | np.random.choice(np.arange(len(self.q_templates))) 248 | ] 249 | if q_template.message_alternatives is not None: 250 | message = q_template.message_alternatives[ 251 | j % len(q_template.message_alternatives) 252 | ] 253 | else: 254 | message = None 255 | 256 | # pick another example in the set 257 | all_indices = list(np.arange(len(x))) 258 | source_example = self.rng.choice(all_indices) 259 | 260 | tpl = q_template.render( 261 | input=x_i, 262 | h=h_i, 263 | source_x=x[source_example], 264 | source_h=h[source_example], 265 | prompt=prompt, 266 | next_prompt=next_prompt, 267 | y=y_i, 268 | message=message, 269 | ) 270 | 271 | # induce randomness 272 | tpls.append(tpl) 273 | 274 | assert len( 275 | tpls 276 | ), "If we are here, it means that either we resample hidden states, or that there are some errors." 277 | 278 | # this might happen when all memories are correct 279 | log_message("Q proposals: " + str(len(tpls)) + ", Q template:" + "\n" + tpls[0]) 280 | log_message("Generating {} ~h proposals...".format(num_samples)) 281 | 282 | sampled = self.evaluate_func( 283 | tpls, 284 | stop=self.stop_tokens, 285 | n=1, 286 | async_generation=True, 287 | return_logprobs=return_logprobs, 288 | ) 289 | if return_logprobs: 290 | sampled, logprobs, lengths = zip(*sampled) 291 | logprobs = np.asarray(logprobs) / np.asarray(lengths) 292 | 293 | # strip any "\n\n" that might have been added 294 | if strip_double_newlines: 295 | sampled = [s.replace("\n\n", "\n") for s in sampled] 296 | 297 | sampled = np.asarray(sampled).reshape(x.shape[0], num_samples) 298 | assert sampled.shape[0] == x.shape[0] 299 | 300 | if return_logprobs: 301 | return sampled, logprobs.reshape(x.shape[0], num_samples) 302 | return sampled 303 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Language Networks 2 |
3 | 4 | [[ArXiv]](https://arxiv.org/abs/2306.12509) 5 | [[Blog]](https://medium.com/@friederike.niedtner/deep-language-networks-stacking-llms-in-trainable-layers-e7f719bcabde) 6 | 7 |
8 | 9 | In this repository, you will find the code for 10 | "Deep Language Networks: Joint Prompt Training of Stacked LLMs using Variational Inference". 11 | Please refer to our paper for further details. 12 | 13 | ## Abstract 14 | We view large language models (LLMs) as stochastic language layers in a network, where the learnable parameters are the natural language prompts at each layer. We stack two such layers, feeding the output of one layer to the next. We call the stacked architecture a Deep Language Network (DLN). We first show how to effectively perform prompt optimization for a 1-Layer language network (DLN-1). We then show how to train 2-layer DLNs (DLN-2), where two prompts must be learnt. We consider the output of the first layer as a latent variable to marginalize, and devise a variational inference algorithm for joint prompt training. A DLN-2 reaches higher performance than a single layer, sometimes comparable to few-shot GPT-4 even when each LLM in the network is smaller and less powerful. 15 | 16 | ## Setup 17 | 18 | ### Clone repo 19 | git clone https://github.com/microsoft/deep-language-networks.git 20 | cd deep-language-networks 21 | 22 | ### Install dependencies 23 | conda create -n dln python=3.10 24 | conda activate dln 25 | pip install -e . 26 | 27 | ### Setup data 28 | bash scripts/setup_data.sh 29 | 30 | ### Set your OpenAI API key 31 | 32 | Export your key or put it in your *shrc, e.g., 33 | 34 | export OPENAI_API_KEY='...your...key...' 35 | 36 | In order to use Microsoft Azure endpoints, in addition to the OPENAI_API_KEY, 37 | you need to set the OPENAI_API_TYPE, OPENAI_BASE_URL and OPENAI_API_VERSION. 38 | The OPENAI_API_TYPE must be set to 'azure' and the others correspond to the properties of your endpoint. 39 | 40 | 41 | > :warning: **Warning:** Setting `echo` and `logprobs` simultaneously is no longer supported for certain OpenAI models. 42 | However, optimizing prompts jointly for 2-DLN using variational inference requires both settings. 43 | To run 2-DLN experiments, consider hosting your own model (see [self-hosted models](#setup-self-hosted-models-vllm)). 44 | Alternatively, you can run 1-DLN by setting output_scoring_function="accuracy" and --one_layer=True. 45 | 46 | 47 | ### Setup self-hosted models (vLLM) 48 | 49 | DLN does not directly serve models; instead, we use [vLLM](https://github.com/vllm-project/vllm), an open-source library that provides an OpenAI-compatible server solution for self-hosted models. For instructions on setting up vLLM, please follow this [guide](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server). 50 | 51 | Once your vLLM server is up and running, if you have loaded a model from weights on your local machine, please define the path to the tokenizer (e.g., `/path/to/Llama-2-70b-chat-hf`) in the environment variable `TOKENIZER_PATH` so that DLN can also load the same tokenizer. If you are downloading the tokenizer from Hugging Face, this environment variable is not required; DLN will download it automatically (e.g., `meta-llama/Llama-2-70b-chat-hf`). 52 | 53 | Then, set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to point to your vLLM server. Finally, remember to unset `OPENAI_API_TYPE` and `OPENAI_API_VERSION` if they were previously set.
54 | 55 | export TOKENIZER_PATH= # /path/to/Llama-2-70b-chat-hf 56 | export OPENAI_API_KEY= # EMPTY 57 | export OPENAI_BASE_URL= # http://127.0.0.1:8000/v1 58 | unset OPENAI_API_TYPE 59 | unset OPENAI_API_VERSION 60 | 61 | 62 | ## Datasets 63 | 64 | We provide an interface to a few datasets from Big-bench Hard, Leopard, Ordered Prompt, and GSM8K that can be used to train and evaluate DLNs. 65 | See [dln/dataset.py](dln/dataset.py) for more details. 66 | 67 | ```python 68 | from dln.dataset import init_dataset 69 | 70 | dataset = "navigate" 71 | seed = 42 72 | data_dir = "data" 73 | 74 | dataset = init_dataset( 75 | dataset_id=dataset, 76 | seed=seed, 77 | data_dir=data_dir, 78 | n_few_shots=5, 79 | # max_train_size=max_train_size, 80 | # max_dev_size=max_dev_size, 81 | # max_test_size=max_test_size, 82 | ) 83 | # Get dataset sizes 84 | dataset.train_size, dataset.dev_size, dataset.test_size 85 | 86 | # Get a batch 87 | sentences, labels, few_shot_examples = dataset.get_batch("train", 10) 88 | 89 | # Reset the split pointer 90 | dataset.reset_pointer("train") 91 | 92 | # Iterate over a dataset split 93 | for sentences, labels, few_shot_examples in dataset.iterate("dev", batch_size=10): 94 | pass 95 | 96 | # Get all data from a split 97 | test_sentences, test_labels = dataset.get_data("test") 98 | ``` 99 | 100 | ## LLMs 101 | 102 | The easiest way to use LLMs in DLN is to register them using `LLMRegistry`. 103 | You can register different models, or the same model with different configurations. 104 | 105 | Any number of keyword arguments can be provided to the `register` method and these will be passed to the model's `generate` method. 106 | Extra keyword arguments can also be provided to the `generate` method, overriding the ones used during the models' instantiation. 107 | 108 | ```python 109 | from dln.operator import LLMRegistry 110 | 111 | llm_registry = LLMRegistry() 112 | 113 | fwd_model = llm_registry.register( 114 | "fwd_model", # how you refer to the model 115 | "gpt-35-turbo-instruct", # model id 116 | temperature=0.0, 117 | max_tokens=256, 118 | stop=None, 119 | ) 120 | 121 | bwd_model = llm_registry.register( 122 | "bwd_model", 123 | "gpt-35-turbo-instruct", 124 | temperature=0.7, 125 | max_tokens=512, 126 | stop=None, 127 | ) 128 | 129 | fwd_model.generate("What is sirop d'érable?") 130 | fwd_model.generate("What is sirop d'érable?", max_tokens=200, stop=[r"\n"]) 131 | ``` 132 | 133 | Alternatively, you can specify the LLMs configuration in a YAML file with the following format: 134 | 135 | ```yaml 136 | - name: fwd_model # how you refer to the model 137 | model: "gpt-35-turbo-instruct" # model id 138 | temperature: 0.0 # any generation kwarg 139 | max_tokens: 256 140 | - name: bwd_model 141 | model: "gpt-35-turbo-instruct" 142 | ... 143 | ``` 144 | 145 | This is particularly useful when you want to use models from different APIs. In this case, 146 | you should unset the default `OPENAI` environment vars, and provide them in the YAML file. 
147 | For example: 148 | 149 | ```yaml 150 | - name: phi2-fwd 151 | model: microsoft/phi-2 152 | api_key: ${PHI2_API_KEY} 153 | base_url: ${PHI2_API_BASE} 154 | api_type: null 155 | api_version: null 156 | max_tokens: 256 157 | temperature: 0.0 158 | 159 | - name: gpt-bwd 160 | model: "gpt-35-turbo-instruct" 161 | api_key: ${GPT_API_KEY} 162 | base_url: ${GPT_API_BASE} 163 | api_type: ${GPT_API_TYPE} 164 | api_version: ${GPT_API_VERSION} 165 | temperature: 0.7 166 | max_tokens: 512 167 | 168 | ``` 169 | 170 | Then, you can register the models using the `register_from_yaml` method and get them using the `get` method, as follows: 171 | 172 | ```python 173 | llm_registry = LLMRegistry.from_yaml("connections.yaml") 174 | fwd_model = llm_registry.get("phi2-fwd") 175 | bwd_model = llm_registry.get("gpt-bwd") 176 | 177 | output = bwd_model.generate("Why do programmers prefer dark mode?") 178 | 179 | # You can always provide extra keyword arguments to the `generate` method, 180 | # which will override the ones provided when instantiating the models. 181 | 182 | output = bwd_model.generate( 183 | "Why do programmers prefer dark mode?", 184 | max_tokens=100, 185 | echo=True, 186 | ) 187 | ``` 188 | 189 | ## Losses, Samplers and Scores 190 | 191 | DLN provides a few losses that can be found in [dln/loss.py](dln/loss.py). A simple example of how to use them is as follows: 192 | 193 | ```python 194 | from dln.loss import LossRegistry 195 | from dln.postprocessing import postprocess_prediction 196 | 197 | LossRegistry.available_losses() # list available losses 198 | loss_fn = LossRegistry.instantiate( 199 | "exact_match_loss", postprocess_prediction 200 | ) 201 | y = ["Montreal", "Toronto", "Sao Paulo"] 202 | y_hat = ["Montréal", "Toronto", "SaoPaulo"] 203 | losses = loss_fn(y_hat, y) 204 | # array([1., 0., 1.], dtype=float32) 205 | ``` 206 | For sampling and scoring both prompts and hidden states for the Variational Inference algorithm, samplers are found in [dln/vi/sampler.py](dln/vi/sampler.py), and the LogProbsScore in [dln/score.py](dln/score.py). Samplers use templates that are found in [dln/templates](dln/templates).
207 | 208 | 209 | ```python 210 | import numpy as np 211 | from dln.operator import LLMRegistry 212 | from dln.vi.sampler import PosteriorSampler, PromptSampler 213 | from dln.score import LogProbsScore, ScoreRequest 214 | 215 | llm_registry = LLMRegistry() 216 | llm = llm_registry.register( 217 | "llm", 218 | "microsoft/phi-2", 219 | ) 220 | 221 | prompt_sampler = PromptSampler(llm, "q_action_prompt") 222 | posterior_sampler = PosteriorSampler(llm, "suffix_forward_tbs") 223 | logprobs_score = LogProbsScore(llm) 224 | 225 | prompt_proposals = prompt_sampler.sample_q_p( 226 | inputs=["France", "Canada", "Brazil"], 227 | y=["Paris", "Ottawa", "Brasilia"], 228 | y_hat=["Paris", "Ottawa", "Sao Paulo"], 229 | losses=[0, 0, 1], 230 | prompt="What is the capital of this country", 231 | num_samples=10, 232 | ) # sample prompts 233 | 234 | hidden_states = posterior_sampler.sample_q_h( 235 | x=np.array(["France", "Canada", "Brazil"]), 236 | y=["Paris", "Ottawa", "Brasilia"], 237 | h=["Paris", "Toronto", "Sao Paulo"], 238 | prompt="What is the largest city in this country", 239 | next_prompt="What is the capital of this country", 240 | num_samples=10, 241 | ) 242 | 243 | score_request = ScoreRequest( 244 | context="What is the capital of this country: Canada", 245 | target="Ottawa", 246 | payload="Ottawa", 247 | ) 248 | score = logprobs_score.score_requests([score_request]) 249 | # LogProbs(logp_targets=array([-7.67090403]), distribution=array([-3.02606859])) 250 | ``` 251 | 252 | 253 | You can refer to [vi_main.py](projects/vi_dln/vi_main.py) for a complete example of how to use the DLN components. 254 | 255 | 256 | ## Variational Inference experiments 257 | 258 | Please see the [Variational Inference README](projects/vi_dln/README.md) for information on how to run VI experiments. 259 | 260 | 261 | ## Limitations 262 | 263 | When it comes to large-scale natural language models, there are particular fairness and responsible AI issues to consider. 264 | People use language to describe the world and to express their beliefs, assumptions, attitudes, and values. 265 | As a result, publicly available text data typically used to train large-scale natural language models contains 266 | societal biases relating to race, gender, religion, age, and other groups of people, as well as other undesirable content. 267 | These societal biases are reflected in the distributions of words, phrases, and syntactic structures. 268 | Large-scale natural language models trained with such data can potentially behave in ways that are unfair, 269 | unreliable, or offensive, in turn causing harms. 270 | 271 | While we are fully aware of the limitations of addressing societal issues through technical work, 272 | we hope that modular approaches like ours will alleviate some of the issues associated with LLMs, 273 | like the concentration of power associated with the difficulty of training them. We also hope that, 274 | by facilitating the reusability and adaptivity of such models, we shall make them more amenable to a wider variety of use cases. 275 | However, while we discuss the performance of these models on artificial benchmarks, 276 | we do not address the question of when and how such models should be deployed, 277 | nor do we offer additional guarantees against their misuse. We also emphasize that performance on artificial tasks, 278 | even if realistic, is neither representative of performance in uncontrolled environments, 279 | nor enough to justify the deployment of these models in high stakes situations.
280 | Please refer to our paper for the specific evaluations we conducted. 281 | 282 | ## Citing Deep Language Networks 283 | If you find DLNs useful, please consider citing this work! 284 | 285 | ```text 286 | @article{sordoni2023deep, 287 | title={Deep Language Networks: Joint Prompt Training of Stacked LLMs using Variational Inference}, 288 | author={Alessandro Sordoni and Xingdi Yuan and Marc-Alexandre Côté and Matheus Pereira and Adam Trischler and Ziang Xiao and Arian Hosseini and Friederike Niedtner and Nicolas Le Roux}, 289 | year={2023}, 290 | eprint={2306.12509}, 291 | archivePrefix={arXiv}, 292 | primaryClass={cs.CL} 293 | } 294 | ``` 295 | 296 | ## Contributing 297 | 298 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 299 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 300 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 301 | 302 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 303 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 304 | provided by the bot. You will only need to do this once across all repos using our CLA. 305 | 306 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 307 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 308 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 309 | 310 | ## Trademarks 311 | 312 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 313 | trademarks or logos is subject to and must follow 314 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 315 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 316 | Any use of third-party trademarks or logos is subject to those third parties' policies. 317 | --------------------------------------------------------------------------------