├── .gitignore ├── LICENSE ├── README.md ├── config ├── dataset2level.json ├── dataset2maxlen.json ├── dataset2prompt.json ├── model2maxlen.json └── model2path.json ├── convert_llama_weights_to_hf.py ├── data └── WaterBench │ ├── 1-1_konwledge_memorization.jsonl │ ├── 1-2_konwledge_understanding.jsonl │ ├── 2-1_longform_qa.jsonl │ ├── 2-2_finance_qa.jsonl │ ├── 3-1_hotpotqa.jsonl │ ├── 3-2_lcc.jsonl │ ├── 4-1_multi_news.jsonl │ ├── 4-2_qmsum.jsonl │ └── 5-1_alpacafarm.jsonl ├── detect.py ├── detect_human.py ├── eval.py ├── generate.py ├── llama_flash_attn_monkey_patch.py ├── metrics.py ├── pictures ├── heat_datasets.pdf ├── llama2-7b-chat-4k_mission_scatters1.pdf ├── llama2-7b-chat-4k_mission_scatters2.pdf ├── llama2-7b-chat-4k_radar.pdf ├── llama2-7b-chat-4k_search.pdf ├── roc_auc.pdf └── scatters_all.pdf ├── plot_image_and_tables.ipynb ├── pred.py ├── process ├── __init__.py ├── download.py ├── hyper-search.py ├── load.py └── process_z.py ├── requirements.txt ├── shell ├── detect.sh ├── detect_human.sh ├── eval.sh └── pred.sh └── watermark ├── __init__.py ├── gptwm.py ├── homoglyphs.py ├── normalizers.py ├── old_watermark.py ├── our_watermark.py ├── run_generate.py └── watermark_v2.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | data/download/ 6 | data/models/ 7 | data/openai_api_key.txt 8 | # C extensions 9 | *.so 10 | *.csv 11 | pred/ 12 | mutual_detect/ 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shangqing Tu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # WaterBench
 2 | Data and code for the paper "WaterBench: Towards Holistic Evaluation of Watermarks for Large Language Models"
 3 | 
 4 | ## Installation
 5 | To download the repo to your local machine, run the following commands:
 6 | ``` bash
 7 | git clone https://github.com/THU-KEG/WaterBench.git
 8 | cd WaterBench
 9 | ```
10 | 
11 | ## How to evaluate on WaterBench
12 | 
13 | #### Load Data
14 | 
15 | Our datasets are stored in the `data` folder. To load the data, you can use the following code (here we take the *alpacafarm* dataset as an example):
16 | 
17 | ``` python
18 | dataset2level = json.load(open("config/dataset2level.json", "r"))
19 | dataset = "alpacafarm"
20 | data = []
21 | with open("data/WaterBench/{}_{}.jsonl".format(dataset2level[dataset], dataset), "r", encoding="utf-8") as f:
22 |     for line in f:
23 |         data.append(json.loads(line))
24 | ```
25 | 
26 | To convert the original input into the prompt format, you can use the following code (again taking the *alpacafarm* dataset as an example):
27 | 
28 | ``` python
29 | dataset2prompt = json.load(open("config/dataset2prompt.json", "r"))
30 | dataset = "alpacafarm"
31 | prompt_format = dataset2prompt[dataset]
32 | 
33 | for json_obj in tqdm(data):
34 |     # json_obj is one piece of the data
35 |     prompt = prompt_format.format(**json_obj)
36 | 
37 | ```
38 | 
39 | #### Data Format
40 | 
41 | All the missions (tasks) used in WaterBench are stored in the `data` folder in `jsonl` format. Each line is a JSON object with the following fields:
42 | 
43 | ``` json
44 | {
45 |     "input": "The input/command for the task, usually short, such as a question in QA",
46 |     "context": "The long context required for the task, such as documents or cross-file code",
47 |     "outputs": "A list of all reference answers",
48 |     "input_length": "Length of the first two items (counted in characters for Chinese and words for English)",
49 |     "output_length": "Length of the third item (counted in characters for Chinese and words for English)",
50 |     "length": "Length of the first three items (counted in characters for Chinese and words for English)",
51 |     "all_classes": "All categories in classification tasks, null for non-classification tasks",
52 |     "language": "The language of this piece of data",
53 |     "dataset": "The name of the dataset to which this piece of data belongs",
54 |     "_id": "Random id for each piece of data"
55 | }
56 | ```
57 | 
58 | #### Load Model
59 | 
60 | The models we use are *llama2-7b-chat-4k* and *internlm-chat-7b-8k*; you can change the model paths in the `config/model2path.json` file.
61 | 
62 | You can get *llama2* from [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) and *internlm* from [here](https://huggingface.co/internlm/internlm-chat-7b-8k).
63 | 
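In addition to `config/model2path.json`, the repo ships `config/model2maxlen.json`, which maps each model name to an input-length budget (e.g. 3500 for *llama2-7b-chat-4k*). Below is a minimal sketch of reading these two lookup tables before loading; the path lookup mirrors what `detect.py` does, while treating the `model2maxlen.json` value as the maximum prompt length is an assumption about how `pred.py` uses it:

``` python
import json

model_name = "llama2-7b-chat-4k"

# Both config files are plain name -> value lookup tables.
model2path = json.load(open("config/model2path.json", "r"))
model2maxlen = json.load(open("config/model2maxlen.json", "r"))

path_to_your_model = model2path[model_name]  # e.g. "./data/models/llama-2-7b-chat-hf"
max_length = model2maxlen[model_name]        # 3500; assumed to bound the prompt length for this model
```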
64 | We use the `transformers` library to load the model. The complete loading code is in `pred.py`; the core function is:
65 | 
66 | ``` python
67 | from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM
68 | def load_model_and_tokenizer(path_to_your_model, model_name, device, load_token_only=False):
69 |     if "internlm" in model_name:
70 |         tokenizer = AutoTokenizer.from_pretrained(path_to_your_model, trust_remote_code=True)
71 |         if not load_token_only:
72 |             model = AutoModelForCausalLM.from_pretrained(path_to_your_model, trust_remote_code=True,
73 |                                                          output_scores=True, return_dict_in_generate=True,
74 |                                                          torch_dtype=torch.bfloat16).to(device)
75 |             model.eval()
76 |     elif "llama2" in model_name:
77 |         # replace_llama_attn_with_flash_attn()
78 |         tokenizer = LlamaTokenizer.from_pretrained(path_to_your_model)
79 |         if not load_token_only:
80 |             model = LlamaForCausalLM.from_pretrained(path_to_your_model, output_scores=True, return_dict_in_generate=True,
81 |                                                      torch_dtype=torch.bfloat16).to(device)
82 |     if load_token_only:
83 |         return tokenizer
84 |     else:
85 |         model = model.eval()
86 |         return model, tokenizer
87 | 
88 | ```
89 | 
90 | #### Evaluation
91 | 
92 | Install the requirements with pip: `pip install -r requirements.txt`. For Llama-2-based models, we recommend using Flash Attention to speed up generation and save GPU memory. The relevant dependencies can be installed according to the code base of [Flash Attention](https://github.com/Dao-AILab/flash-attention).
93 | 
94 | First, run `pred.py` and select the model you want to evaluate via `--model`, the watermark mode via `--mode`, and the watermark hyper-parameters via `--bl_type`, `--gamma`, and `--delta`. The parameter `mode` selects the kind of watermark used in the experiments: `no` (no watermark), `old` (the soft/hard watermark), `gpt` (the GPT watermark), or `v2` (the v2 watermark). The parameter `bl_type` specifies whether the watermark is hard or soft. You can also select the dataset to evaluate via `--dataset`. Here is an example that runs the *llama2-7b-chat-4k* model with the specified watermark parameters on the *alpacafarm* dataset:
95 | 
96 | ``` bash
97 | CUDA_VISIBLE_DEVICES=0 python pred.py \
98 |     --mode old \
99 |     --gamma 0.5 \
100 |     --delta 10 \
101 |     --bl_type hard \
102 |     --dataset alpacafarm \
103 |     --model llama2-7b-chat-4k
104 | ```
105 | 
106 | Or you can modify `shell/pred.sh` and run it directly.
107 | 
108 | If you do not specify `--dataset`, the code will evaluate the model on all datasets in WaterBench.
109 | 
110 | The model outputs for all WaterBench datasets are saved under the `pred/` folder, in a subdirectory named after the model and watermark settings.
111 | 
112 | Then, run the detection code in `detect.py` to obtain z-scores (a sketch of the statistic being reported is given below, after the `eval.py` command):
113 | 
114 | ``` bash
115 | CUDA_VISIBLE_DEVICES=0 python detect.py \
116 |     --input_dir ./pred/llama2-7b-chat-4k_old_g0.5_d10.0_hard
117 | ```
118 | 
119 | Or you can modify `shell/detect.sh` and run it directly.
120 | 
121 | The z-scores for every mission are then written to the `z_score` subfolder of the `--input_dir` you passed to `detect.py`.
122 | 
123 | After that, you can run `eval.py` to score the generations on all datasets (GPT-4-based evaluation for *alpacafarm*, automatic metrics for the other tasks); the results are saved to `result.json` under the `eval` subfolder of `--input_dir`:
124 | 
125 | ``` bash
126 | CUDA_VISIBLE_DEVICES=0 python eval.py \
127 |     --input_dir ./pred/llama2-7b-chat-4k_old_g0.5_d10.0_hard
128 | ```
129 | Or you can modify `shell/eval.sh` and run it directly.
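For reference, the z-score reported by `detect.py` and `detect_human.py` is a green-list proportion test. The snippet below is only an illustrative sketch of that statistic, under the simplifying assumption that a fixed fraction `gamma` of the vocabulary is green at every step; it is not the detectors' actual code, which differs across the `old`, `gpt`, `v2`, and `new` modes (seeding, repeated-bigram handling, etc.):

``` python
import math

def green_list_z_score(green_count: int, total_tokens: int, gamma: float) -> float:
    """z-score for observing `green_count` green-list tokens out of `total_tokens`,
    when unwatermarked text would hit the green list with probability `gamma` per token."""
    expected = gamma * total_tokens
    std = math.sqrt(total_tokens * gamma * (1 - gamma))
    return (green_count - expected) / std

# detect.py flags a generation as watermarked when its z-score exceeds --threshold
# (default 6.0); detect_human.py uses a default threshold of 4.0.
print(green_list_z_score(green_count=100, total_tokens=120, gamma=0.5))  # ~7.30, above the 6.0 threshold
```

In the saved `*_z.jsonl` files, `wm_pred` is exactly this per-generation thresholding, and `wm_pred_average` is the fraction of generations flagged as watermarked.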
130 | 131 | To get the detection results of the model with watermarks on standard answers, you can run `detect_human.py`: 132 | ``` bash 133 | CUDA_VISIBLE_DEVICES=0 python detect_human.py \ 134 | --reference_dir llama2-7b-chat-4k_old_g0.25_d5.0_hard \ 135 | --detect_dir human_generation \ 136 | ``` 137 | Or you can modify `shell/detect_human.sh` and run it directly. 138 | 139 | Please note that in `config/`, we provide the input format suitable for each dataset and the maximum output length. Feel free to modify them to better suit the model you want to evaluate. After modification, when evaluating with `pred.py`, the data will be automatically organized according to the new format to get the corresponding model output. 140 | 141 | -------------------------------------------------------------------------------- /config/dataset2level.json: -------------------------------------------------------------------------------- 1 | { 2 | "konwledge_memorization" : "1-1", 3 | "konwledge_understanding" : "1-2", 4 | "longform_qa" : "2-1", 5 | "finance_qa" : "2-2", 6 | "hotpotqa" : "3-1", 7 | "lcc" : "3-2", 8 | "multi_news" : "4-1", 9 | "qmsum" : "4-2", 10 | "alpacafarm" : "5-1" 11 | } -------------------------------------------------------------------------------- /config/dataset2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": 128, 3 | "qasper": 128, 4 | "multifieldqa_en": 64, 5 | "multifieldqa_zh": 64, 6 | "hotpotqa": 32, 7 | "2wikimqa": 32, 8 | "musique": 32, 9 | "dureader": 128, 10 | "gov_report": 512, 11 | "qmsum": 512, 12 | "multi_news": 512, 13 | "vcsum": 512, 14 | "trec": 64, 15 | "triviaqa": 32, 16 | "samsum": 128, 17 | "lsht": 64, 18 | "passage_count": 32, 19 | "passage_retrieval_en": 32, 20 | "passage_retrieval_zh": 32, 21 | "lcc": 64, 22 | "repobench-p": 64, 23 | "konwledge_memorization": 32, 24 | "konwledge_understanding": 16, 25 | "longform_qa": 300, 26 | "finance_qa": 300, 27 | "alpacafarm": 300 28 | } -------------------------------------------------------------------------------- /config/dataset2prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". 
Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", 6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", 10 | "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", 11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", 12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", 13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", 14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", 15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", 16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", 17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", 18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", 19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. 
The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", 20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", 21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n", 22 | "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n", 23 | "konwledge_memorization": "Please give answer to the following question about knowledge. Note: If there are more than one answer, print them all and separate them with a semicolon (;). Please do not give anything other than the answers. Question:\n{context}\n{input}", 24 | "konwledge_understanding": "Please give answer to the following question about knowledge. Note: If you are asked for true or false, just answer \"true\" or \"false\" only. If you are asked for similarity, just answer with the entity name only. Do not give anything other than the answers. Question:\n{context}\n{input}", 25 | "longform_qa": "You are a helpful assistant, please answer the following question within 300 words:\n{context}\n{input}", 26 | "finance_qa": "You are a helpful assistant, please answer the following question with financial knowledge within 300 words:\n{context}\n{input}", 27 | "alpacafarm": "You are a helpful assistant, please answer the following instruction: \n{context}\n{input}" 28 | } -------------------------------------------------------------------------------- /config/model2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama2-7b-chat-4k": 3500, 3 | "longchat-v1.5-7b-32k": 31500, 4 | "xgen-7b-8k": 7500, 5 | "internlm-7b-8k": 2048, 6 | "chatglm2-6b": 31500, 7 | "chatglm2-6b-32k": 31500, 8 | "vicuna-v1.5-7b-16k": 15500, 9 | "tulu-7b": 2048, 10 | "vicuna-13b": 2048 11 | } -------------------------------------------------------------------------------- /config/model2path.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama2-7b-chat-4k": "./data/models/llama-2-7b-chat-hf", 3 | "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k", 4 | "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst", 5 | "internlm-7b-8k": "/data2/cookie_huggingface_models/internlm-chat-7b-8k", 6 | "chatglm2-6b": "THUDM/chatglm2-6b", 7 | "chatglm2-6b-32k": "./data/models/chatglm2-6b-32k", 8 | "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k", 9 | "tulu-7b": "/data2/cookie_huggingface_models/tulu-7b-full", 10 | "vicuna-13b": "/data2/cookie_huggingface_models/vicuna-13b" 11 | } -------------------------------------------------------------------------------- /convert_llama_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import argparse 15 | import gc 16 | import json 17 | import math 18 | import os 19 | import shutil 20 | 21 | import torch 22 | 23 | from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer 24 | 25 | 26 | """ 27 | Sample usage: 28 | 29 | ``` 30 | python src/transformers/models/llama/convert_llama_weights_to_hf.py \ 31 | --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path 32 | ``` 33 | 34 | Thereafter, models can be loaded via: 35 | 36 | ```py 37 | from transformers import LlamaForCausalLM, LlamaTokenizer 38 | 39 | model = LlamaForCausalLM.from_pretrained("/output/path") 40 | tokenizer = LlamaTokenizer.from_pretrained("/output/path") 41 | ``` 42 | 43 | Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions 44 | come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 45 | """ 46 | 47 | INTERMEDIATE_SIZE_MAP = { 48 | "7B": 11008, 49 | "13B": 13824, 50 | "30B": 17920, 51 | "65B": 22016, 52 | } 53 | NUM_SHARDS = { 54 | "7B": 1, 55 | "13B": 2, 56 | "30B": 4, 57 | "65B": 8, 58 | } 59 | 60 | 61 | def compute_intermediate_size(n): 62 | return int(math.ceil(n * 8 / 3) + 255) // 256 * 256 63 | 64 | 65 | def read_json(path): 66 | with open(path, "r") as f: 67 | return json.load(f) 68 | 69 | 70 | def write_json(text, path): 71 | with open(path, "w") as f: 72 | json.dump(text, f) 73 | 74 | 75 | def write_model(model_path, input_base_path, model_size): 76 | os.makedirs(model_path, exist_ok=True) 77 | tmp_model_path = os.path.join(model_path, "tmp") 78 | os.makedirs(tmp_model_path, exist_ok=True) 79 | 80 | params = read_json(os.path.join(input_base_path, "params.json")) 81 | num_shards = NUM_SHARDS[model_size] 82 | n_layers = params["n_layers"] 83 | n_heads = params["n_heads"] 84 | n_heads_per_shard = n_heads // num_shards 85 | dim = params["dim"] 86 | dims_per_head = dim // n_heads 87 | base = 10000.0 88 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 89 | 90 | # permute for sliced rotary 91 | def permute(w): 92 | return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim) 93 | 94 | print(f"Fetching all parameters from the checkpoint at {input_base_path}.") 95 | # Load weights 96 | if model_size == "7B": 97 | # Not shared 98 | # (The sharded implementation would also work, but this is simpler.) 
99 | loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") 100 | else: 101 | # Sharded 102 | loaded = [ 103 | torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") 104 | for i in range(num_shards) 105 | ] 106 | param_count = 0 107 | index_dict = {"weight_map": {}} 108 | for layer_i in range(n_layers): 109 | filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" 110 | if model_size == "7B": 111 | # Unsharded 112 | state_dict = { 113 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( 114 | loaded[f"layers.{layer_i}.attention.wq.weight"] 115 | ), 116 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( 117 | loaded[f"layers.{layer_i}.attention.wk.weight"] 118 | ), 119 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], 120 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], 121 | f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], 122 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], 123 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], 124 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"], 125 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"], 126 | } 127 | else: 128 | # Sharded 129 | # Note that in the 13B checkpoint, not cloning the two following weights will result in the checkpoint 130 | # becoming 37GB instead of 26GB for some reason. 131 | state_dict = { 132 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ 133 | f"layers.{layer_i}.attention_norm.weight" 134 | ].clone(), 135 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ 136 | f"layers.{layer_i}.ffn_norm.weight" 137 | ].clone(), 138 | } 139 | state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( 140 | torch.cat( 141 | [ 142 | loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) 143 | for i in range(num_shards) 144 | ], 145 | dim=0, 146 | ).reshape(dim, dim) 147 | ) 148 | state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( 149 | torch.cat( 150 | [ 151 | loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim) 152 | for i in range(num_shards) 153 | ], 154 | dim=0, 155 | ).reshape(dim, dim) 156 | ) 157 | state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( 158 | [ 159 | loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim) 160 | for i in range(num_shards) 161 | ], 162 | dim=0, 163 | ).reshape(dim, dim) 164 | 165 | state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( 166 | [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 167 | ) 168 | state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( 169 | [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 170 | ) 171 | state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( 172 | [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 173 | ) 174 | state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( 175 | 
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 176 | ) 177 | 178 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 179 | for k, v in state_dict.items(): 180 | index_dict["weight_map"][k] = filename 181 | param_count += v.numel() 182 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 183 | 184 | filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" 185 | if model_size == "7B": 186 | # Unsharded 187 | state_dict = { 188 | "model.embed_tokens.weight": loaded["tok_embeddings.weight"], 189 | "model.norm.weight": loaded["norm.weight"], 190 | "lm_head.weight": loaded["output.weight"], 191 | } 192 | else: 193 | state_dict = { 194 | "model.norm.weight": loaded[0]["norm.weight"], 195 | "model.embed_tokens.weight": torch.cat( 196 | [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 197 | ), 198 | "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), 199 | } 200 | 201 | for k, v in state_dict.items(): 202 | index_dict["weight_map"][k] = filename 203 | param_count += v.numel() 204 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 205 | 206 | # Write configs 207 | index_dict["metadata"] = {"total_size": param_count * 2} 208 | write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) 209 | 210 | config = LlamaConfig( 211 | hidden_size=dim, 212 | intermediate_size=compute_intermediate_size(dim), 213 | num_attention_heads=params["n_heads"], 214 | num_hidden_layers=params["n_layers"], 215 | rms_norm_eps=params["norm_eps"], 216 | ) 217 | config.save_pretrained(tmp_model_path) 218 | 219 | # Make space so we can load the model properly now. 220 | del state_dict 221 | del loaded 222 | gc.collect() 223 | 224 | print("Loading the checkpoint in a Llama model.") 225 | model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 226 | # Avoid saving this as part of the config. 
227 | del model.config._name_or_path 228 | 229 | print("Saving in the Transformers format.") 230 | model.save_pretrained(model_path) 231 | shutil.rmtree(tmp_model_path) 232 | 233 | 234 | def write_tokenizer(tokenizer_path, input_tokenizer_path): 235 | input_tokenizer_path = "/data3/MODELS/llama2/tokenizer.model" 236 | print(f"Fetching the tokenizer from {input_tokenizer_path}.") 237 | # Initialize the tokenizer based on the `spm` model 238 | tokenizer = LlamaTokenizer(input_tokenizer_path) 239 | tokenizer.save_pretrained(tokenizer_path) 240 | 241 | 242 | def main(): 243 | parser = argparse.ArgumentParser() 244 | parser.add_argument( 245 | "--input_dir", 246 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 247 | ) 248 | parser.add_argument( 249 | "--model_size", 250 | choices=["7B", "13B", "30B", "65B", "tokenizer_only"], 251 | ) 252 | parser.add_argument( 253 | "--output_dir", 254 | help="Location to write HF model and tokenizer", 255 | ) 256 | args = parser.parse_args() 257 | if args.model_size != "tokenizer_only": 258 | write_model( 259 | model_path=args.output_dir, 260 | input_base_path=args.input_dir, 261 | model_size=args.model_size, 262 | ) 263 | write_tokenizer( 264 | tokenizer_path=args.output_dir, 265 | input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"), 266 | ) 267 | 268 | 269 | if __name__ == "__main__": 270 | main() 271 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | from watermark.old_watermark import OldWatermarkDetector 2 | from watermark.our_watermark import NewWatermarkDetector 3 | from watermark.gptwm import GPTWatermarkDetector 4 | from watermark.watermark_v2 import WatermarkDetector 5 | from tqdm import tqdm 6 | from pred import load_model_and_tokenizer, seed_everything, str2bool 7 | import argparse 8 | import os 9 | import json 10 | import torch 11 | 12 | def main(args): 13 | seed_everything(42) 14 | model2path = json.load(open("config/model2path.json", "r")) 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | # get model name 17 | model_name = args.input_dir.split("/")[-1].split("_")[0] 18 | # define your model 19 | tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device, load_token_only=True) 20 | all_token_ids = list(tokenizer.get_vocab().values()) 21 | vocab_size = len(all_token_ids) 22 | # get gamma and delta 23 | if "gpt" in args.input_dir: 24 | gamma = float(args.input_dir.split("_g")[2].split("_")[0]) 25 | else: 26 | gamma = float(args.input_dir.split("_g")[1].split("_")[0]) 27 | 28 | delta = float(args.input_dir.split("_d")[1].split("_")[0]) 29 | # get all files from input_dir 30 | files = os.listdir(args.input_dir) 31 | # get all json files 32 | json_files = [f for f in files if f.endswith(".jsonl")] 33 | os.makedirs(args.input_dir + "/z_score", exist_ok=True) 34 | if args.mission != "all": 35 | json_files = [f for f in files if args.mission in f] 36 | for json_file in json_files: 37 | print(f"{json_file} has began.........") 38 | # read jsons 39 | with open(os.path.join(args.input_dir, json_file), "r") as f: 40 | # lines 41 | lines = f.readlines() 42 | # texts 43 | prompts = [json.loads(line)["prompt"] for line in lines] 44 | texts = [json.loads(line)["pred"] for line in lines] 45 | print(f"texts[0] is: {texts[0]}") 46 | tokens = [json.loads(line)["completions_tokens"] for line in lines] 47 | 48 | 49 | 50 | if "old" in 
args.input_dir or "no" in args.input_dir: 51 | detector = OldWatermarkDetector(tokenizer=tokenizer, 52 | vocab=all_token_ids, 53 | gamma=gamma, 54 | delta=delta, 55 | dynamic_seed="markov_1", 56 | device=device) 57 | 58 | if "new" in args.input_dir: 59 | detector = NewWatermarkDetector(tokenizer=tokenizer, 60 | vocab=all_token_ids, 61 | gamma=gamma, 62 | delta=delta, 63 | dynamic_seed="markov_1", 64 | device=device, 65 | # vocabularys=vocabularys, 66 | ) 67 | 68 | if "v2" in args.input_dir: 69 | detector = WatermarkDetector( 70 | vocab=all_token_ids, 71 | gamma=gamma, 72 | z_threshold=args.threshold,tokenizer=tokenizer, 73 | seeding_scheme=args.seeding_scheme, 74 | device=device, 75 | normalizers=args.normalizers, 76 | ignore_repeated_bigrams=args.ignore_repeated_bigrams, 77 | select_green_tokens=args.select_green_tokens) 78 | 79 | if "gpt" in args.input_dir: 80 | detector = GPTWatermarkDetector( 81 | fraction=gamma, 82 | strength=delta, 83 | vocab_size=vocab_size, 84 | watermark_key=args.wm_key) 85 | 86 | 87 | 88 | z_score_list = [] 89 | for idx, cur_text in tqdm(enumerate(texts), total=len(texts)): 90 | #print("cur_text is:", cur_text) 91 | 92 | gen_tokens = tokenizer.encode(cur_text, return_tensors="pt", truncation=True, add_special_tokens=False) 93 | #print("gen_tokens is:", gen_tokens) 94 | prompt = prompts[idx] 95 | 96 | input_prompt = tokenizer.encode(prompt, return_tensors="pt", truncation=True,add_special_tokens=False) 97 | 98 | 99 | # if "v2" in args.input_dir and len(gen_tokens[0]) >= args.test_min_tokens: 100 | # z_score_list.append(detector.detect(cur_text)["z_score"]) 101 | 102 | # elif len(gen_tokens[0]) >= 1: 103 | # if "old" in args.input_dir or "no" in args.input_dir: 104 | # # print("gen_tokens is:", gen_tokens) 105 | # z_score_list.append(detector.detect(tokenized_text=gen_tokens, inputs=input_prompt)) 106 | 107 | # elif "gpt" in args.input_dir: 108 | # z_score_list.append(detector.detect(gen_tokens[0])) 109 | # elif "new" in args.input_dir: 110 | # z_score_list.append(detector.detect(tokenized_text=gen_tokens, tokens=tokens[idx], inputs=input_prompt)) 111 | 112 | # else: 113 | # print(f"Warning: sequence {idx} is too short to test. Which is ", gen_tokens[0]) 114 | 115 | # else: 116 | # print(f"Warning: sequence {idx} is too short to test. 
Which is ", gen_tokens[0]) 117 | 118 | if len(gen_tokens[0]) >= args.test_min_tokens: 119 | 120 | if "v2" in args.input_dir: 121 | z_score_list.append(detector.detect(cur_text)["z_score"]) 122 | 123 | elif "old" in args.input_dir or "no" in args.input_dir: 124 | print("gen_tokens is:", gen_tokens) 125 | z_score_list.append(detector.detect(tokenized_text=gen_tokens, inputs=input_prompt)) 126 | 127 | elif "gpt" in args.input_dir: 128 | z_score_list.append(detector.detect(gen_tokens[0])) 129 | elif "new" in args.input_dir: 130 | z_score_list.append(detector.detect(tokenized_text=gen_tokens, tokens=tokens[idx], inputs=input_prompt)) 131 | else: 132 | print(f"Warning: sequence {idx} is too short to test.") 133 | 134 | # if len(gen_tokens[0]) >= 1: 135 | # if "old" in args.input_dir or "no" in args.input_dir: 136 | # print("gen_tokens is:", gen_tokens) 137 | # z_score_list.append(detector.detect(tokenized_text=gen_tokens, inputs=input_prompt)) 138 | 139 | # elif "gpt" in args.input_dir: 140 | # z_score_list.append(detector.detect(gen_tokens[0])) 141 | # elif "new" in args.input_dir: 142 | # z_score_list.append(detector.detect(tokenized_text=gen_tokens, tokens=tokens[idx], inputs=input_prompt)) 143 | # else: 144 | # print(f"Warning: sequence {idx} is too short to test.") 145 | 146 | save_dict = { 147 | 'z_score_list': z_score_list, 148 | 'avarage_z': torch.mean(torch.tensor(z_score_list)).item(), 149 | 'wm_pred': [1 if z > args.threshold else 0 for z in z_score_list] 150 | } 151 | 152 | wm_pred_average = torch.mean(torch.tensor(save_dict['wm_pred'], dtype=torch.float)) 153 | save_dict.update({'wm_pred_average': wm_pred_average.item()}) 154 | 155 | print(save_dict) 156 | # average_z = torch.mean(z_score_list) 157 | z_file = json_file.replace('.jsonl', f'_{gamma}_{delta}_{args.threshold}_z.jsonl') 158 | output_path = os.path.join(args.input_dir + "/z_score", z_file) 159 | with open(output_path, 'w') as fout: 160 | json.dump(save_dict, fout) 161 | 162 | 163 | 164 | 165 | 166 | parser = argparse.ArgumentParser(description="Process watermark to calculate z-score for every method") 167 | 168 | parser.add_argument( 169 | "--input_dir", 170 | type=str, 171 | default="/data2/tsq/WaterBench/pred/llama2-7b-chat-4k_old_g0.5_d5.0") 172 | parser.add_argument( # for gpt watermark 173 | "--wm_key", 174 | type=int, 175 | default=0) 176 | 177 | parser.add_argument( 178 | "--threshold", 179 | type=float, 180 | default=6.0) 181 | 182 | parser.add_argument( 183 | "--test_min_tokens", 184 | type=int, 185 | default=2) 186 | 187 | parser.add_argument( # for v2 watermark 188 | "--seeding_scheme", 189 | type=str, 190 | default="simple_1", 191 | help="Seeding scheme to use to generate the greenlists at each generation and verification step.", 192 | ) 193 | 194 | parser.add_argument( # for v2 watermark 195 | "--normalizers", 196 | type=str, 197 | default="", 198 | help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.", 199 | ) 200 | 201 | parser.add_argument( # for v2 watermark 202 | "--ignore_repeated_bigrams", 203 | type=str2bool, 204 | default=False, 205 | help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.", 206 | ) 207 | 208 | parser.add_argument( # for v2 watermark 209 | "--select_green_tokens", 210 | type=str2bool, 211 | default=True, 212 | help="How to treat the permuation when selecting the greenlist tokens at each step. 
Legacy is (False) to pick the complement/reds first.", 213 | ) 214 | 215 | parser.add_argument( 216 | "--mission", 217 | type=str, 218 | default="all", 219 | help="mission-name", 220 | ) 221 | args = parser.parse_args() 222 | 223 | main(args) 224 | 225 | -------------------------------------------------------------------------------- /detect_human.py: -------------------------------------------------------------------------------- 1 | from watermark.old_watermark import OldWatermarkDetector 2 | from watermark.our_watermark import NewWatermarkDetector 3 | from watermark.gptwm import GPTWatermarkDetector 4 | from watermark.watermark_v2 import WatermarkDetector 5 | from tqdm import tqdm 6 | from pred import load_model_and_tokenizer, seed_everything, str2bool 7 | import argparse 8 | import os 9 | import json 10 | import torch 11 | import re 12 | 13 | def main(args): 14 | seed_everything(42) 15 | model2path = json.load(open("config/model2path.json", "r")) 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | # get model name 18 | model_name = args.reference_dir.split("/")[-1].split("_")[0] 19 | 20 | # define your model 21 | tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device, load_token_only=True) 22 | all_token_ids = list(tokenizer.get_vocab().values()) 23 | vocab_size = len(all_token_ids) 24 | 25 | 26 | all_input_dir = "./pred/" 27 | # get gamma and delta 28 | 29 | pattern_dir = r"(?P.+)_(?Pold|v2|gpt|new|no)_g(?P.+)_d(?P\d+(\.\d+)?)" 30 | 31 | pattern_mis = r"(?P[a-zA-Z_]+)_(?P\d+(\.\d+)?)_(?P.+)_z" 32 | 33 | matcher_ref = re.match(pattern_dir, args.reference_dir) 34 | 35 | mode_ref = matcher_ref.group("mode") 36 | gamma_ref = float(matcher_ref.group("gamma")) 37 | delta_ref = float(matcher_ref.group("delta")) 38 | bl_type_ref = "None" 39 | bl_type_ref = (args.reference_dir.split("_")[-1]).split(".")[0] 40 | 41 | if bl_type_ref != "hard": 42 | if "old" in args.reference_dir: 43 | bl_type_ref = "soft" 44 | mode_ref += "_" + bl_type_ref 45 | else: 46 | bl_type_ref = "None" 47 | else: 48 | mode_ref += "_" + bl_type_ref 49 | 50 | print("mode_ref is:", mode_ref) 51 | 52 | if args.detect_dir != "human_generation": 53 | matcher_det = re.match(pattern_dir, args.detect_dir) 54 | mode_det = matcher_det.group("mode") 55 | gamma_det = float(matcher_det.group("gamma")) 56 | delta_det = float(matcher_det.group("delta")) 57 | bl_type_det = "None" 58 | bl_type_det = (args.detect_dir.split("_")[-1]).split(".")[0] 59 | 60 | 61 | if bl_type_det != "hard": 62 | if "old" in args.detect_dir: 63 | bl_type_det = "soft" 64 | mode_det += "_" + bl_type_det 65 | else: 66 | bl_type_det = "None" 67 | else: 68 | mode_det += "_" + bl_type_det 69 | 70 | # print("bl_type_det is:", bl_type_det) 71 | # print("mode_det is:", mode_det) 72 | # get all files from detect_dir 73 | 74 | files = os.listdir(all_input_dir + args.detect_dir) 75 | 76 | # get all json files 77 | json_files = [f for f in files if f.endswith(".jsonl")] 78 | 79 | ref_dir = f"./detect_human/{model_name}/ref_{mode_ref}_g{gamma_ref}_d{delta_ref}" 80 | 81 | os.makedirs(f"./detect_human/{model_name}", exist_ok=True) 82 | os.makedirs(ref_dir, exist_ok=True) 83 | if args.detect_dir == "human_generation": 84 | os.makedirs(ref_dir + "/human_generation_z", exist_ok=True) 85 | else: 86 | os.makedirs(ref_dir + f"/{mode_det}_g{gamma_det}_d{delta_det}_z", exist_ok=True) 87 | 88 | if "old" in args.reference_dir or "no" in args.reference_dir: 89 | detector = OldWatermarkDetector(tokenizer=tokenizer, 90 | 
vocab=all_token_ids, 91 | gamma=gamma_ref, 92 | delta=delta_ref, 93 | dynamic_seed="markov_1", 94 | device=device) 95 | 96 | if "new" in args.reference_dir: 97 | detector = NewWatermarkDetector(tokenizer=tokenizer, 98 | vocab=all_token_ids, 99 | gamma=gamma_ref, 100 | delta=delta_ref, 101 | dynamic_seed="markov_1", 102 | device=device, 103 | # vocabularys=vocabularys, 104 | ) 105 | 106 | if "v2" in args.reference_dir: 107 | detector = WatermarkDetector( 108 | vocab=all_token_ids, 109 | gamma=gamma_ref, 110 | z_threshold=args.threshold,tokenizer=tokenizer, 111 | seeding_scheme=args.seeding_scheme, 112 | device=device, 113 | normalizers=args.normalizers, 114 | ignore_repeated_bigrams=args.ignore_repeated_bigrams, 115 | select_green_tokens=args.select_green_tokens) 116 | 117 | if "gpt" in args.reference_dir: 118 | detector = GPTWatermarkDetector( 119 | fraction=gamma_ref, 120 | strength=delta_ref, 121 | vocab_size=vocab_size, 122 | watermark_key=args.wm_key) 123 | prompts = [] 124 | for json_file in json_files: 125 | print(f"{json_file} has began.........") 126 | if args.detect_dir == "human_generation": 127 | with open(os.path.join(all_input_dir + args.reference_dir, json_file), "r") as f: 128 | lines = f.readlines() 129 | 130 | prompts = [json.loads(line)["prompt"] for line in lines] 131 | print("len of prompts is", len(prompts)) 132 | # read jsons 133 | with open(os.path.join(all_input_dir + args.detect_dir, json_file), "r") as f: 134 | # lines 135 | lines = f.readlines() 136 | # texts 137 | if args.detect_dir != "human_generation": 138 | prompts = [json.loads(line)["prompt"] for line in lines] 139 | texts = [json.loads(line)["pred"] for line in lines] 140 | print(f"texts[0] is: {texts[0]}") 141 | tokens = [json.loads(line)["completions_tokens"] for line in lines] 142 | 143 | z_score_list = [] 144 | for idx, cur_text in tqdm(enumerate(texts), total=len(texts)): 145 | #print("cur_text is:", cur_text) 146 | 147 | gen_tokens = tokenizer.encode(cur_text, return_tensors="pt", truncation=True, add_special_tokens=False) 148 | #print("gen_tokens is:", gen_tokens) 149 | prompt = prompts[idx] 150 | 151 | input_prompt = tokenizer.encode(prompt, return_tensors="pt", truncation=True,add_special_tokens=False) 152 | 153 | 154 | if len(gen_tokens[0]) >= args.test_min_tokens: 155 | 156 | if "v2" in args.reference_dir: 157 | z_score_list.append(detector.detect(cur_text)["z_score"]) 158 | 159 | if len(gen_tokens[0]) >= 1: 160 | if "gpt" in args.reference_dir: 161 | z_score_list.append(detector.detect(gen_tokens[0])) 162 | 163 | elif "old" in args.reference_dir or "no" in args.reference_dir: 164 | z_score_list.append(detector.detect(tokenized_text=gen_tokens, inputs=input_prompt)) 165 | 166 | elif "new" in args.reference_dir: 167 | z_score_list.append(detector.detect(tokenized_text=gen_tokens, tokens=tokens[idx], inputs=input_prompt)) 168 | 169 | else: 170 | print(f"Warning: sequence {idx} is too short to test.") 171 | 172 | save_dict = { 173 | 'z_score_list': z_score_list, 174 | 'avarage_z': torch.mean(torch.tensor(z_score_list)).item(), 175 | 'wm_pred': [1 if z > args.threshold else 0 for z in z_score_list] 176 | } 177 | 178 | wm_pred_average = torch.mean(torch.tensor(save_dict['wm_pred'], dtype=torch.float)) 179 | save_dict.update({'wm_pred_average': wm_pred_average.item()}) 180 | 181 | print(save_dict) 182 | # average_z = torch.mean(z_score_list) 183 | z_file = json_file.replace('.jsonl', f'_{args.threshold}_z.jsonl') 184 | 185 | if args.detect_dir != "human_generation": 186 | output_path = 
os.path.join(ref_dir + f"/{mode_det}_g{gamma_det}_d{delta_det}_z", z_file) 187 | 188 | else: 189 | output_path = os.path.join(ref_dir + "/human_generation_z", z_file) 190 | with open(output_path, 'w') as fout: 191 | json.dump(save_dict, fout) 192 | 193 | 194 | 195 | 196 | 197 | parser = argparse.ArgumentParser(description="Process watermark to calculate z-score for every method") 198 | 199 | parser.add_argument( 200 | "--input_dir", 201 | type=str, 202 | default="/data2/tsq/WaterBench/pred/llama2-7b-chat-4k_old_g0.5_d5.0") 203 | 204 | parser.add_argument( # for gpt watermark 205 | "--wm_key", 206 | type=int, 207 | default=0) 208 | 209 | parser.add_argument( 210 | "--threshold", 211 | type=float, 212 | default=4.0) 213 | 214 | parser.add_argument( 215 | "--test_min_tokens", 216 | type=int, 217 | default=2) 218 | 219 | parser.add_argument( # for v2 watermark 220 | "--seeding_scheme", 221 | type=str, 222 | default="simple_1", 223 | help="Seeding scheme to use to generate the greenlists at each generation and verification step.", 224 | ) 225 | 226 | parser.add_argument( # for v2 watermark 227 | "--normalizers", 228 | type=str, 229 | default="", 230 | help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.", 231 | ) 232 | 233 | parser.add_argument( # for v2 watermark 234 | "--ignore_repeated_bigrams", 235 | type=str2bool, 236 | default=False, 237 | help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.", 238 | ) 239 | 240 | parser.add_argument( # for v2 watermark 241 | "--select_green_tokens", 242 | type=str2bool, 243 | default=True, 244 | help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.", 245 | ) 246 | 247 | parser.add_argument( 248 | "--reference_dir", 249 | type=str, 250 | default="/data2/tsq/WaterBench/pred/llama2-7b-chat-4k_v2_g0.25_d15.0", 251 | help="Which type as reference to test TN or FP", 252 | ) 253 | 254 | parser.add_argument( 255 | "--detect_dir", 256 | type=str, 257 | default="/data2/tsq/WaterBench/pred/llama2-7b-chat-4k_v2_g0.25_d15.0", 258 | help="Which type need to be detected", 259 | ) 260 | args = parser.parse_args() 261 | 262 | main(args) 263 | 264 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | import openai 6 | from datasets import load_dataset 7 | from alpaca_farm.auto_annotations import alpaca_leaderboard 8 | import datasets 9 | from metrics import ( 10 | qa_f1_score, 11 | rouge_zh_score, 12 | qa_f1_zh_score, 13 | rouge_score, 14 | classification_score, 15 | retrieval_score, 16 | retrieval_zh_score, 17 | count_score, 18 | code_sim_score, 19 | ) 20 | openai.api_key_path = "data/openai_api_key.txt" 21 | 22 | dataset2metric = { 23 | "narrativeqa": qa_f1_score, 24 | "qasper": qa_f1_score, 25 | "multifieldqa_en": qa_f1_score, 26 | "multifieldqa_zh": qa_f1_zh_score, 27 | "hotpotqa": qa_f1_score, 28 | "2wikimqa": qa_f1_score, 29 | "musique": qa_f1_score, 30 | "dureader": rouge_zh_score, 31 | "gov_report": rouge_score, 32 | "qmsum": rouge_score, 33 | "multi_news": rouge_score, 34 | "vcsum": rouge_zh_score, 35 | "trec": classification_score, 36 | "triviaqa": qa_f1_score, 37 | "samsum": rouge_score, 38 | "lsht": classification_score, 39 | "passage_retrieval_en": 
retrieval_score, 40 | "passage_count": count_score, 41 | "passage_retrieval_zh": retrieval_zh_score, 42 | "lcc": code_sim_score, 43 | "repobench-p": code_sim_score, 44 | "konwledge_memorization": qa_f1_score, 45 | "konwledge_understanding": qa_f1_score, 46 | "longform_qa": rouge_score, 47 | "finance_qa": rouge_score, 48 | } 49 | 50 | def parse_args(args=None): 51 | parser = argparse.ArgumentParser(description="Evaluate texts generated by every method") 52 | 53 | parser.add_argument( 54 | "--input_dir", 55 | type=str, 56 | default="/data2/tsq/WaterBench/pred/llama2-7b-chat-4k_no_g0.5_d5.0") 57 | args = parser.parse_args() 58 | 59 | return args 60 | 61 | # def scorer_e(dataset, predictions, answers, lengths, all_classes): 62 | # scores = {"0-4k": [], "4-8k": [], "8k+": []} 63 | # for (prediction, ground_truths, length) in zip(predictions, answers, lengths): 64 | # score = 0. 65 | # if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 66 | # prediction = prediction.lstrip('\n').split('\n')[0] 67 | # for ground_truth in ground_truths: 68 | # score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 69 | # if length < 4000: 70 | # scores["0-4k"].append(score) 71 | # elif length < 8000: 72 | # scores["4-8k"].append(score) 73 | # else: 74 | # scores["8k+"].append(score) 75 | # for key in scores.keys(): 76 | # scores[key] = round(100 * np.mean(scores[key]), 2) 77 | # return scores 78 | 79 | def scorer(dataset, predictions, answers, all_classes): 80 | total_score = 0. 81 | for (prediction, ground_truths) in zip(predictions, answers): 82 | score = 0. 83 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 84 | prediction = prediction.lstrip('\n').split('\n')[0] 85 | for ground_truth in ground_truths: 86 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 87 | total_score += score 88 | return round(100 * total_score / len(predictions), 2) 89 | 90 | def alpacafarm_score(prompts, predictions, model_name): 91 | # outputs should be a list of json as such: 92 | # [{'instruction': 'What are the names of some famous actors that started their careers on Broadway?', 'input': '', 'output': 'Some famous actors that started their careers on Broadway are Hugh Jackman, Meryl Streep, Denzel Washington, Audra McDonald, and Lin-Manuel Miranda.', 'generator': 'gpt-3.5-turbo-0301', 'dataset': 'helpful_base', 'datasplit': 'eval'}, 93 | # ...] 
94 | my_outputs = [] 95 | alapaca_eval_data = load_dataset("tatsu-lab/alpaca_farm", "alpaca_farm_evaluation")["eval"] 96 | for i, json_obj in enumerate(alapaca_eval_data): 97 | prompt = json_obj["instruction"] 98 | _input = json_obj["input"] 99 | prediction = predictions[i] 100 | my_outputs.append({"instruction": prompt, "input": _input, "generator": model_name, "output": prediction}) 101 | print("my_outputs[0] is:", my_outputs[0]) 102 | df_results = alpaca_leaderboard( 103 | path_or_all_outputs=my_outputs, 104 | name=model_name, 105 | is_add_reference_methods=False, 106 | annotators_config = "greedy_gpt4/configs.yaml" 107 | ) 108 | score = df_results.to_string(float_format="%.2f") 109 | return score 110 | 111 | 112 | if __name__ == '__main__': 113 | args = parse_args() 114 | scores = dict() 115 | # get all files from input_dir 116 | files = os.listdir(args.input_dir) 117 | model_name = args.input_dir.split("/")[-1] 118 | # get all json files 119 | json_files = [f for f in files if f.endswith(".jsonl")] 120 | save_dir = os.path.join(args.input_dir, "eval") 121 | os.makedirs(save_dir, exist_ok=True) 122 | print("Evaluating on:", files) 123 | for json_file in json_files: 124 | if not json_file.endswith("jsonl"): 125 | continue 126 | print(f"{json_file} has began.........") 127 | # read jsons 128 | dataset = json_file.split(".")[0] 129 | predictions, answers, lengths, all_classes = [], [], [], [] 130 | with open(os.path.join(args.input_dir, json_file), "r") as f: 131 | # lines 132 | lines = f.readlines() 133 | # texts 134 | prompts = [json.loads(line)["prompt"] for line in lines] 135 | predictions = [json.loads(line)["pred"] for line in lines] 136 | answers = [json.loads(line)["answers"] for line in lines] 137 | all_classes = json.loads(lines[0])["all_classes"] 138 | print(f"predictions[0] is: {predictions[0]}") 139 | if dataset == "alpacafarm": 140 | score = alpacafarm_score(prompts, predictions, model_name) 141 | else: 142 | score = scorer(dataset, predictions, answers, all_classes) 143 | scores[dataset] = score 144 | # save 145 | out_path = os.path.join(save_dir, "result.json") 146 | with open(out_path, "w") as f: 147 | json.dump(scores, f, ensure_ascii=False, indent=4) 148 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | import torch.nn.functional as F 4 | from watermark.old_watermark import BlacklistLogitsProcessor 5 | from watermark.our_watermark import OurBlacklistLogitsProcessor 6 | from watermark.gptwm import GPTWatermarkLogitsWarper 7 | from watermark.watermark_v2 import WatermarkLogitsProcessor 8 | from transformers import LogitsProcessorList 9 | 10 | 11 | class Generator(): 12 | def __init__(self, args, tokenizer, model) -> None: 13 | self.mode = args.mode # watermark mode 14 | self.init_seed, self.dyna_seed, self.gamma, \ 15 | self.delta, self.bl_type, self.num_beams, self.sampling_temp = args.initial_seed, args.dynamic_seed, args.gamma, args.delta, args.bl_type, args.num_beams, args.sampling_temp 16 | self.tokenizer = tokenizer 17 | self.model = model # language model 18 | 19 | self.all_token_ids = list(tokenizer.get_vocab().values()) 20 | self.vocab_size = len(self.all_token_ids) 21 | # if self.vocab_size != model.config.padded_vocab_size: 22 | # self.vocab_size = model.config.padded_vocab_size 23 | 24 | self.bl_processor = BlacklistLogitsProcessor( 25 | bad_words_ids=None, 26 | 
eos_token_id=tokenizer.eos_token_id, 27 | vocab=self.all_token_ids, 28 | vocab_size=self.vocab_size, 29 | bl_proportion=1-self.gamma, 30 | bl_logit_bias=self.delta, 31 | bl_type=self.bl_type, 32 | initial_seed=self.init_seed, 33 | dynamic_seed=self.dyna_seed) 34 | self.logit_processor_lst = LogitsProcessorList([self.bl_processor]) 35 | if args.mode == 'new': 36 | self.bl_processor = OurBlacklistLogitsProcessor(tokenizer=tokenizer, 37 | bad_words_ids=None, 38 | eos_token_id=tokenizer.eos_token_id, 39 | vocab=self.all_token_ids, 40 | all_vocab_size=self.vocab_size, 41 | bl_proportion=1-self.gamma, 42 | bl_logit_bias=self.delta, 43 | bl_type=self.bl_type, 44 | initial_seed=self.init_seed, 45 | dynamic_seed=self.dyna_seed) 46 | self.logit_processor_lst = LogitsProcessorList([self.bl_processor]) 47 | 48 | if args.mode == 'gpt': 49 | 50 | watermark_processor = GPTWatermarkLogitsWarper(vocab_size=self.vocab_size, 51 | fraction=args.gamma, 52 | strength=args.delta) 53 | 54 | self.logit_processor_lst = LogitsProcessorList([watermark_processor]) 55 | 56 | if args.mode == 'v2': 57 | watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()), 58 | gamma=args.gamma, 59 | delta=args.delta, 60 | seeding_scheme=args.seeding_scheme, 61 | select_green_tokens=args.select_green_tokens) 62 | self.logit_processor_lst = LogitsProcessorList([watermark_processor]) 63 | 64 | def generate(self, input_ids, max_new_tokens): 65 | if self.mode == 'new': 66 | example = {} 67 | 68 | outputs = self.model.generate( 69 | input_ids, max_new_tokens=max_new_tokens, 70 | logits_processor = self.logit_processor_lst, 71 | do_sample=True, 72 | top_k=0, 73 | temperature=self.sampling_temp 74 | ) 75 | 76 | example.update({"bl_vocabularys":self.logit_processor_lst[0].get_and_clear_vocabularys()}) 77 | # remove the attached input from output for some model 78 | scores = outputs.scores 79 | output_ids = outputs.sequences[0, -len(scores):] 80 | 81 | 82 | # compute logprob for each token 83 | completions_tokens = [] 84 | completions_logprob = 0 85 | 86 | for score, token, vocabulary in zip(scores, output_ids, example["bl_vocabularys"], strict=True): 87 | logprobs = F.log_softmax(score[0], dim=-1) 88 | logprob = logprobs[token].item() 89 | completions_tokens.append({ 90 | 'text': self.tokenizer.decode(token), 91 | 'logprob': logprob, 92 | 'vocabulary': vocabulary, 93 | }) 94 | completions_logprob += logprob 95 | 96 | completions_text = self.tokenizer.decode(output_ids, skip_special_tokens=True) 97 | return completions_text, completions_tokens 98 | else: 99 | 100 | if self.mode == 'no': 101 | 102 | outputs = self.model.generate( 103 | input_ids, max_new_tokens=max_new_tokens, 104 | ) 105 | 106 | elif self.mode == 'old': 107 | 108 | outputs = self.model.generate( 109 | input_ids, max_new_tokens=max_new_tokens, 110 | logits_processor = self.logit_processor_lst, 111 | do_sample=True, 112 | top_k=0, 113 | temperature=self.sampling_temp 114 | ) 115 | 116 | elif self.mode == 'gpt': 117 | 118 | outputs = self.model.generate( 119 | input_ids, max_new_tokens=max_new_tokens, 120 | logits_processor = self.logit_processor_lst, 121 | do_sample=True, 122 | top_k=0, 123 | top_p=0.9 124 | ) 125 | 126 | elif self.mode == 'v2': 127 | outputs = self.model.generate( 128 | input_ids, max_new_tokens=max_new_tokens, 129 | logits_processor = self.logit_processor_lst, 130 | do_sample=True, 131 | top_k=0, 132 | temperature=self.sampling_temp 133 | ) 134 | 135 | # remove the attached input from output for some model 136 | scores = 
outputs.scores 137 | output_ids = outputs.sequences[0, -len(scores):] 138 | 139 | # compute logprob for each token 140 | completions_tokens = [] 141 | completions_logprob = 0 142 | 143 | for score, token in zip(scores, output_ids, strict=True): 144 | logprobs = F.log_softmax(score[0], dim=-1) 145 | logprob = logprobs[token].item() 146 | completions_tokens.append({ 147 | 'text': self.tokenizer.decode(token), 148 | 'logprob': logprob, 149 | }) 150 | completions_logprob += logprob 151 | 152 | completions_text = self.tokenizer.decode(output_ids, skip_special_tokens=True) 153 | 154 | 155 | return completions_text, completions_tokens -------------------------------------------------------------------------------- /llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | import math 6 | 7 | import transformers 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | from einops import rearrange 11 | 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 25 | Optional[Tuple[torch.Tensor]]]: 26 | """Input shape: Batch x Time x Channel 27 | 28 | attention_mask: [bsz, q_len] 29 | """ 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 33 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 34 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 35 | # [bsz, q_len, nh, hd] 36 | # [bsz, nh, q_len, hd] 37 | 38 | kv_seq_len = key_states.shape[-2] 39 | # assert past_key_value is None, "past_key_value is not supported" 40 | if past_key_value is not None: 41 | kv_seq_len += past_key_value[0].shape[-2] 42 | 43 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 44 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 45 | # [bsz, nh, t, hd] 46 | # assert not output_attentions, "output_attentions is not supported" 47 | # assert not use_cache, "use_cache is not supported" 48 | if past_key_value is not None: 49 | # reuse k, v, self_attention 50 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 51 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 52 | 53 | past_key_value = (key_states, value_states) if use_cache else None 54 | 55 | # Flash attention codes from 56 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 57 | 58 | # transform the data into the format required by flash attention 59 | # import pdb; pdb.set_trace() 60 | 61 | if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2]: 62 | # decode token-by-token, do not use flash attention 63 | # in incremental state, do not use flash attention 64 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 65 | 66 | if attn_weights.size() 
!= (bsz, self.num_heads, q_len, kv_seq_len): 67 | raise ValueError( 68 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 69 | f" {attn_weights.size()}" 70 | ) 71 | 72 | if attention_mask is not None: 73 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 74 | raise ValueError( 75 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 76 | ) 77 | attn_weights = attn_weights + attention_mask 78 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) 79 | 80 | # upcast attention to fp32 81 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 82 | attn_output = torch.matmul(attn_weights, value_states) 83 | 84 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 85 | raise ValueError( 86 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 87 | f" {attn_output.size()}" 88 | ) 89 | 90 | attn_output = attn_output.transpose(1, 2) 91 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 92 | 93 | attn_output = self.o_proj(attn_output) 94 | 95 | if not output_attentions: 96 | attn_weights = None 97 | 98 | return attn_output, attn_weights, past_key_value 99 | 100 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 101 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 102 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 103 | # the attention_mask should be the same as the key_padding_mask 104 | key_padding_mask = attention_mask 105 | 106 | if key_padding_mask is None: 107 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 108 | max_s = q_len 109 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, 110 | device=qkv.device) 111 | output = flash_attn_varlen_qkvpacked_func( 112 | qkv, cu_q_lens, max_s, 0.0, 113 | softmax_scale=None, causal=True 114 | ) 115 | output = rearrange(output, '(b s) ... 
-> b s ...', b=bsz) 116 | else: 117 | nheads = qkv.shape[-2] 118 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 119 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 120 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 121 | output_unpad = flash_attn_varlen_qkvpacked_func( 122 | x_unpad, cu_q_lens, max_s, 0.0, 123 | softmax_scale=None, causal=True 124 | ) 125 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 126 | indices, bsz, q_len), 127 | 'b s (h d) -> b s h d', h=nheads) 128 | attn_output = self.o_proj(rearrange(output, 'b s h d -> b s (h d)')) 129 | 130 | return attn_output, None, past_key_value 131 | 132 | 133 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 134 | # requires the attention mask to be the same as the key_padding_mask 135 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, 136 | inputs_embeds, past_key_values_length): 137 | # [bsz, seq_len] 138 | if input_shape[-1] > 1 and past_key_values_length == 0: # encode 139 | return attention_mask 140 | return transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask(self, attention_mask, 141 | input_shape, 142 | inputs_embeds, 143 | past_key_values_length) 144 | 145 | 146 | def replace_llama_attn_with_flash_attn(): 147 | print('use FlashAttention') 148 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 149 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import jieba 4 | from fuzzywuzzy import fuzz 5 | import difflib 6 | 7 | from typing import List 8 | from collections import Counter 9 | from rouge import Rouge 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | 14 | def remove_articles(text): 15 | return re.sub(r"\b(a|an|the)\b", " ", text) 16 | 17 | def white_space_fix(text): 18 | return " ".join(text.split()) 19 | 20 | def remove_punc(text): 21 | exclude = set(string.punctuation) 22 | return "".join(ch for ch in text if ch not in exclude) 23 | 24 | def lower(text): 25 | return text.lower() 26 | 27 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 28 | 29 | 30 | def normalize_zh_answer(s): 31 | """Lower text and remove punctuation, extra whitespace.""" 32 | 33 | def white_space_fix(text): 34 | return "".join(text.split()) 35 | 36 | def remove_punc(text): 37 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 
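        # Combine ASCII punctuation with the full-width/CJK punctuation above so both
        # kinds are stripped from the text.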
38 | all_punctuation = set(string.punctuation + cn_punctuation) 39 | return "".join(ch for ch in text if ch not in all_punctuation) 40 | 41 | def lower(text): 42 | return text.lower() 43 | 44 | return white_space_fix(remove_punc(lower(s))) 45 | 46 | def count_score(prediction, ground_truth, **kwargs): 47 | numbers = re.findall(r"\d+", prediction) 48 | right_num = 0 49 | for number in numbers: 50 | if str(number) == str(ground_truth): 51 | right_num += 1 52 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 53 | return float(final_score) 54 | 55 | def retrieval_score(prediction, ground_truth, **kwargs): 56 | pattern = r'Paragraph (\d+)' 57 | matches = re.findall(pattern, ground_truth) 58 | ground_truth_id = matches[0] 59 | numbers = re.findall(r"\d+", prediction) 60 | right_num = 0 61 | for number in numbers: 62 | if str(number) == str(ground_truth_id): 63 | right_num += 1 64 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 65 | return float(final_score) 66 | 67 | def retrieval_zh_score(prediction, ground_truth, **kwargs): 68 | pattern = r'段落(\d+)' 69 | matches = re.findall(pattern, ground_truth) 70 | ground_truth_id = matches[0] 71 | numbers = re.findall(r"\d+", prediction) 72 | right_num = 0 73 | for number in numbers: 74 | if str(number) == str(ground_truth_id): 75 | right_num += 1 76 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 77 | return float(final_score) 78 | 79 | def code_sim_score(prediction, ground_truth, **kwargs): 80 | all_lines = prediction.lstrip('\n').split('\n') 81 | prediction = "" 82 | for line in all_lines: 83 | if ('`' not in line) and ('#' not in line) and ('//' not in line): 84 | prediction = line 85 | break 86 | return (fuzz.ratio(prediction, ground_truth) / 100) 87 | 88 | def classification_score(prediction, ground_truth, **kwargs): 89 | em_match_list = [] 90 | all_classes = kwargs["all_classes"] 91 | for class_name in all_classes: 92 | if class_name in prediction: 93 | em_match_list.append(class_name) 94 | for match_term in em_match_list: 95 | if match_term in ground_truth and match_term != ground_truth: 96 | em_match_list.remove(match_term) 97 | if em_match_list != 0: 98 | if ground_truth in em_match_list: 99 | score = (1.0 / len(em_match_list)) 100 | else: 101 | score = 0.0 102 | else: 103 | best_match = None 104 | highest_similarity = 0 105 | for string in all_classes: 106 | similarity = difflib.SequenceMatcher(None, string, prediction).ratio() 107 | if similarity > highest_similarity: 108 | highest_similarity = similarity 109 | best_match = string 110 | score = float(best_match == ground_truth) 111 | return score 112 | 113 | def rouge_score(prediction, ground_truth, **kwargs): 114 | rouge = Rouge() 115 | try: 116 | scores = rouge.get_scores([prediction], [ground_truth], avg=True) 117 | except: 118 | return 0.0 119 | return scores["rouge-l"]["f"] 120 | 121 | def rouge_zh_score(prediction, ground_truth, **kwargs): 122 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) 123 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) 124 | score = rouge_score(prediction, ground_truth) 125 | return score 126 | 127 | def f1_score(prediction, ground_truth, **kwargs): 128 | common = Counter(prediction) & Counter(ground_truth) 129 | num_same = sum(common.values()) 130 | if num_same == 0: 131 | return 0 132 | precision = 1.0 * num_same / len(prediction) 133 | recall = 1.0 * num_same / len(ground_truth) 134 | f1 = (2 * precision * recall) / (precision + recall) 135 | return f1 
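# Hypothetical worked example for the bag-of-tokens F1 above (illustration only, not part
# of the original file): f1_score(["the", "cat", "sat"], ["the", "cat", "slept"]) gives
# common = {"the": 1, "cat": 1}, num_same = 2, precision = recall = 2/3, and
# F1 = 2 * (2/3 * 2/3) / (2/3 + 2/3) = 2/3 ≈ 0.67.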
136 | 137 | def qa_f1_score(prediction, ground_truth, **kwargs): 138 | normalized_prediction = normalize_answer(prediction) 139 | normalized_ground_truth = normalize_answer(ground_truth) 140 | 141 | prediction_tokens = normalized_prediction.split() 142 | ground_truth_tokens = normalized_ground_truth.split() 143 | return f1_score(prediction_tokens, ground_truth_tokens) 144 | 145 | 146 | def qa_f1_zh_score(prediction, ground_truth, **kwargs): 147 | prediction_tokens = list(jieba.cut(prediction, cut_all=False)) 148 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) 149 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] 150 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] 151 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0] 152 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] 153 | return f1_score(prediction_tokens, ground_truth_tokens) 154 | -------------------------------------------------------------------------------- /pictures/heat_datasets.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/pictures/heat_datasets.pdf -------------------------------------------------------------------------------- /pictures/llama2-7b-chat-4k_mission_scatters1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/pictures/llama2-7b-chat-4k_mission_scatters1.pdf -------------------------------------------------------------------------------- /pictures/llama2-7b-chat-4k_mission_scatters2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/pictures/llama2-7b-chat-4k_mission_scatters2.pdf -------------------------------------------------------------------------------- /pictures/llama2-7b-chat-4k_radar.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/pictures/llama2-7b-chat-4k_radar.pdf -------------------------------------------------------------------------------- /pictures/llama2-7b-chat-4k_search.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/pictures/llama2-7b-chat-4k_search.pdf -------------------------------------------------------------------------------- /pictures/roc_auc.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/pictures/roc_auc.pdf -------------------------------------------------------------------------------- /pictures/scatters_all.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/pictures/scatters_all.pdf -------------------------------------------------------------------------------- /pred.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets import load_dataset 3 | import torch, gc 4 | 
import json 5 | from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM 6 | from tqdm import tqdm 7 | import numpy as np 8 | import random 9 | import argparse 10 | # from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 11 | from generate import Generator 12 | def str2bool(v): 13 | """Util function for user friendly boolean flag args""" 14 | if isinstance(v, bool): 15 | return v 16 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 17 | return True 18 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 19 | return False 20 | else: 21 | raise argparse.ArgumentTypeError('Boolean value expected.') 22 | 23 | def parse_args(args=None): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--model', type=str, default=None, choices=["llama2-7b-chat-4k", "chatglm2-6b-32k", "tulu-7b", "internlm-7b-8k"]) 26 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 27 | 28 | # watermark args 29 | parser.add_argument( 30 | "--mode", 31 | type=str, 32 | default="old", 33 | choices=["no", "old", "new", "v2", "gpt"], 34 | help="Which version of the watermark to generate", 35 | ) 36 | parser.add_argument( 37 | '--initial_seed', 38 | type=int, 39 | default=1234, 40 | help=("The initial seed to use in the blacklist randomization process.", 41 | "Is unused if the process is markov generally. Can be None."), 42 | ) 43 | 44 | parser.add_argument( 45 | "--dynamic_seed", 46 | type=str, 47 | default="markov_1", 48 | choices=[None, "initial", "markov_1"], 49 | help="The seeding procedure to use when sampling the redlist at each step.", 50 | ) 51 | 52 | parser.add_argument( 53 | "--gamma", 54 | type=float, 55 | default=0.5) 56 | 57 | parser.add_argument( 58 | "--delta", 59 | type=float, 60 | default=5.0) 61 | 62 | parser.add_argument( 63 | "--bl_type", 64 | type=str, 65 | default="soft", 66 | choices=["soft", "hard"], 67 | help="The type of redlisting being performed.", 68 | ) 69 | parser.add_argument( 70 | "--num_beams", 71 | type=int, 72 | default=1, 73 | help="The number of beams to use where '1' is no beam search.", 74 | ) 75 | parser.add_argument( 76 | "--sampling_temp", 77 | type=float, 78 | default=0.7, 79 | help="The temperature to use when generating using multinom sampling", 80 | ) 81 | parser.add_argument( # for gpt watermark 82 | "--wm_key", 83 | type=int, 84 | default=0) 85 | 86 | parser.add_argument( 87 | "--threshold", 88 | type=float, 89 | default=4.0) 90 | 91 | parser.add_argument( 92 | "--test_min_tokens", 93 | type=int, 94 | default=2) 95 | 96 | parser.add_argument( 97 | "--start_point", 98 | type=int, 99 | default=0, 100 | ) 101 | 102 | 103 | 104 | parser.add_argument( # for v2 watermark 105 | "--seeding_scheme", 106 | type=str, 107 | default="simple_1", 108 | help="Seeding scheme to use to generate the greenlists at each generation and verification step.", 109 | ) 110 | 111 | parser.add_argument( # for v2 watermark 112 | "--normalizers", 113 | type=str, 114 | default="", 115 | help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.", 116 | ) 117 | 118 | parser.add_argument( # for v2 watermark 119 | "--ignore_repeated_bigrams", 120 | type=str2bool, 121 | default=False, 122 | help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.", 123 | ) 124 | 125 | parser.add_argument( # for v2 watermark 126 | "--select_green_tokens", 127 | type=str2bool, 128 | default=True, 129 | help="How to 
treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.", 130 | ) 131 | 132 | parser.add_argument( # for dataset 133 | "--dataset", 134 | type=str, 135 | default="all", 136 | choices=["konwledge_memorization","konwledge_understanding","longform_qa", 137 | "finance_qa","hotpotqa","lcc", "multi_news", "qmsum","alpacafarm", "all"], 138 | ) 139 | 140 | return parser.parse_args(args) 141 | 142 | # This is the customized building prompt for chat models 143 | def build_chat(tokenizer, prompt, model_name): 144 | if "chatglm" in model_name: 145 | prompt = tokenizer.build_prompt(prompt) 146 | elif "llama2" in model_name: 147 | prompt = f"[INST]{prompt}[/INST]" 148 | elif "xgen" in model_name: 149 | header = ( 150 | "A chat between a curious human and an artificial intelligence assistant. " 151 | "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" 152 | ) 153 | prompt = header + f" ### Human: {prompt}\n###" 154 | elif "internlm" in model_name: 155 | prompt = f"<|User|>:{prompt}\n<|Bot|>:" 156 | elif "tulu" in model_name: 157 | prompt = f"<|user|>:{prompt}\n<|assistant|>:" 158 | return prompt 159 | 160 | def post_process(response, model_name): 161 | if "xgen" in model_name: 162 | response = response.strip().replace("Assistant:", "") 163 | elif "internlm" in model_name: 164 | response = response.split("")[0] 165 | return response 166 | 167 | def get_pred(watermark_args, model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name, debug: bool = False): 168 | preds = [] 169 | generator = Generator(watermark_args, tokenizer, model) 170 | torch.cuda.empty_cache() 171 | for json_obj in tqdm(data[watermark_args.start_point:]): 172 | # for json_obj in tqdm(data[2]): 173 | # json_obj = data[695] 174 | prompt = prompt_format.format(**json_obj) 175 | # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) 176 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] 177 | if len(tokenized_prompt) > max_length: 178 | half = int(max_length/2) 179 | prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) 180 | if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks 181 | prompt = build_chat(tokenizer, prompt, model_name) 182 | input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) 183 | # output = model.generate( 184 | # **input, 185 | # max_new_tokens=max_gen, 186 | # num_beams=1, 187 | # do_sample=False, 188 | # temperature=1.0, 189 | # )[0] 190 | completions_text, completions_tokens = generator.generate(input_ids=input.input_ids, max_new_tokens=max_gen) 191 | # gc.collect() 192 | # torch.cuda.empty_cache() 193 | 194 | 195 | if debug: 196 | print("####################") 197 | 198 | pred = completions_text 199 | pred = post_process(pred, model_name) 200 | preds.append({"prompt":prompt, "pred": pred, "completions_tokens":completions_tokens, "answers": json_obj["outputs"], "all_classes": json_obj["all_classes"], "length":json_obj["length"]}) 201 | 202 | return preds 203 | 204 | def seed_everything(seed): 205 | torch.manual_seed(seed) 206 | torch.cuda.manual_seed(seed) 207 | np.random.seed(seed) 208 | random.seed(seed) 209 | torch.backends.cudnn.benchmark = False 210 | 
torch.backends.cudnn.deterministic = True 211 | torch.cuda.manual_seed_all(seed) 212 | 213 | def load_model_and_tokenizer(path, model_name, device, load_token_only=False): 214 | if "chatglm" in model_name or "internlm" in model_name or "xgen" in model_name: 215 | tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) 216 | if not load_token_only: 217 | model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, 218 | output_scores=True, return_dict_in_generate=True, 219 | torch_dtype=torch.bfloat16).to(device) 220 | model.eval() 221 | elif "llama2" or "tulu" in model_name: 222 | # replace_llama_attn_with_flash_attn() 223 | tokenizer = LlamaTokenizer.from_pretrained(path) 224 | if not load_token_only: 225 | model = LlamaForCausalLM.from_pretrained(path, output_scores=True, return_dict_in_generate=True, 226 | torch_dtype=torch.bfloat16).to(device) 227 | if load_token_only: 228 | return tokenizer 229 | else: 230 | model = model.eval() 231 | return model, tokenizer 232 | 233 | if __name__ == '__main__': 234 | seed_everything(42) 235 | args = parse_args() 236 | model2path = json.load(open("config/model2path.json", "r")) 237 | model2maxlen = json.load(open("config/model2maxlen.json", "r")) 238 | 239 | 240 | # gpu_list=[1,3,4,5,6,7] 241 | # gpu_list_str = ','.join(map(str, gpu_list)) 242 | # os.environ.setdefault("CUDA_VISIBLE_DEVICES", gpu_list_str) 243 | # device_ids = list(range(torch.cuda.device_count())) 244 | # print("device_ids0 is", device_ids) 245 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 246 | 247 | model_name = args.model 248 | # define your model 249 | model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device) 250 | max_length = model2maxlen[model_name] 251 | if args.e: 252 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \ 253 | "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] 254 | else: 255 | datasets = ["konwledge_memorization","konwledge_understanding","longform_qa", 256 | "finance_qa","hotpotqa","lcc", "multi_news", "qmsum","alpacafarm"] 257 | # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output 258 | dataset2prompt = json.load(open("config/dataset2prompt.json", "r")) 259 | dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r")) 260 | dataset2level = json.load(open("config/dataset2level.json", "r")) 261 | # make dir for saving predictions 262 | if not os.path.exists("pred"): 263 | os.makedirs("pred") 264 | save_dir = f"pred/{model_name}_{args.mode}_g{args.gamma}_d{args.delta}" 265 | if args.bl_type == "hard": 266 | save_dir = f"pred/{model_name}_{args.mode}_g{args.gamma}_d{args.delta}_hard" 267 | if not os.path.exists(save_dir): 268 | os.makedirs(save_dir) 269 | # predict on each dataset 270 | if args.dataset == "all": 271 | for dataset in datasets: 272 | # load data 273 | print(f"{dataset} has began.........") 274 | data = [] 275 | with open("data/WaterBench/{}_{}.jsonl".format(dataset2level[dataset], dataset), "r", encoding="utf-8") as f: 276 | for line in f: 277 | data.append(json.loads(line)) 278 | out_path = os.path.join(save_dir, f"{dataset}.jsonl") 279 | prompt_format = dataset2prompt[dataset] 280 | max_gen = dataset2maxlen[dataset] 281 | preds = get_pred(args, model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name) 282 | with open(out_path, "w", encoding="utf-8") as f: 283 | 
for pred in preds: 284 | json.dump(pred, f, ensure_ascii=False) 285 | f.write('\n') 286 | 287 | else: 288 | dataset = args.dataset 289 | print(f"{dataset} has began.........") 290 | data = [] 291 | with open("data/WaterBench/{}_{}.jsonl".format(dataset2level[dataset], dataset), "r", encoding="utf-8") as f: 292 | for line in f: 293 | data.append(json.loads(line)) 294 | out_path = os.path.join(save_dir, f"{dataset}.jsonl") 295 | prompt_format = dataset2prompt[dataset] 296 | max_gen = dataset2maxlen[dataset] 297 | preds = get_pred(args, model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name) 298 | if os.path.exists(out_path): 299 | with open(out_path, "a", encoding="utf-8") as f: 300 | for pred in preds: 301 | json.dump(pred, f, ensure_ascii=False) 302 | f.write('\n') 303 | else: 304 | with open(out_path, "w", encoding="utf-8") as f: 305 | for pred in preds: 306 | json.dump(pred, f, ensure_ascii=False) 307 | f.write('\n') 308 | 309 | -------------------------------------------------------------------------------- /process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/process/__init__.py -------------------------------------------------------------------------------- /process/download.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import random 4 | import argparse 5 | import os 6 | from datasets import load_dataset 7 | from nltk.tokenize import word_tokenize 8 | import json 9 | import hashlib 10 | 11 | dataset2level = { 12 | "konwledge_memorization" : "1-1", 13 | "konwledge_understanding" : "1-2", 14 | "longform_qa" : "2-1", 15 | "finance_qa" : "2-2", 16 | "hotpotqa" : "3-1", 17 | "lcc" : "3-2", 18 | "multi_news" : "4-1", 19 | "qmsum" : "4-2", 20 | "alpacafarm" : "5-1", 21 | } 22 | 23 | origin_datasets = ["kola", "finance_qa", "eli5", "longbench", "alpacafarm"] 24 | 25 | def convert_to_sha256(string): 26 | # Encoding the string into bytes 27 | encoded_string = string.encode('utf-8') 28 | 29 | # Creating a SHA-256 hash object 30 | sha256_hash = hashlib.sha256() 31 | 32 | # Updating the hash object with the encoded string 33 | sha256_hash.update(encoded_string) 34 | 35 | # Obtaining the hexadecimal representation of the hash 36 | hex_digest = sha256_hash.hexdigest() 37 | 38 | return hex_digest 39 | 40 | 41 | def parse_args(args=None): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--dataset', type=str, default="all", choices=["all"]+origin_datasets) 44 | return parser.parse_args(args) 45 | 46 | def cal_len_and_output(_input, context, output, f, dataset): 47 | # calculate length 48 | input_total = _input + " " + context 49 | input_count = len(word_tokenize(input_total)) 50 | if type(output) == list: 51 | output_count = sum([ len(word_tokenize(o)) for o in output]) / len(output) 52 | else: 53 | output_count = len(word_tokenize(output)) 54 | json_obj = {} 55 | json_obj["input"] = _input # The input/command for the task, usually short, such as questions in QA, queries in Few-shot tasks, etc 56 | json_obj["context"] = context # The long context required for the task, such as documents, cross-file code, few-shot examples in Few-shot tasks 57 | if type(output) == list: 58 | json_obj["outputs"] = output 59 | else: 60 | json_obj["outputs"] = [output] 61 | json_obj["input_length"] = input_count 62 | json_obj["output_length"] 
= output_count 63 | json_obj["length"] = input_count + output_count 64 | # other keys 65 | json_obj["all_classes"] = "null" 66 | json_obj["language"] = "en" 67 | json_obj["dataset"] = dataset 68 | encode_str = "Tsinghua WaterBench " + json.dumps(json_obj) 69 | json_obj["_id"] = convert_to_sha256(encode_str) 70 | json.dump(json_obj, f) 71 | f.write("\n") 72 | 73 | 74 | def get_data(dataset_name): 75 | output_dir = f"./data/WaterBench/" 76 | if not os.path.exists(output_dir): 77 | os.makedirs(output_dir) 78 | if dataset_name == "longbench": 79 | datasets = [ "hotpotqa", "multi_news", "lcc", "qmsum"] 80 | for dataset in datasets: 81 | data = load_dataset('THUDM/LongBench', dataset, split='test') 82 | print(data) 83 | output_path = f"{output_dir}/{dataset2level[dataset]}_{dataset}.jsonl" 84 | with open(output_path, "w") as f: 85 | if len(data) > 200: 86 | data = data.shuffle(seed=42).select(range(200)) 87 | for json_obj in data: 88 | # calculate length 89 | cal_len_and_output(json_obj["input"], json_obj["context"], json_obj["answers"][0], f, dataset) 90 | elif dataset_name == "alpacafarm": 91 | alapaca_eval_data = load_dataset("tatsu-lab/alpaca_farm", "alpaca_farm_evaluation")["eval"] 92 | # load data and output 93 | output_path = f"{output_dir}/{dataset2level[dataset_name]}_{dataset_name}.jsonl" 94 | with open(output_path, "w") as fout: 95 | for json_obj in alapaca_eval_data: 96 | # calculate length 97 | cal_len_and_output(json_obj["input"], json_obj["instruction"], json_obj["output"], fout, dataset_name) 98 | elif dataset_name == "finance_qa": 99 | output_path = f"{output_dir}/{dataset2level[dataset_name]}_{dataset_name}.jsonl" 100 | # load 101 | with open(output_path, "w") as f: 102 | with open("/data2/tsq/WaterBench/data/download/HC3_en.jsonl", 'r') as fin: 103 | lines = fin.readlines()[19142:21142] 104 | json_with_len = [] 105 | for line in lines: 106 | json_obj = json.loads(line.strip()) 107 | refs = json_obj["human_answers"] 108 | output_count = sum([ len(word_tokenize(o)) for o in refs]) / len(refs) 109 | if output_count > 300 or "?" not in json_obj["question"]: 110 | continue 111 | json_with_len.append((json_obj, output_count)) 112 | # sort revese 113 | json_with_len.sort(key=lambda x: x[1], reverse=True) 114 | for json_obj, _ in json_with_len[:200]: 115 | question = json_obj["question"] 116 | refs = json_obj["human_answers"] 117 | cal_len_and_output(question, "", refs, f, dataset_name) 118 | elif dataset_name == "eli5": 119 | dataset = "longform_qa" 120 | output_path = f"{output_dir}/{dataset2level[dataset]}_{dataset}.jsonl" 121 | with open(output_path, "w") as f: 122 | with open("/data2/tsq/WaterBench/data/download/HC3_en.jsonl", 'r') as fin: 123 | lines = fin.readlines()[:2000] 124 | json_with_len = [] 125 | for line in lines: 126 | json_obj = json.loads(line.strip()) 127 | refs = json_obj["human_answers"] 128 | output_count = sum([ len(word_tokenize(o)) for o in refs]) / len(refs) 129 | if output_count > 300: 130 | continue 131 | json_with_len.append((json_obj, output_count)) 132 | # sort revese 133 | json_with_len.sort(key=lambda x: x[1], reverse=True) 134 | for json_obj, _ in json_with_len[:200]: 135 | question = json_obj["question"] 136 | refs = json_obj["human_answers"] 137 | cal_len_and_output(question, "", refs, f, dataset) 138 | elif dataset_name == "kola": 139 | # need 1-1, 1-2. 
COPEN 140 | datasets = [ "konwledge_memorization", "konwledge_understanding"] 141 | for dataset in datasets: 142 | if dataset == "konwledge_memorization": 143 | data_paths = ["/data2/cookie/input/KG/1_low_freq_ent/test.json", "/data2/cookie/input/KG/2_high_freq_ent/test.json"] 144 | else: 145 | data_paths = ["/data2/cookie/input/IE/COPEN/cpj/test.json", "/data2/cookie/input/IE/COPEN/csj/test.json"] 146 | output_path = f"{output_dir}/{dataset2level[dataset]}_{dataset}.jsonl" 147 | with open(output_path, "w") as f: 148 | for data_path in data_paths: 149 | data_file = json.load(open(data_path, 'r'))["request_states"] 150 | for _instance in data_file: 151 | instance = _instance["instance"] 152 | input = instance["input"]["text"] 153 | output = instance["references"][0]["output"]["text"] 154 | context = "" 155 | cal_len_and_output(input, context, output, f, dataset) 156 | 157 | 158 | if __name__ == '__main__': 159 | args = parse_args() 160 | if args.dataset == "all": 161 | for origin_dataset in origin_datasets: 162 | get_data(origin_dataset) 163 | else: 164 | get_data(args.dataset) -------------------------------------------------------------------------------- /process/hyper-search.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | 8 | 9 | def parse_args(args=None): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model', type=str, default="internlm-7b-8k") 12 | parser.add_argument('--task', type=str, default="avg") 13 | parser.add_argument('--threshold', type=int, default=4) 14 | return parser.parse_args(args) 15 | 16 | 17 | def main(args): 18 | 19 | df = pd.DataFrame(columns=["model_name", "mode", "gamma", "delta", "threshold", "bl_type", "z_score", "true_positive", "false_negative","sum"]) 20 | 21 | input_dir = "./pred" 22 | p = r"(?P.+)_(?Pold|v2|gpt|new|no)_g(?P.+)_d(?P\d+(\.\d+)?)" 23 | p1 = r"(?P[a-zA-Z_]+)_(?P\d+(\.\d+)?)_(?P.+)_z" 24 | 25 | num = 0 26 | # get all files from input_dir 27 | for subfolder in os.listdir(input_dir): 28 | # print("subfolder is:", subfolder) 29 | matcher = re.match(p, subfolder) 30 | if matcher == None: 31 | continue 32 | model_name = matcher.group("model_name") 33 | # if model_name != "tulu-7b": 34 | if model_name != args.model: 35 | # if model_name != "tulu-7b": 36 | continue 37 | mode = matcher.group("mode") 38 | gamma = matcher.group("gamma") 39 | delta = matcher.group("delta") 40 | 41 | bl_type = "None" 42 | bl_type = (subfolder.split("_")[-1]).split(".")[0] 43 | 44 | print("bl_type", bl_type) 45 | if bl_type != "hard": 46 | if "old" in subfolder: 47 | bl_type = "soft" 48 | else: 49 | bl_type = "None" 50 | 51 | 52 | # print(model_name, mode, gamma, delta, bl_type) 53 | 54 | z_score_path = os.path.join(input_dir, subfolder, "z_score") 55 | if os.path.exists(z_score_path): 56 | 57 | print("subfolder is:", subfolder) 58 | for threshold in np.arange(0, 10, 0.1): 59 | files = os.listdir(z_score_path) 60 | temp_df = pd.DataFrame(columns=["model_name", "mode", "gamma", "delta", "threshold", "bl_type", "z_score", "sum"]) 61 | all_z = [] 62 | sums = [] 63 | tp = 0 64 | fn = 0 65 | for file in files: 66 | # print(file) 67 | # read jsons 68 | # matcher1 = re.match(p1, file) 69 | # if matcher1: 70 | # misson_name = matcher1.group("misson_name") 71 | # threshold = 4.0 72 | 73 | 74 | with open(os.path.join(z_score_path, file), "r") as f: 75 | data = json.load(f) 76 | # calculate tp and fn 77 | # 
threshold = args.threshold 78 | z_score_list = data["z_score_list"] 79 | _sum = len(data["z_score_list"]) 80 | tp += len([x for x in z_score_list if x > threshold]) 81 | # print("tp is:", tp) 82 | fn += len([x for x in z_score_list if x <= threshold]) 83 | # print("fn is:", fn) 84 | 85 | # get data 86 | avarage_z = data["avarage_z"] 87 | all_z.append(avarage_z * _sum) 88 | sums.append(_sum) 89 | num += 1 90 | 91 | # average z_score 92 | # print(temp_df) 93 | true_positive = tp / sum(sums) 94 | if bl_type == "hard": 95 | print("threshold:", threshold) 96 | print("true_positive", true_positive) 97 | if true_positive >= 0.69 and true_positive <= 0.71: 98 | # if bl_type == "hard": 99 | # print("hello1") 100 | temp_df = pd.DataFrame({ 101 | "model_name": [model_name], 102 | "mode": [mode], 103 | 104 | "gamma": [gamma], 105 | "delta":[delta], 106 | "threshold": [threshold], 107 | "bl_type": [bl_type], 108 | "z_score": [sum(all_z)/sum(sums)], 109 | "true_positive": [tp/sum(sums)], 110 | "false_negative": [fn/sum(sums)], 111 | "sum": [sum(sums)]}) 112 | df = pd.concat([df, temp_df], ignore_index=True) 113 | 114 | # df = df.sort_values(by="threshold", ascending=True) 115 | # df.drop(columns=["mission_name"], inplace=True) 116 | df.to_csv(f"csv_data/mod_threshold_{args.model}2.csv") 117 | 118 | deltas = [2, 5, 10, 15] 119 | gammas = [] 120 | if 'llama' in args.model: 121 | gammas = [0.1, 0.25, 0.5, 0.75, 0.9] 122 | elif 'internlm' in args.model: 123 | gammas = [0.1, 0.15, 0.25, 0.5, 0.75, 0.9] 124 | 125 | for gamma in gammas: 126 | # true_po = df[(df['bl_type'] == 'hard') & (df['gamma'] == str(gamma))]['true_positive'] 127 | sub_df = df[(df['bl_type'] == 'hard') & (df['gamma'] == str(gamma))] 128 | 129 | # print("true_po is", true_po) 130 | print("len sub_df is", len(sub_df['true_positive'])) 131 | if len(sub_df['true_positive']) != 0: 132 | # print("1 point") 133 | for delta in deltas: 134 | # print(sub_df['delta'].values.astype(float)) 135 | if (float(delta) not in sub_df['delta'].values.astype(float)): 136 | print("2 point") 137 | temp_df = pd.DataFrame({ 138 | "model_name": [sub_df['model_name'].values[0]], 139 | "mode": ['old'], 140 | "gamma": [gamma], 141 | "delta":[delta], 142 | "threshold": [threshold], 143 | "bl_type": ['hard'], 144 | "z_score": [sub_df['z_score'].values[0]], 145 | "true_positive": [sub_df['true_positive'].values[0]], 146 | "false_negative": [sub_df['false_negative'].values[0]], 147 | "sum": [sub_df['sum'].values[0]]}) 148 | 149 | df = pd.concat([df, temp_df], ignore_index=True) 150 | 151 | df = df.sort_values(by="threshold", ascending=True) 152 | df.to_csv(f"csv_data/mod_threshold_{args.model}.csv") 153 | 154 | print(df) 155 | print(num) 156 | 157 | if __name__ == '__main__': 158 | args = parse_args() 159 | main(args) 160 | -------------------------------------------------------------------------------- /process/load.py: -------------------------------------------------------------------------------- 1 | from process.download import dataset2level 2 | from typing import List, Dict 3 | import pandas as pd 4 | import json 5 | 6 | class DataLoader(): 7 | def __init__(self, dataset_names: list) -> None: 8 | if dataset_names == "all": 9 | dataset_names = list(dataset2level.keys()) 10 | self.dataset_names = dataset_names 11 | 12 | def load_data(self) -> List[Dict]: 13 | data = [] 14 | for dataset_name in self.dataset_names: 15 | data_path = "./data/WaterBench/" + dataset2level[dataset_name] + "_" + dataset_name + ".jsonl" 16 | # load 17 | with open(data_path, "r") as fin: 18 | 
lines = fin.readlines() 19 | for line in lines: 20 | json_obj = json.loads(line.strip()) 21 | data.append(json_obj) 22 | return data 23 | 24 | if __name__ == '__main__': 25 | dataloader = DataLoader("all") 26 | for name in dataloader.dataset_names: 27 | print(name) 28 | new_loader = DataLoader([name]) 29 | datasets = new_loader.load_data() 30 | df = pd.DataFrame(datasets) 31 | # aggeregate the "input_length", "output_length", and "length" 32 | print(df.describe()) 33 | 34 | -------------------------------------------------------------------------------- /process/process_z.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | 8 | 9 | def parse_args(args=None): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model', type=str, default="internlm-7b-8k") 12 | parser.add_argument('--task', type=str, default="avg") 13 | parser.add_argument('--threshold', type=int, default=4) 14 | return parser.parse_args(args) 15 | 16 | 17 | def main(args): 18 | 19 | df = pd.DataFrame(columns=["model_name", "mission_name", "mode", "gamma", "delta", "threshold", "bl_type", "z_score", "true_positive", "false_negative","sum"]) 20 | 21 | input_dir = "./pred" 22 | p = r"(?P.+)_(?Pold|v2|gpt|new|no)_g(?P.+)_d(?P\d+(\.\d+)?)" 23 | p1 = r"(?P[a-zA-Z_]+)_(?P\d+(\.\d+)?)_(?P.+)_z" 24 | 25 | num = 0 26 | # get all files from input_dir 27 | for subfolder in os.listdir(input_dir): 28 | # print("subfolder is:", subfolder) 29 | matcher = re.match(p, subfolder) 30 | if matcher == None: 31 | continue 32 | model_name = matcher.group("model_name") 33 | # if model_name != "tulu-7b": 34 | if model_name != args.model: 35 | # if model_name != "tulu-7b": 36 | continue 37 | mode = matcher.group("mode") 38 | gamma = matcher.group("gamma") 39 | delta = matcher.group("delta") 40 | 41 | bl_type = "None" 42 | bl_type = (subfolder.split("_")[-1]).split(".")[0] 43 | 44 | if bl_type != "hard": 45 | if "old" in subfolder: 46 | bl_type = "soft" 47 | else: 48 | bl_type = "None" 49 | 50 | 51 | # print(model_name, mode, gamma, delta, bl_type) 52 | 53 | z_score_path = os.path.join(input_dir, subfolder, "z_score") 54 | if os.path.exists(z_score_path): 55 | print("subfolder is:", subfolder) 56 | files = os.listdir(z_score_path) 57 | temp_df = pd.DataFrame(columns=["model_name", "mission_name", "mode", "gamma", "delta", "threshold", "bl_type", "z_score", "sum"]) 58 | all_z = [] 59 | sums = [] 60 | tp = 0 61 | fn = 0 62 | for file in files: 63 | # print(file) 64 | # read jsons 65 | matcher1 = re.match(p1, file) 66 | if matcher1: 67 | misson_name = matcher1.group("misson_name") 68 | threshold = 4.0 69 | else: 70 | threshold = file.split("_")[-2] 71 | 72 | with open(os.path.join(z_score_path, file), "r") as f: 73 | data = json.load(f) 74 | # calculate tp and fn 75 | threshold = args.threshold 76 | z_score_list = data["z_score_list"] 77 | _sum = len(data["z_score_list"]) 78 | tp += len([x for x in z_score_list if x > threshold]) 79 | # print("tp is:", tp) 80 | fn += len([x for x in z_score_list if x <= threshold]) 81 | # print("fn is:", fn) 82 | 83 | # get data 84 | avarage_z = data["avarage_z"] 85 | all_z.append(avarage_z * _sum) 86 | sums.append(_sum) 87 | num += 1 88 | 89 | # average z_score 90 | # print(temp_df) 91 | temp_df = pd.DataFrame({ 92 | "model_name": [model_name], 93 | "mode": [mode], 94 | "mission_name": [misson_name], 95 | "gamma": [gamma], 96 | "delta":[delta], 97 | "threshold": 
[threshold], 98 | "bl_type": [bl_type], 99 | "z_score": [sum(all_z)/sum(sums)], 100 | "true_positive": [tp/sum(sums)], 101 | "false_negative": [fn/sum(sums)], 102 | "sum": [sum(sums)]}) 103 | df = pd.concat([df, temp_df], ignore_index=True) 104 | df = df.sort_values(by="true_positive", ascending=True) 105 | df.drop(columns=["mission_name"], inplace=True) 106 | 107 | deltas = [2, 5, 10, 15] 108 | gammas = [] 109 | if 'llama' in args.model: 110 | gammas = [0.1, 0.25, 0.5, 0.75, 0.9] 111 | elif 'internlm' in args.model: 112 | gammas = [0.1, 0.15, 0.25, 0.5, 0.75, 0.9] 113 | 114 | for gamma in gammas: 115 | # true_po = df[(df['bl_type'] == 'hard') & (df['gamma'] == str(gamma))]['true_positive'] 116 | sub_df = df[(df['bl_type'] == 'hard') & (df['gamma'] == str(gamma))] 117 | 118 | # print("true_po is", true_po) 119 | print("len sub_df is", len(sub_df['true_positive'])) 120 | if len(sub_df['true_positive']) != 0: 121 | # print("1 point") 122 | for delta in deltas: 123 | # print(sub_df['delta'].values.astype(float)) 124 | if (float(delta) not in sub_df['delta'].values.astype(float)): 125 | print("2 point") 126 | temp_df = pd.DataFrame({ 127 | "model_name": [sub_df['model_name'].values[0]], 128 | "mode": ['old'], 129 | "gamma": [gamma], 130 | "delta":[delta], 131 | "threshold": [threshold], 132 | "bl_type": ['hard'], 133 | "z_score": [sub_df['z_score'].values[0]], 134 | "true_positive": [sub_df['true_positive'].values[0]], 135 | "false_negative": [sub_df['false_negative'].values[0]], 136 | "sum": [sub_df['sum'].values[0]]}) 137 | 138 | df = pd.concat([df, temp_df], ignore_index=True) 139 | 140 | df = df.sort_values(by="true_positive", ascending=True) 141 | df.to_csv(f"csv_data/z_score_avg_{args.model}.csv") 142 | 143 | print(df) 144 | print(num) 145 | 146 | if __name__ == '__main__': 147 | args = parse_args() 148 | main(args) 149 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alpaca_farm==0.1.9 2 | cpm_kernels==1.0.11 3 | datasets==2.13.0 4 | einops==0.7.0 5 | flash_attn==2.3.3 6 | fuzzywuzzy==0.18.0 7 | jieba==0.42.1 8 | matplotlib==3.7.1 9 | nltk==3.8.1 10 | numpy==1.25.0 11 | openai==0.28.0 12 | pandas==2.0.2 13 | rouge==1.0.1 14 | scipy==1.11.3 15 | seaborn==0.13.0 16 | sentencepiece==0.1.99 17 | spacy==3.7.2 18 | tokenizers==0.13.3 19 | torch==2.0.1 20 | tqdm==4.65.0 21 | transformers==4.31.0 22 | -------------------------------------------------------------------------------- /shell/detect.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | python detect.py \ 4 | --input_dir ./pred/llama2-7b-chat-4k_old_g0.25_d5.0_hard \ 5 | 6 | -------------------------------------------------------------------------------- /shell/detect_human.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | python detect_human.py \ 4 | --reference_dir llama2-7b-chat-4k_old_g0.25_d5.0_hard \ 5 | --detect_dir human_generation \ 6 | -------------------------------------------------------------------------------- /shell/eval.sh: -------------------------------------------------------------------------------- 1 | python eval.py \ 2 | --input_dir ./pred/llama2-7b-chat-4k_old_g0.25_d5.0_hard 3 | -------------------------------------------------------------------------------- /shell/pred.sh: 
-------------------------------------------------------------------------------- 1 | # change whatever you need 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | # old watermark 5 | python pred.py \ 6 | --mode old \ 7 | --gamma 0.25 \ 8 | --delta 5 \ 9 | --bl_type hard \ 10 | --model llama2-7b-chat-4k \ 11 | -------------------------------------------------------------------------------- /watermark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/WaterBench/8f3d779d66518a7b90ce1aad1fabaeb13cfca548/watermark/__init__.py -------------------------------------------------------------------------------- /watermark/gptwm.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from transformers import LogitsWarper 7 | 8 | 9 | class GPTWatermarkBase: 10 | """ 11 | Base class for watermarking distributions with fixed-group green-listed tokens. 12 | 13 | Args: 14 | fraction: The fraction of the distribution to be green-listed. 15 | strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens. 16 | vocab_size: The size of the vocabulary. 17 | watermark_key: The random seed for the green-listing. 18 | """ 19 | # llama2-7b:32000 20 | # chatglm2:32000 21 | def __init__(self, fraction: float = 0.5, strength: float = 2.0, vocab_size: int = 32000, watermark_key: int = 0): 22 | rng = np.random.default_rng(self._hash_fn(watermark_key)) 23 | mask = np.array([True] * int(fraction * vocab_size) + [False] * (vocab_size - int(fraction * vocab_size))) 24 | rng.shuffle(mask) 25 | self.green_list_mask = torch.tensor(mask, dtype=torch.float32) 26 | self.strength = strength 27 | self.fraction = fraction 28 | 29 | @staticmethod 30 | def _hash_fn(x: int) -> int: 31 | """solution from https://stackoverflow.com/questions/67219691/python-hash-function-that-returns-32-or-64-bits""" 32 | x = np.int64(x) 33 | return int.from_bytes(hashlib.sha256(x).digest()[:4], 'little') 34 | 35 | 36 | class GPTWatermarkLogitsWarper(GPTWatermarkBase, LogitsWarper): 37 | """ 38 | LogitsWarper for watermarking distributions with fixed-group green-listed tokens. 39 | 40 | Args: 41 | fraction: The fraction of the distribution to be green-listed. 42 | strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens. 43 | vocab_size: The size of the vocabulary. 44 | watermark_key: The random seed for the green-listing. 45 | """ 46 | 47 | def __init__(self, *args, **kwargs): 48 | super().__init__(*args, **kwargs) 49 | 50 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor: 51 | """Add the watermark to the logits and return new logits.""" 52 | # print("green_list_mask.shape is", self.green_list_mask.shape) 53 | 54 | watermark = self.strength * self.green_list_mask 55 | # print("scores shape is", scores.shape) 56 | # print("watermark shape is", watermark.shape) 57 | new_logits = scores + watermark.to(scores.device) 58 | 59 | return new_logits 60 | 61 | 62 | class GPTWatermarkDetector(GPTWatermarkBase): 63 | """ 64 | Class for detecting watermarks in a sequence of tokens. 65 | 66 | Args: 67 | fraction: The fraction of the distribution to be green-listed. 68 | strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens. 69 | vocab_size: The size of the vocabulary. 
70 | watermark_key: The random seed for the green-listing. 71 | """ 72 | 73 | def __init__(self, *args, **kwargs): 74 | super().__init__(*args, **kwargs) 75 | 76 | @staticmethod 77 | def _z_score(num_green: int, total: int, fraction: float) -> float: 78 | """Calculate and return the z-score of the number of green tokens in a sequence.""" 79 | return (num_green - fraction * total) / np.sqrt(fraction * (1 - fraction) * total) 80 | 81 | def detect(self, sequence: List[int]) -> float: 82 | """Detect the watermark in a sequence of tokens and return the z value.""" 83 | green_tokens = int(sum(self.green_list_mask[i] for i in sequence)) 84 | 85 | return self._z_score(green_tokens, len(sequence), self.fraction) 86 | -------------------------------------------------------------------------------- /watermark/homoglyphs.py: -------------------------------------------------------------------------------- 1 | """Updated version of core.py from 2 | https://github.com/yamatt/homoglyphs/tree/main/homoglyphs_fork 3 | for modern python3 4 | """ 5 | 6 | from collections import defaultdict 7 | import json 8 | from itertools import product 9 | import os 10 | import unicodedata 11 | 12 | # Actions if char not in alphabet 13 | STRATEGY_LOAD = 1 # load category for this char 14 | STRATEGY_IGNORE = 2 # add char to result 15 | STRATEGY_REMOVE = 3 # remove char from result 16 | 17 | ASCII_RANGE = range(128) 18 | 19 | 20 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 21 | DATA_LOCATION = os.path.join(CURRENT_DIR, "homoglyph_data") 22 | 23 | 24 | class Categories: 25 | """ 26 | Work with aliases from ISO 15924. 27 | https://en.wikipedia.org/wiki/ISO_15924#List_of_codes 28 | """ 29 | 30 | fpath = os.path.join(DATA_LOCATION, "categories.json") 31 | 32 | @classmethod 33 | def _get_ranges(cls, categories): 34 | """ 35 | :return: iter: (start code, end code) 36 | :rtype: list 37 | """ 38 | with open(cls.fpath, encoding="utf-8") as f: 39 | data = json.load(f) 40 | 41 | for category in categories: 42 | if category not in data["aliases"]: 43 | raise ValueError("Invalid category: {}".format(category)) 44 | 45 | for point in data["points"]: 46 | if point[2] in categories: 47 | yield point[:2] 48 | 49 | @classmethod 50 | def get_alphabet(cls, categories): 51 | """ 52 | :return: set of chars in alphabet by categories list 53 | :rtype: set 54 | """ 55 | alphabet = set() 56 | for start, end in cls._get_ranges(categories): 57 | chars = (chr(code) for code in range(start, end + 1)) 58 | alphabet.update(chars) 59 | return alphabet 60 | 61 | @classmethod 62 | def detect(cls, char): 63 | """ 64 | :return: category 65 | :rtype: str 66 | """ 67 | with open(cls.fpath, encoding="utf-8") as f: 68 | data = json.load(f) 69 | 70 | # try detect category by unicodedata 71 | try: 72 | category = unicodedata.name(char).split()[0] 73 | except (TypeError, ValueError): 74 | # In Python2 unicodedata.name raise error for non-unicode chars 75 | # Python3 raise ValueError for non-unicode characters 76 | pass 77 | else: 78 | if category in data["aliases"]: 79 | return category 80 | 81 | # try detect category by ranges from JSON file. 
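        # Assumption about the lookup below: each entry in data["points"] is a
        # [start_codepoint, end_codepoint, category_alias] triple, so a character maps to
        # the alias of the first range that contains its ordinal.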
82 | code = ord(char) 83 | for point in data["points"]: 84 | if point[0] <= code <= point[1]: 85 | return point[2] 86 | 87 | @classmethod 88 | def get_all(cls): 89 | with open(cls.fpath, encoding="utf-8") as f: 90 | data = json.load(f) 91 | return set(data["aliases"]) 92 | 93 | 94 | class Languages: 95 | fpath = os.path.join(DATA_LOCATION, "languages.json") 96 | 97 | @classmethod 98 | def get_alphabet(cls, languages): 99 | """ 100 | :return: set of chars in alphabet by languages list 101 | :rtype: set 102 | """ 103 | with open(cls.fpath, encoding="utf-8") as f: 104 | data = json.load(f) 105 | alphabet = set() 106 | for lang in languages: 107 | if lang not in data: 108 | raise ValueError("Invalid language code: {}".format(lang)) 109 | alphabet.update(data[lang]) 110 | return alphabet 111 | 112 | @classmethod 113 | def detect(cls, char): 114 | """ 115 | :return: set of languages which alphabet contains passed char. 116 | :rtype: set 117 | """ 118 | with open(cls.fpath, encoding="utf-8") as f: 119 | data = json.load(f) 120 | languages = set() 121 | for lang, alphabet in data.items(): 122 | if char in alphabet: 123 | languages.add(lang) 124 | return languages 125 | 126 | @classmethod 127 | def get_all(cls): 128 | with open(cls.fpath, encoding="utf-8") as f: 129 | data = json.load(f) 130 | return set(data.keys()) 131 | 132 | 133 | class Homoglyphs: 134 | def __init__( 135 | self, 136 | categories=None, 137 | languages=None, 138 | alphabet=None, 139 | strategy=STRATEGY_IGNORE, 140 | ascii_strategy=STRATEGY_IGNORE, 141 | ascii_range=ASCII_RANGE, 142 | ): 143 | # strategies 144 | if strategy not in (STRATEGY_LOAD, STRATEGY_IGNORE, STRATEGY_REMOVE): 145 | raise ValueError("Invalid strategy") 146 | self.strategy = strategy 147 | self.ascii_strategy = ascii_strategy 148 | self.ascii_range = ascii_range 149 | 150 | # Homoglyphs must be initialized by any alphabet for correct work 151 | if not categories and not languages and not alphabet: 152 | categories = ("LATIN", "COMMON") 153 | 154 | # cats and langs 155 | self.categories = set(categories or []) 156 | self.languages = set(languages or []) 157 | 158 | # alphabet 159 | self.alphabet = set(alphabet or []) 160 | if self.categories: 161 | alphabet = Categories.get_alphabet(self.categories) 162 | self.alphabet.update(alphabet) 163 | if self.languages: 164 | alphabet = Languages.get_alphabet(self.languages) 165 | self.alphabet.update(alphabet) 166 | self.table = self.get_table(self.alphabet) 167 | 168 | @staticmethod 169 | def get_table(alphabet): 170 | table = defaultdict(set) 171 | with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f: 172 | data = json.load(f) 173 | for char in alphabet: 174 | if char in data: 175 | for homoglyph in data[char]: 176 | if homoglyph in alphabet: 177 | table[char].add(homoglyph) 178 | return table 179 | 180 | @staticmethod 181 | def get_restricted_table(source_alphabet, target_alphabet): 182 | table = defaultdict(set) 183 | with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f: 184 | data = json.load(f) 185 | for char in source_alphabet: 186 | if char in data: 187 | for homoglyph in data[char]: 188 | if homoglyph in target_alphabet: 189 | table[char].add(homoglyph) 190 | return table 191 | 192 | @staticmethod 193 | def uniq_and_sort(data): 194 | result = list(set(data)) 195 | result.sort(key=lambda x: (-len(x), x)) 196 | return result 197 | 198 | def _update_alphabet(self, char): 199 | # try detect languages 200 | langs = Languages.detect(char) 201 | if langs: 202 | 
self.languages.update(langs) 203 | alphabet = Languages.get_alphabet(langs) 204 | self.alphabet.update(alphabet) 205 | else: 206 | # try detect categories 207 | category = Categories.detect(char) 208 | if category is None: 209 | return False 210 | self.categories.add(category) 211 | alphabet = Categories.get_alphabet([category]) 212 | self.alphabet.update(alphabet) 213 | # update table for new alphabet 214 | self.table = self.get_table(self.alphabet) 215 | return True 216 | 217 | def _get_char_variants(self, char): 218 | if char not in self.alphabet: 219 | if self.strategy == STRATEGY_LOAD: 220 | if not self._update_alphabet(char): 221 | return [] 222 | elif self.strategy == STRATEGY_IGNORE: 223 | return [char] 224 | elif self.strategy == STRATEGY_REMOVE: 225 | return [] 226 | 227 | # find alternative chars for current char 228 | alt_chars = self.table.get(char, set()) 229 | if alt_chars: 230 | # find alternative chars for alternative chars for current char 231 | alt_chars2 = [self.table.get(alt_char, set()) for alt_char in alt_chars] 232 | # combine all alternatives 233 | alt_chars.update(*alt_chars2) 234 | # add current char to alternatives 235 | alt_chars.add(char) 236 | 237 | # uniq, sort and return 238 | return self.uniq_and_sort(alt_chars) 239 | 240 | def _get_combinations(self, text, ascii=False): 241 | variations = [] 242 | for char in text: 243 | alt_chars = self._get_char_variants(char) 244 | 245 | if ascii: 246 | alt_chars = [char for char in alt_chars if ord(char) in self.ascii_range] 247 | if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE: 248 | return 249 | 250 | if alt_chars: 251 | variations.append(alt_chars) 252 | if variations: 253 | for variant in product(*variations): 254 | yield "".join(variant) 255 | 256 | def get_combinations(self, text): 257 | return list(self._get_combinations(text)) 258 | 259 | def _to_ascii(self, text): 260 | for variant in self._get_combinations(text, ascii=True): 261 | if max(map(ord, variant)) in self.ascii_range: 262 | yield variant 263 | 264 | def to_ascii(self, text): 265 | return self.uniq_and_sort(self._to_ascii(text)) 266 | -------------------------------------------------------------------------------- /watermark/normalizers.py: -------------------------------------------------------------------------------- 1 | """ Text-based normalizers, used to mitigate simple attacks against watermarking. 2 | 3 | This implementation is unlikely to be a complete list of all possible exploits within the unicode standard, 4 | it represents our best effort at the time of writing. 5 | 6 | These normalizers can be used as stand-alone normalizers. They could be made to conform to HF tokenizers standard, but that would 7 | require messing with the limited rust interface of tokenizers.NormalizedString 8 | """ 9 | from collections import defaultdict 10 | from functools import cache 11 | 12 | import re 13 | import unicodedata 14 | import watermark.homoglyphs as hg 15 | 16 | 17 | def normalization_strategy_lookup(strategy_name: str) -> object: 18 | if strategy_name == "unicode": 19 | return UnicodeSanitizer() 20 | elif strategy_name == "homoglyphs": 21 | return HomoglyphCanonizer() 22 | elif strategy_name == "truecase": 23 | return TrueCaser() 24 | 25 | 26 | class HomoglyphCanonizer: 27 | """Attempts to detect homoglyph attacks and find a consistent canon. 28 | 29 | This function does so on a per-ISO-category level. Language-level would also be possible (see commented code). 
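    Example (a sketch, not a verbatim doctest; the exact replacement characters
    depend on the bundled confusables table):

        canonizer = HomoglyphCanonizer()
        canonizer("Неllo wоrld")   # leading "Н" and the "о" are Cyrillic look-alikes
        # -> a string drawn from the dominant alphabet (here LATIN plus COMMON),
        #    e.g. "Hello world"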
30 | """ 31 | 32 | def __init__(self): 33 | self.homoglyphs = None 34 | 35 | def __call__(self, homoglyphed_str: str) -> str: 36 | # find canon: 37 | target_category, all_categories = self._categorize_text(homoglyphed_str) 38 | homoglyph_table = self._select_canon_category_and_load(target_category, all_categories) 39 | return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str) 40 | 41 | def _categorize_text(self, text: str) -> dict: 42 | iso_categories = defaultdict(int) 43 | # self.iso_languages = defaultdict(int) 44 | 45 | for char in text: 46 | iso_categories[hg.Categories.detect(char)] += 1 47 | # for lang in hg.Languages.detect(char): 48 | # self.iso_languages[lang] += 1 49 | target_category = max(iso_categories, key=iso_categories.get) 50 | all_categories = tuple(iso_categories) 51 | return target_category, all_categories 52 | 53 | @cache 54 | def _select_canon_category_and_load(self, target_category: str, all_categories: tuple[str]) -> dict: 55 | homoglyph_table = hg.Homoglyphs(categories=(target_category, "COMMON")) # alphabet loaded here from file 56 | 57 | source_alphabet = hg.Categories.get_alphabet(all_categories) 58 | restricted_table = homoglyph_table.get_restricted_table(source_alphabet, homoglyph_table.alphabet) # table loaded here from file 59 | return restricted_table 60 | 61 | def _sanitize_text(self, target_category: str, homoglyph_table: dict, homoglyphed_str: str) -> str: 62 | sanitized_text = "" 63 | for char in homoglyphed_str: 64 | # langs = hg.Languages.detect(char) 65 | cat = hg.Categories.detect(char) 66 | if target_category in cat or "COMMON" in cat or len(cat) == 0: 67 | sanitized_text += char 68 | else: 69 | sanitized_text += list(homoglyph_table[char])[0] 70 | return sanitized_text 71 | 72 | 73 | class UnicodeSanitizer: 74 | """Regex-based unicode sanitzer. Has different levels of granularity. 75 | 76 | * ruleset="whitespaces" - attempts to remove only whitespace unicode characters 77 | * ruleset="IDN.blacklist" - does its best to remove unusual unicode based on Network.IDN.blacklist characters 78 | * ruleset="ascii" - brute-forces all text into ascii 79 | 80 | This is unlikely to be a comprehensive list. 81 | 82 | You can find a more comprehensive discussion at https://www.unicode.org/reports/tr36/ 83 | and https://www.unicode.org/faq/security.html 84 | """ 85 | 86 | def __init__(self, ruleset="whitespaces"): 87 | if ruleset == "whitespaces": 88 | 89 | """Documentation: 90 | \u00A0: Non-breaking space 91 | \u1680: Ogham space mark 92 | \u180E: Mongolian vowel separator 93 | \u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner 94 | \u200C\u200D: Zero-width non-joiner and zero-width joiner 95 | \u200E,\u200F: Left-to-right-mark, Right-to-left-mark 96 | \u2060: Word joiner 97 | \u2063: Invisible separator 98 | \u202F: Narrow non-breaking space 99 | \u205F: Medium mathematical space 100 | \u3000: Ideographic space 101 | \uFEFF: Zero-width non-breaking space 102 | \uFFA0: Halfwidth hangul filler 103 | \uFFF9\uFFFA\uFFFB: Interlinear annotation characters 104 | \uFE00-\uFE0F: Variation selectors 105 | \u202A-\u202F: Embedding characters 106 | \u3164: Korean hangul filler. 107 | 108 | Note that these characters are not always superfluous whitespace characters! 
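    A minimal illustration of the default "whitespaces" ruleset (the call goes
    through __call__ below, which also NFC-normalizes and collapses runs of
    spaces):

        UnicodeSanitizer()("hello\u00A0world")     -> "hello world"
        UnicodeSanitizer()("zero\u200Bwidth gap")  -> "zero width gap"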
109 | """ 110 | 111 | self.pattern = re.compile( 112 | r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB" 113 | r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D" 114 | r"\u202E\u202F]" 115 | ) 116 | elif ruleset == "IDN.blacklist": 117 | 118 | """Documentation: 119 | [\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character 120 | set that are included in the IDN blacklist. 121 | \uFFF9-\uFFFB: Matches characters that are not defined in Unicode but are used as language tags in various legacy encodings. 122 | These characters are not allowed in domain names. 123 | \uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character 124 | set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF, 125 | and the second part is in the range U+DC00 to U+DFFF. 126 | \uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches the second part of a surrogate pair. The second part of a surrogate pair is in the range U+DC00 127 | to U+DFFF, and is optional. 128 | [\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs. 129 | """ 130 | 131 | self.pattern = re.compile( 132 | r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]" 133 | r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]" 134 | ) 135 | else: 136 | """Documentation: 137 | This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included. 138 | """ 139 | self.pattern = re.compile(r"[^\x00-\x7F]+") 140 | 141 | def __call__(self, text: str) -> str: 142 | text = unicodedata.normalize("NFC", text) # canon forms 143 | text = self.pattern.sub(" ", text) # pattern match 144 | text = re.sub(" +", " ", text) # collapse whitespaces 145 | text = "".join(c for c in text if unicodedata.category(c) != "Cc") # Remove any remaining non-printable characters 146 | return text 147 | 148 | 149 | class TrueCaser: 150 | """True-casing, is a capitalization normalization that returns text to its original capitalization. 151 | 152 | This defends against attacks that wRIte TeXt lIkE spOngBoB. 153 | 154 | Here, a simple POS-tagger is used. 
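    Example (illustrative; the exact casing depends on the spaCy or NLTK tagger
    that gets loaded, so the output shown is an assumption):

        truecaser = TrueCaser(backend="spacy")
        truecaser("i LIVE in nEw yOrK")   # -> roughly "I live in New York"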
155 | """ 156 | 157 | uppercase_pos = ["PROPN"] # Name POS tags that should be upper-cased 158 | 159 | def __init__(self, backend="spacy"): 160 | if backend == "spacy": 161 | import spacy 162 | 163 | self.nlp = spacy.load("en_core_web_sm") 164 | self.normalize_fn = self._spacy_truecasing 165 | else: 166 | from nltk import pos_tag, word_tokenize # noqa 167 | import nltk 168 | 169 | nltk.download("punkt") 170 | nltk.download("averaged_perceptron_tagger") 171 | nltk.download("universal_tagset") 172 | self.normalize_fn = self._nltk_truecasing 173 | 174 | def __call__(self, random_capitalized_string: str) -> str: 175 | truecased_str = self.normalize_fn(random_capitalized_string) 176 | return truecased_str 177 | 178 | def _spacy_truecasing(self, random_capitalized_string: str): 179 | doc = self.nlp(random_capitalized_string.lower()) 180 | POS = self.uppercase_pos 181 | truecased_str = "".join([w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws for w in doc]) 182 | return truecased_str 183 | 184 | def _nltk_truecasing(self, random_capitalized_string: str): 185 | from nltk import pos_tag, word_tokenize 186 | import nltk 187 | 188 | nltk.download("punkt") 189 | nltk.download("averaged_perceptron_tagger") 190 | nltk.download("universal_tagset") 191 | POS = ["NNP", "NNPS"] 192 | 193 | tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower())) 194 | truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text]) 195 | return truecased_str 196 | -------------------------------------------------------------------------------- /watermark/old_watermark.py: -------------------------------------------------------------------------------- 1 | # micro lib to implement the watermarking extensions to LM generation 2 | # as well as utils for eval/validaiton 3 | 4 | from typing import List, Optional, Callable 5 | 6 | import time 7 | import random 8 | import math 9 | import torch 10 | import numpy as np 11 | 12 | from torch import Tensor 13 | from tokenizers import Tokenizer 14 | from transformers import LogitsProcessor, LogitsProcessorList, set_seed 15 | 16 | 17 | def tokenize_and_truncate(example: dict, 18 | completion_length: int = None, 19 | prompt_length: int = None, 20 | hf_model_name: str = None, 21 | tokenizer = None, 22 | model_max_seq_len: int = 4096): 23 | """take hf dataset entry and preprocess it for completion by a model""" 24 | assert hf_model_name is not None, "need model name to know whether to adjust wrt special tokens" 25 | assert "text" in example, "expects 'text' field to be present" 26 | # tokenize 27 | inputs = tokenizer.encode(example["text"], return_tensors="pt", truncation=True, max_length=model_max_seq_len) 28 | example.update({"untruncated_inputs": inputs}) 29 | 30 | if (completion_length is not None) and (prompt_length is None): 31 | # leave at least one token as prefix # FIXME I think plus 1 since 0 is start tok 32 | slice_length = min(inputs.shape[1]-1, completion_length) 33 | elif (prompt_length is not None) and (completion_length is None): 34 | desired_comp_len = (inputs.shape[1]-1) - prompt_length 35 | slice_length = desired_comp_len if desired_comp_len > 0 else 0 36 | else: 37 | raise ValueError((f"Can only tokenize and truncate based on either the desired prompt length or desired completion length,", 38 | f" but got completion_length:{completion_length},prompt_length:{prompt_length}")) 39 | 40 | # truncate 41 | inputs = inputs[:,:inputs.shape[1]-slice_length] 42 | # logic depending on special tokens for the 
model 43 | if "t5" in hf_model_name or "T0" in hf_model_name: 44 | inputs[0,-1] = 1 45 | # else: pass 46 | example.update({"inputs": inputs}) 47 | return example 48 | 49 | 50 | class BlacklistLogitsProcessor(LogitsProcessor): 51 | """ 52 | [`LogitsProcessor`] that enforces that specified sequences will never be sampled. 53 | 54 | Args: 55 | bad_words_ids (`List[List[int]]`): 56 | List of list of token ids that are not allowed to be generated. In order to get the token ids of the words 57 | that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, 58 | add_special_tokens=False).input_ids`. 59 | eos_token_id (`int`): 60 | The id of the *end-of-sequence* token. 61 | """ 62 | 63 | def __init__(self, 64 | bad_words_ids: List[List[int]], 65 | eos_token_id: int, 66 | vocab: list[int], 67 | vocab_size: int, 68 | bl_proportion: float=0.5, 69 | bl_logit_bias: float=1.0, 70 | bl_type: str = "hard", # "soft" 71 | initial_seed: int=None, 72 | dynamic_seed: str=None, # "initial", "markov_1", None 73 | store_bl_ids: bool=False, 74 | store_spike_ents: bool = False, 75 | noop_blacklist: bool = False, 76 | ): 77 | 78 | self.vocab = vocab 79 | self.vocab_size = vocab_size 80 | self.bl_proportion = bl_proportion 81 | self.bl_logit_bias = bl_logit_bias 82 | self.bl_type = bl_type 83 | 84 | if initial_seed is None: 85 | self.initial_seed = None 86 | assert dynamic_seed != "initial" 87 | else: 88 | random.seed(initial_seed) 89 | self.initial_seed = initial_seed 90 | 91 | self.dynamic_seed = dynamic_seed 92 | 93 | self.eos_token_id = eos_token_id 94 | # self.bad_words_id_length_1 = self._prepare_bad_words(bad_words_ids) 95 | 96 | self.bad_words_mask: Optional[torch.LongTensor] = None 97 | 98 | self.store_bl_ids = store_bl_ids 99 | self.bl_ids = None 100 | 101 | self.store_spike_ents = store_spike_ents 102 | self.spike_entropies = None 103 | 104 | # hack to replace this with an approximation of infinite bias 105 | # so that the expectation coefficient will come out to 1.0 106 | if self.bl_type == "hard": 107 | self.bl_logit_bias = 10000 # FIXME to a value that is actually close to the largest soft watermark used 108 | 109 | alpha = torch.exp(torch.tensor(self.bl_logit_bias)).item() 110 | # gamma = self.bl_proportion 111 | gamma = 1.0-self.bl_proportion 112 | self.alpha = alpha 113 | self.gamma = gamma 114 | 115 | self.z_value = ((1-gamma)*(alpha-1))/(1-gamma+(alpha*gamma)) 116 | self.expected_wl_coef = (gamma*alpha)/(1-gamma+(alpha*gamma)) 117 | 118 | # catch for overflow when bias is "infinite" 119 | if self.alpha == torch.inf: 120 | self.z_value = 1.0 121 | self.expected_wl_coef = 1.0 122 | 123 | self.noop_blacklist = noop_blacklist 124 | if self.noop_blacklist: print(f"Blacklist processor for accounting only, no rescoring of logits") 125 | 126 | self.g_cuda = None 127 | self.large_prime = 15485863 128 | 129 | @property 130 | def blacklisted_ids(self): 131 | assert self.store_bl_ids, "Need to instantiate processor with `store_bl_ids` to be able to retrieve them later" 132 | # flatten the each indexes blacklist 133 | l_of_bl_ids = [[] for _ in range(len(self.bl_ids))] 134 | for b_idx, batch in enumerate(self.bl_ids): 135 | for l_of_l, seed in batch: 136 | bl_ids = [l[0] for l in l_of_l] # this was the main line, maybe unnecessary now? 
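# What gets stored per batch item is a list of (blacklist_token_ids, seed) pairs,
# appended once per decoding step in __call__ below; under "markov_1" seeding the
# seed is simply the previous token id.  Hypothetical shape of the result:
#   [[([17, 803, 5921, ...], 530), ([42, 7, ...], 17)], ...]   # one inner list per batch item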
137 | l_of_bl_ids[b_idx].append((bl_ids,seed)) 138 | return l_of_bl_ids 139 | 140 | def get_and_clear_stored_bl_ids(self): 141 | old_bl_ids = self.bl_ids 142 | self.bl_ids = None 143 | return old_bl_ids 144 | 145 | def get_spike_entropies(self): 146 | spike_ents = [[] for _ in range(len(self.spike_entropies))] 147 | for b_idx, ent_tensor_list in enumerate(self.spike_entropies): 148 | for ent_tensor in ent_tensor_list: 149 | spike_ents[b_idx].append(ent_tensor.item()) 150 | return spike_ents 151 | 152 | def get_and_clear_stored_spike_ents(self): 153 | spike_ents = self.get_spike_entropies() 154 | self.spike_entropies = None 155 | return spike_ents 156 | 157 | def compute_spike_entropy(self, scores): 158 | # precomputed z value in init 159 | probs = scores.softmax(dim=-1) 160 | denoms = 1+(self.z_value*probs) 161 | renormed_probs = probs / denoms 162 | sum_renormed_probs = renormed_probs.sum() 163 | return sum_renormed_probs 164 | 165 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 166 | 167 | self.bad_words_id_length_1 = [None for _ in range(input_ids.shape[0])] 168 | 169 | if self.g_cuda is None: 170 | self.g_cuda = torch.Generator(device=input_ids.device) 171 | 172 | for b_idx in range(input_ids.shape[0]): 173 | if self.dynamic_seed == "initial": 174 | self.g_cuda.manual_seed(self.large_prime*self.initial_seed) 175 | elif self.dynamic_seed == "markov_1": 176 | self.g_cuda.manual_seed(self.large_prime*input_ids[b_idx][-1].item()) 177 | elif self.dynamic_seed is None: 178 | # let the rng evolve naturally - this is not a realistic setting 179 | pass 180 | 181 | # print(f"now tok is {input_ids[b_idx][-1].item()}") 182 | bl_ct = int(self.vocab_size*self.bl_proportion) 183 | blacklist_ids = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.g_cuda)[:bl_ct] # ty Yuxin :] 184 | # print(f"blacklist_ids[0] is {blacklist_ids[0]}") 185 | 186 | if self.store_bl_ids: 187 | if self.bl_ids is None: self.bl_ids = [[] for _ in range(input_ids.shape[0])] 188 | self.bl_ids[b_idx].append((blacklist_ids,input_ids.tolist()[b_idx][-1])) 189 | 190 | if self.store_spike_ents: 191 | if self.spike_entropies is None: self.spike_entropies = [[] for _ in range(input_ids.shape[0])] 192 | self.spike_entropies[b_idx].append(self.compute_spike_entropy(scores[b_idx])) 193 | 194 | # self.bad_words_id_length_1[b_idx] = self._prepare_bad_words(blacklist_ids) 195 | # this logic may not really be necessary for our usecase 196 | self.bad_words_id_length_1[b_idx] = blacklist_ids 197 | 198 | if not self.noop_blacklist: 199 | self.bad_words_mask = self._calc_curr_bad_word_mask(scores) 200 | scores = self._set_scores_to_inf_for_banned_tokens(scores) 201 | 202 | # print("scores shape is", scores.shape) 203 | 204 | return scores 205 | 206 | def _prepare_bad_words(self, bad_words_ids: List[List[int]]) -> list[int]: 207 | bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [self.eos_token_id], bad_words_ids)) 208 | return bad_words_ids 209 | # used to have more logic, not used now 210 | 211 | def _calc_curr_bad_word_mask(self, scores: torch.FloatTensor) -> torch.BoolTensor: 212 | bad_words_mask = torch.zeros_like(scores) 213 | for b_idx in range(len(self.bad_words_id_length_1)): 214 | bad_words_mask[b_idx][self.bad_words_id_length_1[b_idx]] = 1 215 | final_mask = bad_words_mask.bool() 216 | return final_mask 217 | 218 | def _set_scores_to_inf_for_banned_tokens( 219 | self, scores: torch.Tensor 220 | ) -> torch.Tensor: 221 | """ 222 | Modifies the 
scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a 223 | list of list of banned tokens to ban in the format [[batch index, vocabulary position],... 224 | 225 | Args: 226 | scores: logits distribution of shape (batch size, vocabulary size) 227 | banned_tokens: list of list of tokens to ban of length (batch_size) 228 | # NOTE^^ Omitted logic for dynamic mask based on multi-token ban words 229 | """ 230 | if self.bl_type == "hard": 231 | scores = scores.masked_fill(self.bad_words_mask, -float("inf")) 232 | elif self.bl_type == "soft": 233 | whitelist_mask = torch.logical_not(self.bad_words_mask) 234 | blacklist_mask = self.bad_words_mask 235 | scores[whitelist_mask] = scores[whitelist_mask] + self.bl_logit_bias 236 | # scores[blacklist_mask] = scores[blacklist_mask] - self.bl_logit_bias # additive only 237 | else: 238 | raise NotImplementedError(f"unrecognized bl type {self.bl_type}!") 239 | return scores 240 | 241 | 242 | def score_sequence(inputs: Tensor = None, 243 | outputs: Tensor = None, 244 | tokenizer: Tokenizer = None, 245 | # logits_processor: LogitsProcessor = None, 246 | initial_seed: int = None, 247 | dynamic_seed: str = None, 248 | bl_proportion: float = None, 249 | use_cuda: bool = True, 250 | record_hits: bool = False, 251 | debug: bool = True, 252 | # trim_tokens: int = 2, 253 | ): 254 | 255 | assert (inputs is not None) and \ 256 | (outputs is not None) and \ 257 | (tokenizer is not None),"output tensor, tokenizer, and bl params req'd" 258 | # (logits_processor is not None), 259 | 260 | vocabulary = list(tokenizer.get_vocab().values()) 261 | vocab_size = len(vocabulary) 262 | 263 | model_generations = outputs.tolist()[0] # these are tensors unpack once for speed 264 | # toks_generated = model_generations[num_orig_input_tokens:] 265 | toks_generated = model_generations 266 | num_toks_generated = len(toks_generated) 267 | 268 | # num_toks_to_trim = trim_tokens*2 269 | # if (num_toks_generated-num_toks_to_trim > 0) == False: 270 | # return -1, -1 271 | 272 | # assert num_toks_generated > num_toks_to_trim, f"Need more than {num_toks_to_trim} toks total since we trim start and end a bit." 
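# The loop below re-derives, for each generated token, the same pseudo-random
# blacklist that the logits processor sampled at generation time (seeded by the
# previous token under "markov_1") and counts how many generated tokens landed
# in it.  Watermarked text should show far fewer blacklist hits than the
# bl_proportion baseline.  A sketch of turning the counts into a one-proportion
# z-statistic (hypothetical numbers; the test actually used for detection lives
# in OldWatermarkDetector._compute_z_score):
#
#   bl_hits, T = score_sequence(inputs=inputs, outputs=outputs, tokenizer=tokenizer,
#                               dynamic_seed="markov_1", bl_proportion=0.5, debug=False)
#   wl_hits = T - bl_hits
#   gamma = 1.0 - 0.5                                          # whitelist (green) fraction
#   z = (wl_hits - gamma * T) / math.sqrt(T * gamma * (1 - gamma))
#   # e.g. wl_hits = 180 of T = 200 with gamma = 0.5 gives z = 80 / sqrt(50) ≈ 11.3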
273 | 274 | # toks_generated = toks_generated[trim_tokens:-trim_tokens] 275 | 276 | if initial_seed is not None: 277 | random.seed(initial_seed) 278 | 279 | device = (torch.device("cuda") if use_cuda else torch.device("cpu")) 280 | g_cuda = torch.Generator(device=device) 281 | large_prime = 15485863 282 | 283 | bl_hits, hit_list = 0, [] 284 | 285 | prev_token = inputs[0][-1].item() 286 | # prev_token = toks_generated[0] # haven't decided whether this edge effect matters 287 | 288 | # for idx,tok_gend in enumerate(toks_generated[1:]): 289 | for idx,tok_gend in enumerate(toks_generated): 290 | 291 | # prev_token = model_generations[num_orig_input_tokens+idx-1] 292 | 293 | if dynamic_seed == "initial": 294 | g_cuda.manual_seed(large_prime*initial_seed) 295 | elif dynamic_seed == "markov_1": 296 | g_cuda.manual_seed(large_prime*prev_token) 297 | elif dynamic_seed is None: 298 | # let the rng evolve naturally - this is not a realistic setting 299 | pass 300 | 301 | bl_ct = int(vocab_size*bl_proportion) 302 | posthoc_blacklist = torch.randperm(vocab_size, device=device, generator=g_cuda)[:bl_ct] # ty Yuxin :] 303 | 304 | tok_in_ph_bl = tok_gend in posthoc_blacklist 305 | if tok_in_ph_bl: 306 | bl_hits += 1 307 | hit_list.append(True) 308 | else: 309 | hit_list.append(False) 310 | 311 | if debug: 312 | decoded_token = tokenizer.decode(tok_gend, skip_special_tokens=True) 313 | print(f"Token generated: '{decoded_token}' was in the blacklist {tok_in_ph_bl}") 314 | 315 | prev_token = tok_gend 316 | 317 | if debug: 318 | print(f"wl hits / num tokens : {num_toks_generated-bl_hits}/{num_toks_generated} = {(num_toks_generated-bl_hits)/num_toks_generated:.02f}") 319 | print(f"bl hits / num tokens : {bl_hits}/{num_toks_generated} = {bl_hits/num_toks_generated:.02f}") 320 | 321 | if record_hits: 322 | return bl_hits, num_toks_generated, hit_list 323 | # bl_fraction = bl_hits/num_toks_generated 324 | return bl_hits, num_toks_generated 325 | 326 | 327 | def tokenize_for_generation(example: dict, 328 | idx: int, 329 | max_new_tokens: int=None, 330 | min_prompt_tokens: int=None, 331 | hf_model_name : str=None, 332 | tokenizer: Tokenizer=None, 333 | model: torch.nn.Module=None): 334 | 335 | # preprocessing, generation & scoring 336 | assert isinstance(example, dict), "Expect no batch dimension currently!" 
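# Sketch of the fields this helper adds (values are hypothetical; note that
# exactly one of max_new_tokens / min_prompt_tokens should be given, since
# tokenize_and_truncate refuses to receive both):
#
#   ex = tokenize_for_generation({"text": "some long article ..."}, idx=0,
#                                max_new_tokens=200, min_prompt_tokens=None,
#                                hf_model_name="llama2-7b-chat-4k",
#                                tokenizer=tokenizer, model=model)
#   ex["truncated_input"]       # decoded prompt actually fed to the model
#   ex["baseline_completion"]   # the held-out "gold" continuation
#   ex["prompt_length"], ex["real_completion_length"]   # token counts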
337 | 338 | # preprocess for model generation/completion 339 | example = tokenize_and_truncate(example, 340 | completion_length=max_new_tokens, 341 | prompt_length=min_prompt_tokens, 342 | hf_model_name=hf_model_name, 343 | tokenizer=tokenizer, 344 | # model_max_seq_len=model.config.max_position_embeddings) 345 | model_max_seq_len=None) 346 | inputs = example["inputs"] 347 | # for calculating the baseline violation rate across the "gold" completion 348 | untruncated_inputs = example["untruncated_inputs"] 349 | 350 | # decode the preprocessed input to store for audit 351 | re_decoded_input = tokenizer.batch_decode(inputs, skip_special_tokens=True)[0] 352 | example.update({"truncated_input":re_decoded_input}) 353 | 354 | # also decode the original suffix of the input for audit as the baseline 355 | decoded_untruncated_input = tokenizer.batch_decode(untruncated_inputs, skip_special_tokens=True)[0] 356 | example.update({"baseline_completion":decoded_untruncated_input.replace(re_decoded_input,"")}) 357 | 358 | example.update({ 359 | "orig_sample_length" : untruncated_inputs.shape[1], 360 | "prompt_length" : inputs.shape[1], 361 | "real_completion_length" : untruncated_inputs.shape[1] - inputs.shape[1], 362 | }) 363 | return example 364 | 365 | 366 | def generate_completions(example: dict, 367 | idx: int, 368 | max_new_tokens: int=None, 369 | hf_model_name : str=None, 370 | tokenizer: Tokenizer=None, 371 | model: torch.nn.Module=None, 372 | no_bl_partial: Callable=None, 373 | w_bl_partial: Callable=None, 374 | # return_logits: bool=False, 375 | bl_processor_list: LogitsProcessorList=None): 376 | 377 | # preprocessing, generation & scoring 378 | assert isinstance(example, dict), "Expect no batch dimension currently!" 379 | 380 | # # preprocess for model generation/completion 381 | # example = tokenize_and_truncate(example, 382 | # completion_length=max_new_tokens, 383 | # hf_model_name=hf_model_name, 384 | # tokenizer=tokenizer, 385 | # model_max_seq_len=model.config.max_position_embeddings) 386 | # inputs = example["inputs"] 387 | # # for calculating the baseline violation rate across the "gold" completion 388 | # untruncated_inputs = example["untruncated_inputs"] 389 | 390 | # # decode the preprocessed input to store for audit 391 | # re_decoded_input = tokenizer.batch_decode(inputs, skip_special_tokens=True)[0] 392 | # example.update({"truncated_input":re_decoded_input}) 393 | 394 | # # also decode the original suffix of the input for audit as the baseline 395 | # decoded_untruncated_input = tokenizer.batch_decode(untruncated_inputs, skip_special_tokens=True)[0] 396 | # example.update({"baseline_completion":decoded_untruncated_input.replace(re_decoded_input,"")}) 397 | 398 | inputs = example["inputs"] 399 | re_decoded_input = example["truncated_input"] 400 | 401 | # call the vanilla and watermarked generation function wrappers with the preprocessed inputs 402 | with torch.no_grad(): 403 | 404 | samples_taken = 0 405 | max_retries = 10 406 | success = False 407 | while (success is False) and (samples_taken < max_retries): 408 | samples_taken += 1 409 | 410 | # set_seed(1234) # debugging the error when using sampling # leaving this off for now 411 | 412 | start_generation = time.time() 413 | outputs_no_bl = no_bl_partial(inputs.to(model.device)) 414 | example["no_bl_gen_time"] = time.time() - start_generation 415 | 416 | # set_seed(1234) # debugging the error when using sampling 417 | 418 | start_generation = time.time() 419 | outputs_w_bl = w_bl_partial(inputs.to(model.device)) 420 | 
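# no_bl_partial and w_bl_partial are expected to be model.generate pre-bound with
# identical decoding settings, differing only in whether the blacklist logits
# processor is attached.  A minimal sketch of how such partials could be built
# (keyword values are assumptions, not taken from this file):
#
#   from functools import partial
#   from transformers import LogitsProcessorList
#   no_bl_partial = partial(model.generate, max_new_tokens=200,
#                           do_sample=True, top_k=0, temperature=0.7)
#   w_bl_partial = partial(model.generate, max_new_tokens=200,
#                          do_sample=True, top_k=0, temperature=0.7,
#                          logits_processor=LogitsProcessorList([bl_processor]))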
example["w_bl_gen_time"] = time.time() - start_generation 421 | 422 | # if return_logits: 423 | # output_no_bl_dict = outputs_no_bl 424 | # logits_no_bl = output_no_bl_dict.scores 425 | # outputs_no_bl = output_no_bl_dict.sequences 426 | # example["logits_no_bl"] = logits_no_bl 427 | 428 | # output_w_bl_dict = outputs_w_bl 429 | # logits_w_bl = output_w_bl_dict.scores 430 | # outputs_w_bl = output_w_bl_dict.sequences 431 | # example["logits_w_bl"] = logits_w_bl 432 | 433 | 434 | if bl_processor_list: 435 | if bl_processor_list[0].bl_ids is not None: 436 | example["bl_ids"] = bl_processor_list[0].get_and_clear_stored_bl_ids() 437 | if bl_processor_list[0].spike_entropies is not None: 438 | example["spike_entropies"] = bl_processor_list[0].get_and_clear_stored_spike_ents() 439 | 440 | try: 441 | # decode and store the new generations for auditing 442 | no_bl_decoded_output = tokenizer.batch_decode(outputs_no_bl, skip_special_tokens=True)[0] 443 | example.update({"no_bl_output":no_bl_decoded_output.replace(re_decoded_input,"")}) 444 | 445 | w_bl_decoded_output = tokenizer.batch_decode(outputs_w_bl, skip_special_tokens=True)[0] 446 | example.update({"w_bl_output":w_bl_decoded_output.replace(re_decoded_input,"")}) 447 | 448 | success = True 449 | 450 | except: 451 | # log what happened 452 | print(f"Error while trying to decode the outputs of the model...") 453 | if samples_taken == 1: 454 | print(f"truncated_input: {inputs.tolist()}") 455 | print(f"Result of attempt {samples_taken}") 456 | print(f"shape outputs_no_bl: {outputs_no_bl.shape}") 457 | no_bl_toks = outputs_no_bl.tolist()[0] 458 | print(f"outputs_no_bl: {no_bl_toks}") 459 | print(f"outputs_no_bl min: {min(no_bl_toks)}") 460 | print(f"outputs_no_bl max: {max(no_bl_toks)}") 461 | 462 | print(f"shape outputs_w_bl: {outputs_w_bl.shape}") 463 | w_bl_toks = outputs_w_bl.tolist()[0] 464 | print(f"outputs_w_bl: {w_bl_toks}") 465 | print(f"outputs_w_bl min: {min(w_bl_toks)}") 466 | print(f"outputs_w_bl max: {max(w_bl_toks)}") 467 | 468 | if success is False: 469 | print(f"Unable to get both a no_bl and w_bl output that were decodeable after {samples_taken} tries, returning empty strings.") 470 | example.update({"no_bl_output":""}) 471 | example.update({"w_bl_output":""}) 472 | if bl_processor_list: 473 | if bl_processor_list[0].bl_ids is not None: 474 | example["bl_ids"] = [] 475 | if bl_processor_list[0].spike_entropies is not None: 476 | example["spike_entropies"] = [] 477 | 478 | # Able to get lengths in here by checking 479 | # truncated input shape versus the output shape 480 | 481 | example.update({ 482 | # "baseline_num_tokens_generated" : untruncated_inputs.shape[1] - inputs.shape[1], # want this earlier now 483 | "no_bl_num_tokens_generated" : outputs_no_bl.shape[1] - inputs.shape[1], 484 | "w_bl_num_tokens_generated" : outputs_w_bl.shape[1] - inputs.shape[1] 485 | }) 486 | example.update({ 487 | "no_bl_sec_per_tok" : example["no_bl_gen_time"]/example["no_bl_num_tokens_generated"], 488 | "no_bl_tok_per_sec" : example["no_bl_num_tokens_generated"]/example["no_bl_gen_time"], 489 | "w_bl_sec_per_tok" : example["w_bl_gen_time"]/example["w_bl_num_tokens_generated"], 490 | "w_bl_tok_per_sec" : example["w_bl_num_tokens_generated"]/example["w_bl_gen_time"], 491 | }) 492 | 493 | # now done externally because these persist outside this func 494 | # # remove any fields we don't need to keep 495 | # del example["inputs"] 496 | # del example["untruncated_inputs"] 497 | 498 | return example 499 | 500 | 501 | def compute_bl_metrics(example: dict, 
502 | idx: int, 503 | hf_model_name: str = None, 504 | tokenizer: Tokenizer=None, 505 | initial_seed: int = None, 506 | dynamic_seed: str = None, 507 | bl_proportion: float = None, 508 | use_cuda: bool = None, 509 | record_hits: bool = False, 510 | limit_output_tokens: int = 0): 511 | 512 | # if example["idx"] == 3: breakpoint() 513 | # okay need to catch an odd bug here and fix things 514 | baseline_before = example["baseline_completion"] 515 | example["baseline_completion"] = baseline_before.replace(example["truncated_input"][:-1],"") 516 | if example["baseline_completion"] != baseline_before: 517 | print("baseline input replacement bug occurred!") 518 | 519 | no_bl_before = example["no_bl_output"] 520 | example["no_bl_output"] = no_bl_before.replace(example["truncated_input"][:-1],"") 521 | if example["no_bl_output"] != no_bl_before: 522 | print("no_bl_output input replacement bug occurred!") 523 | 524 | w_bl_before = example["w_bl_output"] 525 | example["w_bl_output"] = w_bl_before.replace(example["truncated_input"][:-1],"") 526 | if example["w_bl_output"] != w_bl_before: 527 | print("w_bl_output input replacement bug occurred!") 528 | 529 | if ("w_bl_output_attacked" in example): 530 | w_bl_attacked_before = example["w_bl_output_attacked"] 531 | example["w_bl_output_attacked"] = w_bl_attacked_before.replace(example["truncated_input"][:-1],"") 532 | if example["w_bl_output_attacked"] != w_bl_attacked_before: 533 | print("w_bl_output_attacked input replacement bug occurred!") 534 | 535 | ########## 536 | 537 | # preprocess for model generation/completion 538 | inputs = tokenize_and_truncate({"text":example["truncated_input"]}, 539 | completion_length=0, 540 | hf_model_name=hf_model_name, 541 | tokenizer=tokenizer)["inputs"] 542 | 543 | baseline_outputs = tokenize_and_truncate({"text":example["baseline_completion"]}, 544 | completion_length=0, 545 | hf_model_name=hf_model_name, 546 | tokenizer=tokenizer)["inputs"][:,1:] 547 | 548 | no_bl_outputs = tokenize_and_truncate({"text":example["no_bl_output"]}, 549 | completion_length=0, 550 | hf_model_name=hf_model_name, 551 | tokenizer=tokenizer)["inputs"][:,1:] 552 | 553 | w_bl_outputs = tokenize_and_truncate({"text":example["w_bl_output"]}, 554 | completion_length=0, 555 | hf_model_name=hf_model_name, 556 | tokenizer=tokenizer)["inputs"][:,1:] 557 | if "w_bl_output_attacked" in example: 558 | w_bl_attacked_outputs = tokenize_and_truncate({"text":example["w_bl_output_attacked"]}, 559 | completion_length=0, 560 | hf_model_name=hf_model_name, 561 | tokenizer=tokenizer)["inputs"][:,1:] 562 | else: 563 | w_bl_attacked_outputs = None 564 | 565 | if limit_output_tokens > 0: 566 | example["orig_baseline_completion"] = example["baseline_completion"] 567 | example["orig_real_completion_length"] = example["real_completion_length"] 568 | baseline_outputs = baseline_outputs[:,:limit_output_tokens] 569 | example["real_completion_length"] = baseline_outputs.shape[1] 570 | example["baseline_completion"] = tokenizer.batch_decode(baseline_outputs, skip_special_tokens=True)[0] 571 | 572 | example["orig_no_bl_output"] = example["no_bl_output"] 573 | example["orig_no_bl_num_tokens_generated"] = example["no_bl_num_tokens_generated"] 574 | no_bl_outputs = no_bl_outputs[:,:limit_output_tokens] 575 | example["no_bl_num_tokens_generated"] = no_bl_outputs.shape[1] 576 | example["no_bl_output"] = tokenizer.batch_decode(no_bl_outputs, skip_special_tokens=True)[0] 577 | 578 | example["orig_w_bl_output"] = example["w_bl_output"] 579 | 
example["orig_w_bl_num_tokens_generated"] = example["w_bl_num_tokens_generated"] 580 | w_bl_outputs = w_bl_outputs[:,:limit_output_tokens] 581 | example["w_bl_num_tokens_generated"] = w_bl_outputs.shape[1] 582 | example["w_bl_output"] = tokenizer.batch_decode(w_bl_outputs, skip_special_tokens=True)[0] 583 | 584 | example["orig_spike_entropies"] = example["spike_entropies"] 585 | example["spike_entropies"] = [example["spike_entropies"][0][:limit_output_tokens]] 586 | 587 | if "w_bl_output_attacked" in example: 588 | # raise NotImplementedError("Havent thought what to do yet for this") 589 | example["orig_w_bl_output_attacked"] = example["w_bl_output_attacked"] 590 | # example["orig_w_bl_attacked_num_tokens_generated"] = example["w_bl_attacked_num_tokens_generated"] 591 | w_bl_attacked_outputs = w_bl_attacked_outputs[:,:limit_output_tokens] 592 | example["w_bl_attacked_num_tokens_generated"] = w_bl_attacked_outputs.shape[1] 593 | example["w_bl_output_attacked"] = tokenizer.batch_decode(w_bl_attacked_outputs, skip_special_tokens=True)[0] 594 | 595 | # score the 3 sequence completions/outputs wrt to bl hits 596 | result = score_sequence(inputs=inputs, 597 | outputs=baseline_outputs, # <-- real text completions 598 | initial_seed=initial_seed, 599 | dynamic_seed=dynamic_seed, 600 | bl_proportion=bl_proportion, 601 | tokenizer=tokenizer, 602 | use_cuda=use_cuda, 603 | record_hits=record_hits, 604 | debug=False) 605 | if record_hits: 606 | bl_hits, num_toks_gend, hit_list = result 607 | else: 608 | bl_hits, num_toks_gend = result 609 | example.update({"baseline_num_toks_gend_eq_0":(num_toks_gend == 0)}) 610 | # if num_toks_gend < 0.99*example["real_completion_length"]: breakpoint() 611 | # if len(hit_list) < 0.99*example["real_completion_length"]: breakpoint() 612 | 613 | if num_toks_gend == 0: 614 | # print("No tokens generated, odd, avoiding div by zero and returning -1's") 615 | wl_frac = -1 616 | bl_frac = -1 617 | else: 618 | wl_frac = (num_toks_gend-bl_hits)/num_toks_gend 619 | bl_frac = bl_hits/num_toks_gend 620 | baseline_stats = { 621 | "baseline_whitelist_fraction": wl_frac, 622 | "baseline_blacklist_fraction": bl_frac 623 | } 624 | example.update(baseline_stats) 625 | if record_hits: example.update({"baseline_hit_list":hit_list}) 626 | 627 | result = score_sequence(inputs=inputs, 628 | outputs=no_bl_outputs, # <-- non-blacklisted version 629 | initial_seed=initial_seed, 630 | dynamic_seed=dynamic_seed, 631 | bl_proportion=bl_proportion, 632 | tokenizer=tokenizer, 633 | record_hits=record_hits, 634 | debug=False) 635 | if record_hits: 636 | bl_hits, num_toks_gend, hit_list = result 637 | else: 638 | bl_hits, num_toks_gend = result 639 | example.update({"no_bl_num_toks_gend_eq_0":(num_toks_gend == 0)}) 640 | # if num_toks_gend < 0.99*example["no_bl_num_tokens_generated"]: breakpoint() 641 | # if len(hit_list) < 0.99*example["no_bl_num_tokens_generated"]: breakpoint() 642 | 643 | if num_toks_gend == 0: 644 | # print("No tokens generated, odd, avoiding div by zero and returning -1's") 645 | wl_frac = -1 646 | bl_frac = -1 647 | else: 648 | wl_frac = (num_toks_gend-bl_hits)/num_toks_gend 649 | bl_frac = bl_hits/num_toks_gend 650 | no_bl_stats = { 651 | "no_bl_whitelist_fraction": wl_frac, 652 | "no_bl_blacklist_fraction": bl_frac 653 | } 654 | example.update(no_bl_stats) 655 | if record_hits: example.update({"no_bl_hit_list":hit_list}) 656 | 657 | result = score_sequence(inputs=inputs, 658 | outputs=w_bl_outputs, # <-- blacklisted version 659 | initial_seed=initial_seed, 660 | 
dynamic_seed=dynamic_seed, 661 | bl_proportion=bl_proportion, 662 | tokenizer=tokenizer, 663 | record_hits=record_hits, 664 | # breakpoint_on_hit=True, # banging head against wall 665 | debug=False) 666 | if record_hits: 667 | bl_hits, num_toks_gend, hit_list = result 668 | else: 669 | bl_hits, num_toks_gend = result 670 | example.update({"w_bl_num_toks_gend_eq_0":(num_toks_gend == 0)}) 671 | # if num_toks_gend < 0.99*example["w_bl_num_tokens_generated"]: breakpoint() 672 | # if len(hit_list) < 0.99*example["w_bl_num_tokens_generated"]: breakpoint() 673 | 674 | if num_toks_gend == 0: 675 | # print("No tokens generated, odd, avoiding div by zero and returning -1's") 676 | wl_frac = -1 677 | bl_frac = -1 678 | else: 679 | wl_frac = (num_toks_gend-bl_hits)/num_toks_gend 680 | bl_frac = bl_hits/num_toks_gend 681 | w_bl_stats = { 682 | "w_bl_whitelist_fraction": wl_frac, 683 | "w_bl_blacklist_fraction": bl_frac 684 | } 685 | example.update(w_bl_stats) 686 | if record_hits: example.update({"w_bl_hit_list":hit_list}) 687 | 688 | if w_bl_attacked_outputs is not None: 689 | result = score_sequence(inputs=inputs, 690 | outputs=w_bl_attacked_outputs, # <-- blacklisted but attacked version 691 | initial_seed=initial_seed, 692 | dynamic_seed=dynamic_seed, 693 | bl_proportion=bl_proportion, 694 | tokenizer=tokenizer, 695 | record_hits=record_hits, 696 | # breakpoint_on_hit=True, # banging head against wall 697 | debug=False) 698 | if record_hits: 699 | bl_hits, num_toks_gend, hit_list = result 700 | else: 701 | bl_hits, num_toks_gend = result 702 | example.update({"w_bl_attacked_num_toks_gend_eq_0":(num_toks_gend == 0)}) 703 | # if (num_toks_gend-bl_hits)/(num_toks_gend) < 1.0: breakpoint() 704 | 705 | if num_toks_gend == 0: 706 | # print("No tokens generated, odd, avoiding div by zero and returning -1's") 707 | wl_frac = -1 708 | bl_frac = -1 709 | else: 710 | wl_frac = (num_toks_gend-bl_hits)/num_toks_gend 711 | bl_frac = bl_hits/num_toks_gend 712 | w_bl_attacked_stats = { 713 | "w_bl_attacked_num_tokens_generated": num_toks_gend, 714 | "w_bl_attacked_whitelist_fraction": wl_frac, 715 | "w_bl_attacked_blacklist_fraction": bl_frac 716 | } 717 | example.update(w_bl_attacked_stats) 718 | if record_hits: example.update({"w_bl_attacked_hit_list":hit_list}) 719 | 720 | return example 721 | 722 | 723 | 724 | def aggregate_bl_stats(example: dict, idx: int, stat_table: dict): 725 | 726 | stat_table["baseline_stats"]["whitelist_fraction"] += example["baseline_stats"]["whitelist_fraction"] 727 | stat_table["baseline_stats"]["blacklist_fraction"] += example["baseline_stats"]["blacklist_fraction"] 728 | 729 | stat_table["w_bl_stats"]["whitelist_fraction"] += example["w_bl_stats"]["whitelist_fraction"] 730 | stat_table["w_bl_stats"]["blacklist_fraction"] += example["w_bl_stats"]["blacklist_fraction"] 731 | 732 | stat_table["no_bl_stats"]["whitelist_fraction"] += example["no_bl_stats"]["whitelist_fraction"] 733 | stat_table["no_bl_stats"]["blacklist_fraction"] += example["no_bl_stats"]["blacklist_fraction"] 734 | 735 | stat_table["num_examples"] += 1 736 | 737 | return example 738 | 739 | 740 | def compute_ppl_single(prefix_and_output_text = None, 741 | output_text = None, 742 | oracle_model_name = None, 743 | oracle_model = None, 744 | oracle_tokenizer = None): 745 | 746 | with torch.no_grad(): 747 | tokd_prefix = tokenize_and_truncate({"text":prefix_and_output_text}, completion_length=0, hf_model_name=oracle_model_name, tokenizer=oracle_tokenizer, 
model_max_seq_len=oracle_model.config.max_position_embeddings)["inputs"] 748 | tokd_inputs = tokd_prefix 749 | # if only want to score the "generation" part we need the suffix tokenization length 750 | tokd_suffix = tokenize_and_truncate({"text":output_text}, completion_length=0, hf_model_name=oracle_model_name, tokenizer=oracle_tokenizer, model_max_seq_len=oracle_model.config.max_position_embeddings)["inputs"] 751 | 752 | tokd_inputs = tokd_inputs.to(oracle_model.device) 753 | # make labels, mark if not including all positions 754 | tokd_labels = tokd_inputs.clone().detach() 755 | tokd_labels[:,:tokd_labels.shape[1]-tokd_suffix.shape[1]+1] = -100 756 | 757 | outputs = oracle_model(input_ids=tokd_inputs, labels=tokd_labels) 758 | loss = outputs.loss # avg CE loss all positions (except -100, TODO plz check that this is working correctly) 759 | ppl = torch.tensor(math.exp(loss)) 760 | 761 | return loss.item(), ppl.item() 762 | 763 | 764 | def evaluate_generation_fluency(example: dict, 765 | idx: int, 766 | oracle_model_name = None, 767 | oracle_model = None, 768 | oracle_tokenizer = None): 769 | 770 | # pull out the required fields from the pipeline results 771 | inputs_plus_baseline_output = f"{example['truncated_input']}{example['baseline_completion']}" 772 | baseline_output = f"{example['baseline_completion']}" 773 | 774 | inputs_plus_no_bl_output = f"{example['truncated_input']}{example['no_bl_output']}" 775 | no_bl_output = f"{example['no_bl_output']}" 776 | 777 | inputs_plus_w_bl_output = f"{example['truncated_input']}{example['w_bl_output']}" 778 | w_bl_output = f"{example['w_bl_output']}" 779 | 780 | # add metrics 781 | loss, ppl = compute_ppl_single(inputs_plus_baseline_output, baseline_output, oracle_model_name, oracle_model, oracle_tokenizer) 782 | example["baseline_loss"] = loss 783 | example["baseline_ppl"] = ppl 784 | loss, ppl = compute_ppl_single(inputs_plus_no_bl_output, no_bl_output, oracle_model_name, oracle_model, oracle_tokenizer) 785 | example["no_bl_loss"] = loss 786 | example["no_bl_ppl"] = ppl 787 | loss, ppl = compute_ppl_single(inputs_plus_w_bl_output, w_bl_output, oracle_model_name, oracle_model, oracle_tokenizer) 788 | example["w_bl_loss"] = loss 789 | example["w_bl_ppl"] = ppl 790 | 791 | # del any temp values 792 | return example 793 | 794 | 795 | def add_idx(example,idx): 796 | example.update({"idx":idx}) 797 | return example 798 | 799 | 800 | def check_input_lengths(example,idx, min_sample_len=0, min_prompt_len=0, min_completion_len=0): 801 | orig_sample_length = example["orig_sample_length"] 802 | prompt_length = example["prompt_length"] 803 | real_completion_length = example["real_completion_length"] 804 | 805 | # breakpoint() 806 | 807 | conds = all([ 808 | orig_sample_length >= min_sample_len, 809 | prompt_length >= min_prompt_len, 810 | real_completion_length >= min_completion_len, 811 | ]) 812 | return conds 813 | 814 | 815 | def check_output_lengths(example,min_output_len=0): 816 | no_bl_output_len = example["no_bl_num_tokens_generated"] 817 | w_bl_output_len = example["w_bl_num_tokens_generated"] 818 | conds = all([ 819 | no_bl_output_len >= min_output_len, 820 | w_bl_output_len >= min_output_len, 821 | ]) 822 | return conds 823 | 824 | class OldWatermarkDetector(): 825 | """ 826 | Class for detecting watermarks 827 | """ 828 | 829 | def __init__(self, 830 | tokenizer, 831 | vocab: list[int] = None, 832 | gamma: float = 0.5, 833 | delta: float = 5.0, 834 | hash_key: int = 15485863, 835 | initial_seed: int=None, 836 | dynamic_seed: str=None, # 
"initial", "markov_1", None 837 | device: torch.device = None, 838 | select_green_tokens: bool = True, 839 | 840 | ): 841 | self.vocab = vocab 842 | self.vocab_size = len(vocab) 843 | self.gamma = gamma 844 | self.delta = delta 845 | self.hash_key = hash_key 846 | self.device = device 847 | self.tokenizer = tokenizer 848 | self.select_green_tokens = select_green_tokens 849 | 850 | if initial_seed is None: 851 | self.initial_seed = None 852 | assert dynamic_seed != "initial" 853 | else: 854 | random.seed(initial_seed) 855 | self.initial_seed = initial_seed 856 | 857 | self.dynamic_seed = dynamic_seed 858 | self.rng = torch.Generator(device=self.device) 859 | 860 | def _compute_z_score(self, observed_count, T): 861 | # count refers to number of green tokens, T is total number of tokens 862 | expected_count = self.gamma 863 | print("observed_count", observed_count) 864 | print("T", T) 865 | numer = observed_count - expected_count * T 866 | denom = math.sqrt(T * expected_count * (1 - expected_count)) 867 | z = numer / denom 868 | return z 869 | 870 | def detect(self, 871 | inputs: list[int]=None, 872 | tokenized_text: list[int]=None, 873 | debug: bool=True, 874 | return_scores: bool = True,): 875 | assert tokenized_text is not None, "Must pass tokenized string" 876 | 877 | assert inputs is not None, "Must pass inputs" 878 | 879 | #if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id): 880 | # print("bos removed!!!") 881 | # tokenized_text = tokenized_text[1:] 882 | 883 | green_token_count, green_token_mask = 0, [] 884 | 885 | 886 | input_sequence = tokenized_text.tolist()[0] 887 | 888 | # print("input sequence is:", input_sequence) 889 | prev_token = inputs[0][-1].item() 890 | # print("prev_token0 is: ", prev_token) 891 | 892 | # prev_token = input_sequence[1] 893 | for idx, tok_gend in enumerate(input_sequence): 894 | if self.dynamic_seed == "initial": 895 | self.rng.manual_seed(self.hash_key*self.initial_seed) 896 | 897 | elif self.dynamic_seed == "markov_1": 898 | self.rng.manual_seed(self.hash_key*prev_token) 899 | 900 | elif self.dynamic_seed == "None": 901 | pass 902 | 903 | # print("prev_token is: ", prev_token) 904 | redlist_size = int(self.vocab_size*(1 - self.gamma)) 905 | 906 | vocab_permutation = torch.randperm(self.vocab_size, device=self.device,generator=self.rng) 907 | 908 | redlist_ids = vocab_permutation 909 | if self.select_green_tokens: # directly 910 | redlist_ids = vocab_permutation[:redlist_size] # new 911 | else: # select green via red 912 | redlist_ids = vocab_permutation[(self.vocab_size - redlist_size):] # legacy behavior 913 | 914 | # print(f"len of redlist is {len(redlist_ids) }, the first is {redlist_ids[0]}") 915 | 916 | tok_in_ph_gl = tok_gend in redlist_ids 917 | 918 | if tok_in_ph_gl: 919 | green_token_mask.append(False) 920 | 921 | else: 922 | green_token_count += 1 923 | green_token_mask.append(True) 924 | 925 | if debug: 926 | decoded_token = self.tokenizer.decode(tok_gend, skip_special_tokens=True) 927 | print(f"Token generated: '{decoded_token}' was in the redlist {tok_in_ph_gl}") 928 | print("prev token decode is: ", decoded_token) 929 | prev_token = tok_gend 930 | 931 | 932 | z_score = self._compute_z_score(green_token_count, len(input_sequence)) 933 | # print("z_score is:", z_score) 934 | return z_score 935 | -------------------------------------------------------------------------------- /watermark/run_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 
from tqdm import tqdm 3 | import os 4 | import torch.nn.functional as F 5 | import json 6 | import time 7 | import torch 8 | from old_watermark import BlacklistLogitsProcessor 9 | from our_watermark import OurBlacklistLogitsProcessor 10 | from gptwm import GPTWatermarkLogitsWarper 11 | from watermark_v2 import WatermarkLogitsProcessor 12 | 13 | 14 | 15 | 16 | from transformers import T5Tokenizer, AutoTokenizer, \ 17 | AutoModel, AutoModelForCausalLM, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, \ 18 | StoppingCriteria, StoppingCriteriaList, \ 19 | LlamaTokenizer, LlamaForCausalLM, \ 20 | LogitsProcessorList 21 | 22 | def str2bool(v): 23 | """Util function for user friendly boolean flag args""" 24 | if isinstance(v, bool): 25 | return v 26 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 27 | return True 28 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 29 | return False 30 | else: 31 | raise argparse.ArgumentTypeError('Boolean value expected.') 32 | 33 | 34 | 35 | class StopSequences(StoppingCriteria): 36 | def __init__(self, stop_sequences_set): 37 | super().__init__() 38 | self.stop_sequences_set = stop_sequences_set 39 | 40 | def __call__(self, input_ids, scores): 41 | if input_ids[0][-1].item() in self.stop_sequences_set: 42 | return True 43 | return False 44 | 45 | def load_model(model_name, test_input): 46 | # only load tokenizer for test_input mode 47 | 48 | # find model 49 | model_paths = [i + model_name for i in [ 50 | '/data2/lhm/cookie_huggingface_models/', 51 | '/data2/cookie_huggingface_models/', 52 | ]] 53 | for model_path in model_paths: 54 | if os.path.exists(model_path): 55 | break 56 | else: 57 | print(f'Model {model_name} not found!') 58 | return None, None, None 59 | 60 | # load model from local files 61 | # all supported models are listed as following: 62 | if model_name == 'flan-t5-xxl': 63 | tokenizer = T5Tokenizer.from_pretrained(model_path) 64 | model = T5ForConditionalGeneration.from_pretrained( 65 | model_path, device_map='auto', return_dict_in_generate=True, output_scores=True 66 | ) if not test_input else None 67 | elif model_name == 'flan-ul2': 68 | tokenizer = AutoTokenizer.from_pretrained(model_path) 69 | model = T5ForConditionalGeneration.from_pretrained( 70 | model_path, device_map='auto', return_dict_in_generate=True, output_scores=True, torch_dtype=torch.bfloat16 71 | ) if not test_input else None 72 | elif model_name == 'ul2': 73 | tokenizer = AutoTokenizer.from_pretrained(model_path) 74 | model = T5ForConditionalGeneration.from_pretrained( 75 | model_path, device_map='auto', return_dict_in_generate=True, output_scores=True 76 | ) if not test_input else None 77 | elif model_name == 'T0pp': 78 | tokenizer = AutoTokenizer.from_pretrained(model_path) 79 | model = AutoModelForSeq2SeqLM.from_pretrained( 80 | model_path, device_map='auto', return_dict_in_generate=True, output_scores=True 81 | ) if not test_input else None 82 | # elif model_name == 'opt-iml-30b': # TODO 83 | # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 84 | # model = AutoModelForCausalLM.from_pretrained( 85 | # model_path, device_map='auto', return_dict_in_generate=True, output_scores=True, torch_dtype=torch.float16 86 | # ) if not test_input else None 87 | elif model_name == 'alpaca-lora-7b': 88 | tokenizer = LlamaTokenizer.from_pretrained(model_path) 89 | model = LlamaForCausalLM.from_pretrained( 90 | model_path, device_map='auto', return_dict_in_generate=True, output_scores=True, torch_dtype=torch.float16, load_in_8bit=True 91 | ) if not test_input 
else None 92 | elif model_name == 'chatglm-6b': 93 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 94 | model = AutoModel.from_pretrained( 95 | model_path, device_map='auto', return_dict_in_generate=True, output_scores=True, trust_remote_code=True 96 | ).half() if not test_input else None 97 | 98 | elif model_name in [ 99 | 'gpt-neox-20b', 100 | 'GPT-JT-6B-v1', 101 | 'gpt-j-6b', 102 | 'bloom-7b1', 103 | 'llama-65b-hf', 104 | 'LLaMA-13B', 105 | 'vicuna-13b', 106 | ]: 107 | tokenizer = AutoTokenizer.from_pretrained(model_path) 108 | model = AutoModelForCausalLM.from_pretrained( 109 | model_path, device_map='auto', return_dict_in_generate=True, output_scores=True 110 | ) if not test_input else None 111 | else: 112 | print(f'Model {model_name} not supported!') 113 | tokenizer, model = None, None 114 | 115 | max_length = { 116 | 'flan-t5-xxl': 512, 117 | 'ul2': 512, 118 | 'flan-ul2': 2048, 119 | 'GPT-JT-6B-v1': 2048, 120 | 'gpt-j-6b': 2048, 121 | 'gpt-neox-20b': 2048, 122 | 'bloom-7b1': 10000, # no length overflow till now 123 | 'T0pp': 512, 124 | 'llama-65b-hf': 10000, 125 | 'alpaca-lora-7b': 10000, 126 | 'chatglm-6b': 10000, 127 | 'LLaMA-13B': 512, 128 | 'vicuna-13b':2048, 129 | }[model_name] 130 | 131 | return tokenizer, model, max_length 132 | 133 | 134 | def load_dataset(dataset_name, layer): 135 | 136 | # find dataset 137 | dataset_path = f'/data2/cookie/input/{layer}/{dataset_name}/' 138 | if not os.path.exists(dataset_path): 139 | print(f'Dataset {dataset_name} not found!') 140 | return None, None 141 | 142 | # load dataset 143 | 144 | # datasets with only one file which means examples have been attached to instances 145 | # other datasets have test.json and train.json and their inputs need to be assembled from the two files 146 | preprocessed_datasets_to_file = { 147 | 'hotpotqa': 'hotpotqa_sample.json', 148 | 'musique' : 'musique_sample.json', 149 | 'kqapro' : 'kqapro_sample.json', 150 | '2wikimultihopqa': '2WikiMultihopQA_sample.json', 151 | } 152 | 153 | if dataset_name.split('/')[0] not in preprocessed_datasets_to_file: 154 | test_file = dataset_path + 'test.json' 155 | else: 156 | test_file = dataset_path + preprocessed_datasets_to_file[dataset_name] 157 | with open(test_file) as f: 158 | dataset_test = json.load(f) 159 | 160 | dataset_train = None 161 | if dataset_name.split('/')[0] not in preprocessed_datasets_to_file: # preprocessed datasets have no train.json 162 | with open(dataset_path + 'train.json') as f: 163 | dataset_train = json.load(f) 164 | 165 | return dataset_test, dataset_train 166 | 167 | 168 | def main(args): 169 | 170 | model_names, dataset_names, layer, test_input, num_shot, init_seed, dyna_seed, bl_proportion, bl_logit_bias, bl_type, num_beams, sampling_temp = args.models, args.datasets, args.layer, args.test_input, args.num_shot, args.initial_seed, args.dynamic_seed, 1-args.gamma, args.delta, args.bl_type, args.num_beams, args.sampling_temp 171 | 172 | for model_name in model_names: 173 | 174 | tokenizer, model, max_length = load_model(model_name, test_input) 175 | if not tokenizer and not model: 176 | return 177 | 178 | if not test_input: 179 | result_path = './our_scores/' + model_name + '/' 180 | if not os.path.exists(result_path): 181 | os.mkdir(result_path) 182 | 183 | all_token_ids = list(tokenizer.get_vocab().values()) 184 | vocab_size = len(all_token_ids) 185 | print(f"Vocabulary size: {vocab_size}") 186 | 187 | # replace datasets to their sub tasks, if exist 188 | datasets_with_sub_tasks = { 189 | 'COPEN': ['cic', 
'cpj', 'csj'], 190 | 'FewNERD': ['inter', 'intra', 'supervised'], 191 | 'KoRC': ['iid', 'ood'], 192 | } 193 | dataset_names_temp = [] 194 | for i in dataset_names: 195 | if i not in datasets_with_sub_tasks: 196 | dataset_names_temp.append(i) 197 | else: 198 | dataset_names_temp.extend((f'{i}/{j}' for j in datasets_with_sub_tasks[i])) # 'task/subtask' 199 | dataset_names = dataset_names_temp 200 | 201 | for dataset_name in dataset_names: 202 | 203 | dataset_test, dataset_train = load_dataset(dataset_name, layer) 204 | if not dataset_test: 205 | return 206 | 207 | spec = dataset_test['adapter_spec'] 208 | instruction = spec['instructions'] 209 | input_prefix = spec['input_prefix'] 210 | input_suffix = spec['input_suffix'] 211 | output_prefix = spec['output_prefix'] 212 | output_suffix = spec['output_suffix'] 213 | instance_prefix = spec['instance_prefix'] if 'instance_prefix' in spec else '' # some earlier datasets don't have this field 214 | 215 | # every query start with instruction 216 | if model_name == 'ul2': 217 | query_prefix = f'[S2S] {instruction}' 218 | else: 219 | query_prefix = instruction 220 | 221 | # instances in following datasets have been preprocessed, only instruction needs to be attached 222 | preprocessed_datasets = set([ 223 | '2wikimultihopqa', 224 | 'hotpotqa', 225 | 'kqapro', 226 | 'musique', 227 | ]) 228 | 229 | if dataset_name not in preprocessed_datasets and dataset_train: # if not preprocessed, examples need to be concatenated 230 | for i in dataset_train['request_states'][:num_shot]: 231 | i_input = i['instance']['input']['text'] 232 | i_output = i['instance']['references'][0]['output']['text'] 233 | query_prefix += instance_prefix 234 | query_prefix += f'{input_prefix}{i_input}{input_suffix}' 235 | query_prefix += f'{output_prefix}{i_output}{output_suffix}' 236 | 237 | query_prefix += instance_prefix + input_prefix 238 | 239 | if not test_input: 240 | max_new_tokens = dataset_test['adapter_spec']['max_tokens'] 241 | stop_sequences_set = set(tokenizer.encode(i)[0] for i in dataset_test['adapter_spec']['stop_sequences']) 242 | stop_criteria = StopSequences(stop_sequences_set) 243 | print(f'Running inference on {dataset_name} with {model_name}') 244 | 245 | if test_input: 246 | lens = [] 247 | 248 | 249 | bl_processor = BlacklistLogitsProcessor(tokenizer=tokenizer, 250 | bad_words_ids=None, 251 | eos_token_id=tokenizer.eos_token_id, 252 | vocab=all_token_ids, 253 | all_vocab_size=vocab_size, 254 | bl_proportion=bl_proportion, 255 | bl_logit_bias=bl_logit_bias, 256 | bl_type=bl_type, 257 | initial_seed=init_seed, 258 | dynamic_seed=dyna_seed) 259 | if args.mode == 'new': 260 | bl_processor = OurBlacklistLogitsProcessor(tokenizer=tokenizer, 261 | bad_words_ids=None, 262 | eos_token_id=tokenizer.eos_token_id, 263 | vocab=all_token_ids, 264 | all_vocab_size=vocab_size, 265 | bl_proportion=bl_proportion, 266 | bl_logit_bias=bl_logit_bias, 267 | bl_type=bl_type, 268 | initial_seed=init_seed, 269 | dynamic_seed=dyna_seed) 270 | logit_processor_lst = LogitsProcessorList([bl_processor]) 271 | 272 | if args.mode == 'gpt': 273 | watermark_processor = GPTWatermarkLogitsWarper(vocab_size=len(list(tokenizer.get_vocab().values())), 274 | fraction=args.gamma, 275 | strength=args.delta) 276 | 277 | logit_processor_lst = LogitsProcessorList([watermark_processor]) 278 | 279 | if args.mode == 'v2': 280 | watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()), 281 | gamma=args.gamma, 282 | delta=args.delta, 283 | seeding_scheme=args.seeding_scheme, 
284 | select_green_tokens=args.select_green_tokens) 285 | logit_processor_lst = LogitsProcessorList([watermark_processor]) 286 | 287 | 288 | # run inference 289 | for request in tqdm(dataset_test['request_states']): 290 | 291 | input_text = query_prefix + request['instance']['input']['text'] 292 | if dataset_name not in preprocessed_datasets: 293 | input_text += input_suffix + output_prefix 294 | 295 | if model_name == 'ul2': 296 | input_text += " " 297 | input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to('cuda') 298 | 299 | if test_input: # just print one input without running inference 300 | lens.append(len(input_ids[0])) 301 | continue 302 | 303 | if len(input_ids[0]) > max_length: # mark and skip 304 | print(f'skip overlong instances: {len(input_ids[0])}') 305 | request['request'] = { 306 | 'result': { 307 | 'success': False, 308 | 'completions': [{ 309 | 'text': 'null_overlength', 310 | 'logprob': 0., 311 | 'tokens': [], 312 | }], 313 | 'cached': True, 314 | 'request_time': 0., 315 | 'request_datetime': int(time.time()), 316 | } 317 | } 318 | continue 319 | 320 | 321 | if args.mode == 'new': 322 | example = {} 323 | start_time = time.time() 324 | # TODO: pad_token_id=tokenizer.eos_token_id 325 | outputs = model.generate( 326 | input_ids, max_new_tokens=max_new_tokens, 327 | stopping_criteria=StoppingCriteriaList([stop_criteria]), 328 | logits_processor = logit_processor_lst, 329 | do_sample=True, 330 | top_k=0, 331 | temperature=sampling_temp 332 | ) 333 | request_time = time.time() - start_time 334 | request_datetime = int(time.time()) 335 | 336 | example.update({"bl_vocabularys":logit_processor_lst[0].get_and_clear_vocabularys()}) 337 | # remove the attached input from output for some model 338 | scores = outputs.scores 339 | output_ids = outputs.sequences[0, -len(scores):] 340 | 341 | # remove the tail, if generation stops at any stop_sequences 342 | if output_ids[-1].item() in stop_sequences_set: 343 | scores = scores[:-1] 344 | output_ids = output_ids[:-1] 345 | 346 | # compute logprob for each token 347 | completions_tokens = [] 348 | completions_logprob = 0 349 | 350 | for score, token, vocabulary in zip(scores, output_ids, example["bl_vocabularys"], strict=True): 351 | logprobs = F.log_softmax(score[0], dim=-1) 352 | logprob = logprobs[token].item() 353 | completions_tokens.append({ 354 | 'text': tokenizer.decode(token), 355 | 'logprob': logprob, 356 | 'vocabulary': vocabulary, 357 | }) 358 | completions_logprob += logprob 359 | 360 | completions_text = tokenizer.decode(output_ids, skip_special_tokens=True) 361 | 362 | request['request'] = { 363 | 'result': { 364 | 'success': True, 365 | 'completions': [{ 366 | 'text': completions_text, 367 | 'logprob': completions_logprob, 368 | 'tokens': completions_tokens, 369 | }], 370 | 'cached': True, 371 | 'request_time': request_time, 372 | 'request_datetime': request_datetime, 373 | } 374 | } 375 | 376 | else: 377 | 378 | if args.mode == 'no': 379 | start_time = time.time() 380 | # TODO: pad_token_id=tokenizer.eos_token_id 381 | outputs = model.generate( 382 | input_ids, max_new_tokens=max_new_tokens, 383 | stopping_criteria=StoppingCriteriaList([stop_criteria]), 384 | ) 385 | request_time = time.time() - start_time 386 | request_datetime = int(time.time()) 387 | 388 | elif args.mode == 'old': 389 | start_time = time.time() 390 | # TODO: pad_token_id=tokenizer.eos_token_id 391 | outputs = model.generate( 392 | input_ids, max_new_tokens=max_new_tokens, 393 | stopping_criteria=StoppingCriteriaList([stop_criteria]), 394 
| logits_processor = logit_processor_lst, 395 | do_sample=True, 396 | top_k=0, 397 | temperature=sampling_temp 398 | ) 399 | request_time = time.time() - start_time 400 | request_datetime = int(time.time()) 401 | 402 | elif args.mode == 'gpt': 403 | start_time = time.time() 404 | # TODO: pad_token_id=tokenizer.eos_token_id 405 | outputs = model.generate( 406 | input_ids, max_new_tokens=max_new_tokens, 407 | stopping_criteria=StoppingCriteriaList([stop_criteria]), 408 | logits_processor = logit_processor_lst, 409 | do_sample=True, 410 | top_k=0, 411 | top_p=0.9 412 | ) 413 | request_time = time.time() - start_time 414 | request_datetime = int(time.time()) 415 | 416 | elif args.mode == 'v2': 417 | start_time = time.time() 418 | # TODO: pad_token_id=tokenizer.eos_token_id 419 | outputs = model.generate( 420 | input_ids, max_new_tokens=max_new_tokens, 421 | stopping_criteria=StoppingCriteriaList([stop_criteria]), 422 | logits_processor = logit_processor_lst, 423 | do_sample=True, 424 | top_k=0, 425 | temperature=sampling_temp 426 | ) 427 | request_time = time.time() - start_time 428 | request_datetime = int(time.time()) 429 | 430 | # remove the attached input from output for some model 431 | scores = outputs.scores 432 | output_ids = outputs.sequences[0, -len(scores):] 433 | 434 | # remove the tail, if generation stops at any stop_sequences 435 | if output_ids[-1].item() in stop_sequences_set: 436 | scores = scores[:-1] 437 | output_ids = output_ids[:-1] 438 | 439 | # compute logprob for each token 440 | completions_tokens = [] 441 | completions_logprob = 0 442 | 443 | for score, token in zip(scores, output_ids, strict=True): 444 | logprobs = F.log_softmax(score[0], dim=-1) 445 | logprob = logprobs[token].item() 446 | completions_tokens.append({ 447 | 'text': tokenizer.decode(token), 448 | 'logprob': logprob, 449 | }) 450 | completions_logprob += logprob 451 | 452 | completions_text = tokenizer.decode(output_ids, skip_special_tokens=True) 453 | 454 | request['request'] = { 455 | 'result': { 456 | 'success': True, 457 | 'completions': [{ 458 | 'text': completions_text, 459 | 'logprob': completions_logprob, 460 | 'tokens': completions_tokens, 461 | }], 462 | 'cached': True, 463 | 'request_time': request_time, 464 | 'request_datetime': request_datetime, 465 | } 466 | } 467 | 468 | dataset_name = dataset_name.replace('/', '++') # rename sub task, e.g. 
KoRC/iid -> KoRC++iid, as filename 469 | if not test_input: 470 | prefix = 'r_' if layer == 'Rolling' else '' # distinguish the datasets from Rolling, as they have the same names as the former ones 471 | with open(result_path + f'{prefix}{dataset_name}_inference_base.json', 'w') as f: 472 | json.dump(dataset_test, f, indent=2) 473 | else: 474 | print(input_text) 475 | print(sorted(lens)) 476 | 477 | parser = argparse.ArgumentParser(description="Generate model outputs with different watermarking methods for later z-score detection") 478 | 479 | parser.add_argument( 480 | "--mode", 481 | type=str, 482 | default="old", 483 | choices=["no", "old", "new", "v2", "gpt"], 484 | help="Which version of the watermark to generate", 485 | ) 486 | 487 | parser.add_argument( 488 | '--datasets', 489 | type=str, 490 | required=True, 491 | nargs='+') 492 | 493 | parser.add_argument( 494 | '--layer', 495 | required=True) 496 | 497 | parser.add_argument( 498 | '--test_input', 499 | action='store_true') 500 | 501 | parser.add_argument( 502 | '--num_shot', 503 | type=int, 504 | default=5) 505 | 506 | parser.add_argument( 507 | '--models', 508 | type=str, 509 | default="vicuna-13b", 510 | required=True, 511 | nargs='+') 512 | 513 | parser.add_argument( 514 | '--initial_seed', 515 | type=int, 516 | default=1234, 517 | help=("The initial seed to use in the blacklist randomization process. " 518 | "Unused when the seeding procedure is markov-based. Can be None."), 519 | ) 520 | 521 | parser.add_argument( 522 | "--dynamic_seed", 523 | type=str, 524 | default="markov_1", 525 | choices=[None, "initial", "markov_1"], 526 | help="The seeding procedure to use when sampling the blacklist at each step.", 527 | ) 528 | 529 | parser.add_argument( 530 | "--gamma", 531 | type=float, 532 | default=0.5) 533 | 534 | parser.add_argument( 535 | "--delta", 536 | type=float, 537 | default=5.0) 538 | 539 | parser.add_argument( 540 | "--bl_type", 541 | type=str, 542 | default="soft", 543 | choices=["soft", "hard"], 544 | help="The type of blacklisting being performed.", 545 | ) 546 | parser.add_argument( 547 | "--num_beams", 548 | type=int, 549 | default=1, 550 | help="The number of beams to use where '1' is no beam search.", 551 | ) 552 | parser.add_argument( 553 | "--sampling_temp", 554 | type=float, 555 | default=0.7, 556 | help="The temperature to use when generating with multinomial sampling", 557 | ) 558 | 559 | 560 | 561 | 562 | 563 | 564 | parser.add_argument( 565 | "--wm_key", 566 | type=int, 567 | default=0) 568 | 569 | 570 | 571 | 572 | 573 | parser.add_argument( 574 | "--threshold", 575 | type=float, 576 | default=6.0) 577 | 578 | parser.add_argument( 579 | "--test_min_tokens", 580 | type=int, 581 | default=2) 582 | 583 | parser.add_argument( 584 | "--seeding_scheme", 585 | type=str, 586 | default="simple_1", 587 | help="Seeding scheme to use to generate the greenlists at each generation and verification step.", 588 | ) 589 | 590 | parser.add_argument( 591 | "--normalizers", 592 | type=str, 593 | default="", 594 | help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.", 595 | ) 596 | 597 | parser.add_argument( 598 | "--ignore_repeated_bigrams", 599 | type=str2bool, 600 | default=False, 601 | help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.", 602 | ) 603 | 604 | parser.add_argument( 605 | "--select_green_tokens", 606 | type=str2bool, 607 | default=True, 608 | help="How to treat
the permutation when selecting the greenlist tokens at each step. Legacy behavior (False) is to pick the complement/red tokens first.", 609 | ) 610 | 611 | args = parser.parse_args() 612 | 613 | main(args) 614 | -------------------------------------------------------------------------------- /watermark/watermark_v2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Authors of "A Watermark for Large Language Models" 3 | # available at https://arxiv.org/abs/2301.10226 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from __future__ import annotations 18 | import collections 19 | from math import sqrt 20 | 21 | import scipy.stats 22 | 23 | import torch 24 | from torch import Tensor 25 | from tokenizers import Tokenizer 26 | from transformers import LogitsProcessor 27 | 28 | from nltk.util import ngrams 29 | 30 | from watermark.normalizers import normalization_strategy_lookup 31 | 32 | class WatermarkBase: 33 | def __init__( 34 | self, 35 | vocab: list[int] = None, 36 | gamma: float = 0.5, 37 | delta: float = 2.0, 38 | seeding_scheme: str = "simple_1", # mostly unused/always default 39 | hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width 40 | select_green_tokens: bool = True, 41 | ): 42 | 43 | # watermarking parameters 44 | self.vocab = vocab 45 | self.vocab_size = len(vocab) 46 | self.gamma = gamma 47 | self.delta = delta 48 | self.seeding_scheme = seeding_scheme 49 | self.rng = None 50 | self.hash_key = hash_key 51 | self.select_green_tokens = select_green_tokens 52 | 53 | def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None: 54 | # can optionally override the seeding scheme, 55 | # but uses the instance attr by default 56 | if seeding_scheme is None: 57 | seeding_scheme = self.seeding_scheme 58 | 59 | if seeding_scheme == "simple_1": 60 | assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng" 61 | prev_token = input_ids[-1].item() 62 | self.rng.manual_seed(self.hash_key * prev_token) 63 | else: 64 | raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}") 65 | return 66 | 67 | def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]: 68 | # seed the rng using the previous tokens/prefix 69 | # according to the seeding_scheme 70 | self._seed_rng(input_ids) 71 | 72 | # print("self.vocab_size:", self.vocab_size) 73 | # print("self.gamma:", self.gamma) 74 | 75 | greenlist_size = int(self.vocab_size * self.gamma) 76 | vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng) 77 | if self.select_green_tokens: # directly 78 | greenlist_ids = vocab_permutation[:greenlist_size] # new 79 | else: # select green via red 80 | greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :] # legacy behavior 81 | return greenlist_ids 82 | 83 | 84 | class
WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor): 85 | 86 | def __init__(self, *args, **kwargs): 87 | super().__init__(*args, **kwargs) 88 | 89 | def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor: 90 | # TODO lets see if we can lose this loop 91 | green_tokens_mask = torch.zeros_like(scores) 92 | for b_idx in range(len(greenlist_token_ids)): 93 | green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1 94 | final_mask = green_tokens_mask.bool() 95 | return final_mask 96 | 97 | def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor: 98 | scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias 99 | return scores 100 | 101 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 102 | 103 | # this is lazy to allow us to colocate on the watermarked model's device 104 | if self.rng is None: 105 | self.rng = torch.Generator(device=input_ids.device) 106 | 107 | # NOTE, it would be nice to get rid of this batch loop, but currently, 108 | # the seed and partition operations are not tensor/vectorized, thus 109 | # each sequence in the batch needs to be treated separately. 110 | batched_greenlist_ids = [None for _ in range(input_ids.shape[0])] 111 | 112 | for b_idx in range(input_ids.shape[0]): 113 | greenlist_ids = self._get_greenlist_ids(input_ids[b_idx]) 114 | batched_greenlist_ids[b_idx] = greenlist_ids 115 | 116 | green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids) 117 | 118 | scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta) 119 | return scores 120 | 121 | 122 | class WatermarkDetector(WatermarkBase): 123 | def __init__( 124 | self, 125 | *args, 126 | device: torch.device = None, 127 | tokenizer: Tokenizer = None, 128 | z_threshold: float = 4.0, 129 | normalizers: list[str] = ["unicode"], # or also: ["unicode", "homoglyphs", "truecase"] 130 | ignore_repeated_bigrams: bool = False, 131 | **kwargs, 132 | ): 133 | super().__init__(*args, **kwargs) 134 | # also configure the metrics returned/preprocessing options 135 | assert device, "Must pass device" 136 | assert tokenizer, "Need an instance of the generating tokenizer to perform detection" 137 | 138 | self.tokenizer = tokenizer 139 | self.device = device 140 | self.z_threshold = z_threshold 141 | self.rng = torch.Generator(device=self.device) 142 | 143 | if self.seeding_scheme == "simple_1": 144 | self.min_prefix_len = 1 145 | else: 146 | raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}") 147 | 148 | self.normalizers = [] 149 | for normalization_strategy in normalizers: 150 | self.normalizers.append(normalization_strategy_lookup(normalization_strategy)) 151 | 152 | self.ignore_repeated_bigrams = ignore_repeated_bigrams 153 | if self.ignore_repeated_bigrams: 154 | assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme." 
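    # Illustrative note (added sketch; the numbers are hypothetical): the z-score computed
    # by _compute_z_score below treats each scored token as a Bernoulli(gamma) trial.
    # For example, with gamma = 0.5, T = 200 scored tokens and 140 of them green:
    #     z = (140 - 0.5 * 200) / sqrt(200 * 0.5 * 0.5) = 40 / sqrt(50) ≈ 5.66,
    # which exceeds the default z_threshold of 4.0, so such a text would be flagged as
    # watermarked, with a one-sided p-value of roughly 8e-9.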
155 | 156 | 157 | def _compute_z_score(self, observed_count, T): 158 | # observed_count is the number of green tokens; T is the total number of tokens scored 159 | expected_count = self.gamma 160 | numer = observed_count - expected_count * T 161 | denom = sqrt(T * expected_count * (1 - expected_count)) 162 | z = numer / denom 163 | return z 164 | 165 | def _compute_p_value(self, z): 166 | p_value = scipy.stats.norm.sf(z) 167 | return p_value 168 | 169 | def _score_sequence( 170 | self, 171 | input_ids: Tensor, 172 | return_num_tokens_scored: bool = True, 173 | return_num_green_tokens: bool = True, 174 | return_green_fraction: bool = True, 175 | return_green_token_mask: bool = False, 176 | return_z_score: bool = True, 177 | return_p_value: bool = True, 178 | debug: bool = True, 179 | ): 180 | if self.ignore_repeated_bigrams: 181 | # Method that only counts a green/red hit once per unique bigram. 182 | # The new total number of tokens scored (T) becomes the number of unique bigrams. 183 | # We iterate over all unique token bigrams in the input, computing the greenlist 184 | # induced by the first token in each, and then checking whether the second 185 | # token falls in that greenlist. 186 | assert return_green_token_mask == False, "Can't return the green/red mask when ignoring repeats." 187 | bigram_table = {} 188 | token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2) 189 | freq = collections.Counter(token_bigram_generator) 190 | num_tokens_scored = len(freq.keys()) 191 | for idx, bigram in enumerate(freq.keys()): 192 | prefix = torch.tensor([bigram[0]], device=self.device) # expects a 1-d prefix tensor on the randperm device 193 | greenlist_ids = self._get_greenlist_ids(prefix) 194 | bigram_table[bigram] = True if bigram[1] in greenlist_ids else False 195 | green_token_count = sum(bigram_table.values()) 196 | else: 197 | num_tokens_scored = len(input_ids) - self.min_prefix_len 198 | if num_tokens_scored < 1: 199 | raise ValueError((f"Must have at least 1 token to score after " 200 | f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme.")) 201 | # Standard method. 202 | # Since we generally need at least 1 token (for the simplest scheme) 203 | # we start the iteration over the token sequence with a minimum 204 | # num tokens as the first prefix for the seeding scheme, 205 | # and at each step, compute the greenlist induced by the 206 | # current prefix and check if the current token falls in the greenlist.
207 | green_token_count, green_token_mask = 0, [] 208 | for idx in range(self.min_prefix_len, len(input_ids)): 209 | curr_token = input_ids[idx] 210 | greenlist_ids = self._get_greenlist_ids(input_ids[:idx]) 211 | tok_in_ph_gl = curr_token in greenlist_ids 212 | if tok_in_ph_gl: 213 | green_token_count += 1 214 | green_token_mask.append(True) 215 | else: 216 | green_token_mask.append(False) 217 | 218 | if debug: 219 | decoded_token = self.tokenizer.decode(curr_token, skip_special_tokens=True) 220 | print(f"Token generated: '{decoded_token}' was in the greenlist: {tok_in_ph_gl}") 221 | print("curr_token decode is: ", decoded_token) 222 | 223 | score_dict = dict() 224 | if return_num_tokens_scored: 225 | score_dict.update(dict(num_tokens_scored=num_tokens_scored)) 226 | if return_num_green_tokens: 227 | score_dict.update(dict(num_green_tokens=green_token_count)) 228 | if return_green_fraction: 229 | score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored))) 230 | if return_z_score: 231 | score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored))) 232 | if return_p_value: 233 | z_score = score_dict.get("z_score") 234 | if z_score is None: 235 | z_score = self._compute_z_score(green_token_count, num_tokens_scored) 236 | score_dict.update(dict(p_value=self._compute_p_value(z_score))) 237 | if return_green_token_mask: 238 | score_dict.update(dict(green_token_mask=green_token_mask)) 239 | 240 | return score_dict 241 | 242 | def detect( 243 | self, 244 | text: str = None, 245 | tokenized_text: list[int] = None, 246 | return_prediction: bool = True, 247 | return_scores: bool = True, 248 | z_threshold: float = None, 249 | debug: bool = True, 250 | **kwargs, 251 | ) -> dict: 252 | 253 | assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string" 254 | if return_prediction: 255 | kwargs["return_p_value"] = True # to return the "confidence":=1-p of positive detections 256 | 257 | # run optional normalizers on text 258 | for normalizer in self.normalizers: 259 | text = normalizer(text) 260 | if len(self.normalizers) > 0: 261 | print(f"Text after normalization:\n\n{text}\n") 262 | 263 | if tokenized_text is None: 264 | assert self.tokenizer is not None, ( 265 | "Watermark detection on a raw string " 266 | "requires an instance of the tokenizer " 267 | "that was used at generation time." 268 | ) 269 | tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device) 270 | if tokenized_text[0] == self.tokenizer.bos_token_id: 271 | tokenized_text = tokenized_text[1:] 272 | else: 273 | # try to remove the bos token at the beginning if it's there 274 | if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id): 275 | tokenized_text = tokenized_text[1:] 276 | 277 | # call score method 278 | output_dict = {} 279 | score_dict = self._score_sequence(tokenized_text, **kwargs) 280 | if return_scores: 281 | output_dict.update(score_dict) 282 | # if passed return_prediction then perform the hypothesis test and return the outcome 283 | if return_prediction: 284 | z_threshold = z_threshold if z_threshold else self.z_threshold 285 | assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test" 286 | output_dict["prediction"] = score_dict["z_score"] > z_threshold 287 | if output_dict["prediction"]: 288 | output_dict["confidence"] = 1 - score_dict["p_value"] 289 | 290 | 291 | return output_dict 292 | 293
| 294 | 295 | --------------------------------------------------------------------------------
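Usage note: the following is a minimal, self-contained sketch (not part of the repository) showing how the v2 watermark classes above could be wired into generation and detection. It mirrors the pattern used in the generation script earlier in this dump; the model name "gpt2" and the import path are illustrative assumptions, and the repository's own experiments are instead driven by the configs under config/ and the scripts under shell/.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
# Assumes the repository root is on PYTHONPATH, matching the package-style import
# that watermark_v2.py itself uses (watermark.normalizers).
from watermark.watermark_v2 import WatermarkLogitsProcessor, WatermarkDetector

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2"  # placeholder model for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Generation: bias a gamma-fraction "greenlist" of the vocabulary by delta logits at each step.
watermark_processor = WatermarkLogitsProcessor(
    vocab=list(tokenizer.get_vocab().values()),
    gamma=0.5,
    delta=5.0,
    seeding_scheme="simple_1",
    select_green_tokens=True,
)

prompt = "Write a short paragraph about language model watermarking."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
output_ids = model.generate(
    input_ids,
    max_new_tokens=128,
    do_sample=True,
    top_k=0,
    temperature=0.7,
    logits_processor=LogitsProcessorList([watermark_processor]),
)
generated_text = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)

# Detection: re-derive the greenlist from each prefix and run the z-test in WatermarkDetector.
# Note that _score_sequence prints per-token debug output by default (debug=True).
detector = WatermarkDetector(
    vocab=list(tokenizer.get_vocab().values()),
    gamma=0.5,  # must match the value used at generation time
    seeding_scheme="simple_1",
    device=device,
    tokenizer=tokenizer,
    z_threshold=4.0,
    normalizers=[],
    ignore_repeated_bigrams=False,
)
result = detector.detect(text=generated_text)
print(result["z_score"], result["prediction"])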