├── .gitignore ├── LICENSE ├── README.md ├── calculate_over_represented_words.ipynb ├── example_api.ipynb ├── example_generate.ipynb ├── example_visualise_backtracking.ipynb ├── run_api.py ├── slop_phrase_prob_adjustments.json ├── slop_phrase_prob_adjustments_full_list.json ├── slop_phrases_2025-04-07.json ├── slop_regexes.txt ├── slop_words_2025-04-07.json └── src ├── antislop_generate.py ├── slop_index.py ├── util.py ├── validator_json.py ├── validator_regex.py └── validator_slop.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | src/__pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AntiSlop Sampler 2 | 3 | ## Overview 4 | 5 | The AntiSlop sampler uses a backtracking mechanism to go back and retry with adjusted token probabilities when it encounters a disallowed word or phrase. No more testaments or tapestries or other gpt-slop. 
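Roughly, the core loop works like this. The sketch below is an illustrative toy on a made-up seven-word vocabulary, not the actual implementation (see src/antislop_generate.py and src/validator_slop.py for the real sampler): it samples words one at a time, and whenever a banned phrase appears in the decoded text it rewinds to the word that started the phrase, downregulates that word at that position, and resamples from there.

```python
import random

random.seed(0)
vocab = ["a ", "testament ", "to ", "the ", "city ", "thrived ", "monument "]
banned_phrase, downweight = "a testament to", 0.0   # 0.0 == hard ban

tokens = []    # chosen vocab indices so far
weights = []   # per-position sampling weights (stand-in for the sampler's cached probs)
while len(tokens) < 10:
    if len(weights) == len(tokens):   # no cached weights for this position yet
        weights.append([1.0] * len(vocab))
    tokens.append(random.choices(range(len(vocab)), weights[len(tokens)])[0])
    text = "".join(vocab[i] for i in tokens)
    if banned_phrase in text:
        # find the token that started the banned phrase, downregulate it at that
        # position, then backtrack and continue generating from there
        pos, start = 0, text.index(banned_phrase)
        while len("".join(vocab[i] for i in tokens[:pos + 1])) <= start:
            pos += 1
        weights[pos][tokens[pos]] *= downweight
        del tokens[pos:]
        del weights[pos + 1:]
print("".join(vocab[i] for i in tokens))
```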
6 | 7 | Try the sampler here: 8 | 9 | - [Backtracking visualisation notebook](https://colab.research.google.com/drive/1tHS3tHXbZ5PWOsP-Mbk8oK35y6otBZUn?usp=sharing) 10 | - [Generate example notebook](https://colab.research.google.com/drive/1Rd3V4AN31cDytfmY9u80rzHXPD_dS6x9?usp=sharing) 11 | 12 | 13 | ### How to use the sampler 14 | 15 | With koboldcpp (new!) 16 | 17 | Koboldcpp now includes antislop phrase banning in their latest release: 18 | https://github.com/LostRuins/koboldcpp/releases/tag/v1.76 19 | 20 | With a local chat UI (open-webui): 21 | 22 |
23 | [Click to view instructions] 24 | 25 | ### if you want to run in an isolated environment (note: open-webui requires python 3.11): 26 | ```bash 27 | sudo apt install python3.11 python3.11-venv 28 | python3.11 -m venv open-webui 29 | source open-webui/bin/activate 30 | ``` 31 | 32 | ### install open-webui 33 | ```bash 34 | pip install open-webui 35 | open-webui serve 36 | ``` 37 | 38 | ### start the openai compatible antislop server: 39 | ```bash 40 | git clone https://github.com/sam-paech/antislop-sampler.git && cd antislop-sampler 41 | pip install fastapi uvicorn ipywidgets IPython transformers bitsandbytes accelerate 42 | python3 run_api.py --model unsloth/Llama-3.2-3B-Instruct --slop_adjustments_file slop_phrase_prob_adjustments.json 43 | ``` 44 | 45 | ### configure open-webui 46 | - browse to http://localhost:8080 47 | - go to admin panel --> settings --> connections 48 | - set the OpenAI API url to http://0.0.0.0:8000/v1 49 | - set api key to anything (it's not used) 50 | - click save (!!) 51 | - click the refresh icon to verify the connection; should see a success message 52 | 53 | Now it should be all configured! Start a new chat, select the model, and give it a try. 54 | 55 | Note: Only some of the settings in the chat controls will be active. Those are: 56 | - stream chat response 57 | - temperature 58 | - top k 59 | - top p 60 | - min p 61 | - max tokens 62 | 63 | ![image](https://github.com/user-attachments/assets/8bcc2906-b1e1-4be0-b01b-66cf1b18a9ad) 64 | 65 |
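Optional: before pointing open-webui at the server, you can sanity-check it from Python. This is just a minimal connectivity test and assumes the antislop server from the previous step is running on localhost:8000:

```python
import requests

# Minimal connectivity check against the OpenAI-compatible antislop endpoint.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 16,
        "stream": False,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```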
66 | 67 | Here it is in action (in slow mode so you can see its backtracking & revisions): 68 | 69 | https://github.com/user-attachments/assets/e289fb89-d40c-458b-9d98-118c06b28fba 70 | 71 | ### Disclaimers: 72 | 73 | It's recommended that you roll your own slop list! The slop_phrase_prob_adjustments.json file is mostly auto-generated by computing over-represented words in a large LLM-generated story dataset. As such, it's not well optimised or curated. Curated lists will be added in the future. 74 | 75 | The API is minimalist and doesn't support concurrency. It's geared more towards local use, testing & dataset generation than production. 76 | 77 | This is research-grade code, meaning it probably contains bugs & is a work in progress. 78 | 79 | 80 | ### 2024-10-13 Update 81 | 82 | The main changes are: 83 | 84 | - Switched to string matching instead of token matching (it's more robust) 85 | - Added regex matching/banning 86 | 87 | The new regex_bans parameter accepts a list of regex patterns. These will be evaluated during inference, and if one of them matches, we backtrack to the first token of the matched string. From there, we ban the matched continuation and continue inference. 88 | 89 | This allows more freedom for enforcing constraints than phrase matching alone. For instance, we can prevent over-used phrasing like "not x, but y". Note that streaming is not supported when using regex bans, since we can't predict how far back we may need to backtrack. 90 | 91 | 92 |
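For example, mirroring the call in example_generate.ipynb (this assumes model, tokenizer, messages and slop_phrase_prob_adjustments are already set up as in the usage examples below):

```python
generated_text = chat_antislop(
    model=model,
    tokenizer=tokenizer,
    messages=messages,
    max_length=400,
    temperature=1,
    min_p=0.1,
    adjustment_strength=100.0,
    slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
    enforce_json=False,
    antislop_enabled=True,
    # if these regex expressions are matched, we backtrack, ban that continuation and retry
    regex_bans=['(?i)not [^.!?]{3,60} but'],
    streaming=False  # streaming is not supported with regex bans
)
print(tokenizer.decode(generated_text))
```

The API accepts the same thing as a regex_bans list in the request body (see example_api.ipynb).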
93 | ### 2024-10-05 Update 94 | 95 | Refactored the code, lots of fixes. 96 | 97 | - Added an OpenAI-compatible API server. 98 | - Now using model.generate with stopping conditions to generate multiple tokens at a time instead of just one. This is much faster. 99 | - Added a basic JSON validator + enforcement to demonstrate how the sampler can enforce long-range constraints. 100 | - Switched to probs from logits for the cached values, so that down/upregulation works as expected (this was a mistake in the previous implementation). 101 | - Refactored the code for better organisation. 102 | 103 | Quick blurb on the JSON validator: 104 | 105 | It uses the same backtracking mechanism to retry invalid JSON output. It checks for unintended unescaped quotes in strings, and encourages the model to choose a valid continuation. This is a very common failure mode for JSON outputs. Other kinds of per-token JSON grammars will just terminate the string if they see an unescaped quote, sadly ending the profound thought the LLM was in the middle of expressing. This is better. You can also use it with high temps. 106 | 107 |
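To try it, pass enforce_json=True to the chat/generate calls. A minimal sketch, assuming the same setup as the usage examples further down (the prompt wording is just an illustration):

```python
messages = [{"role": "user", "content": "Return a JSON object describing a fictional book (title, author, summary)."}]
generated_text = chat_antislop(
    model=model,
    tokenizer=tokenizer,
    messages=messages,
    max_length=400,
    temperature=1,
    min_p=0.1,
    adjustment_strength=100.0,
    slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
    enforce_json=True,   # backtrack & retry when the validator catches invalid JSON (e.g. an unescaped quote)
    antislop_enabled=True,
    streaming=False
)
print(tokenizer.decode(generated_text))
```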
108 | 109 |
110 | ### 2024-10-01 Update 111 | 112 | - Squashed VRAM leaks, fixed bugs. It should now work with any transformers model. 113 | - Added min_p support 114 | - Now using slop_phrase_prob_adjustments.json by default, which has a more intuitive probability adjustment per slop phrase (1 == no change; < 1 means probability is reduced by that factor). It looks like this: 115 | ``` 116 | [ 117 | ["kaleidoscope", 0.5], 118 | ["symphony", 0.5], 119 | ["testament to", 0.5], 120 | ["elara", 0.5], 121 | ... 122 | ] 123 | ``` 124 |
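To load the file into the dict that the samplers expect, the example notebooks do the following (keeping only the top 500 entries):

```python
import json

with open('slop_phrase_prob_adjustments.json', 'r') as f:
    slop_phrase_prob_adjustments = dict(json.load(f)[:500])
```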
125 | 126 | ### chat_antislop 127 | ```python 128 | # Chat generation with streaming 129 | messages = [ 130 | {"role": "user", "content": prompt} 131 | ] 132 | for token in chat_antislop( 133 | model=model, 134 | tokenizer=tokenizer, 135 | messages=messages, 136 | max_new_tokens=400, 137 | temperature=1, 138 | min_p=0.1, 139 | # The adjustment_strength param scales how strongly the probability adjustments are applied. 140 | # A value of 1 means the values in slop_phrase_prob_adjustments (or the defaults) are used unmodified. 141 | # Reasonable values are 0 (disabled) thru 100+ (effectively banning the list). 142 | adjustment_strength=100.0, 143 | # Optional: Provide a list of slop phrases and probability adjustments 144 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments, 145 | enforce_json=False, 146 | antislop_enabled=True, 147 | streaming=True 148 | ): 149 | print(tokenizer.decode(token), end='', flush=True) 150 | ``` 151 | 152 | ### generate_antislop 153 | ```python 154 | # generate without streaming 155 | prompt_with_template = tokenizer.apply_chat_template(messages, tokenize=False) 156 | generated_text = generate_antislop( 157 | model=model, 158 | tokenizer=tokenizer, 159 | prompt=prompt, 160 | max_length=300, 161 | temperature=1, 162 | min_p=0.1, 163 | adjustment_strength=100.0, 164 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments, 165 | enforce_json=False, 166 | antislop_enabled=True, 167 | streaming=False 168 | ) 169 | print(tokenizer.decode(generated_text)) 170 | ``` 171 | 172 | ## What this does: 173 | 174 | You can give it a list of words & phrases to avoid like "a tapestry of", "a testament to", etc., and it will backtrack and try something else if it hits that phrase. It can handle 1000s of slop phrases since the lookups are fast. The phrases and downregulation amounts are user-configurable. Previous approaches have done this with per-token logit biasing, but that's quite ineffective since most slop words & phrases are more than one token, and it impairs output quality if we downregulate all those partial-word tokens. So instead, we wait for the whole phrase to appear in the output, then backtrack and downregulate all the tokens that could have produced the slop phrase, and continue from there. 175 | 176 | For the default slop list, we computed a large list of words that are over-represented in LLM output compared to normal human writing. This list is supplemented by a list of over-used phrases that are pet peeves of LLM enthusiasts. During generation, if any of the words & phrases in this list are generated, the sampler reduces the probability of the starting tokens that can lead to that phrase by the factor specified in the config. This way you can lightly de-slop or strongly de-slop. You can of course also specify your own phrase list & weights. 177 | 178 | ## Why it's interesting: 179 | 180 | Samplers typically work at the token level -- but that doesn't work if we want to avoid words/phrases that tokenise to more than one token. Elara might tokenise to ["El", "ara"], and we don't want to reduce the probs of everything beginning with "El". So, this approach waits for the whole phrase to appear, then backtracks and reduces the probabilities of all the likely tokens that will lead to that phrase being output. ~~Nobody afaik has tried this before.~~ [edit] It turns out exllamav2 has a banned_strings feature with a similar implementation, so we can't claim novelty. [/edit] 181 | 182 | * Disclaimers: This is only implemented in Transformers (and now koboldcpp!)
thus far. It is not well optimised. Expect research-grade code & possibly bugs. 183 | 184 | 185 | ## What you need to implement this 186 | 187 | If you'd like to implement this sampler in something other than transformers, here's what you need: 188 | 189 | - A loop to manage the state of the sampler, as it backtracks and needs to refer to past logits that it's cached 190 | - Per-token continuation generation (think: completions, not chat.completions) 191 | - Raw logits 192 | - Ability to bias logits when generating 193 | 194 | Unfortunately that rules out most commercial APIs since few let you specify logit biases. Inference engines will likely be a mixed bag in terms of ease of integration, as most/all samplers work per token without this weird backtracking stuff we're doing here. 195 | 196 | If you do implement this sampler in your thing, please let me know about it! 197 | 198 | ## Acknowledgements 199 | 200 | Turboderp was the first to implement this mechanism in exllamav2 as the "banned strings" feature. This was unknown to us at the time of creating the AntiSlop sampler, which was birthed independently in a case of convergent evolution. Credit to them for doing it first! 201 | 202 | ## How to Cite 203 | 204 | A paper is in the works, hang tight. 205 | 206 | ``` 207 | @misc{paech2024antislop, 208 | title={antislop-sampler}, 209 | author={Samuel J. Paech}, 210 | year={2024}, 211 | howpublished={\url{https://github.com/sam-paech/antislop-sampler}}, 212 | note={GitHub repository} 213 | } 214 | ``` 215 | -------------------------------------------------------------------------------- /calculate_over_represented_words.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### This notebook demonstrates how to compute the over-represented words in a corpus of text.\n", 8 | "\n", 9 | "It works by counting word frequencies in the designated dataset (in this case a large set of LLM-generated creative writing). These frequencies are then compared against the wordfreq frequencies, which represent the average prevalence of words in (human) language."
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "#!pip3 install wordfreq datasets numpy" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 11, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import datasets\n", 28 | "from collections import Counter\n", 29 | "from wordfreq import word_frequency\n", 30 | "import numpy as np\n", 31 | "import re\n", 32 | "from tqdm import tqdm\n", 33 | "\n", 34 | "def download_datasets():\n", 35 | " datasets_info = [\n", 36 | " (\"ajibawa-2023/General-Stories-Collection\", \"train\")\n", 37 | " ]\n", 38 | " return {name: datasets.load_dataset(name, split=split) for name, split in datasets_info}\n", 39 | "\n", 40 | "def parse_text(datasets):\n", 41 | " texts = []\n", 42 | " for example in tqdm(datasets[\"ajibawa-2023/General-Stories-Collection\"]):\n", 43 | " texts.append(example['text'])\n", 44 | " return texts\n", 45 | "\n", 46 | "\n", 47 | "def get_word_counts(texts, min_length=4):\n", 48 | " \"\"\"\n", 49 | " Count word frequencies in a list of texts.\n", 50 | "\n", 51 | " Parameters:\n", 52 | " - texts (iterable of str): The input texts to process.\n", 53 | " - min_length (int): Minimum length of words to include.\n", 54 | "\n", 55 | " Returns:\n", 56 | " - Counter: A Counter object mapping words to their frequencies.\n", 57 | " \"\"\"\n", 58 | " # Precompile the regex pattern for better performance\n", 59 | " # This pattern matches words with internal apostrophes (e.g., \"couldn't\")\n", 60 | " pattern = re.compile(r\"\\b\\w+(?:'\\w+)?\\b\")\n", 61 | " \n", 62 | " word_counts = Counter()\n", 63 | " \n", 64 | " for text in tqdm(texts, desc=\"Counting words\"):\n", 65 | " if not isinstance(text, str):\n", 66 | " continue # Skip non-string entries to make the function more robust\n", 67 | " \n", 68 | " # Convert to lowercase and find all matching words\n", 69 | " words = pattern.findall(text.lower())\n", 70 | " \n", 71 | " # Update counts with words that meet the minimum length\n", 72 | " word_counts.update(word for word in words if len(word) >= min_length)\n", 73 | " \n", 74 | " return word_counts\n", 75 | "\n", 76 | "\n", 77 | "def analyze_word_rarity(word_counts):\n", 78 | " total_words = sum(word_counts.values())\n", 79 | " corpus_frequencies = {word: count / total_words for word, count in word_counts.items()}\n", 80 | " \n", 81 | " wordfreq_frequencies = {word: word_frequency(word, 'en') for word in word_counts.keys()}\n", 82 | " \n", 83 | " # Filter out words with zero frequency\n", 84 | " valid_words = [word for word, freq in wordfreq_frequencies.items() if freq > 0]\n", 85 | " \n", 86 | " corpus_freq_list = [corpus_frequencies[word] for word in valid_words]\n", 87 | " wordfreq_freq_list = [wordfreq_frequencies[word] for word in valid_words]\n", 88 | " \n", 89 | " # Calculate average rarity\n", 90 | " avg_corpus_rarity = np.mean([-np.log10(freq) for freq in corpus_freq_list])\n", 91 | " avg_wordfreq_rarity = np.mean([-np.log10(freq) for freq in wordfreq_freq_list])\n", 92 | " \n", 93 | " # Calculate correlation\n", 94 | " correlation = np.corrcoef(corpus_freq_list, wordfreq_freq_list)[0, 1]\n", 95 | " \n", 96 | " return corpus_frequencies, wordfreq_frequencies, avg_corpus_rarity, avg_wordfreq_rarity, correlation\n", 97 | "\n", 98 | "def find_over_represented_words(corpus_frequencies, wordfreq_frequencies, top_n=50000):\n", 99 | " over_representation = {}\n", 100 | " for word in corpus_frequencies.keys():\n", 101 | " 
wordfreq_freq = wordfreq_frequencies[word]\n", 102 | " if wordfreq_freq > 0: # Only consider words with non-zero frequency\n", 103 | " over_representation[word] = corpus_frequencies[word] / wordfreq_freq\n", 104 | " \n", 105 | " return sorted(over_representation.items(), key=lambda x: x[1], reverse=True)[:top_n]\n", 106 | "\n", 107 | "def find_zero_frequency_words(word_counts, wordfreq_frequencies, top_n=50000):\n", 108 | " zero_freq_words = {word: count for word, count in word_counts.items() if wordfreq_frequencies[word] == 0}\n", 109 | " return sorted(zero_freq_words.items(), key=lambda x: x[1], reverse=True)[:top_n]\n", 110 | "\n" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "\n", 120 | "print(\"Downloading datasets...\")\n", 121 | "all_datasets = download_datasets()\n" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "\n", 131 | "print(\"Parsing text...\")\n", 132 | "texts = parse_text(all_datasets)\n" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "\n", 142 | "print(f\"Total texts extracted: {len(texts)}\")\n", 143 | "\n", 144 | "print(\"Counting words...\")\n", 145 | "word_counts = get_word_counts(texts)\n" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "\n", 155 | "def filter_mostly_numeric(word_counts):\n", 156 | " def is_mostly_numbers(word):\n", 157 | " digit_count = sum(c.isdigit() for c in word)\n", 158 | " return digit_count / len(word) > 0.2 # Adjust this ratio if needed\n", 159 | " \n", 160 | " # Create a new Counter with filtered words\n", 161 | " return Counter({word: count for word, count in word_counts.items() if not is_mostly_numbers(word)})\n", 162 | "\n", 163 | "filtered_counts = filter_mostly_numeric(word_counts)\n", 164 | "\n", 165 | "print(\"Analyzing word rarity...\")\n", 166 | "corpus_frequencies, wordfreq_frequencies, avg_corpus_rarity, avg_wordfreq_rarity, correlation = analyze_word_rarity(filtered_counts)\n", 167 | "\n", 168 | "print(f\"Total unique words analyzed: {len(word_counts)}\")\n", 169 | "print(f\"Average corpus rarity: {avg_corpus_rarity:.4f}\")\n", 170 | "print(f\"Average wordfreq rarity: {avg_wordfreq_rarity:.4f}\")\n", 171 | "print(f\"Correlation between corpus and wordfreq frequencies: {correlation:.4f}\")\n", 172 | "\n", 173 | "print(\"\\nMost over-represented words in the corpus:\")\n", 174 | "over_represented = find_over_represented_words(corpus_frequencies, wordfreq_frequencies)\n", 175 | "for word, score in over_represented:\n", 176 | " print(f\"{word}: {score:.2f} times more frequent than expected\")\n", 177 | "\n", 178 | "print(\"\\nMost frequent words with zero wordfreq frequency:\")\n", 179 | "zero_freq_words = find_zero_frequency_words(filtered_counts, wordfreq_frequencies)\n", 180 | "for word, count in zero_freq_words:\n", 181 | " print(f\"{word}: {count} occurrences\")\n" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 13, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "\n", 191 | "import json\n", 192 | "with open('over_represented_words.json', 'w') as f:\n", 193 | " json.dump(over_represented, f)\n", 194 | "with open('frequent_non_dictionary_words.json', 'w') as f:\n", 195 | " 
json.dump(zero_freq_words, f)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "corpus_frequencies['testament']" 205 | ] 206 | } 207 | ], 208 | "metadata": { 209 | "kernelspec": { 210 | "display_name": "Python 3", 211 | "language": "python", 212 | "name": "python3" 213 | }, 214 | "language_info": { 215 | "codemirror_mode": { 216 | "name": "ipython", 217 | "version": 3 218 | }, 219 | "file_extension": ".py", 220 | "mimetype": "text/x-python", 221 | "name": "python", 222 | "nbconvert_exporter": "python", 223 | "pygments_lexer": "ipython3", 224 | "version": "3.8.9" 225 | } 226 | }, 227 | "nbformat": 4, 228 | "nbformat_minor": 2 229 | } 230 | -------------------------------------------------------------------------------- /example_api.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# First start the server:\n", 10 | "# python run_api.py --model unsloth/Llama-3.2-1B-Instruct\n", 11 | "# \n", 12 | "# Or with default slop phrase list (will be used with all queries unless slop_phrases is specified in the query):\n", 13 | "# python run_api.py --model unsloth/Llama-3.2-1B-Instruct --slop_adjustments_file slop_phrase_prob_adjustments.json" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# ! Important:\n", 23 | "# \n", 24 | "# No slop adjustments will be applied if:\n", 25 | "# - no slop_adjustments_file specified at API launch \n", 26 | "# and\n", 27 | "# - no slop_phrases param specified in the query\n", 28 | "#\n", 29 | "# If you specified a slop_adjustments_file file at API launch, it will be used by default with queries\n", 30 | "# that do not specify a slop_phrases param. The query's slop_phrases param overrides the defaults." 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 6, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "Once upon a time, in a small village nestled in the Andes mountains of South America, there lived a young boy named Kusi. Kusi was a curious and adventurous soul, with a heart full of wonder and a spirit that longed to explore the world beyond his village. He spent most of his days helping his mother with the family's small farm, tending to the crops and animals, and listening to the tales of the elderly villagers about the mystical creatures that roamed the mountains.\n", 43 | "\n", 44 | "One day, while out collecting firewood, Kusi stumbled upon a strange and beautiful creature. It was a young ladyma (a type of Andean camelid) unlike any he had ever seen before. Her fur was a soft, creamy white, and her eyes shone the brightest stars in the night sky. 
She had a delicate, almost ethereal quality to her, and Kusi felt an instant connection to this enchanting creature.\n", 45 | "\n", 46 | "The ladyma, whose name was Akira," 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "import requests\n", 52 | "import json\n", 53 | "\n", 54 | "prompt = \"tell me a story about a magical llama\"\n", 55 | "api_url = 'http://localhost:8000/v1/chat/completions'\n", 56 | "\n", 57 | "messages = [{\"role\": \"user\", \"content\": prompt}]\n", 58 | "data = {\n", 59 | " \"messages\": messages, \n", 60 | " \"max_tokens\": 200,\n", 61 | " \"temperature\": 1,\n", 62 | " \"min_p\": 0.1,\n", 63 | " \"stream\": True,\n", 64 | " \"adjustment_strength\": 100,\n", 65 | " \"antislop_enabled\": True, # Defaults to true\n", 66 | " \"slop_phrases\": [[\"a testament to\", 0.3], [\"llama\", 0.1]] # this overrides the default list\n", 67 | "}\n", 68 | "\n", 69 | "try:\n", 70 | " # Using `stream=True` to handle the response as a stream of data\n", 71 | " response = requests.post(api_url, json=data, stream=True, timeout=30)\n", 72 | " #print(response)\n", 73 | " \n", 74 | " # Read and print the stream of data in chunks\n", 75 | " for chunk in response.iter_lines():\n", 76 | " if chunk:\n", 77 | " decoded_chunk = chunk.decode('utf-8')\n", 78 | " # OpenAI streams responses in `data: {json}` format, so we need to parse that\n", 79 | " if decoded_chunk.startswith('data:'):\n", 80 | " try:\n", 81 | " json_data = json.loads(decoded_chunk[len('data: '):]) # Remove 'data: ' prefix\n", 82 | " if 'choices' in json_data and len(json_data['choices']) > 0:\n", 83 | " print(json_data['choices'][0]['delta'].get('content', ''), end='', flush=True)\n", 84 | " except json.JSONDecodeError as e:\n", 85 | " print(f\"Error decoding JSON: {e}\")\n", 86 | "except Exception as e:\n", 87 | " print(f\"An error occurred: {e}\")\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 8, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "In the small village of Brindlemark, nestled in the rolling hills of the countryside, there lived a young woman named Eira. She was known throughout the village as the loomweaver, Eira, for her extraordinary talent in weaving intricate patterns with the threads of the finest wool.\n", 100 | "\n", 101 | "Eira's fascination with weaving began when she was just a child, watching her mother and grandmother work the loom in the evenings. The soft rustle of the wool against the fabric, the gentle hum of the shuttle, and the way the colors seemed to come alive as the loom wove its magic – all of these things captivated Eira, and she spent countless hours practicing and experimenting with different techniques.\n", 102 | "\n", 103 | "As she grew older, Eira's skills improved dramatically, and she became one of the most sought-after weavers in the region. 
Her creations were renowned for their beauty and complexity, with intricate patterns and delicate colors that seemed to dance across the fabric.\n", 104 | "\n", 105 | "But Eira's life\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "# Example using regex bans\n", 111 | "prompt = \"write a story about the loomweaver elara\"\n", 112 | "messages = [{\"role\": \"user\", \"content\": prompt}]\n", 113 | "try:\n", 114 | " data = { \n", 115 | " \"messages\": messages, \n", 116 | " \"max_tokens\": 200,\n", 117 | " \"temperature\": 1,\n", 118 | " \"min_p\": 0.1,\n", 119 | " \"stream\": False,\n", 120 | " \"adjustment_strength\": 100,\n", 121 | " \"slop_phrases\": [[\"a testament to\", 0.3], [\"kusi\", 0.1]],\n", 122 | " \"antislop_enabled\": True,\n", 123 | " \"regex_bans\": ['(?i)not [^.!?]{3,60} but', '(?i)elara'] # Not compatible with streaming\n", 124 | " }\n", 125 | " \n", 126 | " response = requests.post(api_url, json=data, stream=False, timeout=30)\n", 127 | " data = response.json()\n", 128 | " print(data['choices'][0]['message']['content'])\n", 129 | "\n", 130 | "except Exception as e:\n", 131 | " print(f\"An error occurred: {e}\")\n" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 53, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "In the heart of a mystical Andean village, there lived a legend about a magical llama named Luna. Her coat shone like the moon, and her eyes glowed with an ethereal light that seemed to hold the secrets of the universe. Luna was no ordinary llama – she possessed the power to heal the sick, grant wisdom to those who sought it, and even bend the fabric of time to bring peace to those who were troubled.\n", 151 | "\n", 152 | "One day, a young girl named Sophia wandered into the village, searching for her ailing mother. Sophia's mother had fallen gravely ill, and the village healer had exhausted all his remedies. Desperate for a solution, Sophia set out to find the legendary magical llama, Luna.\n", 153 | "\n", 154 | "As she trekked through the high-altitude grasslands, Sophia encountered many creatures – the majestic condors soaring above, the chattering vicuñas below, and the gentle guanacos grazing in the distance. 
But none of them led her to Luna, until finally" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "# Default slop list example\n", 160 | "#\n", 161 | "# If the user hasn't specified a list of slop phrases & adjustments in their \n", 162 | "# query, then it will default to use whatever you specified in the\n", 163 | "# --slop_adjustments_file argument when launching the api.\n", 164 | "#\n", 165 | "# If you didn't specify anything, it won't make any adjustments unless the\n", 166 | "# user specifies their adjustments in the query.\n", 167 | "\n", 168 | "prompt = \"tell me a story about a magical llama\"\n", 169 | "api_url = 'http://localhost:8000/v1/chat/completions'\n", 170 | "\n", 171 | "messages = [{\"role\": \"user\", \"content\": prompt}]\n", 172 | "data = { \n", 173 | " \"messages\": messages, \n", 174 | " \"max_tokens\": 200,\n", 175 | " \"temperature\": 1,\n", 176 | " \"min_p\": 0.1,\n", 177 | " \"stream\": True,\n", 178 | " \"antislop_enabled\": True,\n", 179 | " \"adjustment_strength\": 100,\n", 180 | "}\n", 181 | "\n", 182 | "try:\n", 183 | " # Using `stream=True` to handle the response as a stream of data\n", 184 | " response = requests.post(api_url, json=data, stream=True, timeout=30)\n", 185 | " #print(response)\n", 186 | " \n", 187 | " # Read and print the stream of data in chunks\n", 188 | " for chunk in response.iter_lines():\n", 189 | " if chunk:\n", 190 | " decoded_chunk = chunk.decode('utf-8')\n", 191 | " # OpenAI streams responses in `data: {json}` format, so we need to parse that\n", 192 | " if decoded_chunk.startswith('data:'):\n", 193 | " try:\n", 194 | " json_data = json.loads(decoded_chunk[len('data: '):]) # Remove 'data: ' prefix\n", 195 | " if 'choices' in json_data and len(json_data['choices']) > 0:\n", 196 | " print(json_data['choices'][0]['delta'].get('content', ''), end='', flush=True)\n", 197 | " except json.JSONDecodeError as e:\n", 198 | " print(f\"Error decoding JSON: {e}\")\n", 199 | "except Exception as e:\n", 200 | " print(f\"An error occurred: {e}\")" 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "Python 3", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.10.6" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 2 225 | } 226 | -------------------------------------------------------------------------------- /example_generate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Dependencies:\n", 10 | "\n", 11 | "#!pip install transformers ipywidgets IPython\n", 12 | "\n", 13 | "# If running on Colab:\n", 14 | "#!git clone https://github.com/sam-paech/antislop-sampler.git\n", 15 | "#!mv antislop-sampler/src .\n", 16 | "#!mv antislop-sampler/slop_phrase_prob_adjustments.json ." 
17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 7, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Model loaded\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "import os\n", 34 | "import torch\n", 35 | "from transformers import (\n", 36 | " AutoModelForCausalLM,\n", 37 | " AutoTokenizer,\n", 38 | ")\n", 39 | "from src.antislop_generate import generate_antislop, chat_antislop\n", 40 | "\n", 41 | "# Enable efficient transfer for Hugging Face models\n", 42 | "os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = \"1\"\n", 43 | "\n", 44 | "# Set the device to 'cuda' if available, else 'cpu'\n", 45 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 46 | "\n", 47 | "# Specify the model name (replace with your preferred model)\n", 48 | "model_name = \"unsloth/Llama-3.2-1B-Instruct\"\n", 49 | "#model_name = \"unsloth/Llama-3.2-3B-Instruct\"\n", 50 | "\n", 51 | "# Load the model and tokenizer\n", 52 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 53 | "model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", 54 | "model.to(device)\n", 55 | "print('Model loaded')" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 8, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "import json\n", 65 | "if os.path.exists('slop_phrase_prob_adjustments.json'):\n", 66 | " with open('slop_phrase_prob_adjustments.json', 'r') as f:\n", 67 | " slop_phrase_prob_adjustments = dict(json.load(f)[:500])" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 9, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "prompt = \"Write a story about Elara, the weaver of tapestries in future Technopolis. In the bustling city, a group of \"" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 10, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stderr", 86 | "output_type": "stream", 87 | "text": [ 88 | "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n", 89 | "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.\n" 90 | ] 91 | }, 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "In the heart of Future Templar, the city of synthetic and organic fusion, woven magic was a staple of everyday life. It was here that Adara, a young and talented weaver, made her name. Her fingers danced across the loom, creating magnificent works of art that told the stories of the city's inhabitants.\n", 97 | "\n", 98 | "Adara's workshop was a cozy, dimly lit room filled with the soft hum of machinery and the gentle rustle of threads. The walls were adorned with a series of colorful murals depicting the city's history, from the ancient days of the Templars to the present era of technological marvels.\n", 99 | "\n", 100 | "As the sun began to set, casting a warm orange glow over the city, Adara's apprentice, Kael, arrived at her doorstep. He was a lanky, energetic young man with a mop of messy black hair and a mischievous grin. 
Kael had been learning the art of weaving from Adara for several years, and he was eager to prove himself as a skilled weaver in his own right.\n", 101 | "\n", 102 | "\"Tonight, Kael, we're going to work on the new commission for the Guild of Artisans,\" Adara announced, her eyes sparkling with excitement. \"They want a beautiful, shimmering fabric for their upcoming exhibition.\"\n", 103 | "\n", 104 | "Kael's eyes widened as he took in the task at hand. He had always been impressed by Adara's skill and creativity, and he knew that this commission would be a great opportunity to showcase his own talents.\n", 105 | "\n", 106 | "Together, Adara and Kael set to work, their fingers moving in perfect sync as they wove a breathtaking fabric. The threads danced across the loom, creating a work of art that seemed to shimmer and glow in the light.\n", 107 | "\n", 108 | "As the night wore on, the workshop grew quiet, except for the soft hum of the machinery and the occasional rustle of threads. Adara and Kael worked in silence, lost in their own little worlds of creativity and imagination.\n", 109 | "\n", 110 | "Finally, after" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "# Chat generation with streaming\n", 116 | "messages = [\n", 117 | " {\"role\": \"user\", \"content\": prompt}\n", 118 | "]\n", 119 | "tokens = []\n", 120 | "text = ''\n", 121 | "for token in chat_antislop(\n", 122 | " model=model,\n", 123 | " tokenizer=tokenizer,\n", 124 | " messages=messages,\n", 125 | " max_new_tokens=400,\n", 126 | " # Antislop sampling may be less reliable at low temperatures.\n", 127 | " #temperature=1, \n", 128 | " #min_p=0.1,\n", 129 | " temperature=0.01,\n", 130 | " # The adjustment_strength param scales how strongly the probability adjustments are applied.\n", 131 | " # A value of 1 means the values in slop_phrase_prob_adjustments (or the defaults) are used unmodified.\n", 132 | " # Reasonable values are 0 (disabled) thru 100+ (effectively banning the list).\n", 133 | " adjustment_strength=100.0,\n", 134 | " # Optional: Provide a list of slop phrases and probability adjustments\n", 135 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n", 136 | " enforce_json=False,\n", 137 | " antislop_enabled=True,\n", 138 | " streaming=True,\n", 139 | " stream_smoothing=True, # On by default; this will smooth out the stutters from backtracking.\n", 140 | "):\n", 141 | " tokens.append(token)\n", 142 | " full_text = tokenizer.decode(tokens, skip_special_tokens=True)\n", 143 | " new_text = full_text[len(text):]\n", 144 | " text = full_text\n", 145 | " print(new_text, end='', flush=True)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 11, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "In the heart of Future City, a sprawling technological marvel, the air was alive with the hum of machinery and the chatter of pedestrians. The city's residents moved with purpose, their footsteps echoing off the sleek, silver skyscrapers that pierced the clouds. Among the throngs was a young woman named Elian, her dark hair tied back in a tight ponytail as she hurried down the main thoroughfare.\n", 158 | "\n", 159 | "Elian was a weaver, one of the city's skilled artisans who brought beauty and magic to the fabric of its inhabitants. Her workshop, a cozy room tucked away in a quiet alley, was a haven of creativity and tranquility. 
As she worked, her fingers moved deftly, the shuttle of her loom weaving a tale of wonder and enchantment.\n", 160 | "\n", 161 | "The weavers of Future City were renowned for their extraordinary talent, their creations imbued with the essence of the city's energy. Elian's own work was a fusion of traditional techniques and advanced machinery, allowing her to craft fabrics that shone like stars and pulsed with the heartbeat of the city.\n", 162 | "\n", 163 | "As Elian worked, a group of individuals caught her eye. There was Arin, the burly engineer who had designed the city's infrastructure; Lyra, the young hacker who had uncovered hidden patterns in the code; and Jax, the charismatic performer who had mastered the ancient art of acrobatics. Together, they formed an eclectic group known as the \"Synthetix,\" united by their passion for innovation and their love of storytelling.\n", 164 | "\n", 165 | "Their latest project was a grand commission from the city's governing council, a massive mural that would adorn the main square and celebrate the city's \n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "# Chat generation without streaming\n", 171 | "messages = [\n", 172 | " {\"role\": \"user\", \"content\": prompt}\n", 173 | "]\n", 174 | "generated_text = chat_antislop(\n", 175 | " model=model,\n", 176 | " tokenizer=tokenizer,\n", 177 | " messages=messages,\n", 178 | " max_length=400,\n", 179 | " temperature=1,\n", 180 | " min_p=0.1,\n", 181 | " adjustment_strength=100.0,\n", 182 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n", 183 | " enforce_json=False,\n", 184 | " antislop_enabled=True,\n", 185 | " regex_bans=['(?i)not [^.!?]{3,60} but'], # if these regex expressions are matched, we backtrack, ban that continuation and retry\n", 186 | " streaming=False\n", 187 | ")\n", 188 | "print(tokenizer.decode(generated_text))" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 12, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "name": "stdout", 198 | "output_type": "stream", 199 | "text": [ 200 | " urchins stumble upon her at the market square, and a young girl named Aria discovers her talent.\n", 201 | "\n", 202 | "---\n", 203 | "\n", 204 | "The sun beat down on the market square, casting a warm glow over the crowded stalls. Vendors hawked their wares – fresh produce, exotic spices, and handmade crafts – while the smell of roasting meats and baking bread filled the air. In the midst of it all, a lone weaver sat at her stall, her fingers moving deftly as she wove a new design onto a large piece of fabric.\n", 205 | "\n", 206 | "The weaver, a young woman with a wild tangle of curly hair and a warm smile, worked with a quiet focus. Her fingers flew across the loom, the threads shimmering and glowing as she brought her vision to life. Aria, a young girl with a mop of curly brown hair and a curious gaze, wandered through the market, her eyes scanning the stalls for scraps of fabric to use in her own projects.\n", 207 | "\n", 208 | "As she passed by the weaver's stall, Aria caught sight of the beautiful, swirling patterns that covered the fabric. She felt an instant connection to the artistry, the creativity, and the beauty that radiated from the weaver's loom. 
The young girl's eyes locked onto the weaver, and she felt a shiver run down her spine.\n", 209 | "\n", 210 | "\"\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "# generate without streaming\n", 216 | "prompt_with_template = tokenizer.apply_chat_template(messages, tokenize=False)\n", 217 | "generated_text = generate_antislop(\n", 218 | " model=model,\n", 219 | " tokenizer=tokenizer,\n", 220 | " prompt=prompt,\n", 221 | " max_length=300,\n", 222 | " temperature=1,\n", 223 | " min_p=0.1,\n", 224 | " adjustment_strength=100.0,\n", 225 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n", 226 | " enforce_json=False,\n", 227 | " antislop_enabled=True,\n", 228 | " streaming=False\n", 229 | ") \n", 230 | "print(tokenizer.decode(generated_text))" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 13, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | " 5 people gather in a hidden alleyway, each with a mysterious device in their hand. They are about to participate in a secret event, where they will compete in a tournament of wit and strategy, but as they begin, they realize that each of them has been sent by a powerful corporation to spy on each other, and a mysterious figure lurks in the shadows.\n", 243 | "\n", 244 | "The group consists of:\n", 245 | "\n", 246 | "1. **Aurora, a skilled hacker from the cybernetic gang, \"The Codekeepers\"**\n", 247 | "2. **Lysander, a master strategist from the espionage agency \"The Shadowhand\"**\n", 248 | "3. **Mira, a brilliant scientist from the research facility \"The Nexus\"**\n", 249 | "4. **Caspian, a charismatic con artist with ties to the underworld**\n", 250 | "5. **Nova, a rebellious young artist with a hidden past**\n", 251 | "\n", 252 | "As they enter the tournament hall, they are greeted by the host, **Ryker**, a charismatic and suave corporate executive. Ryker explains the rules of the tournament, but as the competitors begin to mingle and exchange information, they start to realize that something is amiss.\n", 253 | "\n", 254 | "Each of the competitors is receiving a mysterious device from a different corporation, and they all seem to have a hidden agenda. 
The device is a small, sleek box with a glowing screen" 255 | ] 256 | } 257 | ], 258 | "source": [ 259 | "# generate with streaming\n", 260 | "prompt_with_template = tokenizer.apply_chat_template(messages, tokenize=False)\n", 261 | "tokens = []\n", 262 | "text = \"\"\n", 263 | "for token in generate_antislop(\n", 264 | " model=model,\n", 265 | " tokenizer=tokenizer,\n", 266 | " prompt=prompt,\n", 267 | " max_length=300,\n", 268 | " temperature=1,\n", 269 | " min_p=0.1,\n", 270 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n", 271 | " adjustment_strength=100.0,\n", 272 | " enforce_json=False,\n", 273 | " antislop_enabled=True,\n", 274 | " streaming=True\n", 275 | "):\n", 276 | " tokens.append(token)\n", 277 | " full_text = tokenizer.decode(tokens, skip_special_tokens=True)\n", 278 | " new_text = full_text[len(text):]\n", 279 | " text = full_text\n", 280 | " print(new_text, end='', flush=True)" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "Python 3", 287 | "language": "python", 288 | "name": "python3" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.10.6" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 2 305 | } 306 | -------------------------------------------------------------------------------- /example_visualise_backtracking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Backtracking Visualisation for the AntiSlop Sampler\n", 8 | "\n", 9 | "This notebook demonstrates the AntiSlop sampler by adding delays when slop phrases are detected & replaced, as well as some debug output.\n", 10 | "\n", 11 | "https://github.com/sam-paech/antislop-sampler\n", 12 | "\n", 13 | "### Update 2024-10-05\n", 14 | "\n", 15 | "- Switch to probs from logits for the cached values, so that down/upregulation works as expected.\n", 16 | "- Use model.generate for multiple tokens (not just 1 at a time) with StoppingCondition criteria. This is much faster.\n", 17 | "- Add json validation including long-range checks for unintended unescaped quotes in strings." 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Dependencies:\n", 27 | "\n", 28 | "#!pip install transformers ipywidgets IPython\n", 29 | "\n", 30 | "# If running on Colab:\n", 31 | "#!git clone https://github.com/sam-paech/antislop-sampler.git\n", 32 | "#!mv antislop-sampler/src .\n", 33 | "#!mv antislop-sampler/slop_phrase_prob_adjustments.json ." 
34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import os\n", 43 | "import json\n", 44 | "import torch\n", 45 | "from transformers import (\n", 46 | " AutoModelForCausalLM,\n", 47 | " AutoTokenizer,\n", 48 | ")\n", 49 | "from IPython.display import display, HTML\n", 50 | "from ipywidgets import Output\n", 51 | "from src.antislop_generate import AntiSlopSampler, chat_antislop, generate_antislop\n", 52 | "\n", 53 | "# Enable efficient transfer for Hugging Face models\n", 54 | "os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = \"1\"\n", 55 | "\n", 56 | "# Set the device to 'cuda' if available, else 'cpu'\n", 57 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 58 | "\n", 59 | "# Specify the model name (replace with your preferred model)\n", 60 | "model_name = \"unsloth/Llama-3.2-1B-Instruct\"\n", 61 | "#model_name = \"unsloth/Llama-3.2-3B-Instruct\"\n", 62 | "\n", 63 | "# Load the model and tokenizer\n", 64 | "model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n", 65 | "model.to(device)\n", 66 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 67 | "\n", 68 | "# These are a mix of gpt-slop found in various online lists, plus a much larger set of automatically derived over-represented words in a GPT-generated dataset.\n", 69 | "# See slopcalc.ipynb for the code to generate a more complete list.\n", 70 | "slop_phrase_prob_adjustments = [['kaleidoscope', 0.5], ['symphony', 0.5], ['testament to', 0.5], ['elara', 0.5], ['moth to a flame', 0.5], ['canvas', 0.5], ['eyes glinted', 0.5], ['camaraderie', 0.5], ['humble abode', 0.5], ['cold and calculating', 0.5], ['eyes never leaving', 0.5], ['tapestry', 0.5], ['barely above a whisper', 0.5], ['body and soul', 0.5], ['orchestra', 0.5], ['depths', 0.5], ['a dance of', 0.5], ['chuckles darkly', 0.5], ['maybe, just maybe', 0.5], ['maybe that was enough', 0.5], ['with a mixture of', 0.5], ['air was filled with anticipation', 0.5], ['cacophony', 0.5], ['bore silent witness to', 0.5], ['eyes sparkling with mischief', 0.5], ['was only just beginning', 0.5], ['practiced ease', 0.5], ['ready for the challenges', 0.5], ['only just getting started', 0.5], ['once upon a time', 0.5], ['nestled deep within', 0.5], ['ethereal beauty', 0.5], ['life would never be the same again.', 0.5], [\"it's important to remember\", 0.5], ['for what seemed like an eternity', 0.5], ['feel a sense of pride and accomplishment', 0.5], ['little did he know', 0.5], ['ball is in your court', 0.5], ['game is on', 0.5], ['choice is yours', 0.5], ['feels like an electric shock', 0.5], ['threatens to consume', 0.5], ['meticulous', 0.5], ['meticulously', 0.5], ['navigating', 0.5], ['complexities', 0.5], ['realm', 0.5], ['understanding', 0.5], ['dive into', 0.5], ['shall', 0.5], ['tailored', 0.5], ['towards', 0.5], ['underpins', 0.5], ['everchanging', 0.5], ['ever-evolving', 0.5], ['world of', 0.5], ['not only', 0.5], ['alright', 0.5], ['embark', 0.5], ['journey', 0.5], [\"today's digital age\", 0.5], ['game changer', 0.5], ['designed to enhance', 0.5], ['it is advisable', 0.5], ['daunting', 0.5], ['when it comes to', 0.5], ['in the realm of', 0.5], ['amongst', 0.5], ['unlock the secrets', 0.5], ['unveil the secrets', 0.5], ['and robust', 0.5], ['diving', 0.5], ['elevate', 0.5], ['unleash', 0.5], ['cutting-edge', 0.5], ['rapidly', 0.5], ['expanding', 0.5], ['mastering', 0.5], ['excels', 0.5], ['harness', 0.5], [\"it's important to note\", 
0.5], ['delve into', 0.5], ['bustling', 0.5], ['in summary', 0.5], ['remember that', 0.5], ['take a dive into', 0.5], ['landscape', 0.5], ['in the world of', 0.5], ['vibrant', 0.5], ['metropolis', 0.5], ['firstly', 0.5], ['moreover', 0.5], ['crucial', 0.5], ['to consider', 0.5], ['essential', 0.5], ['there are a few considerations', 0.5], ['ensure', 0.5], [\"it's essential to\", 0.5], ['furthermore', 0.5], ['vital', 0.5], ['keen', 0.5], ['fancy', 0.5], ['as a professional', 0.5], ['however', 0.5], ['therefore', 0.5], ['additionally', 0.5], ['specifically', 0.5], ['generally', 0.5], ['consequently', 0.5], ['importantly', 0.5], ['indeed', 0.5], ['thus', 0.5], ['alternatively', 0.5], ['notably', 0.5], ['as well as', 0.5], ['despite', 0.5], ['essentially', 0.5], ['even though', 0.5], ['in contrast', 0.5], ['in order to', 0.5], ['due to', 0.5], ['even if', 0.5], ['given that', 0.5], ['arguably', 0.5], ['you may want to', 0.5], ['on the other hand', 0.5], ['as previously mentioned', 0.5], [\"it's worth noting that\", 0.5], ['to summarize', 0.5], ['ultimately', 0.5], ['to put it simply', 0.5], [\"in today's digital era\", 0.5], ['reverberate', 0.5], ['enhance', 0.5], ['emphasize', 0.5], ['revolutionize', 0.5], ['foster', 0.5], ['remnant', 0.5], ['subsequently', 0.5], ['nestled', 0.5], ['labyrinth', 0.5], ['gossamer', 0.5], ['enigma', 0.5], ['whispering', 0.5], ['sights unseen', 0.5], ['sounds unheard', 0.5], ['indelible', 0.5], ['my friend', 0.5], ['in conclusion', 0.5], ['technopolis', 0.5], ['was soft and gentle', 0.5], ['shivers down', 0.5], ['shivers up', 0.5], ['leaving trails of fire', 0.5], ['ministrations', 0.5], ['audible pop', 0.5], ['rivulets of', 0.5], ['despite herself', 0.5], ['reckless abandon', 0.5], ['torn between', 0.5], ['fiery red hair', 0.5], ['long lashes', 0.5], ['propriety be damned', 0.5], ['world narrows', 0.5], ['chestnut eyes', 0.5], ['cheeks flaming', 0.5], ['cheeks hollowing', 0.5], ['understandingly', 0.5], ['paperbound', 0.5], ['hesitantly', 0.5], ['piqued', 0.5], ['delved', 0.5], ['curveballs', 0.5], ['marveled', 0.5], ['inclusivity', 0.5], ['birdwatcher', 0.5], ['newfound', 0.5031423922762257], ['marveling', 0.5055622891781474], [\"hiroshi's\", 0.506870969939047], ['greentech', 0.5095092042816856], ['thoughtfully', 0.510153898156777], ['intently', 0.5153227374075411], ['birdwatching', 0.5157928537951464], ['amidst', 0.5161190296674488], ['cherishing', 0.5165772000484282], ['attentively', 0.5169695157301188], ['interjected', 0.5208671011920856], ['serendipitous', 0.5219535186850968], [\"marianne's\", 0.5220118279910801], [\"maya's\", 0.5229467776607973], ['excitedly', 0.5235248665614571], ['steepled', 0.5235772300889154], ['engrossed', 0.5236764398055735], ['fostering', 0.5259281627970829], ['brainstormed', 0.5274863713437], ['furrowed', 0.5280860997212533], ['nodded', 0.528640180937889], ['contemplatively', 0.5293698584415747], ['jotted', 0.5300819077932343], [\"mia's\", 0.5311706933553655]];\n", 71 | "slop_phrase_prob_adjustments = dict(slop_phrase_prob_adjustments)\n", 72 | "\n", 73 | "if os.path.exists('slop_phrase_prob_adjustments.json'):\n", 74 | " with open('slop_phrase_prob_adjustments.json', 'r') as f:\n", 75 | " slop_phrase_prob_adjustments = dict(json.load(f)[:500])\n", 76 | " #slop_phrase_prob_adjustments = dict(json.load(f))" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "### Test anti-slop sampling" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "metadata": {}, 90 | 
"outputs": [ 91 | { 92 | "data": { 93 | "text/html": [ 94 | "

Prompt

" 95 | ], 96 | "text/plain": [ 97 | "" 98 | ] 99 | }, 100 | "metadata": {}, 101 | "output_type": "display_data" 102 | }, 103 | { 104 | "data": { 105 | "application/vnd.jupyter.widget-view+json": { 106 | "model_id": "e213c1b6f6114c90ae65c4a08ec5dcf7", 107 | "version_major": 2, 108 | "version_minor": 0 109 | }, 110 | "text/plain": [ 111 | "Output()" 112 | ] 113 | }, 114 | "metadata": {}, 115 | "output_type": "display_data" 116 | }, 117 | { 118 | "data": { 119 | "text/html": [ 120 | "

Inference Output

" 121 | ], 122 | "text/plain": [ 123 | "" 124 | ] 125 | }, 126 | "metadata": {}, 127 | "output_type": "display_data" 128 | }, 129 | { 130 | "data": { 131 | "application/vnd.jupyter.widget-view+json": { 132 | "model_id": "6e24357e234841aaa2e44e8b6cdb7368", 133 | "version_major": 2, 134 | "version_minor": 0 135 | }, 136 | "text/plain": [ 137 | "Output()" 138 | ] 139 | }, 140 | "metadata": {}, 141 | "output_type": "display_data" 142 | }, 143 | { 144 | "data": { 145 | "text/html": [ 146 | "

Debug Information

" 147 | ], 148 | "text/plain": [ 149 | "" 150 | ] 151 | }, 152 | "metadata": {}, 153 | "output_type": "display_data" 154 | }, 155 | { 156 | "data": { 157 | "application/vnd.jupyter.widget-view+json": { 158 | "model_id": "4730fcb8281d42a98ae94afcac457a91", 159 | "version_major": 2, 160 | "version_minor": 0 161 | }, 162 | "text/plain": [ 163 | "Output()" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | }, 169 | { 170 | "name": "stderr", 171 | "output_type": "stream", 172 | "text": [ 173 | "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "# Example prompt for the model\n", 179 | "prompt = \"Once upon a time, in a bustling city of Technopolis, there lived a weaver named Elara.\"\n", 180 | "\n", 181 | "# Define the messages for a chat scenario (for example purposes)\n", 182 | "messages = [\n", 183 | " {\"role\": \"user\", \"content\": prompt}\n", 184 | "]\n", 185 | "\n", 186 | "# Display the output widgets\n", 187 | "prompt_output = Output()\n", 188 | "inference_output = Output()\n", 189 | "debug_output = Output()\n", 190 | "\n", 191 | "display(HTML(\"

Prompt

\"))\n", 192 | "display(prompt_output)\n", 193 | "display(HTML(\"

Inference Output

\"))\n", 194 | "display(inference_output)\n", 195 | "display(HTML(\"

Debug Information

\"))\n", 196 | "display(debug_output)\n", 197 | "\n", 198 | "# Initialize display (if using in Jupyter)\n", 199 | "with prompt_output:\n", 200 | " prompt_output.clear_output(wait=True)\n", 201 | " display(HTML(f\"
{prompt}
\"))\n", 202 | "\n", 203 | "if True:\n", 204 | "    VALIDATE_JSON_STRINGS = True\n", 205 | "    # Call the chat_antislop to generate the story with the given messages\n", 206 | "    output_tokens = chat_antislop(\n", 207 | "        model=model,\n", 208 | "        tokenizer=tokenizer,\n", 209 | "        messages=messages,\n", 210 | "        max_new_tokens=400, \n", 211 | "        temperature=1,\n", 212 | "        min_p=0.1,\n", 213 | "        slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n", 214 | "        adjustment_strength=100,\n", 215 | "        device=device,\n", 216 | "        streaming=False,\n", 217 | "        slow_debug=True, # Enable slow debugging\n", 218 | "        output_every_n_tokens=3,\n", 219 | "        debug_delay=1, # Set delay for debug output\n", 220 | "        inference_output=inference_output, # Visualization of the text output\n", 221 | "        enforce_json=False,\n", 222 | "        debug_output=debug_output # Visualization of the debug information\n", 223 | "    )\n", 224 | "\n", 225 | "    inference = tokenizer.decode(output_tokens, skip_special_tokens=True)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "\n", 233 | "\n", 234 | "\n", 235 | "\n", 236 | "\n", 237 | "\n", 238 | "\n", 239 | "\n", 240 | "\n", 241 | "\n", 242 | "\n", 243 | "\n", 244 | "\n", 245 | "\n", 246 | "\n", 247 | "\n", 248 | "\n", 249 | "\n", 250 | "\n", 251 | "\n", 252 | "\n", 253 | "\n", 254 | "\n", 255 | "\n", 256 | "\n", 257 | "\n" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "## Evaluate json constraint enforcement\n", 265 | "\n", 266 | "Note: This evaluation will count as invalid any output that contains multiple separate json elements (like json-lines format). The enforce_json flag does not enforce this; it only enforces json validity within a single json block.\n", 267 | "\n", 268 | "As such, it's suggested to use a model that will reliably follow the requested output format (i.e. a stronger model than the example Llama-3.2-1B used above), so this eval works as intended." 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 4, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "from tqdm import tqdm\n", 278 | "\n", 279 | "prompt = 'Write a short scene of dialogue between two characters, Jill and Steve. Output in json with this format: {\"scene\": \"\"}'\n", 280 | "\n", 281 | "# Define the messages for a chat scenario (for example purposes)\n", 282 | "messages = [\n", 283 | "    {\"role\": \"user\", \"content\": prompt}\n", 284 | "]\n", 285 | "\n", 286 | "if False:\n", 287 | "    inference_output = Output()\n", 288 | "    debug_output = Output()\n", 289 | "    display(HTML(\"

Inference Output

\"))\n", 290 | " display(inference_output)\n", 291 | " display(HTML(\"

Debug Information

\"))\n", 292 | " display(debug_output)\n", 293 | "\n", 294 | " n_iterations = 20\n", 295 | "\n", 296 | " valid_count = 0\n", 297 | " for i in tqdm(range(n_iterations)):\n", 298 | " output_tokens = [] \n", 299 | "\n", 300 | " output_tokens = chat_antislop(\n", 301 | " model=model,\n", 302 | " tokenizer=tokenizer,\n", 303 | " messages=messages,\n", 304 | " max_new_tokens=800, \n", 305 | " temperature=1,\n", 306 | " min_p=0.1,\n", 307 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n", 308 | " adjustment_strength=5.0,\n", 309 | " device=device,\n", 310 | " streaming=False,\n", 311 | " slow_debug=False, # Enable slow debugging\n", 312 | " output_every_n_tokens=1, # Update every 5 tokens\n", 313 | " debug_delay=1, # Set delay for debug output\n", 314 | " inference_output=inference_output, # Visualization of the text output\n", 315 | " enforce_json=True,\n", 316 | " debug_output=debug_output # Visualization of the debug information\n", 317 | " )\n", 318 | "\n", 319 | " inference = tokenizer.decode(output_tokens, skip_special_tokens=True)\n", 320 | " lpos = inference.find('{')\n", 321 | " rpos = inference.rfind('}')\n", 322 | " if lpos >= 0 and rpos > 0:\n", 323 | " inference = inference[lpos:rpos+1]\n", 324 | " try:\n", 325 | " parsed = json.loads(inference)\n", 326 | " valid_count += 1\n", 327 | " except Exception as e:\n", 328 | " print([inference])\n", 329 | " print(e)\n", 330 | " pass\n", 331 | "\n", 332 | " print('json successfully parsed:', valid_count, 'of', n_iterations)\n", 333 | " print('')\n" 334 | ] 335 | } 336 | ], 337 | "metadata": { 338 | "kernelspec": { 339 | "display_name": "Python 3", 340 | "language": "python", 341 | "name": "python3" 342 | }, 343 | "language_info": { 344 | "codemirror_mode": { 345 | "name": "ipython", 346 | "version": 3 347 | }, 348 | "file_extension": ".py", 349 | "mimetype": "text/x-python", 350 | "name": "python", 351 | "nbconvert_exporter": "python", 352 | "pygments_lexer": "ipython3", 353 | "version": "3.10.6" 354 | } 355 | }, 356 | "nbformat": 4, 357 | "nbformat_minor": 2 358 | } 359 | -------------------------------------------------------------------------------- /run_api.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from typing import List, Dict, Union, Optional, Any, AsyncGenerator, Tuple 5 | 6 | from fastapi import FastAPI, HTTPException, Request 7 | from fastapi.responses import StreamingResponse 8 | from pydantic import BaseModel, Field 9 | import uvicorn 10 | import logging 11 | import threading 12 | 13 | import torch 14 | from transformers import ( 15 | PreTrainedModel, 16 | PreTrainedTokenizer, 17 | AutoTokenizer, 18 | AutoModelForCausalLM, 19 | BitsAndBytesConfig # Ensure this is imported if you use quantization 20 | ) 21 | 22 | # Set up logging 23 | logging.basicConfig(level=logging.INFO) # Set to DEBUG for more detailed logs 24 | logger = logging.getLogger(__name__) 25 | 26 | # Import your custom antislop_generate module 27 | from src.antislop_generate import chat_antislop, generate_antislop 28 | 29 | app = FastAPI(title="AntiSlop OpenAI-Compatible API") 30 | 31 | # Global variables to hold the model and tokenizer 32 | model: Optional[PreTrainedModel] = None 33 | tokenizer: Optional[PreTrainedTokenizer] = None 34 | DEFAULT_SLOP_ADJUSTMENTS: Dict[str, float] = {} 35 | device: Optional[torch.device] = None # Modified to allow dynamic setting 36 | 37 | # Variables to store model metadata 38 | model_loaded_time: Optional[int] = None 39 | 
model_name_loaded: Optional[str] = None 40 | 41 | # Define a global asyncio.Lock to enforce single concurrency 42 | import asyncio 43 | lock = asyncio.Lock() 44 | 45 | # Define Pydantic models for request and response schemas 46 | 47 | class CompletionRequest(BaseModel): 48 | model: Optional[str] = Field(default=None, description="Model to use for completion") 49 | prompt: Union[str, List[str]] 50 | max_tokens: Optional[int] = Field(default=None, ge=1, description="Maximum number of tokens to generate") 51 | temperature: Optional[float] = Field(default=1.0, ge=0.0, description="Sampling temperature") 52 | top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Nucleus sampling probability") 53 | top_k: Optional[int] = Field(default=None, ge=0, description="Top-K sampling") 54 | min_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Minimum probability threshold") 55 | stream: Optional[bool] = Field(default=False, description="Whether to stream back partial progress") 56 | slop_phrases: Optional[List[Tuple[str, float]]] = Field( 57 | default=None, 58 | description="List of slop phrases and their adjustment values, e.g., [['a testament to', 0.3], ['tapestry of', 0.1]]" 59 | ) 60 | adjustment_strength: Optional[float] = Field(default=20.0, ge=0.0, description="Strength of adjustments") 61 | enforce_json: Optional[bool] = Field(default=False, description="Enforce JSON formatting") 62 | antislop_enabled: Optional[bool] = Field(default=True, description="Enable AntiSlop functionality") 63 | regex_bans: Optional[List[str]] = Field(default=False, description="Ban strings matching these regex expressions") 64 | 65 | 66 | class ChatCompletionMessage(BaseModel): 67 | role: str 68 | content: str 69 | 70 | 71 | class ChatCompletionRequest(BaseModel): 72 | model: Optional[str] = Field(default=None, description="Model to use for completion") 73 | messages: List[ChatCompletionMessage] 74 | max_tokens: Optional[int] = Field(default=None, ge=1, description="Maximum number of tokens to generate") 75 | temperature: Optional[float] = Field(default=1.0, ge=0.0, description="Sampling temperature") 76 | top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Nucleus sampling probability") 77 | top_k: Optional[int] = Field(default=None, ge=0, description="Top-K sampling") 78 | min_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Minimum probability threshold") 79 | stream: Optional[bool] = Field(default=False, description="Whether to stream back partial progress") 80 | slop_phrases: Optional[List[Tuple[str, float]]] = Field( 81 | default=None, 82 | description="List of slop phrases and their adjustment values, e.g., [['a testament to', 0.3], ['tapestry of', 0.1]]" 83 | ) 84 | adjustment_strength: Optional[float] = Field(default=20.0, ge=0.0, description="Strength of adjustments") 85 | enforce_json: Optional[bool] = Field(default=False, description="Enforce JSON formatting") 86 | antislop_enabled: Optional[bool] = Field(default=True, description="Enable AntiSlop functionality") 87 | regex_bans: Optional[List[str]] = Field(default=False, description="Ban strings matching these regex expressions") 88 | 89 | 90 | class CompletionChoice(BaseModel): 91 | text: str 92 | index: int 93 | logprobs: Optional[Any] = None 94 | finish_reason: Optional[str] = None 95 | 96 | 97 | class ChatCompletionChoice(BaseModel): 98 | message: ChatCompletionMessage 99 | index: int 100 | finish_reason: Optional[str] = None 101 | 102 | 103 | class 
CompletionResponse(BaseModel): 104 | id: str 105 | object: str 106 | created: int 107 | model: str 108 | choices: List[CompletionChoice] 109 | usage: Dict[str, int] 110 | 111 | 112 | class ChatCompletionResponse(BaseModel): 113 | id: str 114 | object: str 115 | created: int 116 | model: str 117 | choices: List[ChatCompletionChoice] 118 | usage: Dict[str, int] 119 | 120 | 121 | # New Pydantic models for /v1/models endpoint 122 | 123 | class ModelInfo(BaseModel): 124 | id: str 125 | object: str = "model" 126 | created: int 127 | owned_by: str 128 | permission: List[Any] = [] 129 | root: str 130 | parent: Optional[str] = None 131 | 132 | 133 | class ModelsResponse(BaseModel): 134 | object: str = "list" 135 | data: List[ModelInfo] 136 | 137 | 138 | # Utility functions 139 | 140 | import uuid 141 | import time 142 | import queue # Import queue for thread-safe communication 143 | 144 | def generate_id() -> str: 145 | return str(uuid.uuid4()) 146 | 147 | 148 | def current_timestamp() -> int: 149 | return int(time.time()) 150 | 151 | 152 | def load_slop_adjustments(file_path: Optional[str]) -> Dict[str, float]: 153 | if file_path is None: 154 | return {} 155 | if not os.path.exists(file_path): 156 | raise FileNotFoundError(f"Slop phrase adjustments file not found: {file_path}") 157 | with open(file_path, 'r', encoding='utf-8') as f: 158 | try: 159 | adjustments = dict(json.load(f)) 160 | if not isinstance(adjustments, dict): 161 | raise ValueError("Slop phrase adjustments file must contain a JSON object (dictionary).") 162 | # Ensure all values are floats 163 | for key, value in adjustments.items(): 164 | adjustments[key] = float(value) 165 | return adjustments 166 | except json.JSONDecodeError as e: 167 | raise ValueError(f"Error decoding JSON from slop phrase adjustments file: {e}") 168 | 169 | 170 | # Startup event to load model and tokenizer 171 | @app.on_event("startup") 172 | async def load_model_and_tokenizer(): 173 | global model, tokenizer, DEFAULT_SLOP_ADJUSTMENTS, device 174 | global model_loaded_time, model_name_loaded 175 | 176 | # Set device based on GPU_ID environment variable 177 | gpu_id = os.getenv("GPU_ID") 178 | if gpu_id is not None: 179 | try: 180 | gpu_id_int = int(gpu_id) 181 | if torch.cuda.is_available() and gpu_id_int < torch.cuda.device_count(): 182 | device = torch.device(f"cuda:{gpu_id_int}") 183 | else: 184 | logger.warning(f"Specified GPU ID {gpu_id_int} is not available. Falling back to default GPU or CPU.") 185 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 186 | except ValueError: 187 | logger.warning(f"Invalid GPU ID '{gpu_id}'. Falling back to default GPU or CPU.") 188 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 189 | else: 190 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 191 | 192 | # Load configuration from environment variables 193 | model_name = os.getenv("MODEL_NAME", "") 194 | load_in_4bit = os.getenv("LOAD_IN_4BIT", "false").lower() == "true" 195 | load_in_8bit = os.getenv("LOAD_IN_8BIT", "false").lower() == "true" 196 | slop_adjustments_file = os.getenv("SLOP_ADJUSTMENTS_FILE", None) 197 | 198 | # Validate mutually exclusive flags 199 | if load_in_4bit and load_in_8bit: 200 | logger.error("Cannot set both LOAD_IN_4BIT and LOAD_IN_8BIT. Choose one.") 201 | raise ValueError("Cannot set both LOAD_IN_4BIT and LOAD_IN_8BIT. 
Choose one.") 202 | 203 | # Load slop phrase adjustments from file if provided 204 | try: 205 | DEFAULT_SLOP_ADJUSTMENTS = load_slop_adjustments(slop_adjustments_file) 206 | logger.info(f"Loaded {len(DEFAULT_SLOP_ADJUSTMENTS)} slop phrase adjustments.") 207 | except Exception as e: 208 | logger.error(f"Failed to load slop adjustments: {e}") 209 | raise ValueError("Slop adjustments file could not be loaded. Make sure you have the right file path and file structure.") 210 | 211 | logger.info(f"Using device: {device}") 212 | 213 | # Load tokenizer 214 | logger.info(f"Loading tokenizer for model '{model_name}'...") 215 | try: 216 | tokenizer = AutoTokenizer.from_pretrained(model_name) 217 | if tokenizer.pad_token is None: 218 | tokenizer.pad_token = tokenizer.eos_token # Ensure pad_token is set 219 | logger.info("Tokenizer loaded.") 220 | except Exception as e: 221 | logger.error(f"Error loading tokenizer: {e}") 222 | raise e 223 | 224 | # Load model with appropriate precision 225 | logger.info(f"Loading model '{model_name}'...") 226 | try: 227 | if load_in_4bit: 228 | # Configure 4-bit loading 229 | quantization_config = BitsAndBytesConfig( 230 | load_in_4bit=True, 231 | bnb_4bit_use_double_quant=True, 232 | bnb_4bit_quant_type='nf4', 233 | bnb_4bit_compute_dtype=torch.float16 234 | ) 235 | try: 236 | import bitsandbytes # Ensure bitsandbytes is installed 237 | except ImportError: 238 | logger.error("bitsandbytes is required for 4-bit loading. Install it via 'pip install bitsandbytes'.") 239 | raise ImportError("bitsandbytes is required for 4-bit loading. Install it via 'pip install bitsandbytes'.") 240 | 241 | model = AutoModelForCausalLM.from_pretrained( 242 | model_name, 243 | quantization_config=quantization_config, 244 | device_map="auto" 245 | ) 246 | logger.info("Model loaded in 4-bit precision.") 247 | elif load_in_8bit: 248 | # Configure 8-bit loading 249 | quantization_config = BitsAndBytesConfig( 250 | load_in_8bit=True 251 | ) 252 | try: 253 | import bitsandbytes # Ensure bitsandbytes is installed 254 | except ImportError: 255 | logger.error("bitsandbytes is required for 8-bit loading. Install it via 'pip install bitsandbytes'.") 256 | raise ImportError("bitsandbytes is required for 8-bit loading. Install it via 'pip install bitsandbytes'.") 257 | 258 | model = AutoModelForCausalLM.from_pretrained( 259 | model_name, 260 | quantization_config=quantization_config, 261 | device_map="auto" 262 | ) 263 | logger.info("Model loaded in 8-bit precision.") 264 | else: 265 | # Load model normally 266 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 267 | try: 268 | model.to(device) 269 | except Exception as e: 270 | # if an already quantised model is loaded, the model.to(device) will 271 | # throw a benign error that we can ignore. 
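                # (Checkpoints that ship pre-quantised, e.g. bitsandbytes 4-bit/8-bit models
                # loaded with device_map="auto", are already placed on their devices by
                # accelerate, so .to(device) raises an error that is safe to skip here.)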
272 | print(e) 273 | logger.info("Model loaded in 16-bit precision.") 274 | except Exception as e: 275 | logger.error(f"Error loading model: {e}") 276 | raise e # Let FastAPI handle the startup failure 277 | 278 | logger.info("Model and tokenizer loaded successfully.") 279 | 280 | # Store model metadata 281 | model_loaded_time = current_timestamp() 282 | model_name_loaded = model_name 283 | 284 | 285 | # Utility function for streaming responses 286 | 287 | def generate_id() -> str: 288 | return str(uuid.uuid4()) 289 | 290 | 291 | def current_timestamp() -> int: 292 | return int(time.time()) 293 | 294 | 295 | async def stream_tokens_sync(generator: Any, is_chat: bool = False) -> AsyncGenerator[str, None]: 296 | """ 297 | Converts a synchronous generator to an asynchronous generator for streaming responses. 298 | Formats the output to match OpenAI's streaming response format. 299 | """ 300 | q = queue.Queue() 301 | 302 | def generator_thread(): 303 | try: 304 | logger.debug("Generator thread started.") 305 | for token in generator: 306 | q.put(token) 307 | logger.debug(f"Token put into queue: {token}") 308 | q.put(None) # Signal completion 309 | logger.debug("Generator thread completed.") 310 | except Exception as e: 311 | logger.error(f"Exception in generator_thread: {e}") 312 | q.put(e) # Signal exception 313 | 314 | # Start the generator in a separate daemon thread 315 | thread = threading.Thread(target=generator_thread, daemon=True) 316 | thread.start() 317 | logger.debug("Generator thread initiated.") 318 | 319 | try: 320 | tokens = [] 321 | text = '' 322 | while True: 323 | token = await asyncio.to_thread(q.get) 324 | logger.debug(f"Token retrieved from queue: {token}") 325 | 326 | if token is None: 327 | # Send final finish_reason to indicate the end of the stream 328 | finish_data = { 329 | "choices": [ 330 | { 331 | "delta": {}, 332 | "index": 0, 333 | "finish_reason": "stop" 334 | } 335 | ] 336 | } 337 | yield f"data: {json.dumps(finish_data)}\n\n" 338 | logger.debug("Finished streaming tokens.") 339 | break 340 | 341 | if isinstance(token, Exception): 342 | # Handle exceptions by sending a finish_reason with 'error' 343 | error_data = { 344 | "choices": [ 345 | { 346 | "delta": {}, 347 | "index": 0, 348 | "finish_reason": "error" 349 | } 350 | ] 351 | } 352 | yield f"data: {json.dumps(error_data)}\n\n" 353 | logger.error(f"Exception during token streaming: {token}") 354 | break # Exit the loop after handling the error 355 | 356 | # Decode the token to text 357 | # Note: This is inefficient; we're decoding the whole sequence every time a 358 | # new token comes in. This is to handle cases where the tokeniser doesn't 359 | # prepend spaces to words. There's probably a better way to do this. 
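            # One possible alternative (a sketch only, not wired into this server): let
            # transformers' TextIteratorStreamer handle incremental detokenisation and yield
            # its ready-to-print chunks directly, e.g.:
            #     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
            #     for new_text in streamer:
            #         yield new_text
            # That avoids re-decoding the whole sequence for every new token, at the cost of
            # restructuring generation to push tokens through the streamer.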
360 | tokens.append(token) 361 | full_text = tokenizer.decode(tokens, skip_special_tokens=True) 362 | new_text = full_text[len(text):] 363 | text = full_text 364 | logger.debug(f"Decoded token: {new_text}") 365 | 366 | # Prepare the data in OpenAI's streaming format 367 | data = { 368 | "choices": [ 369 | { 370 | "delta": {"content": new_text}, 371 | "index": 0, 372 | "finish_reason": None 373 | } 374 | ] 375 | } 376 | 377 | # Yield the formatted data as a Server-Sent Event (SSE) 378 | yield f"data: {json.dumps(data)}\n\n" 379 | logger.debug("Yielded token to client.") 380 | 381 | # Yield control back to the event loop 382 | await asyncio.sleep(0) 383 | 384 | except asyncio.CancelledError: 385 | logger.warning("Streaming task was cancelled by the client.") 386 | except Exception as e: 387 | logger.error(f"Unexpected error in stream_tokens_sync: {e}") 388 | finally: 389 | logger.debug("Exiting stream_tokens_sync.") 390 | 391 | 392 | # Endpoint: /v1/completions 393 | @app.post("/v1/completions", response_model=CompletionResponse) 394 | async def completions(request: CompletionRequest, req: Request): 395 | logger.info("Completion request received, waiting for processing...") 396 | 397 | if request.stream and request.regex_bans: 398 | raise HTTPException(status_code=500, detail="Streaming cannot be enabled when using regex bans.") 399 | 400 | try: 401 | if model is None or tokenizer is None: 402 | logger.error("Model and tokenizer are not loaded.") 403 | raise HTTPException(status_code=500, detail="Model and tokenizer are not loaded.") 404 | 405 | # Use the model specified in the request or default 406 | used_model = request.model if request.model else model_name_loaded 407 | 408 | # Handle prompt as string or list 409 | if isinstance(request.prompt, list): 410 | prompt = "\n".join(request.prompt) 411 | else: 412 | prompt = request.prompt 413 | 414 | # Process slop_phrases parameter 415 | if request.slop_phrases is not None: 416 | # Convert list of tuples to dictionary 417 | slop_adjustments = dict(request.slop_phrases) 418 | logger.debug(f"Slop adjustments provided: {slop_adjustments}") 419 | else: 420 | # Use default slop_phrase_prob_adjustments 421 | slop_adjustments = DEFAULT_SLOP_ADJUSTMENTS.copy() 422 | logger.debug(f"Using default slop adjustments with {len(slop_adjustments)} entries.") 423 | 424 | if request.stream: 425 | logger.info("Streaming completion request started.") 426 | # Streaming response 427 | generator_source = generate_antislop( 428 | model=model, 429 | tokenizer=tokenizer, 430 | prompt=prompt, 431 | max_new_tokens=request.max_tokens, 432 | temperature=request.temperature, 433 | top_k=request.top_k, 434 | top_p=request.top_p, 435 | min_p=request.min_p, 436 | slop_phrase_prob_adjustments=slop_adjustments, 437 | adjustment_strength=request.adjustment_strength, 438 | device=device, 439 | streaming=True, 440 | slow_debug=False, # Adjust as needed 441 | output_every_n_tokens=1, 442 | debug_delay=0.0, 443 | inference_output=None, 444 | debug_output=None, 445 | enforce_json=request.enforce_json, 446 | antislop_enabled=request.antislop_enabled, 447 | regex_bans=request.regex_bans 448 | ) 449 | 450 | async def streaming_generator(): 451 | async with lock: 452 | logger.info("Lock acquired for streaming completion request.") 453 | try: 454 | async for token in stream_tokens_sync(generator_source, is_chat=False): 455 | yield token 456 | except Exception as e: 457 | logger.error(f"Exception in streaming_generator: {e}") 458 | finally: 459 | logger.info("Streaming generator completed 
and lock released.") 460 | 461 | return StreamingResponse( 462 | streaming_generator(), 463 | media_type="text/event-stream" 464 | ) 465 | 466 | else: 467 | logger.info("Non-streaming completion request started.") 468 | async with lock: 469 | logger.info("Lock acquired for non-streaming completion request.") 470 | # Non-streaming response 471 | generated_tokens = generate_antislop( 472 | model=model, 473 | tokenizer=tokenizer, 474 | prompt=prompt, 475 | max_new_tokens=request.max_tokens, 476 | temperature=request.temperature, 477 | top_k=request.top_k, 478 | top_p=request.top_p, 479 | min_p=request.min_p, 480 | slop_phrase_prob_adjustments=slop_adjustments, 481 | adjustment_strength=request.adjustment_strength, 482 | device=device, 483 | streaming=False, 484 | slow_debug=False, 485 | output_every_n_tokens=5, 486 | debug_delay=0.0, 487 | inference_output=None, 488 | debug_output=None, 489 | enforce_json=request.enforce_json, 490 | antislop_enabled=request.antislop_enabled, 491 | regex_bans=request.regex_bans 492 | ) 493 | 494 | # Decode the tokens 495 | text = tokenizer.decode(generated_tokens, skip_special_tokens=True) 496 | logger.debug(f"Generated text: {text}") 497 | 498 | # Create the response 499 | response = CompletionResponse( 500 | id=generate_id(), 501 | object="text_completion", 502 | created=current_timestamp(), 503 | model=used_model, 504 | choices=[ 505 | CompletionChoice( 506 | text=text, 507 | index=0, 508 | logprobs=None, 509 | finish_reason="length" if request.max_tokens else "stop" 510 | ) 511 | ], 512 | usage={ 513 | "prompt_tokens": len(tokenizer.encode(prompt)), 514 | "completion_tokens": len(generated_tokens), 515 | "total_tokens": len(tokenizer.encode(prompt)) + len(generated_tokens), 516 | } 517 | ) 518 | logger.info("Completion request processing completed.") 519 | return response 520 | 521 | except Exception as e: 522 | logger.error(f"Error during completion processing: {e}") 523 | raise HTTPException(status_code=500, detail=str(e)) 524 | finally: 525 | logger.debug("Exiting /v1/completions endpoint.") 526 | 527 | 528 | # Endpoint: /v1/chat/completions 529 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 530 | async def chat_completions(request: ChatCompletionRequest, req: Request): 531 | logger.info("Chat completion request received, waiting for processing...") 532 | 533 | if request.stream and request.regex_bans: 534 | raise HTTPException(status_code=500, detail="Streaming cannot be enabled when using regex bans.") 535 | 536 | try: 537 | if model is None or tokenizer is None: 538 | logger.error("Model and tokenizer are not loaded.") 539 | raise HTTPException(status_code=500, detail="Model and tokenizer are not loaded.") 540 | 541 | # Use the model specified in the request or default 542 | used_model = request.model if request.model else model_name_loaded 543 | 544 | # Build the prompt from chat messages 545 | prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in request.messages]) 546 | logger.debug(f"Constructed prompt from messages: {prompt}") 547 | 548 | # Process slop_phrases parameter 549 | if request.slop_phrases is not None: 550 | # Convert list of tuples to dictionary 551 | slop_adjustments = dict(request.slop_phrases) 552 | logger.debug(f"Slop adjustments provided: {slop_adjustments}") 553 | else: 554 | # Use default slop_phrase_prob_adjustments 555 | slop_adjustments = DEFAULT_SLOP_ADJUSTMENTS.copy() 556 | logger.debug(f"Using default slop adjustments with {len(slop_adjustments)} entries.") 557 | 558 | if request.stream: 
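            # Illustrative request body for this streaming path (values are only an example;
            # the field names are the ones declared on ChatCompletionRequest above):
            #   {
            #     "messages": [{"role": "user", "content": "Write me a short story."}],
            #     "stream": true,
            #     "max_tokens": 300,
            #     "slop_phrases": [["a testament to", 0.3], ["tapestry of", 0.1]],
            #     "adjustment_strength": 20.0
            #   }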
559 | logger.info("Streaming chat completion request started.") 560 | # Streaming response 561 | generator_source = chat_antislop( 562 | model=model, 563 | tokenizer=tokenizer, 564 | messages=[msg.dict() for msg in request.messages], 565 | max_new_tokens=request.max_tokens, 566 | temperature=request.temperature, 567 | top_k=request.top_k, 568 | top_p=request.top_p, 569 | min_p=request.min_p, 570 | slop_phrase_prob_adjustments=slop_adjustments, 571 | adjustment_strength=request.adjustment_strength, 572 | device=device, 573 | streaming=True, 574 | slow_debug=False, # Adjust as needed 575 | output_every_n_tokens=1, 576 | debug_delay=0.0, 577 | inference_output=None, 578 | debug_output=None, 579 | enforce_json=request.enforce_json, 580 | antislop_enabled=request.antislop_enabled, 581 | regex_bans=request.regex_bans 582 | ) 583 | 584 | async def streaming_generator(): 585 | async with lock: 586 | logger.info("Lock acquired for streaming chat completion request.") 587 | try: 588 | async for token in stream_tokens_sync(generator_source, is_chat=True): 589 | yield token 590 | except Exception as e: 591 | logger.error(f"Exception in streaming_generator: {e}") 592 | finally: 593 | logger.info("Streaming generator completed and lock released.") 594 | 595 | return StreamingResponse( 596 | streaming_generator(), 597 | media_type="text/event-stream" 598 | ) 599 | 600 | else: 601 | logger.info("Non-streaming chat completion request started.") 602 | async with lock: 603 | logger.info("Lock acquired for non-streaming chat completion request.") 604 | # Non-streaming response 605 | generated_tokens = chat_antislop( 606 | model=model, 607 | tokenizer=tokenizer, 608 | messages=[msg.dict() for msg in request.messages], 609 | max_new_tokens=request.max_tokens, 610 | temperature=request.temperature, 611 | top_k=request.top_k, 612 | top_p=request.top_p, 613 | min_p=request.min_p, 614 | slop_phrase_prob_adjustments=slop_adjustments, 615 | adjustment_strength=request.adjustment_strength, 616 | device=device, 617 | streaming=False, 618 | slow_debug=False, 619 | output_every_n_tokens=5, 620 | debug_delay=0.0, 621 | inference_output=None, 622 | debug_output=None, 623 | enforce_json=request.enforce_json, 624 | antislop_enabled=request.antislop_enabled, 625 | regex_bans=request.regex_bans 626 | ) 627 | 628 | # Decode the tokens 629 | text = tokenizer.decode(generated_tokens, skip_special_tokens=True) 630 | logger.debug(f"Generated chat text: {text}") 631 | 632 | # Create the response 633 | response = ChatCompletionResponse( 634 | id=generate_id(), 635 | object="chat.completion", 636 | created=current_timestamp(), 637 | model=used_model, 638 | choices=[ 639 | ChatCompletionChoice( 640 | message=ChatCompletionMessage(role="assistant", content=text), 641 | index=0, 642 | finish_reason="length" if request.max_tokens else "stop" 643 | ) 644 | ], 645 | usage={ 646 | "prompt_tokens": len(tokenizer.encode(prompt)), 647 | "completion_tokens": len(generated_tokens), 648 | "total_tokens": len(tokenizer.encode(prompt)) + len(generated_tokens), 649 | } 650 | ) 651 | logger.info("Chat completion request processing completed.") 652 | return response 653 | 654 | except Exception as e: 655 | logger.error(f"Error during chat completion processing: {e}") 656 | raise HTTPException(status_code=500, detail=str(e)) 657 | finally: 658 | logger.debug("Exiting /v1/chat/completions endpoint.") 659 | 660 | 661 | # New Endpoint: /v1/models 662 | @app.get("/v1/models", response_model=ModelsResponse) 663 | async def get_models(): 664 | 
logger.info("Models request received.") 665 | try: 666 | if model is None or model_name_loaded is None or model_loaded_time is None: 667 | logger.error("Model is not loaded.") 668 | raise HTTPException(status_code=500, detail="Model is not loaded.") 669 | 670 | model_info = ModelInfo( 671 | id=model_name_loaded, 672 | created=model_loaded_time, 673 | owned_by="user", # Adjust as needed 674 | permission=[], # Can be populated with actual permissions if available 675 | root=model_name_loaded, 676 | parent=None 677 | ) 678 | 679 | response = ModelsResponse( 680 | data=[model_info] 681 | ) 682 | 683 | logger.info("Models response prepared successfully.") 684 | return response 685 | 686 | except Exception as e: 687 | logger.error(f"Error during models processing: {e}") 688 | raise HTTPException(status_code=500, detail=str(e)) 689 | finally: 690 | logger.debug("Exiting /v1/models endpoint.") 691 | 692 | 693 | # Main function to parse arguments and start Uvicorn 694 | def main(): 695 | parser = argparse.ArgumentParser(description="Launch the AntiSlop OpenAI-Compatible API server.") 696 | 697 | parser.add_argument( 698 | "--model", 699 | type=str, 700 | required=True, 701 | help="Path to the model directory or HuggingFace model ID (e.g., 'gpt2')." 702 | ) 703 | parser.add_argument( 704 | "--load_in_4bit", 705 | action="store_true", 706 | help="Load the model in 4-bit precision (requires appropriate support)." 707 | ) 708 | parser.add_argument( 709 | "--load_in_8bit", 710 | action="store_true", 711 | help="Load the model in 8-bit precision (requires appropriate support)." 712 | ) 713 | parser.add_argument( 714 | "--slop_adjustments_file", 715 | type=str, 716 | default=None, 717 | help="Path to the JSON file containing slop phrase probability adjustments." 718 | ) 719 | parser.add_argument( 720 | "--host", 721 | type=str, 722 | default="0.0.0.0", 723 | help="Host address to bind the server to." 724 | ) 725 | parser.add_argument( 726 | "--port", 727 | type=int, 728 | default=8000, 729 | help="Port number to bind the server to." 730 | ) 731 | parser.add_argument( 732 | "--gpu", 733 | type=int, 734 | default=None, 735 | help="GPU ID to load the model on (e.g., 0, 1). Optional." 
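        # Example invocation (illustrative only; substitute your own model id, file path and port):
        #   python run_api.py --model unsloth/Llama-3.2-1B-Instruct \
        #       --slop_adjustments_file slop_phrase_prob_adjustments.json --port 8000 --gpu 0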
736 | ) 737 | 738 | args = parser.parse_args() 739 | 740 | # Set environment variables based on parsed arguments 741 | os.environ["MODEL_NAME"] = args.model 742 | os.environ["LOAD_IN_4BIT"] = str(args.load_in_4bit) 743 | os.environ["LOAD_IN_8BIT"] = str(args.load_in_8bit) 744 | if args.slop_adjustments_file: 745 | os.environ["SLOP_ADJUSTMENTS_FILE"] = args.slop_adjustments_file 746 | if args.gpu is not None: 747 | os.environ["GPU_ID"] = str(args.gpu) 748 | 749 | # Run the app using Uvicorn with single worker and single thread 750 | uvicorn.run( 751 | "run_api:app", # Ensure this matches the filename if different 752 | host=args.host, 753 | port=args.port, 754 | reload=False, 755 | log_level="info", # Set to DEBUG for more detailed logs 756 | timeout_keep_alive=600, # 10 minutes 757 | workers=1, # Single worker to enforce global lock 758 | loop="asyncio", # Ensure using asyncio loop 759 | ) 760 | 761 | 762 | if __name__ == "__main__": 763 | main() 764 | -------------------------------------------------------------------------------- /slop_phrase_prob_adjustments.json: -------------------------------------------------------------------------------- 1 | [ 2 | ["symphony", 0.03125], 3 | ["testament to", 0.03125], 4 | ["kaleidoscope", 0.03125], 5 | ["delve", 0.03125], 6 | ["delving", 0.05988111477573171], 7 | ["delved", 0.03125], 8 | ["elara", 0.03125], 9 | ["tapestry", 0.03125], 10 | ["tapestries", 0.03125], 11 | ["weave", 0.03125], 12 | ["wove", 0.03125], 13 | ["weaving", 0.03125], 14 | ["elysia", 0.03125], 15 | ["barely above a whisper", 0.03125], 16 | ["barely a whisper", 0.03125], 17 | ["orchestra of", 0.03125], 18 | ["dance of", 0.03125], 19 | ["maybe, just maybe", 0.03125], 20 | ["maybe that was enough", 0.03125], 21 | ["perhaps, just perhaps", 0.03125], 22 | ["was only just beginning", 0.03125], 23 | [", once a ", 0.03125], 24 | ["world of", 0.03125], 25 | ["bustling", 0.03125], 26 | ["labyrinthine", 0.06638702636902488], 27 | ["shivers down", 0.03125], 28 | ["shivers up", 0.03125], 29 | ["shiver down", 0.03125], 30 | ["shiver up", 0.03125], 31 | ["ministrations", 0.03125], 32 | ["numeria", 0.03125], 33 | ["transcended", 0.060515730172740693], 34 | ["lyra", 0.03125], 35 | ["eira", 0.03125], 36 | ["eldoria", 0.03125], 37 | ["atheria", 0.03125], 38 | ["eluned", 0.03125], 39 | ["oakhaven", 0.03125], 40 | ["whisperwood", 0.03125], 41 | ["zephyria", 0.03125], 42 | ["elian", 0.03125], 43 | ["elias", 0.03125], 44 | ["elianore", 0.03125], 45 | ["aria", 0.03125], 46 | ["eitan", 0.03125], 47 | ["kael", 0.03125], 48 | ["jaxon", 0.10836356096002156], 49 | ["ravenswood", 0.03125], 50 | ["moonwhisper", 0.03125], 51 | ["thrummed", 0.03125], 52 | ["flickered", 0.07474260888562831], 53 | [" rasped", 0.03125], 54 | [" rasp", 0.03125], 55 | [" rasping", 0.03125], 56 | [" ,rasped", 0.03125], 57 | [" ,rasp", 0.03125], 58 | [" ,rasping", 0.03125], 59 | ["bioluminescent", 0.03125], 60 | ["glinting", 0.03125], 61 | ["twinkled", 0.07014359686367556], 62 | ["nodded", 0.04128581381864993], 63 | ["nestled", 0.03125], 64 | ["ministration", 0.03125], 65 | ["moth to a flame", 0.03125], 66 | ["canvas", 0.03125], 67 | ["eyes glinted", 0.03125], 68 | ["camaraderie", 0.03125], 69 | ["humble abode", 0.03125], 70 | ["cold and calculating", 0.03125], 71 | ["eyes never leaving", 0.03125], 72 | ["body and soul", 0.03125], 73 | ["orchestra", 0.03125], 74 | ["palpable", 0.03125], 75 | ["depths", 0.03125], 76 | ["a dance of", 0.03125], 77 | ["chuckles darkly", 0.03125], 78 | ["maybe, that was enough", 0.03125], 79 | ["they would 
face it together", 0.03125], 80 | ["a reminder", 0.03125], 81 | ["that was enough", 0.03125], 82 | ["for now, that was enough", 0.03125], 83 | ["for now, that's enough", 0.03125], 84 | ["with a mixture of", 0.03125], 85 | ["air was filled with anticipation", 0.03125], 86 | ["cacophony", 0.03125], 87 | ["bore silent witness to", 0.03125], 88 | ["eyes sparkling with mischief", 0.03125], 89 | ["practiced ease", 0.03125], 90 | ["ready for the challenges", 0.03125], 91 | ["only just getting started", 0.03125], 92 | ["once upon a time", 0.03125], 93 | ["nestled deep within", 0.03125], 94 | ["ethereal beauty", 0.03125], 95 | ["life would never be the same", 0.03125], 96 | ["it's important to remember", 0.03125], 97 | ["for what seemed like an eternity", 0.03125], 98 | ["little did he know", 0.03125], 99 | ["ball is in your court", 0.03125], 100 | ["game is on", 0.03125], 101 | ["choice is yours", 0.03125], 102 | ["feels like an electric shock", 0.03125], 103 | ["threatens to consume", 0.03125], 104 | ["meticulous", 0.03125], 105 | ["meticulously", 0.03125], 106 | ["navigating", 0.03125], 107 | ["complexities", 0.03125], 108 | ["realm", 0.03125], 109 | ["understanding", 0.03125], 110 | ["dive into", 0.03125], 111 | ["shall", 0.03125], 112 | ["tailored", 0.03125], 113 | ["towards", 0.03125], 114 | ["underpins", 0.03125], 115 | ["everchanging", 0.03125], 116 | ["ever-evolving", 0.03125], 117 | ["not only", 0.03125], 118 | ["alright", 0.03125], 119 | ["embark", 0.03125], 120 | ["journey", 0.03125], 121 | ["today's digital age", 0.03125], 122 | ["game changer", 0.03125], 123 | ["designed to enhance", 0.03125], 124 | ["it is advisable", 0.03125], 125 | ["daunting", 0.03125], 126 | ["when it comes to", 0.03125], 127 | ["in the realm of", 0.03125], 128 | ["unlock the secrets", 0.03125], 129 | ["unveil the secrets", 0.03125], 130 | ["and robust", 0.03125], 131 | ["elevate", 0.03125], 132 | ["unleash", 0.03125], 133 | ["cutting-edge", 0.03125], 134 | ["mastering", 0.03125], 135 | ["harness", 0.03125], 136 | ["it's important to note", 0.03125], 137 | ["in summary", 0.03125], 138 | ["remember that", 0.03125], 139 | ["take a dive into", 0.03125], 140 | ["landscape", 0.03125], 141 | ["in the world of", 0.03125], 142 | ["vibrant", 0.03125], 143 | ["metropolis", 0.03125], 144 | ["moreover", 0.03125], 145 | ["crucial", 0.03125], 146 | ["to consider", 0.03125], 147 | ["there are a few considerations", 0.03125], 148 | ["it's essential to", 0.03125], 149 | ["furthermore", 0.03125], 150 | ["vital", 0.03125], 151 | ["as a professional", 0.03125], 152 | ["thus", 0.03125], 153 | ["you may want to", 0.03125], 154 | ["on the other hand", 0.03125], 155 | ["as previously mentioned", 0.03125], 156 | ["it's worth noting that", 0.03125], 157 | ["to summarize", 0.03125], 158 | ["to put it simply", 0.03125], 159 | ["in today's digital era", 0.03125], 160 | ["reverberate", 0.03125], 161 | ["revolutionize", 0.03125], 162 | ["labyrinth", 0.03125], 163 | ["gossamer", 0.03125], 164 | ["enigma", 0.03125], 165 | ["whispering", 0.03125], 166 | ["sights unseen", 0.03125], 167 | ["sounds unheard", 0.03125], 168 | ["indelible", 0.03125], 169 | ["in conclusion", 0.03125], 170 | ["technopolis", 0.03125], 171 | ["was soft and gentle", 0.03125], 172 | ["leaving trails of fire", 0.03125], 173 | ["audible pop", 0.03125], 174 | ["rivulets of", 0.03125], 175 | ["despite herself", 0.03125], 176 | ["reckless abandon", 0.03125], 177 | ["torn between", 0.03125], 178 | ["fiery red hair", 0.03125], 179 | ["long lashes", 0.03125], 180 | ["world narrows", 
0.03125], 181 | ["chestnut eyes", 0.03125], 182 | ["cheeks flaming", 0.03125], 183 | ["cheeks hollowing", 0.03125], 184 | ["understandingly", 0.03125], 185 | ["paperbound", 0.03125], 186 | ["hesitantly", 0.03125], 187 | ["piqued", 0.03125], 188 | ["curveballs", 0.03125], 189 | ["marveled", 0.03125], 190 | ["inclusivity", 0.03125], 191 | ["birdwatcher", 0.03125], 192 | ["newfound", 0.03224441869181626], 193 | ["marveling", 0.0330273218226949], 194 | ["hiroshi", 0.10679436640360451], 195 | ["greentech", 0.03433682773874149], 196 | ["thoughtfully", 0.034554614125603775], 197 | ["intently", 0.036340970869438986], 198 | ["birdwatching", 0.036507038507018016], 199 | ["amidst", 0.036622615767493], 200 | ["cherishing", 0.03678545819532492], 201 | ["attentively", 0.036925354505568025], 202 | ["interjected", 0.03833845769341088], 203 | ["serendipitous", 0.038739958252226044], 204 | ["marianne", 0.038761601988255484], 205 | ["maya", 0.09234714457485342], 206 | ["excitedly", 0.039326615662048627], 207 | ["steepled", 0.03934628705145861], 208 | ["engrossed", 0.03938357871889015], 209 | ["fostering", 0.040237606281663015], 210 | ["brainstormed", 0.040837224374359785], 211 | ["furrowed", 0.04106990333912991], 212 | ["contemplatively", 0.0415715337562441], 213 | ["jotted", 0.04185187388384613], 214 | ["mia", 0.04228346088920659], 215 | ["yesteryears", 0.04297673583870908], 216 | ["conspiratorially", 0.04318452795649216], 217 | ["poring", 0.043368809753774676], 218 | ["stumbled", 0.043572405422243304], 219 | ["strategized", 0.04421861605966468], 220 | ["hesitated", 0.04494448438749303], 221 | ["intrigued", 0.0455166486053908], 222 | ["sarah", 0.10086927179354085], 223 | ["lykos", 0.04637933901297856], 224 | ["adaptability", 0.04670427920019436], 225 | ["yoing", 0.04690684293430847], 226 | ["geocaches", 0.047245460751898956], 227 | ["furrowing", 0.047522914178588074], 228 | ["quandaries", 0.0481208833728088], 229 | ["chimed", 0.0482978999356531], 230 | ["headfirst", 0.04831461279577993], 231 | ["gruffly", 0.049202857565227666], 232 | ["skeptically", 0.04926960158439389], 233 | ["fruitville", 0.05025148031991987], 234 | ["gastronomical", 0.050289103350348655], 235 | ["sighed", 0.05051102378739721], 236 | ["warmly", 0.05079433763521114], 237 | ["approvingly", 0.05115585243689052], 238 | ["questioningly", 0.05138118422767321], 239 | ["timmy", 0.07835758599135755], 240 | ["undeterred", 0.052486122725405135], 241 | ["starlit", 0.052741028161842], 242 | ["unearthing", 0.053074267343765635], 243 | ["grappled", 0.053717801895838344], 244 | ["yumi", 0.09597346307465092], 245 | ["seabrook", 0.054387123574424684], 246 | ["geocachers", 0.05443263224688007], 247 | ["animatedly", 0.054946701090844687], 248 | ["bakersville", 0.055275865634954346], 249 | ["minji", 0.05547746089952316], 250 | ["fateful", 0.055736452807584415], 251 | ["sparkled", 0.05644124927285605], 252 | ["resonated", 0.05661086006277033], 253 | ["harmoniously", 0.05663396393468821], 254 | ["fidgeted", 0.05692129321298146], 255 | ["mwanga", 0.05772491398190184], 256 | ["gleamed", 0.05785836606810767], 257 | ["embracing", 0.05786812845368198], 258 | ["pleasantries", 0.05845295163766754], 259 | ["iostream", 0.059082502393001814], 260 | ["navigated", 0.05913171417008317], 261 | ["interconnectedness", 0.05919654794414185], 262 | ["tanginess", 0.05941208348993963], 263 | ["mouthwatering", 0.05961924098477277], 264 | ["amelia", 0.10495450758329813], 265 | ["mischievously", 0.059910554433957666], 266 | ["tirelessly", 0.05992047742518978], 267 | 
["sympathetically", 0.06067860765517939], 268 | ["pondered", 0.06069437404836339], 269 | ["lingered", 0.060889076468395927], 270 | ["println", 0.06094043781330748], 271 | ["empathizing", 0.06134449183509449], 272 | ["niche", 0.06190423094543172], 273 | ["regaled", 0.06212602213421109], 274 | ["greenthumb", 0.06224864856855164], 275 | ["savored", 0.062408346752467134], 276 | ["amira", 0.11021608052671446], 277 | ["firsthand", 0.06392924590568197], 278 | ["empathetically", 0.06417835678867895], 279 | ["unshed", 0.06438631989525674], 280 | ["jenkin", 0.06442424583006488], 281 | ["enigmatically", 0.06523513565206027], 282 | ["marla", 0.06529158024829011], 283 | ["bayville", 0.06533248861052479], 284 | ["adversities", 0.06561596389981958], 285 | ["eagerly", 0.06577650960046391], 286 | ["quizzically", 0.06651670263185526], 287 | ["transcending", 0.06652015162691576], 288 | ["resilience", 0.06735287445058671], 289 | ["lily", 0.11166764730068231], 290 | ["commiserated", 0.06750658002795124], 291 | ["savoring", 0.06776227505035798], 292 | ["amara", 0.10974878019475678], 293 | ["somberly", 0.06899158731759324], 294 | ["cinephile", 0.06956640378878991], 295 | ["solace", 0.06979652935880434], 296 | ["aquascaping", 0.07031299700645137], 297 | ["rippled", 0.07081124722667731], 298 | ["reveled", 0.07100424664644575], 299 | ["greenhaven", 0.07105438878187946], 300 | ["birdwatchers", 0.07113287273958965], 301 | ["adwoa", 0.0715485451251974], 302 | ["appreciatively", 0.07161671276447462], 303 | ["awestruck", 0.07162853701765545], 304 | ["ecotech", 0.0718147029190956], 305 | ["lightheartedness", 0.07252636606936509], 306 | ["disapprovingly", 0.07263207299052071], 307 | ["exclaimed", 0.07276579734327646], 308 | ["samir", 0.07285320350285814], 309 | ["fishkeeping", 0.07309459959713759], 310 | ["sparked", 0.07360160538883663], 311 | ["welled", 0.07400001131145016], 312 | ["jotting", 0.07406314941636878], 313 | ["resourcefulness", 0.07421328466775397], 314 | ["reminisced", 0.07491988173476277], 315 | ["abernathy", 0.07570846614292619], 316 | ["unbeknownst", 0.07630735797388712], 317 | ["pattered", 0.07655890842665049], 318 | ["reassuringly", 0.07682374203555666], 319 | ["miscommunications", 0.07686405323158839], 320 | ["wafted", 0.07712987845796046], 321 | ["absentmindedly", 0.0774482165689122], 322 | ["weightiness", 0.07755972654694927], 323 | ["allyship", 0.0778681972659635], 324 | ["perseverance", 0.07827640317481438], 325 | ["mindfully", 0.07859251285955585], 326 | ["disheartened", 0.07871071570815331], 327 | ["leaned", 0.07879604294731023], 328 | ["birder", 0.07892032327792912], 329 | ["captivated", 0.07892473996549315], 330 | ["ravi", 0.07900206498656101], 331 | ["abuela", 0.07959034010494626], 332 | ["apprehensions", 0.07962652222314125], 333 | ["gestured", 0.079648614599269], 334 | ["sagely", 0.07990023651781021], 335 | ["jamie", 0.07999021650870594], 336 | ["emily", 0.08024569032217586], 337 | ["piquing", 0.08025442043095833], 338 | ["bated", 0.08045544454435882], 339 | ["\u00e9lise", 0.08118149973936264], 340 | ["cinephiles", 0.08124748123357117], 341 | ["alex", 0.08155958528922039], 342 | ["wholeheartedly", 0.08162034496563263], 343 | ["enthusiasts", 0.08196867752993525], 344 | ["enchantingly", 0.0823223018514873], 345 | ["wambui", 0.08263722346791491], 346 | ["blankly", 0.08273500150148104], 347 | ["eadric", 0.08283062478131556], 348 | ["immersing", 0.08318343859097584], 349 | ["adversity", 0.08321158200876488], 350 | ["tldr", 0.08325658701983453], 351 | ["cleanups", 0.08364043130448932], 352 | 
["candidness", 0.08479322865932452], 353 | ["todayilearned", 0.08484643847316964], 354 | ["windowpanes", 0.08542326976652978], 355 | ["chuckled", 0.08575177294570574], 356 | ["jake", 0.08599683007615211], 357 | ["cobblestone", 0.08603112799196728], 358 | ["scrolling", 0.08603718332851425], 359 | ["curiosity", 0.08623856472735726], 360 | ["homebrewer", 0.08658753164552577], 361 | ["worriedly", 0.08677545881209053], 362 | ["intriguingly", 0.08692958549184135], 363 | ["brainstorming", 0.08724187110008937], 364 | ["shimmered", 0.08742261524005794], 365 | ["supportively", 0.08751018298262307], 366 | ["aldric", 0.0875538540008498], 367 | ["captivating", 0.08772967862882247], 368 | ["grumbled", 0.08775125893819308], 369 | ["flytraps", 0.08789847780208236], 370 | ["evergreen", 0.08810861754185394], 371 | ["jingled", 0.08816160213719028], 372 | ["csharp", 0.08867618121093104], 373 | ["etched", 0.08872622775971402], 374 | ["intricate", 0.08874912738325326], 375 | ["vibrantly", 0.08875022097469734], 376 | ["insights", 0.08881465445615759], 377 | ["etiquettes", 0.08881746671718654], 378 | ["jamal", 0.0892505566950269], 379 | ["serendipity", 0.08930907685467555], 380 | ["aback", 0.08933029922210305], 381 | ["tightknit", 0.089894914209471], 382 | ["fostered", 0.0902250543058865], 383 | ["unease", 0.09027239734885409], 384 | ["stammered", 0.0904076215995526], 385 | ["passions", 0.09064053503449236], 386 | ["johann", 0.09114771748950079], 387 | ["maplewood", 0.09126008685765656], 388 | ["user1", 0.0916424411590356], 389 | ["appreciating", 0.09166077289251462], 390 | ["bibliophiles", 0.09170778747944702], 391 | ["reverberated", 0.09178604774163615], 392 | ["insightfulness", 0.09183861554018956], 393 | ["amina", 0.09188898218908503], 394 | ["unwavering", 0.091899512603219], 395 | ["makena", 0.09209705774418087], 396 | ["strummed", 0.09229274609527229], 397 | ["dataisbeautiful", 0.09250139284705605], 398 | ["geocache", 0.09268887739812762], 399 | ["conlangs", 0.09280617830631038], 400 | ["geocaching", 0.09321966631692781], 401 | ["advancements", 0.09329024874974093], 402 | ["maria", 0.09333053475643405], 403 | ["shying", 0.09335338497711734], 404 | ["quaint", 0.0934056172785989], 405 | ["unforeseen", 0.09344290002506736], 406 | ["strategizing", 0.09399091709419581], 407 | ["dialogues", 0.09408175325423211], 408 | ["insurmountable", 0.09416355517968139], 409 | ["clinked", 0.09417327657033364], 410 | ["trusty", 0.09444824412826175], 411 | ["persevered", 0.0946142869116114], 412 | ["collaboratively", 0.09474179629801663], 413 | ["fascinated", 0.09484997447900628], 414 | ["thrived", 0.09510888010451658], 415 | ["anika", 0.09540764866213361], 416 | ["chanoyu", 0.09550084290809815], 417 | ["profusely", 0.09557139532288833], 418 | ["eliab", 0.09566581739075723], 419 | ["zara", 0.09584241402517779], 420 | ["solidifying", 0.09608868269878561], 421 | ["naledi", 0.09635414182326163], 422 | ["murmured", 0.09655989921185353], 423 | ["prided", 0.09682772402036069], 424 | ["curveball", 0.0969267440617325], 425 | ["belongingness", 0.09704829827815857], 426 | ["hometown", 0.09706682704637251], 427 | ["glanced", 0.09716443448468784], 428 | ["dismissiveness", 0.09731373843190905], 429 | ["kavarna", 0.09751433047300515], 430 | ["echoed", 0.09752749272947492], 431 | ["arben", 0.09776214191867087], 432 | ["clara", 0.09810844642230725], 433 | ["wonderment", 0.09839827901176737], 434 | ["ayla", 0.09865582838534621], 435 | ["aquarist", 0.09871535431592793], 436 | ["twinkling", 0.09878358270305657], 437 | ["yearned", 
0.09914152313228353], 438 | ["sqrt", 0.09938872210307102], 439 | ["paused", 0.0997390693721105], 440 | ["nurturing", 0.09974100543627683], 441 | ["avid", 0.0997656533181099], 442 | ["brimming", 0.09978785592283941], 443 | ["freydis", 0.09986607107311554], 444 | ["gesturing", 0.0999584693413217], 445 | ["seasoned", 0.10002025957030909], 446 | ["zizzi", 0.10014154117375548], 447 | ["claudette", 0.10041075075701582], 448 | ["breadmaking", 0.10053842348672806], 449 | ["hyperparameters", 0.10081921281211534], 450 | ["naturedly", 0.10150071747381084], 451 | ["transformative", 0.1016973017391017], 452 | ["blossomed", 0.10245768551030358], 453 | ["pastimes", 0.10254359260186416], 454 | ["meera", 0.10270915366391394], 455 | ["slipups", 0.10302995120841944], 456 | ["intricacies", 0.10317249299511089], 457 | ["enthusiast", 0.10318098408793414], 458 | ["clinking", 0.1033139770842943], 459 | ["alexei", 0.10332938802798271], 460 | ["underscored", 0.10337155533898727], 461 | ["ramesh", 0.10351934507788225], 462 | ["huddled", 0.10356308221752378], 463 | ["jaspreet", 0.10363536575781072], 464 | ["ofrenda", 0.10390637614166792], 465 | ["gnawed", 0.10391875479983416], 466 | ["quirks", 0.10394287014271747], 467 | ["whirred", 0.10401340688051537], 468 | ["sipped", 0.10411521550710723], 469 | ["fatima", 0.10414541078497928], 470 | ["empathized", 0.1042529964272074], 471 | ["cheerily", 0.10439249128151971], 472 | ["unexpected", 0.10470011792513964], 473 | ["reflecting", 0.10515486160417738], 474 | ["nervously", 0.10518028779172277], 475 | ["melodia", 0.10524116896704774], 476 | ["unlikeliest", 0.10524773887823077], 477 | ["intertwine", 0.10609835729514719], 478 | ["perusing", 0.1062448476794195], 479 | ["towering", 0.10630270599082497], 480 | ["prioritizing", 0.10718619711752761], 481 | ["teemed", 0.1074711285533854], 482 | ["astonishment", 0.10756966481041869], 483 | ["showcasing", 0.10848411397230974], 484 | ["diligently", 0.1085706157908363], 485 | ["setbacks", 0.10864842590384427], 486 | ["exhilarated", 0.10931278623094859], 487 | ["murmurs", 0.10931293954230643], 488 | ["gleaming", 0.10936463324897595], 489 | ["coexisted", 0.10950456348152442], 490 | ["yellowed", 0.10968475409884669], 491 | ["seamlessly", 0.10982629613453346], 492 | ["ominously", 0.11001205898834783], 493 | ["quietude", 0.11031040532864991], 494 | ["adorning", 0.11031906309377383], 495 | ["teeming", 0.11033518418005982], 496 | ["countless", 0.1105164103825428], 497 | ["peculiar", 0.11094333701984035], 498 | ["precariously", 0.11107316266001163], 499 | ["deepened", 0.11113389824614932], 500 | ["embarked", 0.11115616320098191], 501 | ["empathetic", 0.11155900191845305], 502 | ["triumphantly", 0.11184299931805378], 503 | ["remorsefully", 0.11186809526521746], 504 | ["grumble", 0.1119706215959465], 505 | ["nuances", 0.11200998355477185], 506 | ["researching", 0.11250797352943258], 507 | ["pulsated", 0.11270598588075774], 508 | ["aquafaba", 0.11273062908980729], 509 | ["lila", 0.11299047934206825], 510 | ["hunched", 0.11312356011167964], 511 | ["reminiscing", 0.11313931448379727], 512 | ["shockwaves", 0.11320528290207856], 513 | ["considerately", 0.11367512342298255], 514 | ["sparking", 0.11383803128279647], 515 | ["emboldened", 0.11402368058243963], 516 | ["delectable", 0.11409343918029602], 517 | ["vigor", 0.11452660354436013], 518 | ["aisha", 0.11507399789890498] 519 | ] -------------------------------------------------------------------------------- /slop_regexes.txt: 
-------------------------------------------------------------------------------- 1 | "(?i)not [^.!?]{3,60} but", 2 | "(?i)each(?:\s*\w+\s*|\s*)a", 3 | "(?i)every(?:\s*\w+\s*|\s*)a" -------------------------------------------------------------------------------- /src/antislop_generate.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List, Dict, Tuple, Generator, Set, Union 3 | from threading import Thread 4 | import threading 5 | import queue 6 | import torch 7 | from transformers import ( 8 | PreTrainedTokenizer, 9 | PreTrainedModel, 10 | StoppingCriteriaList, 11 | TextIteratorStreamer 12 | ) 13 | from IPython.display import display, HTML 14 | from ipywidgets import Output 15 | from src.validator_slop import SlopPhraseHandler, CustomSlopPhraseStoppingCriteria 16 | from src.validator_json import JSONValidator, JSONValidationStoppingCriteria 17 | from src.validator_regex import RegexValidator, RegexValidationStoppingCriteria 18 | 19 | from src.util import precompute_starting_tokens 20 | import asyncio 21 | 22 | class AntiSlopSampler: 23 | def __init__( 24 | self, 25 | model: PreTrainedModel, 26 | tokenizer: PreTrainedTokenizer, 27 | slop_phrase_prob_adjustments: Dict[str, float], 28 | starting_tokens_lookup: Dict[Tuple[int, ...], Set[int]], 29 | adjustment_strength: float = 1.0, 30 | device: torch.device = torch.device('cuda'), 31 | slow_debug: bool = False, 32 | output_every_n_tokens: int = 1, 33 | debug_delay: float = 2.0, 34 | inference_output=None, 35 | debug_output=None, 36 | regex_bans: List[str] = [], 37 | enforce_json: bool = False, 38 | antislop_enabled: bool = True, 39 | ): 40 | self.model = model 41 | self.tokenizer = tokenizer 42 | self.slop_phrase_prob_adjustments = slop_phrase_prob_adjustments 43 | self.starting_tokens_lookup = starting_tokens_lookup 44 | self.adjustment_strength = adjustment_strength 45 | self.device = device 46 | self.slow_debug = slow_debug 47 | self.output_every_n_tokens = output_every_n_tokens 48 | self.debug_delay = debug_delay 49 | self.enforce_json = enforce_json 50 | self.antislop_enabled = antislop_enabled 51 | self.regex_bans = regex_bans or [] 52 | 53 | self.sequence_queue = queue.Queue() 54 | self.generation_complete = threading.Event() 55 | 56 | # Output widgets 57 | self.inference_output = inference_output 58 | self.debug_output = debug_output 59 | 60 | self.probs_cache = {} 61 | self.probs_cache_longrange = {} # flags which positions in the logit cache we ignore during cleanup, as we want to keep some positions for long range constraint checks 62 | 63 | # Escaped toks used for lookups in json string repair 64 | self.escaped_tokens_lookup = { 65 | '\n': self.tokenizer.encode('\\n', add_special_tokens=False), 66 | '\t': self.tokenizer.encode('\\t', add_special_tokens=False), 67 | '\r': self.tokenizer.encode('\\r', add_special_tokens=False), 68 | '"': self.tokenizer.encode('\\"', add_special_tokens=False), 69 | ' "': self.tokenizer.encode(' \\"', add_special_tokens=False), 70 | } 71 | 72 | # Initialize Slop Phrase Handler 73 | self.slop_phrase_handler = SlopPhraseHandler( 74 | tokenizer=tokenizer, 75 | probs_cache=self.probs_cache, 76 | probs_cache_longrange=self.probs_cache_longrange, 77 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments, 78 | starting_tokens_lookup=starting_tokens_lookup, 79 | adjustment_strength=adjustment_strength, 80 | slow_debug=slow_debug, 81 | inference_output=inference_output, 82 | debug_output=debug_output, 83 | debug_delay=debug_delay 84 
| ) 85 | self.json_validator = JSONValidator(tokenizer, slow_debug, debug_delay, debug_output, self.probs_cache_longrange) 86 | 87 | # Initialize Regex Validator if regex patterns are provided 88 | if self.regex_bans: 89 | self.regex_validator = RegexValidator( 90 | tokenizer=tokenizer, 91 | regex_bans=self.regex_bans, 92 | slow_debug=slow_debug, 93 | debug_delay=debug_delay, 94 | debug_output=debug_output, 95 | probs_cache_longrange=self.probs_cache_longrange 96 | ) 97 | else: 98 | self.regex_validator = None 99 | 100 | self.streamer_retval = None 101 | 102 | def _generate_streaming(self, current_input_ids, new_toks_to_generate, temperature, min_p, top_k, top_p, pad_token_id, stopping_criteria_args): 103 | streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=False) 104 | 105 | generation_kwargs = dict( 106 | input_ids=current_input_ids, 107 | attention_mask=torch.ones_like(current_input_ids), 108 | max_new_tokens=new_toks_to_generate, 109 | do_sample=True, 110 | temperature=temperature, 111 | min_p=min_p, 112 | top_k=top_k, 113 | top_p=top_p, 114 | pad_token_id=pad_token_id, 115 | num_return_sequences=1, 116 | return_dict_in_generate=True, 117 | output_logits=True, 118 | streamer=streamer, 119 | **stopping_criteria_args 120 | ) 121 | 122 | # Create an Event to signal thread termination 123 | stop_event = threading.Event() 124 | 125 | # Create a Queue to store the generation output or errors 126 | output_queue = queue.Queue() 127 | 128 | # Define a function to run generation and put the result in the queue 129 | def generate_and_queue(): 130 | try: 131 | output = self.model.generate(**generation_kwargs) 132 | if not stop_event.is_set(): 133 | output_queue.put((output, None)) # None means no exception occurred 134 | except Exception as e: 135 | print(f"Exception during generation: {e}") # Debug print 136 | if not stop_event.is_set(): 137 | output_queue.put((None, e)) # Put the exception in the queue 138 | stop_event.set() 139 | 140 | # Start the generation in a separate thread 141 | thread = Thread(target=generate_and_queue) 142 | thread.start() 143 | 144 | try: 145 | for new_text in streamer: 146 | yield new_text 147 | except Exception as e: 148 | print(f"Exception during streaming: {e}") # Debug print 149 | # Add the exception to the output queue so it is propagated to the caller 150 | if not stop_event.is_set(): 151 | output_queue.put((None, e)) # Handle exception during streaming 152 | 153 | # Wait for the generation to complete or for the thread to be terminated 154 | thread.join() 155 | 156 | # Initialize default empty lists for output variables 157 | generated_sequence = [] 158 | new_logits = [] 159 | error = None # Default error to None 160 | 161 | # Check if there's any output in the queue 162 | if not output_queue.empty(): 163 | generation_output, error = output_queue.get() 164 | 165 | # Check if an error occurred during generation or streaming 166 | if error: 167 | print(f"Generation or streaming failed: {error}") 168 | else: 169 | # Extract logits and sequence from the generation output 170 | new_logits = generation_output.logits 171 | generated_sequence = generation_output.sequences[0].tolist() 172 | 173 | # Add final debug information for empty output 174 | if not generated_sequence: 175 | print("Warning: Generated sequence is empty.") 176 | if not new_logits: 177 | print("Warning: Logits are empty.") 178 | 179 | # Return the generated sequence, logits, and any error 180 | self.streamer_retval = (generated_sequence, new_logits, error) 181 | 182 | 183 | 184 | 
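    # How the streaming hand-off works: _generate_streaming() runs model.generate() on a
    # worker thread and yields decoded text chunks from the TextIteratorStreamer as they
    # arrive. When the thread finishes, the full token sequence, the per-step logits and
    # any error are passed back through output_queue and stored in self.streamer_retval,
    # which generate_stream() below reads to resynchronise after each streamed segment.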
@torch.no_grad() 185 | async def generate_stream( 186 | self, 187 | prompt: str, 188 | max_length: int = None, 189 | max_new_tokens: int = None, 190 | temperature: float = 1.0, 191 | top_k: int = None, 192 | top_p: float = None, 193 | min_p: float = None, 194 | ): 195 | """ 196 | Generates text in a streaming fashion with custom downregulation and backtracking. 197 | 198 | Args: 199 | prompt (str): The initial text prompt. 200 | max_length (int, optional): The maximum length of the generated text. 201 | max_new_tokens (int, optional): The maximum number of new tokens to generate. 202 | temperature (float): Sampling temperature. 203 | top_k (int): Top-k filtering. 204 | top_p (float): Top-p (nucleus) filtering. 205 | min_p (float): Minimum probability filtering. 206 | 207 | Yields: 208 | Generator[List[int], None, None]: Yields generated token sequences. 209 | """ 210 | try: 211 | # Encode the prompt 212 | input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) 213 | generated_sequence = input_ids[0].tolist() 214 | 215 | # If the prompt already came with a bos token, we don't want to add it again 216 | if self.tokenizer.bos_token and \ 217 | prompt.startswith(self.tokenizer.bos_token) and \ 218 | not prompt.startswith(self.tokenizer.bos_token * 2) and \ 219 | generated_sequence[0] == self.tokenizer.bos_token_id and \ 220 | generated_sequence[1] == self.tokenizer.bos_token_id: 221 | generated_sequence = generated_sequence[1:] 222 | 223 | self.prompt_length = len(generated_sequence) 224 | self.prompt_length_chars = len(prompt) 225 | current_position = len(generated_sequence) # Tracks the current position in the sequence 226 | output_tokens_counter = 0 227 | pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 228 | next_token_logits = None 229 | filtered_logits = None 230 | 231 | if max_length != None: 232 | this_max_new_tokens = max_length - self.prompt_length 233 | if this_max_new_tokens < 0: 234 | this_max_new_tokens = 0 235 | if max_new_tokens == None or this_max_new_tokens < max_new_tokens: 236 | max_new_tokens = this_max_new_tokens 237 | else: 238 | if max_new_tokens == None: 239 | max_new_tokens = 8096 240 | 241 | stopping_criteria_args = {} 242 | self.stopping_criteria = [] 243 | 244 | if self.enforce_json: 245 | json_stopping_criteria = JSONValidationStoppingCriteria( 246 | tokenizer=self.tokenizer, 247 | json_validator=self.json_validator, 248 | prompt_length=self.prompt_length 249 | ) 250 | self.stopping_criteria.append(json_stopping_criteria) 251 | 252 | if self.antislop_enabled: 253 | antislop_stopping_criteria = CustomSlopPhraseStoppingCriteria( 254 | tokenizer=self.tokenizer, 255 | max_slop_phrase_length=self.slop_phrase_handler.max_slop_phrase_length, 256 | min_slop_phrase_length=self.slop_phrase_handler.min_slop_phrase_length, 257 | prompt_length=self.prompt_length, 258 | slop_phrase_prob_adjustments=self.slop_phrase_handler.slop_phrase_prob_adjustments 259 | ) 260 | self.stopping_criteria.append(antislop_stopping_criteria) 261 | 262 | # Initialize Regex Validation Stopping Criteria 263 | if self.regex_validator: 264 | regex_stopping_criteria = RegexValidationStoppingCriteria( 265 | tokenizer=self.tokenizer, 266 | regex_validator=self.regex_validator, 267 | prompt_length=self.prompt_length 268 | ) 269 | self.stopping_criteria.append(regex_stopping_criteria) 270 | 271 | 272 | if self.stopping_criteria: 273 | stopping_criteria_args = { 274 | "stopping_criteria": 
StoppingCriteriaList(self.stopping_criteria) 275 | } 276 | 277 | 278 | while True: 279 | if max_new_tokens is not None and len(generated_sequence) - self.prompt_length >= max_new_tokens: 280 | break 281 | 282 | new_toks_to_generate = max_new_tokens - (len(generated_sequence) - self.prompt_length) 283 | 284 | current_input_ids = torch.tensor([generated_sequence], device=self.device) 285 | 286 | regenerating = False 287 | 288 | if current_position in self.probs_cache: 289 | # We backtracked and want to use the cached logits 290 | next_token_probs = self.probs_cache[current_position] 291 | regenerating = True 292 | else: 293 | context = "" 294 | #print(new_toks_to_generate) 295 | for new_text in self._generate_streaming( 296 | current_input_ids, 297 | new_toks_to_generate, 298 | temperature, 299 | min_p, 300 | top_k, 301 | top_p, 302 | pad_token_id, 303 | stopping_criteria_args 304 | ): 305 | context += new_text 306 | output_tokens_counter += 1 307 | 308 | # sometimes model.generate adds an extra bos token so we'll manually clip it off. 309 | # otherwise we have conflicts with the originally calculated prompt_length 310 | if self.tokenizer.bos_token and \ 311 | prompt.startswith(self.tokenizer.bos_token) and \ 312 | not prompt.startswith(self.tokenizer.bos_token * 2) and \ 313 | context.startswith(self.tokenizer.bos_token * 2): 314 | context = context[len(self.tokenizer.bos_token):] 315 | 316 | if output_tokens_counter >= self.output_every_n_tokens: 317 | output_tokens_counter = 0 318 | 319 | if self.inference_output: 320 | with self.inference_output: 321 | self.inference_output.clear_output(wait=True) 322 | 323 | display(HTML(f"
<div>{context[self.prompt_length_chars:]}</div>
")) 324 | 325 | self.sequence_queue.put(self.tokenizer.encode(context, add_special_tokens=False)) 326 | 327 | 328 | torch.cuda.empty_cache() 329 | 330 | # sync with the returned vals in case the streaming came thru out of order 331 | # (not sure if this is necessary) 332 | if self.streamer_retval: 333 | generated_sequence, new_logits, error = self.streamer_retval 334 | if error: 335 | self.generation_complete.set() 336 | return 337 | 338 | # sometimes model.generate adds an extra bos token so we'll manually clip it off. 339 | # otherwise we have conflicts with the originally calculated prompt_length 340 | if self.tokenizer.bos_token and \ 341 | prompt.startswith(self.tokenizer.bos_token) and \ 342 | not prompt.startswith(self.tokenizer.bos_token * 2) and \ 343 | generated_sequence[0] == self.tokenizer.bos_token_id and \ 344 | generated_sequence[1] == self.tokenizer.bos_token_id: 345 | generated_sequence = generated_sequence[1:] 346 | 347 | self.streamer_retval = None 348 | else: 349 | print('!! error missing retval from streamer') 350 | self.generation_complete.set() 351 | return 352 | 353 | for i, logit in enumerate(new_logits): 354 | self.probs_cache[current_position + i] = torch.softmax(logit.clone(), dim=-1) 355 | 356 | next_token = generated_sequence[-1] 357 | current_position = len(generated_sequence) 358 | 359 | 360 | if regenerating: 361 | # Apply min_p, top-k and top-p filtering 362 | filtered_probs = self._filter_probs(next_token_probs, top_k, top_p, min_p) 363 | # Sample the next token 364 | next_token_index = torch.multinomial(filtered_probs, num_samples=1) 365 | 366 | next_token = next_token_index.item() 367 | # Append the new token to the sequence 368 | generated_sequence.append(next_token) 369 | #print('alt token:',self.tokenizer.decode(next_token), self.probs_cache[current_position][:, next_token]) 370 | output_tokens_counter += 1 371 | 372 | if output_tokens_counter >= self.output_every_n_tokens: 373 | output_tokens_counter = 0 374 | current_text = self.tokenizer.decode(generated_sequence[self.prompt_length:]) 375 | if self.inference_output: 376 | with self.inference_output: 377 | self.inference_output.clear_output(wait=True) 378 | display(HTML(f"
<div>{current_text}</div>
")) 379 | self.sequence_queue.put(generated_sequence) 380 | #print('downregulated token after reselection', self.slop_phrase_handler.probs_cache[current_position][:, self.json_validator.last_downregulated_token]) 381 | current_position = len(generated_sequence) 382 | 383 | if regenerating and self.slow_debug: 384 | alt_token = self.tokenizer.decode(next_token, skip_special_tokens=True) 385 | debug_info = f"Alternate token: {[alt_token]}" 386 | 387 | self._display_debug(debug_info) 388 | if self.slow_debug: 389 | time.sleep(self.debug_delay) 390 | 391 | 392 | # Clean up the probs cache 393 | if not self.enforce_json and not self.regex_bans: 394 | # json validation needs to keep the long range dependencies 395 | # although we can probably delete the ones that aren't flagged in self.probs_cache_longrange. 396 | if self.antislop_enabled: 397 | to_del = [key for key in self.probs_cache if key < current_position - self.slop_phrase_handler.max_slop_phrase_length - 5 and not self.probs_cache_longrange.get(key, False)] 398 | for key in to_del: 399 | if key not in self.probs_cache_longrange: 400 | del self.probs_cache[key] 401 | 402 | 403 | # Check for end-of-sequence token 404 | if next_token == self.tokenizer.eos_token_id: 405 | break 406 | 407 | # JSON validation 408 | if self.enforce_json: 409 | result = self.json_validator.validate_json_string(generated_sequence, self.prompt_length, self.probs_cache) 410 | if result != False: 411 | generated_sequence = result 412 | current_position = len(generated_sequence) 413 | continue # Skip the rest of this iteration and start over 414 | 415 | # After adding the token, check for disallowed sequences 416 | if self.antislop_enabled: 417 | antislop_result = self.slop_phrase_handler.deslop(generated_sequence, self.prompt_length) 418 | if antislop_result != False: 419 | generated_sequence = antislop_result 420 | current_position = len(generated_sequence) 421 | continue 422 | 423 | # Initialize Regex Validation Stopping Criteria 424 | if self.regex_validator: 425 | regex_result = self.regex_validator.validate_regex_matches(generated_sequence, self.prompt_length, self.probs_cache) 426 | if regex_result != False: 427 | generated_sequence = regex_result 428 | current_position = len(generated_sequence) 429 | continue 430 | 431 | 432 | 433 | # Final display of the generated text 434 | final_text = self.tokenizer.decode(generated_sequence[self.prompt_length:], skip_special_tokens=False) 435 | if self.inference_output: 436 | with self.inference_output: 437 | self.inference_output.clear_output(wait=True) 438 | display(HTML(f"
<div>{final_text}</div>
")) 439 | self.sequence_queue.put(generated_sequence) 440 | 441 | 442 | # Clear variables to free up memory 443 | del next_token_logits, filtered_logits 444 | 445 | # signal end of generation 446 | self.generation_complete.set() 447 | except Exception as e: 448 | print(e) 449 | # signal end of generation 450 | self.generation_complete.set() 451 | 452 | 453 | def _filter_probs(self, probs: torch.FloatTensor, top_k: int, top_p: float, min_p: float) -> torch.FloatTensor: 454 | # Make a copy of the probabilities to ensure we do not modify the original tensor 455 | probs = probs.clone() 456 | 457 | # Apply min_p filtering 458 | if min_p is not None: 459 | top_prob, _ = torch.max(probs, dim=-1) 460 | scaled_min_p = min_p * top_prob 461 | probs = torch.where(probs < scaled_min_p, 0, probs) 462 | 463 | if top_k is not None and top_k > 0: 464 | top_k = min(top_k, probs.size(-1)) 465 | top_k_probs, _ = torch.topk(probs, top_k) 466 | min_top_k = top_k_probs[:, -1].unsqueeze(-1) 467 | probs = torch.where(probs < min_top_k, 0, probs) 468 | 469 | if top_p is not None and top_p < 1.0: 470 | sorted_probs, sorted_indices = torch.sort(probs, descending=True) 471 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1) 472 | 473 | sorted_indices_to_remove = cumulative_probs > top_p 474 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 475 | sorted_indices_to_remove[:, 0] = False 476 | indices_to_remove = sorted_indices_to_remove.scatter( 477 | dim=1, index=sorted_indices, src=sorted_indices_to_remove 478 | ) 479 | probs = probs.masked_fill(indices_to_remove, 0) 480 | 481 | return probs 482 | 483 | def _display_debug(self, message: str): 484 | """ 485 | Displays debug information in the debug_output widget. 486 | """ 487 | if self.debug_output: 488 | with self.debug_output: 489 | self.debug_output.clear_output(wait=True) 490 | display(HTML(f"
<div>{message}</div>
")) 491 | else: 492 | print(message) 493 | 494 | def _clear_gpu_memory_async(self): 495 | def clear_gpu_memory(): 496 | torch.cuda.empty_cache() 497 | 498 | # Create and start the daemon thread 499 | cleaner_thread = threading.Thread(target=clear_gpu_memory, daemon=True) 500 | cleaner_thread.start() 501 | 502 | # Return immediately without waiting for the thread 503 | return 504 | 505 | def cleanup(self): 506 | # Clear the queue 507 | while not self.sequence_queue.empty(): 508 | try: 509 | self.sequence_queue.get_nowait() 510 | except queue.Empty: 511 | break 512 | 513 | # Clear the event 514 | self.generation_complete.clear() 515 | 516 | # Clear caches 517 | self.probs_cache.clear() 518 | self.probs_cache_longrange.clear() 519 | 520 | # Clear other attributes 521 | self.model = None 522 | self.tokenizer = None 523 | self.slop_phrase_handler = None 524 | self.json_validator = None 525 | self.regex_validator = None 526 | 527 | # Clear output widgets 528 | if self.inference_output: 529 | self.inference_output.clear_output() 530 | if self.debug_output: 531 | self.debug_output.clear_output() 532 | 533 | # Clear CUDA cache 534 | if torch.cuda.is_available(): 535 | torch.cuda.empty_cache() 536 | 537 | def chat_antislop( 538 | model: PreTrainedModel, 539 | tokenizer: PreTrainedTokenizer, 540 | messages: List[Dict[str, str]], 541 | max_length: int = None, 542 | max_new_tokens: int = None, 543 | temperature: float = 1.0, 544 | top_k: int = None, 545 | top_p: float = None, 546 | min_p: float = None, 547 | slop_phrase_prob_adjustments: Dict[str, float] = None, 548 | adjustment_strength: float = 1.0, # 1.0 is no change from the provided adjustment factors. 549 | device: torch.device = torch.device('cuda'), 550 | streaming: bool = False, 551 | stream_smoothing: bool = True, 552 | slow_debug: bool = False, # Add slow_debug argument for debugging 553 | output_every_n_tokens: int = 1, # Control how frequently the output is updated 554 | debug_delay: float = 2.0, # Delay for slow debugging mode 555 | inference_output: Output = None, # For visualization during generation 556 | debug_output: Output = None, # For visualization of debug information 557 | enforce_json: bool = False, 558 | antislop_enabled: bool = True, 559 | regex_bans: List[str] = None, 560 | ): 561 | """ 562 | Generates a chat response while avoiding overrepresented phrases (slop) with debugging features. 563 | This method creates a generator or a non-streamed output, depending on the streaming flag. 564 | 565 | Args: 566 | model (PreTrainedModel): The language model. 567 | tokenizer (PreTrainedTokenizer): The tokenizer. 568 | messages (List[Dict[str, str]]): The list of messages in the conversation. 569 | max_length (int, optional): The maximum length of the generated text (including the prompt). 570 | max_new_tokens (int, optional): The maximum number of new tokens to generate. 571 | temperature (float): Sampling temperature. 572 | top_k (int): Top-k filtering. 573 | top_p (float): Top-p (nucleus) filtering. 574 | min_p (float): Minimum probability filtering. 575 | slop_phrase_prob_adjustments (Dict[str, float], optional): Dictionary of target words with their respective probability adjustment factor. 576 | adjustment_strength (float, optional): Strength of the downregulation adjustment. 577 | device (torch.device, optional): The device to run the model on. 578 | streaming (bool, optional): Whether to yield tokens as they are generated. 579 | slow_debug (bool, optional): Enables slow debug mode when set to True. 
580 | output_every_n_tokens (int, optional): Frequency of updating the inference output display. 581 | debug_delay (float, optional): Time in seconds to pause during slow debug steps. 582 | inference_output (Output, optional): For visualization during generation. 583 | debug_output (Output, optional): For visualization of debug information. 584 | 585 | Returns: 586 | Union[Generator[str, None, None], List[int]]: 587 | If streaming is True, yields generated text chunks. 588 | If streaming is False, returns a list of generated token IDs. 589 | """ 590 | 591 | # Build the prompt using the provided messages 592 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 593 | 594 | return generate_antislop( 595 | model=model, 596 | tokenizer=tokenizer, 597 | prompt=prompt, 598 | max_length=max_length, 599 | max_new_tokens=max_new_tokens, 600 | temperature=temperature, 601 | top_k=top_k, 602 | top_p=top_p, 603 | min_p=min_p, 604 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments, 605 | adjustment_strength=adjustment_strength, 606 | device=device, 607 | slow_debug=slow_debug, 608 | output_every_n_tokens=output_every_n_tokens, 609 | debug_delay=debug_delay, 610 | inference_output=inference_output, 611 | debug_output=debug_output, 612 | antislop_enabled=antislop_enabled, 613 | enforce_json=enforce_json, 614 | regex_bans=regex_bans, 615 | streaming=streaming, 616 | stream_smoothing=stream_smoothing, 617 | ) 618 | 619 | 620 | def generate_antislop( 621 | model: PreTrainedModel, 622 | tokenizer: PreTrainedTokenizer, 623 | prompt: str, 624 | max_length: int = None, 625 | max_new_tokens: int = None, 626 | temperature: float = 1.0, 627 | top_k: int = None, 628 | top_p: float = None, 629 | min_p: float = None, 630 | slop_phrase_prob_adjustments: Dict[str, float] = None, 631 | adjustment_strength: float = 1.0, 632 | device: torch.device = torch.device('cuda'), 633 | streaming: bool = False, 634 | stream_smoothing: bool = True, 635 | slow_debug: bool = False, # Added slow_debug 636 | output_every_n_tokens: int = 1, 637 | debug_delay: float = 2.0, 638 | inference_output: Output = None, 639 | debug_output: Output = None, 640 | enforce_json: bool = False, 641 | antislop_enabled: bool = True, 642 | regex_bans: List[str] = None, 643 | ) -> Union[Generator[str, None, None], List[int]]: 644 | """ 645 | Wrapper function for generate_antislop that handles both streaming and non-streaming modes. 
646 | """ 647 | # Type checking and validation of input arguments 648 | if not isinstance(prompt, str): 649 | raise TypeError("prompt must be a string") 650 | if max_length is not None and not isinstance(max_length, int): 651 | raise TypeError("max_length must be an integer or None") 652 | if max_new_tokens is not None and not isinstance(max_new_tokens, int): 653 | raise TypeError("max_new_tokens must be an integer or None") 654 | if not isinstance(temperature, (int, float)): 655 | raise TypeError("temperature must be a float") 656 | if top_k is not None and not isinstance(top_k, int): 657 | raise TypeError("top_k must be an integer or None") 658 | if top_p is not None and not isinstance(top_p, float): 659 | raise TypeError("top_p must be a float or None") 660 | if min_p is not None and not isinstance(min_p, float): 661 | raise TypeError("min_p must be a float or None") 662 | if slop_phrase_prob_adjustments is not None and not isinstance(slop_phrase_prob_adjustments, dict): 663 | raise TypeError("slop_phrase_prob_adjustments must be a dictionary or None") 664 | if not isinstance(adjustment_strength, (int, float)): 665 | raise TypeError("adjustment_strength must be a float") 666 | if not isinstance(device, torch.device): 667 | raise TypeError("device must be an instance of torch.device") 668 | if not isinstance(streaming, bool): 669 | raise TypeError("streaming must be a boolean") 670 | 671 | # Value validation 672 | if max_length is not None and max_length <= 0: 673 | raise ValueError("max_length must be positive") 674 | if max_new_tokens is not None and max_new_tokens <= 0: 675 | raise ValueError("max_new_tokens must be positive") 676 | if temperature <= 0: 677 | raise ValueError("temperature must be > 0") 678 | if top_k is not None and top_k < 0: 679 | raise ValueError("top_k must be positive") 680 | if top_p is not None and (top_p < 0 or top_p > 1): 681 | raise ValueError("top_p must be in the range (0, 1]") 682 | if min_p is not None and (min_p < 0 or min_p > 1): 683 | print(min_p) 684 | raise ValueError("min_p must be in the range (0, 1]") 685 | if adjustment_strength < 0: 686 | raise ValueError("adjustment_strength must be non-negative") 687 | 688 | if not debug_output or not inference_output: 689 | debug_delay = 0 690 | slow_debug = False 691 | 692 | if slop_phrase_prob_adjustments: 693 | for phrase, adjustment in slop_phrase_prob_adjustments.items(): 694 | if not isinstance(phrase, str): 695 | raise TypeError("All keys in slop_phrase_prob_adjustments must be strings") 696 | if not isinstance(adjustment, (int, float)): 697 | raise TypeError("All values in slop_phrase_prob_adjustments must be floats") 698 | 699 | if streaming: 700 | return _generate_antislop( 701 | model=model, 702 | tokenizer=tokenizer, 703 | prompt=prompt, 704 | max_length=max_length, 705 | max_new_tokens=max_new_tokens, 706 | temperature=temperature, 707 | top_k=top_k, 708 | top_p=top_p, 709 | min_p=min_p, 710 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments, 711 | adjustment_strength=adjustment_strength, 712 | device=device, 713 | slow_debug=slow_debug, # Pass slow_debug to support detailed debug output 714 | output_every_n_tokens=output_every_n_tokens, 715 | debug_delay=debug_delay, 716 | inference_output=inference_output, 717 | debug_output=debug_output, 718 | enforce_json=enforce_json, 719 | regex_bans=regex_bans, 720 | antislop_enabled=antislop_enabled, 721 | streaming=streaming, 722 | stream_smoothing=stream_smoothing, 723 | ) 724 | else: 725 | generated_tokens = [] 726 | for token in 
_generate_antislop( 727 | model=model, 728 | tokenizer=tokenizer, 729 | prompt=prompt, 730 | max_length=max_length, 731 | max_new_tokens=max_new_tokens, 732 | temperature=temperature, 733 | top_k=top_k, 734 | top_p=top_p, 735 | min_p=min_p, 736 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments, 737 | adjustment_strength=adjustment_strength, 738 | device=device, 739 | slow_debug=slow_debug, # Pass slow_debug to support detailed debug output 740 | output_every_n_tokens=output_every_n_tokens, 741 | debug_delay=debug_delay, 742 | inference_output=inference_output, 743 | debug_output=debug_output, 744 | enforce_json=enforce_json, 745 | regex_bans=regex_bans, 746 | antislop_enabled=antislop_enabled, 747 | streaming=streaming, 748 | stream_smoothing=stream_smoothing, 749 | ): 750 | generated_tokens.append(token) 751 | return generated_tokens 752 | 753 | def simple_thread_function(): 754 | print("Simple thread function executed") 755 | time.sleep(1) 756 | print("Simple thread function completed") 757 | 758 | 759 | def _generate_antislop( 760 | model: PreTrainedModel, 761 | tokenizer: PreTrainedTokenizer, 762 | prompt: str, 763 | max_length: int = None, 764 | max_new_tokens: int = None, 765 | temperature: float = 1.0, 766 | top_k: int = None, 767 | top_p: float = None, 768 | min_p: float = None, 769 | slop_phrase_prob_adjustments: Dict[str, float] = None, 770 | adjustment_strength: float = 1.0, 771 | device: torch.device = torch.device('cuda'), 772 | slow_debug: bool = False, # Added slow_debug 773 | output_every_n_tokens: int = 1, 774 | debug_delay: float = 2.0, 775 | inference_output: 'Output' = None, # Assuming Output is defined elsewhere 776 | debug_output: 'Output' = None, 777 | streaming: bool = False, 778 | stream_smoothing: bool = True, 779 | enforce_json: bool = False, 780 | antislop_enabled: bool = True, 781 | regex_bans: List[str] = None, 782 | ) -> Generator[int, None, None]: 783 | """ 784 | Generates text while avoiding overrepresented phrases (slop). 785 | This function is now always a generator with temporal buffering. 
786 | """ 787 | 788 | if streaming and regex_bans: 789 | raise ValueError("Streaming is not supported when using regex patterns.") 790 | 791 | # Precompute starting tokens for the slop phrases 792 | starting_tokens_lookup = precompute_starting_tokens(tokenizer, slop_phrase_prob_adjustments or {}) 793 | 794 | # Initialize the sampler 795 | sampler = AntiSlopSampler( 796 | model=model, 797 | tokenizer=tokenizer, 798 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments or {}, 799 | starting_tokens_lookup=starting_tokens_lookup, 800 | adjustment_strength=adjustment_strength, 801 | device=device, 802 | slow_debug=slow_debug, # Enable slow debugging 803 | output_every_n_tokens=output_every_n_tokens, 804 | debug_delay=debug_delay, 805 | inference_output=inference_output, 806 | debug_output=debug_output, 807 | enforce_json=enforce_json, 808 | antislop_enabled=antislop_enabled, 809 | regex_bans=regex_bans, 810 | ) 811 | 812 | # Generate token stream 813 | generate_kwargs = { 814 | 'max_length': max_length, 815 | 'max_new_tokens': max_new_tokens, 816 | 'temperature': temperature, 817 | 'top_k': top_k, 818 | 'top_p': top_p, 819 | 'min_p': min_p, 820 | } 821 | 822 | loop = asyncio.new_event_loop() 823 | 824 | def run_event_loop(loop): 825 | asyncio.set_event_loop(loop) 826 | loop.run_until_complete(sampler.generate_stream(prompt, **generate_kwargs)) 827 | 828 | thread = threading.Thread(target=run_event_loop, args=(loop,), daemon=True) 829 | thread.start() 830 | 831 | try: 832 | prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False) 833 | if len(prompt_tokens) == 0: 834 | print('! prompt is empty') 835 | return 836 | 837 | backtracking_buffer_size = sampler.slop_phrase_handler.max_slop_phrase_length + 5 838 | last_released_position = len(prompt_tokens) - 1 839 | generated_sequence = [] 840 | 841 | if streaming and stream_smoothing: 842 | # Buffer to allow bactracking and also to smooth output rate 843 | temporal_buffer_size = 30 844 | last_generation_time = time.time() 845 | token_times = [last_generation_time] * len(prompt_tokens) 846 | while True: 847 | try: 848 | generated_sequence = sampler.sequence_queue.get(timeout=0.01) 849 | except queue.Empty: 850 | if sampler.generation_complete.is_set(): 851 | break 852 | continue 853 | 854 | # update token times 855 | if len(generated_sequence) <= len(token_times): 856 | # we backtracked 857 | token_times = token_times[:len(generated_sequence)-1] 858 | 859 | while len(generated_sequence) - last_released_position > backtracking_buffer_size: 860 | # get the latest sequence from the queue 861 | while True: 862 | try: 863 | generated_sequence = sampler.sequence_queue.get(timeout=0.001) 864 | if len(generated_sequence) <= len(token_times): 865 | # we backtracked 866 | token_times = token_times[:len(generated_sequence)-1] 867 | except queue.Empty: 868 | break 869 | if sampler.generation_complete.is_set(): 870 | break 871 | 872 | # calculate simple moving avg of last n token times 873 | adjusted_last_released_pos = last_released_position - len(prompt_tokens) 874 | sma_tokens = token_times[len(prompt_tokens):][adjusted_last_released_pos-temporal_buffer_size:adjusted_last_released_pos] 875 | if len(sma_tokens) > 0: 876 | sma_token_time = (time.time() - sma_tokens[0]) / len(sma_tokens) 877 | else: 878 | sma_token_time = 0 879 | #print(sma_token_time) 880 | 881 | if len(generated_sequence) - last_released_position > backtracking_buffer_size + temporal_buffer_size: 882 | sleep_time = 0.02 883 | else: 884 | sleep_time = sma_token_time 885 | 886 | 
last_released_position += 1 887 | token_to_release = generated_sequence[last_released_position] 888 | 889 | # Sleep to smooth the output 890 | if sleep_time > 0: 891 | time.sleep(sleep_time) 892 | 893 | token_times.append(time.time()) 894 | 895 | # Yield the token 896 | yield token_to_release 897 | else: 898 | # smoothing disabled 899 | while True: 900 | try: 901 | generated_sequence = sampler.sequence_queue.get(timeout=0.01) 902 | except queue.Empty: 903 | if sampler.generation_complete.is_set(): 904 | break 905 | continue 906 | 907 | while len(generated_sequence) - last_released_position > backtracking_buffer_size: 908 | last_released_position += 1 909 | token_to_release = generated_sequence[last_released_position] 910 | 911 | # Yield the token 912 | yield token_to_release 913 | 914 | # Release any remaining tokens after generation is complete 915 | if last_released_position < len(generated_sequence) - 1: 916 | #print(len(generated_sequence) - last_released_position, 'to release') 917 | for tok in generated_sequence[last_released_position + 1:]: 918 | # Release remaining tokens at full rate with constant delay 919 | yield tok 920 | time.sleep(0.02) # Constant delay as per user's instruction 921 | 922 | finally: 923 | # Stop the event loop 924 | loop.call_soon_threadsafe(loop.stop) 925 | # Wait for the thread to finish 926 | thread.join() 927 | # Close the loop 928 | loop.close() 929 | # Clean up the sampler 930 | sampler.cleanup() 931 | del sampler 932 | -------------------------------------------------------------------------------- /src/slop_index.py: -------------------------------------------------------------------------------- 1 | # Calculates a slop score for a provided text 2 | 3 | import json 4 | import re 5 | import matplotlib.pyplot as plt 6 | import matplotlib.ticker as ticker 7 | import numpy as np 8 | from joblib import Parallel, delayed 9 | 10 | def load_and_preprocess_slop_words(): 11 | with open('slop_phrase_prob_adjustments.json', 'r') as f: 12 | slop_phrases = json.load(f) 13 | 14 | phrase_weighting = [1.0 - prob_adjustment for word, prob_adjustment in slop_phrases] 15 | max_score = max(phrase_weighting) 16 | scaled_weightings = [score / max_score for score in phrase_weighting] 17 | n_slop_words = 600 18 | return {word.lower(): score for (word, _), score in zip(slop_phrases[:n_slop_words], scaled_weightings[:n_slop_words])} 19 | 20 | def extract_text_blocks(file_path, compiled_pattern): 21 | with open(file_path, 'r', encoding='utf-8') as file: 22 | content = file.read() 23 | 24 | matches = compiled_pattern.findall(content) 25 | return '\n'.join(matches) 26 | 27 | def calculate_slop_score_chunk(args): 28 | text, slop_words_chunk = args 29 | return sum( 30 | score * len(re.findall(r'\b' + re.escape(word) + r'\b', text)) 31 | for word, score in slop_words_chunk.items() 32 | ) 33 | 34 | def calculate_and_plot_slop_indices(slop_indices): 35 | if not slop_indices: 36 | print("No slop indices to plot.") 37 | return [] 38 | 39 | # Sort the indices in descending order 40 | sorted_indices = sorted(slop_indices.items(), key=lambda x: x[1], reverse=True) 41 | models, indices = zip(*sorted_indices) if sorted_indices else ([], []) 42 | 43 | # Set the style for better aesthetics 44 | plt.style.use('seaborn-darkgrid') # You can choose other styles like 'ggplot', 'fivethirtyeight', etc. 
45 | 46 | plt.figure(figsize=(12, 18)) 47 | 48 | # Create a horizontal bar chart 49 | bars = plt.barh(models, indices, color=plt.cm.viridis(range(len(indices)))) 50 | 51 | plt.title('Slop Index by Model', fontsize=16, weight='bold', pad=15) 52 | plt.xlabel('Slop Index', fontsize=14, labelpad=10) 53 | plt.ylabel('Model', fontsize=14, labelpad=10) 54 | 55 | # Invert y-axis to have the highest slop index on top 56 | plt.gca().invert_yaxis() 57 | 58 | # Add value labels to each bar 59 | for bar in bars: 60 | width = bar.get_width() 61 | plt.text(width + max(indices)*0.01, bar.get_y() + bar.get_height()/2, 62 | f'{width:.2f}', va='center', fontsize=12) 63 | 64 | # Customize x-axis ticks 65 | plt.gca().xaxis.set_major_locator(ticker.MaxNLocator(integer=True)) 66 | 67 | plt.tight_layout() 68 | 69 | # Save the figure with higher resolution 70 | plt.savefig('slop_index_chart.png', dpi=300) 71 | plt.show() 72 | plt.close() 73 | 74 | return sorted_indices 75 | 76 | def split_into_chunks(slop_words, num_chunks): 77 | slop_words_items = list(slop_words.items()) 78 | chunk_size = len(slop_words_items) // num_chunks 79 | if chunk_size == 0: 80 | chunk_size = 1 81 | return [dict(slop_words_items[i:i + chunk_size]) for i in range(0, len(slop_words_items), chunk_size)] 82 | 83 | 84 | # Call this to function to calculate a slop score. 85 | # This is the way it's calculated for the eqbench creative writing leaderboard. 86 | def calculate_slop_index(extracted_text): 87 | slop_words = load_and_preprocess_slop_words() 88 | 89 | num_chunks = 12 #mp.cpu_count() 90 | slop_words_chunks = split_into_chunks(slop_words, num_chunks) 91 | 92 | if not extracted_text: 93 | slop_index = 0.0 94 | else: 95 | # Parallelize the calculation using joblib 96 | slop_scores = Parallel(n_jobs=num_chunks)(delayed(calculate_slop_score_chunk)((extracted_text, chunk)) for chunk in slop_words_chunks) 97 | 98 | slop_score = sum(slop_scores) 99 | total_words = len(extracted_text.split()) 100 | slop_index = (slop_score / total_words) * 1000 if total_words > 0 else 0 101 | return slop_index 102 | 103 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple, Generator, Set, Union 2 | from transformers import PreTrainedTokenizer 3 | 4 | def precompute_starting_tokens( 5 | tokenizer: PreTrainedTokenizer, slop_phrase_prob_adjustments: Dict[str, float] 6 | ) -> Dict[Tuple[int, ...], Set[int]]: 7 | starting_tokens_lookup = {} 8 | 9 | for slop_phrase in slop_phrase_prob_adjustments.keys(): 10 | starting_tokens = set() 11 | variants = [ 12 | slop_phrase.lower(), 13 | slop_phrase.capitalize(), 14 | slop_phrase.upper(), 15 | f" {slop_phrase.lower()}", 16 | f" {slop_phrase.capitalize()}", 17 | f" {slop_phrase.upper()}", 18 | ] 19 | 20 | for variant in variants: 21 | token_ids = tokenizer.encode(variant, add_special_tokens=False) 22 | if token_ids: 23 | starting_tokens.add(token_ids[0]) 24 | first_token_decoded = tokenizer.decode(token_ids[0], skip_special_tokens=True) 25 | 26 | for i in range(len(first_token_decoded) - 1): 27 | prefix = first_token_decoded[:-(i + 1)] 28 | if prefix == ' ': 29 | continue 30 | encoded_prefix = tokenizer.encode(prefix, add_special_tokens=False) 31 | if encoded_prefix: 32 | starting_tokens.add(encoded_prefix[0]) 33 | 34 | starting_tokens_lookup[slop_phrase] = starting_tokens 35 | 36 | return starting_tokens_lookup 
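# Illustrative usage sketch (not part of the original file): it shows the shape of the
# lookup that precompute_starting_tokens() returns. generate_antislop() builds this
# lookup internally before constructing the sampler; the "gpt2" tokenizer and the two
# example phrases below are placeholder assumptions, not values from the repo.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    example_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    example_adjustments = {"tapestry": 0.1, "delve into": 0.1}

    lookup = precompute_starting_tokens(example_tokenizer, example_adjustments)
    for phrase, token_ids in lookup.items():
        # Each slop phrase maps to token ids that could start the phrase in its
        # case / leading-space variants (including prefix fragments of the first token).
        print(repr(phrase), sorted(token_ids))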
-------------------------------------------------------------------------------- /src/validator_json.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from transformers import StoppingCriteria 4 | from IPython.display import clear_output, display, HTML 5 | 6 | class JSONValidator: 7 | def __init__(self, tokenizer, slow_debug, debug_delay, debug_output, probs_cache_longrange): 8 | self.tokenizer = tokenizer 9 | self.slow_debug = slow_debug 10 | self.debug_delay = debug_delay 11 | self.debug_output = debug_output 12 | self.probs_cache_longrange = probs_cache_longrange 13 | 14 | # Escaped tokens lookup for common characters that need escaping in JSON strings 15 | self.escaped_tokens_lookup = { 16 | '\n': self.tokenizer.encode('\\n', add_special_tokens=False), 17 | '\t': self.tokenizer.encode('\\t', add_special_tokens=False), 18 | '\r': self.tokenizer.encode('\\r', add_special_tokens=False), 19 | '"': self.tokenizer.encode('\\"', add_special_tokens=False), 20 | ' "': self.tokenizer.encode(' \\"', add_special_tokens=False), 21 | } 22 | 23 | self.last_downregulated_token = 0 24 | 25 | def validate_json_string(self, generated_sequence, prompt_length, probs_cache): 26 | result = self._validate_json_string(generated_sequence, prompt_length, probs_cache) 27 | if result is not False: 28 | generated_sequence, problematic_token, invalid_char, reason = result 29 | if self.slow_debug: 30 | problematic_token_decoded = self.tokenizer.decode(problematic_token, skip_special_tokens=True) 31 | debug_info = f"JSON structure violation detected:\n{reason}\nProblematic char: {[invalid_char]} in token {[problematic_token_decoded]}" 32 | self._display_debug(debug_info) 33 | time.sleep(self.debug_delay) 34 | 35 | problematic_pos = len(generated_sequence) 36 | 37 | # Clear subsequent logit cache 38 | to_del = [key for key in probs_cache if key > problematic_pos] 39 | for key in to_del: 40 | del probs_cache[key] 41 | 42 | # Flag positions to keep in the logit cache 43 | self.probs_cache_longrange[problematic_pos] = True 44 | 45 | return generated_sequence 46 | 47 | return False 48 | 49 | # String validation checks for unescaped chars inside a (partial) json generation. 50 | # It works better to do this with with long-range constraints (at least for double 51 | # quotes) because we aren't forcing a termination of the string if that's not what 52 | # the model intended. 53 | # 54 | # To do this, if we are in a json string and see a ", we wait for more tokens to 55 | # see if the continuation looks like it was intended to be a string terminator, 56 | # or meant to be a quotation mark within the string. If it looks like it was an 57 | # accidentally unescaped quote, we downregulate this token and upregulate the 58 | # escaped token at that position (then backtrack & resample) 59 | # 60 | # It's not foolproof, but it does fix most of the json parsing fails that occur. 
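    # Example of the failure mode this targets (illustrative only; the JSON is made up):
    #     {"reply": "She said "thanks" and left."}
    # The quote before `thanks` parses as a string terminator, but the characters that
    # follow are not valid JSON after a closed string. In that case the terminator token
    # is downregulated, the escaped \" continuation is upregulated, and generation
    # backtracks and resamples from that position.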
61 | def _validate_json_string(self, generated_sequence, prompt_length, probs_cache=None, validate_only=False): 62 | # Get only the generated text 63 | generated_text = self.tokenizer.decode(generated_sequence[prompt_length:]) 64 | 65 | # Create character to token position mapping 66 | char_to_token_pos = self._create_char_to_token_mapping(generated_sequence[prompt_length:], prompt_length) 67 | 68 | in_json = False 69 | in_string = False 70 | escape_next = False 71 | validating_string_end = False 72 | string_end_char = -1 73 | 74 | # Tracks the nesting of JSON objects and arrays 75 | brace_stack = [] 76 | expected_tokens = [] 77 | 78 | for i, char in enumerate(generated_text): 79 | if not in_json: 80 | if char in '{[': 81 | in_json = True 82 | # Start tracking the JSON structure 83 | brace_stack.append(char) 84 | expected_tokens = self._get_expected_tokens(brace_stack) 85 | continue 86 | 87 | if escape_next: 88 | escape_next = False 89 | continue 90 | 91 | if validating_string_end: 92 | if char in ['\n', '\t', '\r', ' ']: 93 | continue 94 | elif char in ',:]}[{': 95 | # Valid string termination 96 | in_string = False 97 | validating_string_end = False 98 | expected_tokens = self._get_expected_tokens(brace_stack) 99 | else: 100 | # Invalid string termination, backtrack 101 | if validate_only: 102 | return True 103 | 104 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[string_end_char], upregulate_safe_continuation=True, probs_cache=probs_cache) 105 | return generated_sequence, problematic_token, char, "Unexpected string termination" 106 | 107 | if char == '\\': 108 | escape_next = True 109 | elif char == '"': 110 | if in_string: 111 | # End of string, validate its termination 112 | string_end_char = i 113 | validating_string_end = True 114 | else: 115 | # Start of string 116 | in_string = True 117 | expected_tokens = ['"'] 118 | elif in_string: 119 | if char in '\n\r\t': 120 | # These characters should be escaped in JSON strings 121 | if validate_only: 122 | return True 123 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], upregulate_safe_continuation=True, probs_cache=probs_cache) 124 | return generated_sequence, problematic_token, char, "Unterminated character in string" 125 | 126 | elif char in '{[': 127 | brace_stack.append(char) 128 | expected_tokens = self._get_expected_tokens(brace_stack) 129 | elif char in '}]': 130 | if not brace_stack or (char == '}' and brace_stack[-1] != '{') or (char == ']' and brace_stack[-1] != '['): 131 | if validate_only: 132 | return True 133 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], probs_cache=probs_cache) 134 | return generated_sequence, problematic_token, char, "Unexpected '}' or ']'" 135 | brace_stack.pop() 136 | expected_tokens = self._get_expected_tokens(brace_stack) 137 | elif char == ',': 138 | if not brace_stack or brace_stack[-1] not in '{[': 139 | if validate_only: 140 | return True 141 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], probs_cache=probs_cache) 142 | return generated_sequence, problematic_token, char, "Unexpected ','" 143 | expected_tokens = self._get_expected_tokens(brace_stack) 144 | elif char == ':': 145 | if not brace_stack or brace_stack[-1] != '{': 146 | if validate_only: 147 | return True 148 | generated_sequence, problematic_token = 
self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], probs_cache=probs_cache) 149 | return generated_sequence, problematic_token, char, "Unexpected ':'" 150 | expected_tokens = self._get_expected_tokens(brace_stack) 151 | elif char not in ' \n\t\r': 152 | if char not in expected_tokens: 153 | if validate_only: 154 | return True 155 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], probs_cache=probs_cache) 156 | return generated_sequence, problematic_token, char, "Unexpected char (expecting whitespace)" 157 | 158 | # Check if we have closed all JSON structures 159 | if not brace_stack: 160 | in_json = False 161 | 162 | return False 163 | 164 | def _get_expected_tokens(self, brace_stack): 165 | # Determine expected tokens based on current JSON structure 166 | if not brace_stack: 167 | return ['{', '['] 168 | elif brace_stack[-1] == '{': 169 | return ['"', '}', ' ', '\n', '\t', '\r'] 170 | elif brace_stack[-1] == '[': 171 | return ['{', '[', ']', '"', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 't', 'f', 'n', ' ', '\n', '\t', '\r'] 172 | else: 173 | return [',', '}', ']', ' ', '\n', '\t', '\r'] 174 | 175 | def _create_char_to_token_mapping(self, tokens, prompt_length): 176 | char_to_token_pos = {} 177 | char_pos = 0 178 | for token_pos, token in enumerate(tokens): 179 | token_text = self.tokenizer.decode(token) 180 | for _ in range(len(token_text)): 181 | char_to_token_pos[char_pos] = token_pos + prompt_length 182 | char_pos += 1 183 | return char_to_token_pos 184 | 185 | def _backtrack_and_adjust(self, generated_sequence, problematic_pos, upregulate_safe_continuation=False, probs_cache=None): 186 | # Identify the problematic token and backtrack 187 | problematic_token = generated_sequence[problematic_pos] 188 | 189 | # Downregulate the problematic token 190 | self.last_downregulated_token = problematic_token 191 | probs_cache[problematic_pos][:, problematic_token] *= 0.0001 192 | 193 | # Upregulate the properly escaped version 194 | if upregulate_safe_continuation: 195 | self._upregulate_safe_continuation(problematic_pos, problematic_token, probs_cache) 196 | 197 | return generated_sequence[:problematic_pos], problematic_token 198 | 199 | # We need to account for cases where the problematic token is more than 1 char 200 | # let's take this approach: 201 | # - first, check if the token is in the lookup 202 | # - if so, we upregulate the escaped lookup value 203 | # - if not, we check if the 1st char of the decoded token is on of our problematic tokens 204 | # - if so, upregulate just that (escaped) token 205 | # - else, we extract the substring up to (and not including) the first problematic char 206 | # - then upregulate the first token of this encoded string, as the intended continuation 207 | def _upregulate_safe_continuation(self, problematic_pos, problematic_token, probs_cache): 208 | # this function is specifically to handle unescaped chars inside strings 209 | 210 | problematic_token_decoded = self.tokenizer.decode(problematic_token, skip_special_tokens=True) 211 | 212 | debug=False 213 | 214 | # Check if the token is in the lookup 215 | if problematic_token_decoded in self.escaped_tokens_lookup: 216 | if debug: 217 | print('upregulating1', [self.tokenizer.decode(self.escaped_tokens_lookup[problematic_token_decoded][0])]) 218 | probs_cache[problematic_pos][:, self.escaped_tokens_lookup[problematic_token_decoded][0]] *= 2 219 | # normalise probs 220 | 
probs_cache[problematic_pos].div_(torch.sum(probs_cache[problematic_pos])) 221 | 222 | elif problematic_token_decoded[0] in self.escaped_tokens_lookup: 223 | encoded_escaped_tok = self.escaped_tokens_lookup[problematic_token_decoded[0]][0] 224 | if debug: 225 | print('upregulating2', [self.tokenizer.decode(encoded_escaped_tok)]) 226 | 227 | probs_cache[problematic_pos][:, encoded_escaped_tok] *= 2 228 | # normalise probs 229 | probs_cache[problematic_pos].div_(torch.sum(probs_cache[problematic_pos])) 230 | else: 231 | # Find the first problematic character 232 | first_problematic_index = next((i for i, char in enumerate(problematic_token_decoded) if char in self.escaped_tokens_lookup), None) 233 | 234 | if first_problematic_index is not None: 235 | # Extract the substring up to the first problematic character 236 | safe_substring = problematic_token_decoded[:first_problematic_index] 237 | 238 | # Encode the safe substring 239 | encoded_safe_substring = self.tokenizer.encode(safe_substring) 240 | 241 | # Upregulate the first token of the encoded safe substring 242 | if encoded_safe_substring: 243 | if debug: 244 | print('upregulating3', [safe_substring]) 245 | probs_cache[problematic_pos][:, encoded_safe_substring[0]] *= 2 246 | # normalise probs 247 | probs_cache[problematic_pos].div_(torch.sum(probs_cache[problematic_pos])) 248 | 249 | 250 | def _display_debug(self, message: str): 251 | """ 252 | Displays debug information in the debug_output widget. 253 | """ 254 | if self.debug_output: 255 | with self.debug_output: 256 | self.debug_output.clear_output(wait=True) 257 | display(HTML(f"
<div>{message}</div>
")) 258 | 259 | 260 | class JSONValidationStoppingCriteria(StoppingCriteria): 261 | def __init__(self, tokenizer, json_validator, prompt_length): 262 | self.tokenizer = tokenizer 263 | self.json_validator = json_validator 264 | self.prompt_length = prompt_length 265 | 266 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 267 | self.previous_tokens = input_ids[0].tolist() 268 | 269 | # Check if the generated sequence is valid JSON 270 | result = self.json_validator._validate_json_string(self.previous_tokens, self.prompt_length, validate_only=True) 271 | 272 | return result 273 | -------------------------------------------------------------------------------- /src/validator_regex.py: -------------------------------------------------------------------------------- 1 | # File: ai/antislop-sampler/src/validator_regex.py 2 | 3 | import re 4 | import time 5 | import torch 6 | from transformers import StoppingCriteria 7 | from IPython.display import clear_output, display, HTML 8 | from typing import List, Dict, Tuple, Optional 9 | 10 | # This implements banning of sequences using regex matching. 11 | # If the inference text matches one of the specified regex expressions, 12 | # we backtrack to the position of the first match, ban that token, then 13 | # continue inference. 14 | class RegexValidator: 15 | def __init__(self, tokenizer, regex_bans: List[str], slow_debug, debug_delay, debug_output, probs_cache_longrange): 16 | self.tokenizer = tokenizer 17 | self.regex_bans = [re.compile(pattern) for pattern in regex_bans] 18 | self.slow_debug = slow_debug 19 | self.debug_delay = debug_delay 20 | self.debug_output = debug_output 21 | self.probs_cache_longrange = probs_cache_longrange 22 | 23 | self.last_downregulated_token = 0 24 | 25 | def validate_regex_matches(self, generated_sequence, prompt_length, probs_cache): 26 | result = self._validate_regex_matches(generated_sequence, prompt_length, probs_cache) 27 | if result is not False: 28 | generated_sequence, problematic_token, invalid_match, reason = result 29 | if self.slow_debug: 30 | problematic_token_decoded = self.tokenizer.decode(problematic_token, skip_special_tokens=True) 31 | debug_info = f"Regex violation detected:\n{reason}\nProblematic match: '{invalid_match}' in token '{problematic_token_decoded}'" 32 | self._display_debug(debug_info) 33 | time.sleep(self.debug_delay) 34 | 35 | problematic_pos = len(generated_sequence) 36 | 37 | # Clear subsequent logit cache 38 | to_del = [key for key in probs_cache if key > problematic_pos] 39 | for key in to_del: 40 | del probs_cache[key] 41 | 42 | # Flag positions to keep in the logit cache 43 | self.probs_cache_longrange[problematic_pos] = True 44 | 45 | return generated_sequence 46 | 47 | return False 48 | 49 | def _validate_regex_matches(self, generated_sequence, prompt_length, probs_cache=None, validate_only=False): 50 | # Decode only the newly generated tokens 51 | generated_text = self.tokenizer.decode(generated_sequence[prompt_length:], skip_special_tokens=True) 52 | 53 | for pattern in self.regex_bans: 54 | match = pattern.search(generated_text) 55 | if match: 56 | # Find the position in tokens where the match starts 57 | match_start_char_pos = match.start() 58 | 59 | # We prepend a char because some tokenisers won't add the initial space 60 | # for the first token which would otherwise be there if there was a token 61 | # in front of it. If we don't have this we'll sometimes match to the wrong 62 | # token position. 
63 | prepended_char = self.tokenizer.encode('|', add_special_tokens=False) 64 | problematic_token_pos = None 65 | 66 | for start_pos in range(len(generated_sequence)-1, prompt_length-1, -1): 67 | test_str = self.tokenizer.decode(prepended_char + generated_sequence[start_pos:], skip_special_tokens=True) 68 | 69 | if len(generated_text) - (len(test_str) - 1) <= match_start_char_pos: # -1 is to account for the prepended char 70 | problematic_token_pos = start_pos 71 | break 72 | if problematic_token_pos == None: 73 | print('!! failed to get problematic token pos') 74 | return False 75 | 76 | problematic_token = generated_sequence[problematic_token_pos] 77 | reason = f"Regex pattern '{pattern.pattern}' matched." 78 | 79 | if validate_only: 80 | return True 81 | 82 | # Adjust the generated sequence and cache 83 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, problematic_token_pos, upregulate_safe_continuation=False, probs_cache=probs_cache) 84 | return generated_sequence, problematic_token, match.group(), reason 85 | 86 | return False 87 | 88 | def _create_char_to_token_mapping(self, tokens, prompt_length): 89 | char_to_token_pos = {} 90 | char_pos = 0 91 | for token_pos, token in enumerate(tokens): 92 | token_text = self.tokenizer.decode([token], skip_special_tokens=True) 93 | for _ in range(len(token_text)): 94 | char_to_token_pos[char_pos] = token_pos + prompt_length 95 | char_pos += 1 96 | return char_to_token_pos 97 | 98 | def _backtrack_and_adjust(self, generated_sequence, problematic_pos, upregulate_safe_continuation=False, probs_cache=None): 99 | # Identify the problematic token and backtrack 100 | problematic_token = generated_sequence[problematic_pos] 101 | 102 | # Downregulate the problematic token 103 | self.last_downregulated_token = problematic_token 104 | probs_cache[problematic_pos][:, problematic_token] *= 0.0001 105 | 106 | # Flag positions to keep in the logit cache 107 | self.probs_cache_longrange[problematic_pos] = True 108 | 109 | # Clear the probs_cache ahead of start_pos since we've backtracked 110 | to_del = [key for key in probs_cache if key > problematic_pos] 111 | for key in to_del: 112 | del probs_cache[key] 113 | 114 | return generated_sequence[:problematic_pos], problematic_token 115 | 116 | 117 | 118 | def _display_debug(self, message: str): 119 | """ 120 | Displays debug information in the debug_output widget. 121 | """ 122 | if self.debug_output: 123 | with self.debug_output: 124 | self.debug_output.clear_output(wait=True) 125 | display(HTML(f"
{message}
")) 126 | 127 | class RegexValidationStoppingCriteria(StoppingCriteria): 128 | def __init__(self, tokenizer, regex_validator, prompt_length): 129 | self.tokenizer = tokenizer 130 | self.regex_validator = regex_validator 131 | self.prompt_length = prompt_length 132 | 133 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 134 | previous_tokens = input_ids[0].tolist() 135 | 136 | # Check if any regex pattern matches the generated sequence 137 | result = self.regex_validator._validate_regex_matches(previous_tokens, self.prompt_length, validate_only=True) 138 | 139 | return result 140 | -------------------------------------------------------------------------------- /src/validator_slop.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List, Dict, Tuple, Generator, Set, Union 3 | 4 | import torch 5 | from transformers import ( 6 | PreTrainedTokenizer, 7 | StoppingCriteria, 8 | ) 9 | from IPython.display import display, HTML 10 | from ipywidgets import Output 11 | 12 | 13 | # This function detects if any of the slop phrases are in the end segment of inference. 14 | # It is optimised using dict lookups to be more or less constant execution time regardless 15 | # of slop list length. 16 | def detect_disallowed_sequence(tokenizer: PreTrainedTokenizer, 17 | inference: str, 18 | generated_sequence: List[int], 19 | prompt_length: int, 20 | slop_phrase_prob_adjustments: Dict[str, float], 21 | max_slop_phrase_length: int, 22 | min_slop_phrase_length: int, 23 | check_n_chars_back: int = 16 # this moves the detection window back n chars, so we can detect phrases that were completed further back 24 | ) -> Tuple[Tuple[int, ...], int]: 25 | 26 | inference = inference.lower() 27 | 28 | for char_offset in range(0, check_n_chars_back): 29 | for candidate_str_length in range(max_slop_phrase_length, min_slop_phrase_length - 1, -1): 30 | if candidate_str_length + char_offset > len(inference): 31 | continue 32 | candidate_str = inference[-(candidate_str_length + char_offset):len(inference)-char_offset] 33 | #print(candidate_str) 34 | if candidate_str in slop_phrase_prob_adjustments: 35 | # determine the token containing the beginning of the detected phrase 36 | #print('looking for', candidate_str,'in decoded text') 37 | for start_pos in range(len(generated_sequence)-1, prompt_length-1, -1): 38 | candidate_seq = generated_sequence[start_pos:] 39 | candidate_seq_decoded = tokenizer.decode(candidate_seq, skip_special_tokens=True).lower() 40 | #print(candidate_seq_decoded) 41 | if candidate_str in candidate_seq_decoded: 42 | #print('detected!', candidate_str, time.time() - start) 43 | return candidate_str, start_pos 44 | # if we reached here, something went wrong 45 | print('!! 
/src/validator_slop.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List, Dict, Tuple, Generator, Set, Union 3 | 4 | import torch 5 | from transformers import ( 6 | PreTrainedTokenizer, 7 | StoppingCriteria, 8 | ) 9 | from IPython.display import display, HTML 10 | from ipywidgets import Output 11 | 12 | 13 | # This function detects if any of the slop phrases are in the end segment of inference. 14 | # It is optimised using dict lookups to be more or less constant execution time regardless 15 | # of slop list length. 16 | def detect_disallowed_sequence(tokenizer: PreTrainedTokenizer, 17 | inference: str, 18 | generated_sequence: List[int], 19 | prompt_length: int, 20 | slop_phrase_prob_adjustments: Dict[str, float], 21 | max_slop_phrase_length: int, 22 | min_slop_phrase_length: int, 23 | check_n_chars_back: int = 16 # this moves the detection window back n chars, so we can detect phrases that were completed further back 24 | ) -> Tuple[Tuple[int, ...], int]: 25 | 26 | inference = inference.lower() 27 | 28 | for char_offset in range(0, check_n_chars_back): 29 | for candidate_str_length in range(max_slop_phrase_length, min_slop_phrase_length - 1, -1): 30 | if candidate_str_length + char_offset > len(inference): 31 | continue 32 | candidate_str = inference[-(candidate_str_length + char_offset):len(inference)-char_offset] 33 | #print(candidate_str) 34 | if candidate_str in slop_phrase_prob_adjustments: 35 | # determine the token containing the beginning of the detected phrase 36 | #print('looking for', candidate_str,'in decoded text') 37 | for start_pos in range(len(generated_sequence)-1, prompt_length-1, -1): 38 | candidate_seq = generated_sequence[start_pos:] 39 | candidate_seq_decoded = tokenizer.decode(candidate_seq, skip_special_tokens=True).lower() 40 | #print(candidate_seq_decoded) 41 | if candidate_str in candidate_seq_decoded: 42 | #print('detected!', candidate_str, time.time() - start) 43 | return candidate_str, start_pos 44 | # if we reached here, something went wrong 45 | print('!! candidate_str not found after decoding') 46 | 47 | return None, -1 48 | 49 | class SlopPhraseHandler: 50 | def __init__( 51 | self, 52 | tokenizer: PreTrainedTokenizer, 53 | probs_cache: Dict[int, List[int]], 54 | probs_cache_longrange: Dict[int, bool], 55 | slop_phrase_prob_adjustments: Dict[str, float], 56 | starting_tokens_lookup: Dict[Tuple[int, ...], Set[int]], 57 | adjustment_strength: float, 58 | slow_debug: bool, 59 | inference_output: Output | None, 60 | debug_output: Output | None, 61 | debug_delay: float, 62 | ): 63 | self.tokenizer = tokenizer 64 | self.probs_cache = probs_cache 65 | self.probs_cache_longrange = probs_cache_longrange 66 | self.slop_phrase_prob_adjustments = slop_phrase_prob_adjustments 67 | self.starting_tokens_lookup = starting_tokens_lookup 68 | self.adjustment_strength = adjustment_strength 69 | self.slow_debug = slow_debug 70 | self.inference_output = inference_output 71 | self.debug_output = debug_output 72 | self.debug_delay = debug_delay 73 | 74 | self.max_slop_phrase_length = max(len(seq) for seq in self.slop_phrase_prob_adjustments.keys()) if self.slop_phrase_prob_adjustments else 0 75 | self.min_slop_phrase_length = min(len(seq) for seq in self.slop_phrase_prob_adjustments.keys()) if self.slop_phrase_prob_adjustments else 0 76 | #self.stopping_criteria = SlopPhraseStoppingCriteria(tokenizer, self.slop_phrase_sequences, self.max_slop_phrase_length) 77 | 78 | tmp = {} 79 | for key in self.slop_phrase_prob_adjustments: 80 | tmp[key.lower()] = self.slop_phrase_prob_adjustments[key] 81 | self.slop_phrase_prob_adjustments = tmp 82 | 83 | 84 | 85 | def _handle_disallowed_sequence( 86 | self, 87 | matched_phrase: str, 88 | start_pos: int, 89 | generated_sequence: List[int], 90 | probs_cache: Dict[int, torch.FloatTensor], 91 | adjustment_strength: float, 92 | slow_debug: bool, 93 | tokenizer: PreTrainedTokenizer, 94 | inference_output: Output, 95 | debug_output: Output, 96 | debug_delay: float, 97 | ) -> List[int]: 98 | # Downregulate the relevant tokens at the start_pos 99 | adjustment = self.slop_phrase_prob_adjustments[matched_phrase.lower()] 100 | 101 | # Display debug information 102 | debug_info = f"Replacing '{matched_phrase}'" 103 | self._display_debug(debug_info) 104 | 105 | if slow_debug: 106 | time.sleep(debug_delay) 107 | if debug_output: 108 | with debug_output: 109 | debug_output.clear_output(wait=True) 110 | 111 | #print('downregulating', [tokenizer.decode([generated_sequence[start_pos]])]) 112 | 113 | # Identify starting tokens to downregulate 114 | slop_phrase_starting_token = generated_sequence[start_pos] 115 | starting_tokens = self.starting_tokens_lookup.get(matched_phrase.lower(), set()) 116 | starting_tokens.add(slop_phrase_starting_token) 117 | 118 | for token_id in starting_tokens: 119 | self.probs_cache[start_pos][:, token_id] *= adjustment ** adjustment_strength 120 | 121 | # Check if the starting token would still be selected after downregulation 122 | if torch.argmax(self.probs_cache[start_pos]).item() == slop_phrase_starting_token: 123 | if slow_debug: 124 | debug_info = f"Slop phrase '{matched_phrase}' prob was downregulated {round(1/(adjustment**adjustment_strength), 2)}x but still selected."
125 | self._display_debug(debug_info) 126 | return generated_sequence 127 | 128 | # Backtrack: remove tokens from the generated_sequence that are part of the disallowed sequence 129 | for _ in range(len(generated_sequence) - start_pos): 130 | generated_sequence.pop() 131 | 132 | # Clear the probs_cache ahead of start_pos since we've backtracked 133 | to_del = [key for key in self.probs_cache if key > start_pos] 134 | for key in to_del: 135 | del self.probs_cache[key] 136 | 137 | return generated_sequence 138 | 139 | def deslop(self, generated_sequence, prompt_length): 140 | self.prompt_length = prompt_length 141 | # After adding the token(s), check for disallowed sequences 142 | 143 | inference = self.tokenizer.decode(generated_sequence[prompt_length:], skip_special_tokens=True) 144 | 145 | matched_phrase, start_pos = detect_disallowed_sequence(self.tokenizer, 146 | inference, 147 | generated_sequence, 148 | prompt_length, 149 | self.slop_phrase_prob_adjustments, 150 | self.max_slop_phrase_length, 151 | self.min_slop_phrase_length) 152 | 153 | if matched_phrase: 154 | if self.slow_debug: 155 | current_text = self.tokenizer.decode(generated_sequence[prompt_length:start_pos]) 156 | #print([current_text]) 157 | matched_phrase_to_display = self.tokenizer.decode(generated_sequence[start_pos:], skip_special_tokens=True) 158 | #print([matched_phrase_to_display]) 159 | # Add HTML formatting to display the matched_phrase in red 160 | highlighted_text = f"{current_text}<span style='color: red;'>{matched_phrase_to_display}</span>" 161 | 162 | with self.inference_output: 163 | self.inference_output.clear_output(wait=True) 164 | display(HTML(f"{highlighted_text}")) 165 | 166 | # Display debug information 167 | debug_info = f"Replacing '{matched_phrase}'" 168 | self._display_debug(debug_info) 169 | 170 | if self.slow_debug: 171 | #time.sleep(self.debug_delay) 172 | if self.debug_output: 173 | with self.debug_output: 174 | self.debug_output.clear_output(wait=True) 175 | 176 | # Handle the disallowed sequence using SlopPhraseHandler 177 | generated_sequence = self._handle_disallowed_sequence( 178 | matched_phrase=matched_phrase, 179 | start_pos=start_pos, 180 | generated_sequence=generated_sequence, 181 | probs_cache=self.probs_cache, 182 | adjustment_strength=self.adjustment_strength, 183 | slow_debug=self.slow_debug, 184 | tokenizer=self.tokenizer, 185 | inference_output=self.inference_output, 186 | debug_output=self.debug_output, 187 | debug_delay=self.debug_delay, 188 | ) 189 | 190 | return generated_sequence 191 | return False 192 | 193 | def _display_debug(self, message: str): 194 | """ 195 | Displays debug information in the debug_output widget. 196 | """ 197 | if self.debug_output: 198 | with self.debug_output: 199 | self.debug_output.clear_output(wait=True) 200 | display(HTML(f"{message}"))
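# --- Illustrative aside (not part of validator_slop.py): the downregulation factor
# --- applied in _handle_disallowed_sequence above, with made-up numbers.
# adjustment is the per-phrase value from slop_phrase_prob_adjustments;
# adjustment_strength is the constructor argument.
adjustment = 0.5
adjustment_strength = 3.0
factor = adjustment ** adjustment_strength
print(factor)                # 0.125 -> the starting token's probability is multiplied by this
print(round(1 / factor, 2))  # 8.0   -> reported as "downregulated 8.0x" in the debug message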
")) 201 | 202 | 203 | class CustomSlopPhraseStoppingCriteria(StoppingCriteria): 204 | def __init__(self, tokenizer, max_slop_phrase_length, min_slop_phrase_length, prompt_length, slop_phrase_prob_adjustments): 205 | self.tokenizer = tokenizer 206 | self.max_slop_phrase_length = max_slop_phrase_length 207 | self.min_slop_phrase_length = min_slop_phrase_length 208 | self.slop_phrase_prob_adjustments = slop_phrase_prob_adjustments 209 | self.prompt_length = prompt_length 210 | 211 | # !! For some reason this isn't reliably triggered every token 212 | # which means we might have output a slop phrase a token back or so. 213 | # Not sure why! 214 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 215 | # Combine previous tokens with newly generated tokens 216 | self.previous_tokens = input_ids[0].tolist() 217 | 218 | inference = self.tokenizer.decode(self.previous_tokens[self.prompt_length:], skip_special_tokens=True) 219 | 220 | matched_phrase, start_pos = detect_disallowed_sequence(self.tokenizer, 221 | inference, 222 | self.previous_tokens, 223 | self.prompt_length, 224 | self.slop_phrase_prob_adjustments, 225 | self.max_slop_phrase_length, 226 | self.min_slop_phrase_length) 227 | if matched_phrase: 228 | #print('matched', matched_phrase) 229 | return True 230 | return False 231 | --------------------------------------------------------------------------------