├── .gitignore
├── LICENSE
├── README.md
├── calculate_over_represented_words.ipynb
├── example_api.ipynb
├── example_generate.ipynb
├── example_visualise_backtracking.ipynb
├── run_api.py
├── slop_phrase_prob_adjustments.json
├── slop_phrase_prob_adjustments_full_list.json
├── slop_phrases_2025-04-07.json
├── slop_regexes.txt
├── slop_words_2025-04-07.json
└── src
    ├── antislop_generate.py
    ├── slop_index.py
    ├── util.py
    ├── validator_json.py
    ├── validator_regex.py
    └── validator_slop.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | src/__pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AntiSlop Sampler
2 |
3 | ## Overview
4 |
5 | The AntiSlop sampler uses a backtracking mechanism: when it encounters a disallowed word or phrase, it backtracks and retries with adjusted token probabilities. No more testaments or tapestries or other gpt-slop.
6 |
7 | Try the sampler here:
8 |
9 | - [Backtracking visualisation notebook](https://colab.research.google.com/drive/1tHS3tHXbZ5PWOsP-Mbk8oK35y6otBZUn?usp=sharing)
10 | - [Generate example notebook](https://colab.research.google.com/drive/1Rd3V4AN31cDytfmY9u80rzHXPD_dS6x9?usp=sharing)
11 |
12 |
13 | ### How to use the sampler
14 |
15 | With koboldcpp (new!)
16 |
17 | Koboldcpp now includes antislop phrase banning in its latest release:
18 | https://github.com/LostRuins/koboldcpp/releases/tag/v1.76
19 |
20 | With a local chat UI (open-webui):
21 |
22 |
23 | [Click to view instructions]
24 |
25 | ### if you want to run in an isolated environment (note: open-webui requires python 3.11):
26 | ```bash
27 | sudo apt install python3.11 python3.11-venv
28 | python3.11 -m venv open-webui
29 | source open-webui/bin/activate
30 | ```
31 |
32 | ### install open-webui
33 | ```bash
34 | pip install open-webui
35 | open-webui serve
36 | ```
37 |
38 | ### start the openai compatible antislop server:
39 | ```bash
40 | git clone https://github.com/sam-paech/antislop-sampler.git && cd antislop-sampler
41 | pip install fastapi uvicorn ipywidgets IPython transformers bitsandbytes accelerate
42 | python3 run_api.py --model unsloth/Llama-3.2-3B-Instruct --slop_adjustments_file slop_phrase_prob_adjustments.json
43 | ```
44 |
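
Before wiring up open-webui, you can sanity-check the server by hitting the chat completions endpoint directly. This is a minimal sketch based on example_api.ipynb; adjust the prompt and sampling parameters as you like:

```python
import requests

# Non-streaming request against the antislop API started above.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "tell me a story about a magical llama"}],
        "max_tokens": 200,
        "temperature": 1,
        "min_p": 0.1,
        "stream": False,
        "adjustment_strength": 100,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```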
45 | ### configure open-webui
46 | - browse to http://localhost:8080
47 | - go to admin panel --> settings --> connections
48 | - set the OpenAI API url to http://0.0.0.0:8000/v1
49 | - set api key to anything (it's not used)
50 | - click save (!!)
51 | - click the refresh icon to verify the connection; should see a success message
52 |
53 | Now it should be all configured! Start a new chat, select the model, and give it a try.
54 |
55 | Note: Only some of the settings in the chat controls will be active. Those are:
56 | - stream chat response
57 | - temperature
58 | - top k
59 | - top p
60 | - min p
61 | - max tokens
62 |
63 | 
64 |
65 |
66 |
67 | Here it is in action (in slow mode so you can see its backtracking & revisions):
68 |
69 | https://github.com/user-attachments/assets/e289fb89-d40c-458b-9d98-118c06b28fba
70 |
71 | ### Disclaimers:
72 |
73 | It's recommended that you roll your own slop list! The slop_phrase_prob_adjustments.json file is mostly auto-generated by computing over-represented words in a large LLM-generated story dataset. As such, it's not well optimised or curated. Curated lists will be added in the future.
74 |
75 | The API is very minimalist and doesn't support concurrency. It's geared more for local use, testing & dataset generation than for production.
76 |
77 | This is research-grade code, meaning it probably contains bugs & is a work in progress.
78 |
79 |
80 | ### 2024-10-13 Update
81 |
82 | The main changes are:
83 |
84 | - Switched to string matching instead of token matching (it's more robust)
85 | - Added regex matching/banning
86 |
87 | The new regex_bans parameter accepts a list of regex patterns. These will be evaluated during inference, and if one of them matches, we backtrack to the first token of the matched string. From there, we ban the matched continuation and continue inference.
88 |
89 | This allows more freedom for enforcing constraints than phrase matching alone. For instance, we can prevent over-used phrasing like "not x, but y". Note that streaming is not supported when using regex bans, since we can't predict how far back we may need to backtrack.
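
For example, to ban the "not x, but y" construction via chat_antislop (mirroring example_generate.ipynb; assumes `model`, `tokenizer` and `messages` are already set up as in the examples below):

```python
generated_text = chat_antislop(
    model=model,
    tokenizer=tokenizer,
    messages=messages,
    max_length=400,
    temperature=1,
    min_p=0.1,
    adjustment_strength=100.0,
    antislop_enabled=True,
    # if these regex expressions are matched, we backtrack, ban that continuation and retry
    regex_bans=['(?i)not [^.!?]{3,60} but'],
    streaming=False  # streaming is not supported with regex bans
)
print(tokenizer.decode(generated_text))
```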
90 |
91 |
92 |
93 | ### 2024-10-05 Update
94 |
95 | Refactored the code, lots of fixes.
96 |
97 | - Added an OpenAI compatible API server.
98 | - Now using model.generate with stopping conditions, generating multiple tokens at a time instead of just one. This is much faster.
99 | - Added a basic JSON validator + enforcement to demonstrate how the sampler can enforce long-range constraints.
100 | - Switched to probs from logits for the cached values, so that down/upregulation works as expected (this was a mistake in the previous implementation).
101 | - Refactored the code for better organisation.
102 |
103 | Quick blurb on the JSON validator:
104 |
105 | It uses the same backtracking mechanism to retry invalid JSON output. It checks for unintended unescaped quotes in strings, and encourages the model to choose a valid continuation. This is a very common fail mode for JSON outputs. Other kinds of per-token JSON grammars will just terminate the string if they see an unescaped quote, sadly ending the profound thought the LLM was in the middle of expressing. This is better. You can also use it with high temps.
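
A minimal usage sketch (assumes `model` and `tokenizer` are loaded as in the example notebooks; the prompt here is illustrative):

```python
messages = [{"role": "user", "content": "Return a JSON object describing a fictional book (title, author, year)."}]
generated = chat_antislop(
    model=model,
    tokenizer=tokenizer,
    messages=messages,
    max_new_tokens=300,
    temperature=1,
    min_p=0.1,
    enforce_json=True,  # backtrack & retry when the output stops being valid JSON
    antislop_enabled=True,
    streaming=False
)
print(tokenizer.decode(generated))
```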
106 |
107 |
108 |
109 |
110 | ### 2024-10-01 Update
111 |
112 | - Squashed VRAM leaks, fixed bugs. It should work with any transformers model now.
113 | - Support min_p
114 | - Now using slop_phrase_prob_adjustments.json by default, which has a more intuitive probability adjustment per slop phrase (1 == no change; < 1 means probability is reduced by that factor). It looks like this:
115 | ```
116 | [
117 | ["kaleidoscope", 0.5],
118 | ["symphony", 0.5],
119 | ["testament to", 0.5],
120 | ["elara", 0.5],
121 | ...
122 | ]
123 | ```
124 |
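
To load that file into the form the samplers expect, you can build a dict from it (this mirrors example_generate.ipynb, which keeps only the first 500 entries; use the full list if you prefer):

```python
import json

with open('slop_phrase_prob_adjustments.json', 'r') as f:
    slop_phrase_prob_adjustments = dict(json.load(f)[:500])
```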
125 |
126 | ### chat_antislop
127 | ```python
128 | # Chat generation with streaming
129 | messages = [
130 | {"role": "user", "content": prompt}
131 | ]
132 | for token in chat_antislop(
133 | model=model,
134 | tokenizer=tokenizer,
135 | messages=messages,
136 | max_new_tokens=400,
137 | temperature=1,
138 | min_p=0.1,
139 | # The adjustment_strength param scales how strongly the probability adjustments are applied.
140 | # A value of 1 means the values in slop_phrase_prob_adjustments (or the defaults) are used unmodified.
141 | # Reasonable values are 0 (disabled) thru 100+ (effectively banning the list).
142 | adjustment_strength=100.0,
143 | # Optional: Provide a list of slop phrases and probability adjustments
144 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
145 | enforce_json=False,
146 | antislop_enabled=True,
147 | streaming=True
148 | ):
149 | print(tokenizer.decode(token), end='', flush=True)
150 | ```
151 |
152 | ### generate_antislop
153 | ```python
154 | # generate without streaming
155 | prompt_with_template = tokenizer.apply_chat_template(messages, tokenize=False)
156 | generated_text = generate_antislop(
157 | model=model,
158 | tokenizer=tokenizer,
159 |     prompt=prompt_with_template,
160 | max_length=300,
161 | temperature=1,
162 | min_p=0.1,
163 | adjustment_strength=100.0,
164 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
165 | enforce_json=False,
166 | antislop_enabled=True,
167 | streaming=False
168 | )
169 | print(tokenizer.decode(generated_text))
170 | ```
171 |
172 | ## What this does:
173 |
174 | You can give it a list of words & phrases to avoid like "a tapestry of", "a testament to", etc., and it will backtrack and try something else if it hits that phrase. It can handle thousands of slop phrases since the lookups are fast. The phrases and downregulation amounts are user configurable. Previous approaches have done this with per-token logit biasing, but that's quite ineffective since most slop words & phrases are more than one token, and it impairs output quality if we downregulate all those partial-word tokens. So instead, we wait for the whole phrase to appear in the output, then backtrack and downregulate all the tokens that could have produced the slop phrase, and continue from there.
175 |
176 | For the default slop list, we computed a large list of words that are over-represented in LLM output compared to normal human writing. This list is supplemented by a list of over-used phrases that are pet peeves of LLM enthusiasts. During generation, if any of the words & phrases in this list are generated, the sampler reduces the probability of the starting tokens that can lead to that phrase, by the factor specified in the config. This way you can lightly de-slop or strongly de-slop. You can of course also specify your own phrase list & weights.
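
The sketch below is a toy illustration of this backtrack-and-downregulate loop, not the repo's actual implementation: it uses a fake word-level "model" and ignores subword boundaries, logits and other details, but it shows the core idea of caching per-position distributions, detecting a completed slop phrase, backtracking to its first token and downregulating it before resampling.

```python
import random

VOCAB = ["a ", "testament ", "to ", "the ", "story ", "of ", "hope ", "courage "]

def toy_next_token_probs(context):
    # Stand-in for a real model: biased toward completing "a testament to".
    probs = [1.0] * len(VOCAB)
    if context.endswith("a "):
        probs[VOCAB.index("testament ")] = 50.0
    if context.endswith("testament "):
        probs[VOCAB.index("to ")] = 50.0
    total = sum(probs)
    return [p / total for p in probs]

def generate(max_tokens, slop_phrases):
    generated, context = [], ""
    cached_probs = []   # distribution used at each position (kept for backtracking)
    downregulated = {}  # position -> {token_id: adjustment factor}

    while len(generated) < max_tokens:
        pos = len(generated)
        if pos == len(cached_probs):
            cached_probs.append(toy_next_token_probs(context))
        weights = list(cached_probs[pos])
        for tok_id, factor in downregulated.get(pos, {}).items():
            weights[tok_id] *= factor  # apply downregulation from earlier backtracks
        tok = random.choices(range(len(VOCAB)), weights=weights)[0]
        generated.append(tok)
        context += VOCAB[tok]

        for phrase, factor in slop_phrases.items():
            if context.endswith(phrase):
                # Walk back to the first token of the matched phrase.
                chars, back = 0, 0
                while chars < len(phrase):
                    back += 1
                    chars += len(VOCAB[generated[-back]])
                start = len(generated) - back
                downregulated.setdefault(start, {})[generated[start]] = factor
                # Discard the slop continuation; keep the cached distribution at `start`.
                generated = generated[:start]
                cached_probs = cached_probs[:start + 1]
                context = "".join(VOCAB[t] for t in generated)
                break
    return context

print(generate(12, {"a testament to ": 0.0}))  # factor 0.0 effectively bans the phrase
```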
177 |
178 | ## Why it's interesting:
179 |
180 | Samplers typically work at the token level -- but that doesn't work if you want to avoid words/phrases that tokenise to more than one token. Elara might tokenise to ["El", "ara"], and we don't want to reduce the probs of everything beginning with "El". So, this approach waits for the whole phrase to appear, then backtracks and reduces the probabilities of all the likely tokens that will lead to that phrase being output. ~~Nobody afaik has tried this before.~~ [edit] It turns out exllamav2 has a banned_strings feature with the same/similar implementation, so we can't claim novelty. [/edit]
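
You can see the multi-token behaviour directly with a tokenizer (exact splits vary by model; this uses one of the models from the example notebooks):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
print(tok.tokenize("Elara"))           # typically splits into multiple subword tokens
print(tok.tokenize("a testament to"))  # a slop phrase spans several tokens
```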
181 |
182 | * Disclaimers: This is only implemented in Transformers (and now koboldcpp!) thus far. It is not well optimised. Expect research-grade code & possibly bugs.
183 |
184 |
185 | ## What you need to implement this
186 |
187 | If you'd like to implement this sampler in something other than transformers, here's what you need:
188 |
189 | - A loop to manage the state of the sampler, as it backtracks and needs to refer to past logits that it's cached
190 | - Per-token continuation generation (think: completions, not chat.completions)
191 | - Raw logits
192 | - Ability to bias logits when generating
193 |
194 | Unfortunately that rules out most commercial APIs, since few let you specify logit biases. Inference engines will likely be a mixed bag in terms of ease of integration, as most/all samplers work per token without this weird backtracking stuff we're doing here.
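
As a rough sketch, the minimal interface an engine would need to expose might look something like this (names are illustrative, not an existing API):

```python
from typing import Dict, Protocol, Sequence

class Engine(Protocol):
    def next_token_logits(self, token_ids: Sequence[int]) -> Sequence[float]:
        """Raw logits for the next position, given the prompt plus the continuation so far."""

    def sample(self, logits: Sequence[float], bias: Dict[int, float]) -> int:
        """Sample one token after applying per-token logit adjustments."""
```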
195 |
196 | If you do implement this sampler in your thing, please let me know about it!
197 |
198 | ## Acknowledgements
199 |
200 | Turboderp was the first to implement this mechanism in exllamav2 as the "banned strings" feature. This was unknown to us at the time of creating the AntiSlop sampler, which was birthed independently in a case of convergent evolution. Credit to them for doing it first!
201 |
202 | ## How to Cite
203 |
204 | A paper is in the works, hang tight.
205 |
206 | ```
207 | @misc{paech2024antislop,
208 | title={antislop-sampler},
209 | author={Samuel J. Paech},
210 | year={2024},
211 | howpublished={\url{https://github.com/sam-paech/antislop-sampler}},
212 | note={GitHub repository}
213 | }
214 | ```
215 |
--------------------------------------------------------------------------------
/calculate_over_represented_words.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### This notebook demonstrates how to compute the over-represented words in a corpus of text.\n",
8 | "\n",
9 | "It works by counting word frequencies in the designated dataset (in this case a large set of LLM-generated creative writing). These frequencies are compared that against the wordfreq frequencies, which represent the average prevalence of words in (human) language."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "#!pip3 install wordfreq datasets numpy"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 11,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "import datasets\n",
28 | "from collections import Counter\n",
29 | "from wordfreq import word_frequency\n",
30 | "import numpy as np\n",
31 | "import re\n",
32 | "from tqdm import tqdm\n",
33 | "\n",
34 | "def download_datasets():\n",
35 | " datasets_info = [\n",
36 | " (\"ajibawa-2023/General-Stories-Collection\", \"train\")\n",
37 | " ]\n",
38 | " return {name: datasets.load_dataset(name, split=split) for name, split in datasets_info}\n",
39 | "\n",
40 | "def parse_text(datasets):\n",
41 | " texts = []\n",
42 | " for example in tqdm(datasets[\"ajibawa-2023/General-Stories-Collection\"]):\n",
43 | " texts.append(example['text'])\n",
44 | " return texts\n",
45 | "\n",
46 | "\n",
47 | "def get_word_counts(texts, min_length=4):\n",
48 | " \"\"\"\n",
49 | " Count word frequencies in a list of texts.\n",
50 | "\n",
51 | " Parameters:\n",
52 | " - texts (iterable of str): The input texts to process.\n",
53 | " - min_length (int): Minimum length of words to include.\n",
54 | "\n",
55 | " Returns:\n",
56 | " - Counter: A Counter object mapping words to their frequencies.\n",
57 | " \"\"\"\n",
58 | " # Precompile the regex pattern for better performance\n",
59 | " # This pattern matches words with internal apostrophes (e.g., \"couldn't\")\n",
60 | " pattern = re.compile(r\"\\b\\w+(?:'\\w+)?\\b\")\n",
61 | " \n",
62 | " word_counts = Counter()\n",
63 | " \n",
64 | " for text in tqdm(texts, desc=\"Counting words\"):\n",
65 | " if not isinstance(text, str):\n",
66 | " continue # Skip non-string entries to make the function more robust\n",
67 | " \n",
68 | " # Convert to lowercase and find all matching words\n",
69 | " words = pattern.findall(text.lower())\n",
70 | " \n",
71 | " # Update counts with words that meet the minimum length\n",
72 | " word_counts.update(word for word in words if len(word) >= min_length)\n",
73 | " \n",
74 | " return word_counts\n",
75 | "\n",
76 | "\n",
77 | "def analyze_word_rarity(word_counts):\n",
78 | " total_words = sum(word_counts.values())\n",
79 | " corpus_frequencies = {word: count / total_words for word, count in word_counts.items()}\n",
80 | " \n",
81 | " wordfreq_frequencies = {word: word_frequency(word, 'en') for word in word_counts.keys()}\n",
82 | " \n",
83 | " # Filter out words with zero frequency\n",
84 | " valid_words = [word for word, freq in wordfreq_frequencies.items() if freq > 0]\n",
85 | " \n",
86 | " corpus_freq_list = [corpus_frequencies[word] for word in valid_words]\n",
87 | " wordfreq_freq_list = [wordfreq_frequencies[word] for word in valid_words]\n",
88 | " \n",
89 | " # Calculate average rarity\n",
90 | " avg_corpus_rarity = np.mean([-np.log10(freq) for freq in corpus_freq_list])\n",
91 | " avg_wordfreq_rarity = np.mean([-np.log10(freq) for freq in wordfreq_freq_list])\n",
92 | " \n",
93 | " # Calculate correlation\n",
94 | " correlation = np.corrcoef(corpus_freq_list, wordfreq_freq_list)[0, 1]\n",
95 | " \n",
96 | " return corpus_frequencies, wordfreq_frequencies, avg_corpus_rarity, avg_wordfreq_rarity, correlation\n",
97 | "\n",
98 | "def find_over_represented_words(corpus_frequencies, wordfreq_frequencies, top_n=50000):\n",
99 | " over_representation = {}\n",
100 | " for word in corpus_frequencies.keys():\n",
101 | " wordfreq_freq = wordfreq_frequencies[word]\n",
102 | " if wordfreq_freq > 0: # Only consider words with non-zero frequency\n",
103 | " over_representation[word] = corpus_frequencies[word] / wordfreq_freq\n",
104 | " \n",
105 | " return sorted(over_representation.items(), key=lambda x: x[1], reverse=True)[:top_n]\n",
106 | "\n",
107 | "def find_zero_frequency_words(word_counts, wordfreq_frequencies, top_n=50000):\n",
108 | " zero_freq_words = {word: count for word, count in word_counts.items() if wordfreq_frequencies[word] == 0}\n",
109 | " return sorted(zero_freq_words.items(), key=lambda x: x[1], reverse=True)[:top_n]\n",
110 | "\n"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "\n",
120 | "print(\"Downloading datasets...\")\n",
121 | "all_datasets = download_datasets()\n"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "\n",
131 | "print(\"Parsing text...\")\n",
132 | "texts = parse_text(all_datasets)\n"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "\n",
142 | "print(f\"Total texts extracted: {len(texts)}\")\n",
143 | "\n",
144 | "print(\"Counting words...\")\n",
145 | "word_counts = get_word_counts(texts)\n"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "\n",
155 | "def filter_mostly_numeric(word_counts):\n",
156 | " def is_mostly_numbers(word):\n",
157 | " digit_count = sum(c.isdigit() for c in word)\n",
158 | " return digit_count / len(word) > 0.2 # Adjust this ratio if needed\n",
159 | " \n",
160 | " # Create a new Counter with filtered words\n",
161 | " return Counter({word: count for word, count in word_counts.items() if not is_mostly_numbers(word)})\n",
162 | "\n",
163 | "filtered_counts = filter_mostly_numeric(word_counts)\n",
164 | "\n",
165 | "print(\"Analyzing word rarity...\")\n",
166 | "corpus_frequencies, wordfreq_frequencies, avg_corpus_rarity, avg_wordfreq_rarity, correlation = analyze_word_rarity(filtered_counts)\n",
167 | "\n",
168 | "print(f\"Total unique words analyzed: {len(word_counts)}\")\n",
169 | "print(f\"Average corpus rarity: {avg_corpus_rarity:.4f}\")\n",
170 | "print(f\"Average wordfreq rarity: {avg_wordfreq_rarity:.4f}\")\n",
171 | "print(f\"Correlation between corpus and wordfreq frequencies: {correlation:.4f}\")\n",
172 | "\n",
173 | "print(\"\\nMost over-represented words in the corpus:\")\n",
174 | "over_represented = find_over_represented_words(corpus_frequencies, wordfreq_frequencies)\n",
175 | "for word, score in over_represented:\n",
176 | " print(f\"{word}: {score:.2f} times more frequent than expected\")\n",
177 | "\n",
178 | "print(\"\\nMost frequent words with zero wordfreq frequency:\")\n",
179 | "zero_freq_words = find_zero_frequency_words(filtered_counts, wordfreq_frequencies)\n",
180 | "for word, count in zero_freq_words:\n",
181 | " print(f\"{word}: {count} occurrences\")\n"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 13,
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "\n",
191 | "import json\n",
192 | "with open('over_represented_words.json', 'w') as f:\n",
193 | " json.dump(over_represented, f)\n",
194 | "with open('frequent_non_dictionary_words.json', 'w') as f:\n",
195 | " json.dump(zero_freq_words, f)"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "corpus_frequencies['testament']"
205 | ]
206 | }
207 | ],
208 | "metadata": {
209 | "kernelspec": {
210 | "display_name": "Python 3",
211 | "language": "python",
212 | "name": "python3"
213 | },
214 | "language_info": {
215 | "codemirror_mode": {
216 | "name": "ipython",
217 | "version": 3
218 | },
219 | "file_extension": ".py",
220 | "mimetype": "text/x-python",
221 | "name": "python",
222 | "nbconvert_exporter": "python",
223 | "pygments_lexer": "ipython3",
224 | "version": "3.8.9"
225 | }
226 | },
227 | "nbformat": 4,
228 | "nbformat_minor": 2
229 | }
230 |
--------------------------------------------------------------------------------
/example_api.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 5,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# First start the server:\n",
10 | "# python run_api.py --model unsloth/Llama-3.2-1B-Instruct\n",
11 | "# \n",
12 | "# Or with default slop phrase list (will be used with all queries unless slop_phrases is specified in the query):\n",
13 | "# python run_api.py --model unsloth/Llama-3.2-1B-Instruct --slop_adjustments_file slop_phrase_prob_adjustments.json"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "# ! Important:\n",
23 | "# \n",
24 | "# No slop adjustments will be applied if:\n",
25 | "# - no slop_adjustments_file specified at API launch \n",
26 | "# and\n",
27 | "# - no slop_phrases param specified in the query\n",
28 | "#\n",
29 | "# If you specified a slop_adjustments_file file at API launch, it will be used by default with queries\n",
30 | "# that do not specify a slop_phrases param. The query's slop_phrases param overrides the defaults."
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 6,
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "name": "stdout",
40 | "output_type": "stream",
41 | "text": [
42 | "Once upon a time, in a small village nestled in the Andes mountains of South America, there lived a young boy named Kusi. Kusi was a curious and adventurous soul, with a heart full of wonder and a spirit that longed to explore the world beyond his village. He spent most of his days helping his mother with the family's small farm, tending to the crops and animals, and listening to the tales of the elderly villagers about the mystical creatures that roamed the mountains.\n",
43 | "\n",
44 | "One day, while out collecting firewood, Kusi stumbled upon a strange and beautiful creature. It was a young ladyma (a type of Andean camelid) unlike any he had ever seen before. Her fur was a soft, creamy white, and her eyes shone the brightest stars in the night sky. She had a delicate, almost ethereal quality to her, and Kusi felt an instant connection to this enchanting creature.\n",
45 | "\n",
46 | "The ladyma, whose name was Akira,"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "import requests\n",
52 | "import json\n",
53 | "\n",
54 | "prompt = \"tell me a story about a magical llama\"\n",
55 | "api_url = 'http://localhost:8000/v1/chat/completions'\n",
56 | "\n",
57 | "messages = [{\"role\": \"user\", \"content\": prompt}]\n",
58 | "data = {\n",
59 | " \"messages\": messages, \n",
60 | " \"max_tokens\": 200,\n",
61 | " \"temperature\": 1,\n",
62 | " \"min_p\": 0.1,\n",
63 | " \"stream\": True,\n",
64 | " \"adjustment_strength\": 100,\n",
65 | " \"antislop_enabled\": True, # Defaults to true\n",
66 | " \"slop_phrases\": [[\"a testament to\", 0.3], [\"llama\", 0.1]] # this overrides the default list\n",
67 | "}\n",
68 | "\n",
69 | "try:\n",
70 | " # Using `stream=True` to handle the response as a stream of data\n",
71 | " response = requests.post(api_url, json=data, stream=True, timeout=30)\n",
72 | " #print(response)\n",
73 | " \n",
74 | " # Read and print the stream of data in chunks\n",
75 | " for chunk in response.iter_lines():\n",
76 | " if chunk:\n",
77 | " decoded_chunk = chunk.decode('utf-8')\n",
78 | " # OpenAI streams responses in `data: {json}` format, so we need to parse that\n",
79 | " if decoded_chunk.startswith('data:'):\n",
80 | " try:\n",
81 | " json_data = json.loads(decoded_chunk[len('data: '):]) # Remove 'data: ' prefix\n",
82 | " if 'choices' in json_data and len(json_data['choices']) > 0:\n",
83 | " print(json_data['choices'][0]['delta'].get('content', ''), end='', flush=True)\n",
84 | " except json.JSONDecodeError as e:\n",
85 | " print(f\"Error decoding JSON: {e}\")\n",
86 | "except Exception as e:\n",
87 | " print(f\"An error occurred: {e}\")\n"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 8,
93 | "metadata": {},
94 | "outputs": [
95 | {
96 | "name": "stdout",
97 | "output_type": "stream",
98 | "text": [
99 | "In the small village of Brindlemark, nestled in the rolling hills of the countryside, there lived a young woman named Eira. She was known throughout the village as the loomweaver, Eira, for her extraordinary talent in weaving intricate patterns with the threads of the finest wool.\n",
100 | "\n",
101 | "Eira's fascination with weaving began when she was just a child, watching her mother and grandmother work the loom in the evenings. The soft rustle of the wool against the fabric, the gentle hum of the shuttle, and the way the colors seemed to come alive as the loom wove its magic – all of these things captivated Eira, and she spent countless hours practicing and experimenting with different techniques.\n",
102 | "\n",
103 | "As she grew older, Eira's skills improved dramatically, and she became one of the most sought-after weavers in the region. Her creations were renowned for their beauty and complexity, with intricate patterns and delicate colors that seemed to dance across the fabric.\n",
104 | "\n",
105 | "But Eira's life\n"
106 | ]
107 | }
108 | ],
109 | "source": [
110 | "# Example using regex bans\n",
111 | "prompt = \"write a story about the loomweaver elara\"\n",
112 | "messages = [{\"role\": \"user\", \"content\": prompt}]\n",
113 | "try:\n",
114 | " data = { \n",
115 | " \"messages\": messages, \n",
116 | " \"max_tokens\": 200,\n",
117 | " \"temperature\": 1,\n",
118 | " \"min_p\": 0.1,\n",
119 | " \"stream\": False,\n",
120 | " \"adjustment_strength\": 100,\n",
121 | " \"slop_phrases\": [[\"a testament to\", 0.3], [\"kusi\", 0.1]],\n",
122 | " \"antislop_enabled\": True,\n",
123 | " \"regex_bans\": ['(?i)not [^.!?]{3,60} but', '(?i)elara'] # Not compatible with streaming\n",
124 | " }\n",
125 | " \n",
126 | " response = requests.post(api_url, json=data, stream=False, timeout=30)\n",
127 | " data = response.json()\n",
128 | " print(data['choices'][0]['message']['content'])\n",
129 | "\n",
130 | "except Exception as e:\n",
131 | " print(f\"An error occurred: {e}\")\n"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": []
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 53,
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "name": "stdout",
148 | "output_type": "stream",
149 | "text": [
150 | "In the heart of a mystical Andean village, there lived a legend about a magical llama named Luna. Her coat shone like the moon, and her eyes glowed with an ethereal light that seemed to hold the secrets of the universe. Luna was no ordinary llama – she possessed the power to heal the sick, grant wisdom to those who sought it, and even bend the fabric of time to bring peace to those who were troubled.\n",
151 | "\n",
152 | "One day, a young girl named Sophia wandered into the village, searching for her ailing mother. Sophia's mother had fallen gravely ill, and the village healer had exhausted all his remedies. Desperate for a solution, Sophia set out to find the legendary magical llama, Luna.\n",
153 | "\n",
154 | "As she trekked through the high-altitude grasslands, Sophia encountered many creatures – the majestic condors soaring above, the chattering vicuñas below, and the gentle guanacos grazing in the distance. But none of them led her to Luna, until finally"
155 | ]
156 | }
157 | ],
158 | "source": [
159 | "# Default slop list example\n",
160 | "#\n",
161 | "# If the user hasn't specified a list of slop phrases & adjustments in their \n",
162 | "# query, then it will default to use whatever you specified in the\n",
163 | "# --slop_adjustments_file argument when launching the api.\n",
164 | "#\n",
165 | "# If you didn't specify anything, it won't make any adjustments unless the\n",
166 | "# user specifies their adjustments in the query.\n",
167 | "\n",
168 | "prompt = \"tell me a story about a magical llama\"\n",
169 | "api_url = 'http://localhost:8000/v1/chat/completions'\n",
170 | "\n",
171 | "messages = [{\"role\": \"user\", \"content\": prompt}]\n",
172 | "data = { \n",
173 | " \"messages\": messages, \n",
174 | " \"max_tokens\": 200,\n",
175 | " \"temperature\": 1,\n",
176 | " \"min_p\": 0.1,\n",
177 | " \"stream\": True,\n",
178 | " \"antislop_enabled\": True,\n",
179 | " \"adjustment_strength\": 100,\n",
180 | "}\n",
181 | "\n",
182 | "try:\n",
183 | " # Using `stream=True` to handle the response as a stream of data\n",
184 | " response = requests.post(api_url, json=data, stream=True, timeout=30)\n",
185 | " #print(response)\n",
186 | " \n",
187 | " # Read and print the stream of data in chunks\n",
188 | " for chunk in response.iter_lines():\n",
189 | " if chunk:\n",
190 | " decoded_chunk = chunk.decode('utf-8')\n",
191 | " # OpenAI streams responses in `data: {json}` format, so we need to parse that\n",
192 | " if decoded_chunk.startswith('data:'):\n",
193 | " try:\n",
194 | " json_data = json.loads(decoded_chunk[len('data: '):]) # Remove 'data: ' prefix\n",
195 | " if 'choices' in json_data and len(json_data['choices']) > 0:\n",
196 | " print(json_data['choices'][0]['delta'].get('content', ''), end='', flush=True)\n",
197 | " except json.JSONDecodeError as e:\n",
198 | " print(f\"Error decoding JSON: {e}\")\n",
199 | "except Exception as e:\n",
200 | " print(f\"An error occurred: {e}\")"
201 | ]
202 | }
203 | ],
204 | "metadata": {
205 | "kernelspec": {
206 | "display_name": "Python 3",
207 | "language": "python",
208 | "name": "python3"
209 | },
210 | "language_info": {
211 | "codemirror_mode": {
212 | "name": "ipython",
213 | "version": 3
214 | },
215 | "file_extension": ".py",
216 | "mimetype": "text/x-python",
217 | "name": "python",
218 | "nbconvert_exporter": "python",
219 | "pygments_lexer": "ipython3",
220 | "version": "3.10.6"
221 | }
222 | },
223 | "nbformat": 4,
224 | "nbformat_minor": 2
225 | }
226 |
--------------------------------------------------------------------------------
/example_generate.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 6,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Dependencies:\n",
10 | "\n",
11 | "#!pip install transformers ipywidgets IPython\n",
12 | "\n",
13 | "# If running on Colab:\n",
14 | "#!git clone https://github.com/sam-paech/antislop-sampler.git\n",
15 | "#!mv antislop-sampler/src .\n",
16 | "#!mv antislop-sampler/slop_phrase_prob_adjustments.json ."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 7,
22 | "metadata": {},
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | "Model loaded\n"
29 | ]
30 | }
31 | ],
32 | "source": [
33 | "import os\n",
34 | "import torch\n",
35 | "from transformers import (\n",
36 | " AutoModelForCausalLM,\n",
37 | " AutoTokenizer,\n",
38 | ")\n",
39 | "from src.antislop_generate import generate_antislop, chat_antislop\n",
40 | "\n",
41 | "# Enable efficient transfer for Hugging Face models\n",
42 | "os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = \"1\"\n",
43 | "\n",
44 | "# Set the device to 'cuda' if available, else 'cpu'\n",
45 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
46 | "\n",
47 | "# Specify the model name (replace with your preferred model)\n",
48 | "model_name = \"unsloth/Llama-3.2-1B-Instruct\"\n",
49 | "#model_name = \"unsloth/Llama-3.2-3B-Instruct\"\n",
50 | "\n",
51 | "# Load the model and tokenizer\n",
52 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
53 | "model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True)\n",
54 | "model.to(device)\n",
55 | "print('Model loaded')"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 8,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "import json\n",
65 | "if os.path.exists('slop_phrase_prob_adjustments.json'):\n",
66 | " with open('slop_phrase_prob_adjustments.json', 'r') as f:\n",
67 | " slop_phrase_prob_adjustments = dict(json.load(f)[:500])"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 9,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "prompt = \"Write a story about Elara, the weaver of tapestries in future Technopolis. In the bustling city, a group of \""
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 10,
82 | "metadata": {},
83 | "outputs": [
84 | {
85 | "name": "stderr",
86 | "output_type": "stream",
87 | "text": [
88 | "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n",
89 | "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.\n"
90 | ]
91 | },
92 | {
93 | "name": "stdout",
94 | "output_type": "stream",
95 | "text": [
96 | "In the heart of Future Templar, the city of synthetic and organic fusion, woven magic was a staple of everyday life. It was here that Adara, a young and talented weaver, made her name. Her fingers danced across the loom, creating magnificent works of art that told the stories of the city's inhabitants.\n",
97 | "\n",
98 | "Adara's workshop was a cozy, dimly lit room filled with the soft hum of machinery and the gentle rustle of threads. The walls were adorned with a series of colorful murals depicting the city's history, from the ancient days of the Templars to the present era of technological marvels.\n",
99 | "\n",
100 | "As the sun began to set, casting a warm orange glow over the city, Adara's apprentice, Kael, arrived at her doorstep. He was a lanky, energetic young man with a mop of messy black hair and a mischievous grin. Kael had been learning the art of weaving from Adara for several years, and he was eager to prove himself as a skilled weaver in his own right.\n",
101 | "\n",
102 | "\"Tonight, Kael, we're going to work on the new commission for the Guild of Artisans,\" Adara announced, her eyes sparkling with excitement. \"They want a beautiful, shimmering fabric for their upcoming exhibition.\"\n",
103 | "\n",
104 | "Kael's eyes widened as he took in the task at hand. He had always been impressed by Adara's skill and creativity, and he knew that this commission would be a great opportunity to showcase his own talents.\n",
105 | "\n",
106 | "Together, Adara and Kael set to work, their fingers moving in perfect sync as they wove a breathtaking fabric. The threads danced across the loom, creating a work of art that seemed to shimmer and glow in the light.\n",
107 | "\n",
108 | "As the night wore on, the workshop grew quiet, except for the soft hum of the machinery and the occasional rustle of threads. Adara and Kael worked in silence, lost in their own little worlds of creativity and imagination.\n",
109 | "\n",
110 | "Finally, after"
111 | ]
112 | }
113 | ],
114 | "source": [
115 | "# Chat generation with streaming\n",
116 | "messages = [\n",
117 | " {\"role\": \"user\", \"content\": prompt}\n",
118 | "]\n",
119 | "tokens = []\n",
120 | "text = ''\n",
121 | "for token in chat_antislop(\n",
122 | " model=model,\n",
123 | " tokenizer=tokenizer,\n",
124 | " messages=messages,\n",
125 | " max_new_tokens=400,\n",
126 | " # Antislop sampling may be less reliable at low temperatures.\n",
127 | " #temperature=1, \n",
128 | " #min_p=0.1,\n",
129 | " temperature=0.01,\n",
130 | " # The adjustment_strength param scales how strongly the probability adjustments are applied.\n",
131 | " # A value of 1 means the values in slop_phrase_prob_adjustments (or the defaults) are used unmodified.\n",
132 | " # Reasonable values are 0 (disabled) thru 100+ (effectively banning the list).\n",
133 | " adjustment_strength=100.0,\n",
134 | " # Optional: Provide a list of slop phrases and probability adjustments\n",
135 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n",
136 | " enforce_json=False,\n",
137 | " antislop_enabled=True,\n",
138 | " streaming=True,\n",
139 | " stream_smoothing=True, # On by default; this will smooth out the stutters from backtracking.\n",
140 | "):\n",
141 | " tokens.append(token)\n",
142 | " full_text = tokenizer.decode(tokens, skip_special_tokens=True)\n",
143 | " new_text = full_text[len(text):]\n",
144 | " text = full_text\n",
145 | " print(new_text, end='', flush=True)"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 11,
151 | "metadata": {},
152 | "outputs": [
153 | {
154 | "name": "stdout",
155 | "output_type": "stream",
156 | "text": [
157 | "In the heart of Future City, a sprawling technological marvel, the air was alive with the hum of machinery and the chatter of pedestrians. The city's residents moved with purpose, their footsteps echoing off the sleek, silver skyscrapers that pierced the clouds. Among the throngs was a young woman named Elian, her dark hair tied back in a tight ponytail as she hurried down the main thoroughfare.\n",
158 | "\n",
159 | "Elian was a weaver, one of the city's skilled artisans who brought beauty and magic to the fabric of its inhabitants. Her workshop, a cozy room tucked away in a quiet alley, was a haven of creativity and tranquility. As she worked, her fingers moved deftly, the shuttle of her loom weaving a tale of wonder and enchantment.\n",
160 | "\n",
161 | "The weavers of Future City were renowned for their extraordinary talent, their creations imbued with the essence of the city's energy. Elian's own work was a fusion of traditional techniques and advanced machinery, allowing her to craft fabrics that shone like stars and pulsed with the heartbeat of the city.\n",
162 | "\n",
163 | "As Elian worked, a group of individuals caught her eye. There was Arin, the burly engineer who had designed the city's infrastructure; Lyra, the young hacker who had uncovered hidden patterns in the code; and Jax, the charismatic performer who had mastered the ancient art of acrobatics. Together, they formed an eclectic group known as the \"Synthetix,\" united by their passion for innovation and their love of storytelling.\n",
164 | "\n",
165 | "Their latest project was a grand commission from the city's governing council, a massive mural that would adorn the main square and celebrate the city's \n"
166 | ]
167 | }
168 | ],
169 | "source": [
170 | "# Chat generation without streaming\n",
171 | "messages = [\n",
172 | " {\"role\": \"user\", \"content\": prompt}\n",
173 | "]\n",
174 | "generated_text = chat_antislop(\n",
175 | " model=model,\n",
176 | " tokenizer=tokenizer,\n",
177 | " messages=messages,\n",
178 | " max_length=400,\n",
179 | " temperature=1,\n",
180 | " min_p=0.1,\n",
181 | " adjustment_strength=100.0,\n",
182 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n",
183 | " enforce_json=False,\n",
184 | " antislop_enabled=True,\n",
185 | " regex_bans=['(?i)not [^.!?]{3,60} but'], # if these regex expressions are matched, we backtrack, ban that continuation and retry\n",
186 | " streaming=False\n",
187 | ")\n",
188 | "print(tokenizer.decode(generated_text))"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": 12,
194 | "metadata": {},
195 | "outputs": [
196 | {
197 | "name": "stdout",
198 | "output_type": "stream",
199 | "text": [
200 | " urchins stumble upon her at the market square, and a young girl named Aria discovers her talent.\n",
201 | "\n",
202 | "---\n",
203 | "\n",
204 | "The sun beat down on the market square, casting a warm glow over the crowded stalls. Vendors hawked their wares – fresh produce, exotic spices, and handmade crafts – while the smell of roasting meats and baking bread filled the air. In the midst of it all, a lone weaver sat at her stall, her fingers moving deftly as she wove a new design onto a large piece of fabric.\n",
205 | "\n",
206 | "The weaver, a young woman with a wild tangle of curly hair and a warm smile, worked with a quiet focus. Her fingers flew across the loom, the threads shimmering and glowing as she brought her vision to life. Aria, a young girl with a mop of curly brown hair and a curious gaze, wandered through the market, her eyes scanning the stalls for scraps of fabric to use in her own projects.\n",
207 | "\n",
208 | "As she passed by the weaver's stall, Aria caught sight of the beautiful, swirling patterns that covered the fabric. She felt an instant connection to the artistry, the creativity, and the beauty that radiated from the weaver's loom. The young girl's eyes locked onto the weaver, and she felt a shiver run down her spine.\n",
209 | "\n",
210 | "\"\n"
211 | ]
212 | }
213 | ],
214 | "source": [
215 | "# generate without streaming\n",
216 | "prompt_with_template = tokenizer.apply_chat_template(messages, tokenize=False)\n",
217 | "generated_text = generate_antislop(\n",
218 | " model=model,\n",
219 | " tokenizer=tokenizer,\n",
220 | " prompt=prompt,\n",
221 | " max_length=300,\n",
222 | " temperature=1,\n",
223 | " min_p=0.1,\n",
224 | " adjustment_strength=100.0,\n",
225 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n",
226 | " enforce_json=False,\n",
227 | " antislop_enabled=True,\n",
228 | " streaming=False\n",
229 | ") \n",
230 | "print(tokenizer.decode(generated_text))"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 13,
236 | "metadata": {},
237 | "outputs": [
238 | {
239 | "name": "stdout",
240 | "output_type": "stream",
241 | "text": [
242 | " 5 people gather in a hidden alleyway, each with a mysterious device in their hand. They are about to participate in a secret event, where they will compete in a tournament of wit and strategy, but as they begin, they realize that each of them has been sent by a powerful corporation to spy on each other, and a mysterious figure lurks in the shadows.\n",
243 | "\n",
244 | "The group consists of:\n",
245 | "\n",
246 | "1. **Aurora, a skilled hacker from the cybernetic gang, \"The Codekeepers\"**\n",
247 | "2. **Lysander, a master strategist from the espionage agency \"The Shadowhand\"**\n",
248 | "3. **Mira, a brilliant scientist from the research facility \"The Nexus\"**\n",
249 | "4. **Caspian, a charismatic con artist with ties to the underworld**\n",
250 | "5. **Nova, a rebellious young artist with a hidden past**\n",
251 | "\n",
252 | "As they enter the tournament hall, they are greeted by the host, **Ryker**, a charismatic and suave corporate executive. Ryker explains the rules of the tournament, but as the competitors begin to mingle and exchange information, they start to realize that something is amiss.\n",
253 | "\n",
254 | "Each of the competitors is receiving a mysterious device from a different corporation, and they all seem to have a hidden agenda. The device is a small, sleek box with a glowing screen"
255 | ]
256 | }
257 | ],
258 | "source": [
259 | "# generate with streaming\n",
260 | "prompt_with_template = tokenizer.apply_chat_template(messages, tokenize=False)\n",
261 | "tokens = []\n",
262 | "text = \"\"\n",
263 | "for token in generate_antislop(\n",
264 | " model=model,\n",
265 | " tokenizer=tokenizer,\n",
266 | " prompt=prompt,\n",
267 | " max_length=300,\n",
268 | " temperature=1,\n",
269 | " min_p=0.1,\n",
270 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n",
271 | " adjustment_strength=100.0,\n",
272 | " enforce_json=False,\n",
273 | " antislop_enabled=True,\n",
274 | " streaming=True\n",
275 | "):\n",
276 | " tokens.append(token)\n",
277 | " full_text = tokenizer.decode(tokens, skip_special_tokens=True)\n",
278 | " new_text = full_text[len(text):]\n",
279 | " text = full_text\n",
280 | " print(new_text, end='', flush=True)"
281 | ]
282 | }
283 | ],
284 | "metadata": {
285 | "kernelspec": {
286 | "display_name": "Python 3",
287 | "language": "python",
288 | "name": "python3"
289 | },
290 | "language_info": {
291 | "codemirror_mode": {
292 | "name": "ipython",
293 | "version": 3
294 | },
295 | "file_extension": ".py",
296 | "mimetype": "text/x-python",
297 | "name": "python",
298 | "nbconvert_exporter": "python",
299 | "pygments_lexer": "ipython3",
300 | "version": "3.10.6"
301 | }
302 | },
303 | "nbformat": 4,
304 | "nbformat_minor": 2
305 | }
306 |
--------------------------------------------------------------------------------
/example_visualise_backtracking.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Backtracking Visualisation for the AntiSlop Sampler\n",
8 | "\n",
9 | "This notebook demonstrates the AntiSlop sampler by adding delays when slop phrases are detected & replaced, as well as some debug output.\n",
10 | "\n",
11 | "https://github.com/sam-paech/antislop-sampler\n",
12 | "\n",
13 | "### Update 2024-10-05\n",
14 | "\n",
15 | "- Switch to probs from logits for the cached values, so that down/upregulation works as expected.\n",
16 | "- Use model.generate for multiple tokens (not just 1 at a time) with StoppingCondition criteria. This is much faster.\n",
17 | "- Add json validation including long-range checks for unintended unescaped quotes in strings."
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 1,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "# Dependencies:\n",
27 | "\n",
28 | "#!pip install transformers ipywidgets IPython\n",
29 | "\n",
30 | "# If running on Colab:\n",
31 | "#!git clone https://github.com/sam-paech/antislop-sampler.git\n",
32 | "#!mv antislop-sampler/src .\n",
33 | "#!mv antislop-sampler/slop_phrase_prob_adjustments.json ."
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "import os\n",
43 | "import json\n",
44 | "import torch\n",
45 | "from transformers import (\n",
46 | " AutoModelForCausalLM,\n",
47 | " AutoTokenizer,\n",
48 | ")\n",
49 | "from IPython.display import display, HTML\n",
50 | "from ipywidgets import Output\n",
51 | "from src.antislop_generate import AntiSlopSampler, chat_antislop, generate_antislop\n",
52 | "\n",
53 | "# Enable efficient transfer for Hugging Face models\n",
54 | "os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = \"1\"\n",
55 | "\n",
56 | "# Set the device to 'cuda' if available, else 'cpu'\n",
57 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
58 | "\n",
59 | "# Specify the model name (replace with your preferred model)\n",
60 | "model_name = \"unsloth/Llama-3.2-1B-Instruct\"\n",
61 | "#model_name = \"unsloth/Llama-3.2-3B-Instruct\"\n",
62 | "\n",
63 | "# Load the model and tokenizer\n",
64 | "model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
65 | "model.to(device)\n",
66 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
67 | "\n",
68 | "# These are a mix of gpt-slop found in various online lists, plus a much larger set of automatically derived over-represented words in a GPT-generated dataset.\n",
69 | "# See calculate_over_represented_words.ipynb for the code to generate a more complete list.\n",
70 | "slop_phrase_prob_adjustments = [['kaleidoscope', 0.5], ['symphony', 0.5], ['testament to', 0.5], ['elara', 0.5], ['moth to a flame', 0.5], ['canvas', 0.5], ['eyes glinted', 0.5], ['camaraderie', 0.5], ['humble abode', 0.5], ['cold and calculating', 0.5], ['eyes never leaving', 0.5], ['tapestry', 0.5], ['barely above a whisper', 0.5], ['body and soul', 0.5], ['orchestra', 0.5], ['depths', 0.5], ['a dance of', 0.5], ['chuckles darkly', 0.5], ['maybe, just maybe', 0.5], ['maybe that was enough', 0.5], ['with a mixture of', 0.5], ['air was filled with anticipation', 0.5], ['cacophony', 0.5], ['bore silent witness to', 0.5], ['eyes sparkling with mischief', 0.5], ['was only just beginning', 0.5], ['practiced ease', 0.5], ['ready for the challenges', 0.5], ['only just getting started', 0.5], ['once upon a time', 0.5], ['nestled deep within', 0.5], ['ethereal beauty', 0.5], ['life would never be the same again.', 0.5], [\"it's important to remember\", 0.5], ['for what seemed like an eternity', 0.5], ['feel a sense of pride and accomplishment', 0.5], ['little did he know', 0.5], ['ball is in your court', 0.5], ['game is on', 0.5], ['choice is yours', 0.5], ['feels like an electric shock', 0.5], ['threatens to consume', 0.5], ['meticulous', 0.5], ['meticulously', 0.5], ['navigating', 0.5], ['complexities', 0.5], ['realm', 0.5], ['understanding', 0.5], ['dive into', 0.5], ['shall', 0.5], ['tailored', 0.5], ['towards', 0.5], ['underpins', 0.5], ['everchanging', 0.5], ['ever-evolving', 0.5], ['world of', 0.5], ['not only', 0.5], ['alright', 0.5], ['embark', 0.5], ['journey', 0.5], [\"today's digital age\", 0.5], ['game changer', 0.5], ['designed to enhance', 0.5], ['it is advisable', 0.5], ['daunting', 0.5], ['when it comes to', 0.5], ['in the realm of', 0.5], ['amongst', 0.5], ['unlock the secrets', 0.5], ['unveil the secrets', 0.5], ['and robust', 0.5], ['diving', 0.5], ['elevate', 0.5], ['unleash', 0.5], ['cutting-edge', 0.5], ['rapidly', 0.5], ['expanding', 0.5], ['mastering', 0.5], ['excels', 0.5], ['harness', 0.5], [\"it's important to note\", 0.5], ['delve into', 0.5], ['bustling', 0.5], ['in summary', 0.5], ['remember that', 0.5], ['take a dive into', 0.5], ['landscape', 0.5], ['in the world of', 0.5], ['vibrant', 0.5], ['metropolis', 0.5], ['firstly', 0.5], ['moreover', 0.5], ['crucial', 0.5], ['to consider', 0.5], ['essential', 0.5], ['there are a few considerations', 0.5], ['ensure', 0.5], [\"it's essential to\", 0.5], ['furthermore', 0.5], ['vital', 0.5], ['keen', 0.5], ['fancy', 0.5], ['as a professional', 0.5], ['however', 0.5], ['therefore', 0.5], ['additionally', 0.5], ['specifically', 0.5], ['generally', 0.5], ['consequently', 0.5], ['importantly', 0.5], ['indeed', 0.5], ['thus', 0.5], ['alternatively', 0.5], ['notably', 0.5], ['as well as', 0.5], ['despite', 0.5], ['essentially', 0.5], ['even though', 0.5], ['in contrast', 0.5], ['in order to', 0.5], ['due to', 0.5], ['even if', 0.5], ['given that', 0.5], ['arguably', 0.5], ['you may want to', 0.5], ['on the other hand', 0.5], ['as previously mentioned', 0.5], [\"it's worth noting that\", 0.5], ['to summarize', 0.5], ['ultimately', 0.5], ['to put it simply', 0.5], [\"in today's digital era\", 0.5], ['reverberate', 0.5], ['enhance', 0.5], ['emphasize', 0.5], ['revolutionize', 0.5], ['foster', 0.5], ['remnant', 0.5], ['subsequently', 0.5], ['nestled', 0.5], ['labyrinth', 0.5], ['gossamer', 0.5], ['enigma', 0.5], ['whispering', 0.5], ['sights unseen', 0.5], ['sounds unheard', 0.5], ['indelible', 0.5], ['my friend', 0.5], ['in 
conclusion', 0.5], ['technopolis', 0.5], ['was soft and gentle', 0.5], ['shivers down', 0.5], ['shivers up', 0.5], ['leaving trails of fire', 0.5], ['ministrations', 0.5], ['audible pop', 0.5], ['rivulets of', 0.5], ['despite herself', 0.5], ['reckless abandon', 0.5], ['torn between', 0.5], ['fiery red hair', 0.5], ['long lashes', 0.5], ['propriety be damned', 0.5], ['world narrows', 0.5], ['chestnut eyes', 0.5], ['cheeks flaming', 0.5], ['cheeks hollowing', 0.5], ['understandingly', 0.5], ['paperbound', 0.5], ['hesitantly', 0.5], ['piqued', 0.5], ['delved', 0.5], ['curveballs', 0.5], ['marveled', 0.5], ['inclusivity', 0.5], ['birdwatcher', 0.5], ['newfound', 0.5031423922762257], ['marveling', 0.5055622891781474], [\"hiroshi's\", 0.506870969939047], ['greentech', 0.5095092042816856], ['thoughtfully', 0.510153898156777], ['intently', 0.5153227374075411], ['birdwatching', 0.5157928537951464], ['amidst', 0.5161190296674488], ['cherishing', 0.5165772000484282], ['attentively', 0.5169695157301188], ['interjected', 0.5208671011920856], ['serendipitous', 0.5219535186850968], [\"marianne's\", 0.5220118279910801], [\"maya's\", 0.5229467776607973], ['excitedly', 0.5235248665614571], ['steepled', 0.5235772300889154], ['engrossed', 0.5236764398055735], ['fostering', 0.5259281627970829], ['brainstormed', 0.5274863713437], ['furrowed', 0.5280860997212533], ['nodded', 0.528640180937889], ['contemplatively', 0.5293698584415747], ['jotted', 0.5300819077932343], [\"mia's\", 0.5311706933553655]];\n",
71 | "slop_phrase_prob_adjustments = dict(slop_phrase_prob_adjustments)\n",
72 | "\n",
73 | "if os.path.exists('slop_phrase_prob_adjustments.json'):\n",
74 | " with open('slop_phrase_prob_adjustments.json', 'r') as f:\n",
75 | " slop_phrase_prob_adjustments = dict(json.load(f)[:500])\n",
76 | " #slop_phrase_prob_adjustments = dict(json.load(f))"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "### Test anti-slop sampling"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 3,
89 | "metadata": {},
90 | "outputs": [
91 | {
92 | "data": {
93 | "text/html": [
94 | "Prompt"
95 | ],
96 | "text/plain": [
97 | ""
98 | ]
99 | },
100 | "metadata": {},
101 | "output_type": "display_data"
102 | },
103 | {
104 | "data": {
105 | "application/vnd.jupyter.widget-view+json": {
106 | "model_id": "e213c1b6f6114c90ae65c4a08ec5dcf7",
107 | "version_major": 2,
108 | "version_minor": 0
109 | },
110 | "text/plain": [
111 | "Output()"
112 | ]
113 | },
114 | "metadata": {},
115 | "output_type": "display_data"
116 | },
117 | {
118 | "data": {
119 | "text/html": [
120 | "Inference Output"
121 | ],
122 | "text/plain": [
123 | ""
124 | ]
125 | },
126 | "metadata": {},
127 | "output_type": "display_data"
128 | },
129 | {
130 | "data": {
131 | "application/vnd.jupyter.widget-view+json": {
132 | "model_id": "6e24357e234841aaa2e44e8b6cdb7368",
133 | "version_major": 2,
134 | "version_minor": 0
135 | },
136 | "text/plain": [
137 | "Output()"
138 | ]
139 | },
140 | "metadata": {},
141 | "output_type": "display_data"
142 | },
143 | {
144 | "data": {
145 | "text/html": [
146 | "Debug Information"
147 | ],
148 | "text/plain": [
149 | ""
150 | ]
151 | },
152 | "metadata": {},
153 | "output_type": "display_data"
154 | },
155 | {
156 | "data": {
157 | "application/vnd.jupyter.widget-view+json": {
158 | "model_id": "4730fcb8281d42a98ae94afcac457a91",
159 | "version_major": 2,
160 | "version_minor": 0
161 | },
162 | "text/plain": [
163 | "Output()"
164 | ]
165 | },
166 | "metadata": {},
167 | "output_type": "display_data"
168 | },
169 | {
170 | "name": "stderr",
171 | "output_type": "stream",
172 | "text": [
173 | "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n"
174 | ]
175 | }
176 | ],
177 | "source": [
178 | "# Example prompt for the model\n",
179 | "prompt = \"Once upon a time, in a bustling city of Technopolis, there lived a weaver named Elara.\"\n",
180 | "\n",
181 | "# Define the messages for a chat scenario (for example purposes)\n",
182 | "messages = [\n",
183 | " {\"role\": \"user\", \"content\": prompt}\n",
184 | "]\n",
185 | "\n",
186 | "# Display the output widgets\n",
187 | "prompt_output = Output()\n",
188 | "inference_output = Output()\n",
189 | "debug_output = Output()\n",
190 | "\n",
191 | "display(HTML(\"Prompt\"))\n",
192 | "display(prompt_output)\n",
193 | "display(HTML(\"Inference Output\"))\n",
194 | "display(inference_output)\n",
195 | "display(HTML(\"Debug Information\"))\n",
196 | "display(debug_output)\n",
197 | "\n",
198 | "# Initialize display (if using in Jupyter)\n",
199 | "with prompt_output:\n",
200 | " prompt_output.clear_output(wait=True)\n",
201 | "    display(HTML(f\"{prompt}\"))\n",
202 | "\n",
203 | "if True:\n",
204 | "    VALIDATE_JSON_STRINGS = True\n",
205 | " # Call the chat_antislop to generate the story with the given messages\n",
206 | " output_tokens = chat_antislop(\n",
207 | " model=model,\n",
208 | " tokenizer=tokenizer,\n",
209 | " messages=messages,\n",
210 | " max_new_tokens=400, \n",
211 | " temperature=1,\n",
212 | " min_p=0.1,\n",
213 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n",
214 | " adjustment_strength=100,\n",
215 | " device=device,\n",
216 | " streaming=False,\n",
217 | " slow_debug=True, # Enable slow debugging\n",
218 | " output_every_n_tokens=3,\n",
219 | " debug_delay=1, # Set delay for debug output\n",
220 | " inference_output=inference_output, # Visualization of the text output\n",
221 | " enforce_json=False,\n",
222 | " debug_output=debug_output # Visualization of the debug information\n",
223 | " )\n",
224 | "\n",
225 | " inference = tokenizer.decode(output_tokens, skip_special_tokens=True)"
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "metadata": {},
231 | "source": [
232 | "\n",
233 | "\n",
234 | "\n",
235 | "\n",
236 | "\n",
237 | "\n",
238 | "\n",
239 | "\n",
240 | "\n",
241 | "\n",
242 | "\n",
243 | "\n",
244 | "\n",
245 | "\n",
246 | "\n",
247 | "\n",
248 | "\n",
249 | "\n",
250 | "\n",
251 | "\n",
252 | "\n",
253 | "\n",
254 | "\n",
255 | "\n",
256 | "\n",
257 | "\n"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {},
263 | "source": [
264 | "## Evaluate json constraint enforcement\n",
265 | "\n",
266 | "Note: This evaluation will fail output that contains multiple separate json elements (like json-lines format). The enforce_json flag does not enforce this; it only enforces json validity within a json block.\n",
267 | "\n",
268 | "As such, it's suggested to use a model that will at least follow the requested output format, so this eval works as intended (i.e. a stronger model than the Llama-3.2-1B example used above)."
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 4,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "from tqdm import tqdm\n",
278 | "\n",
279 | "prompt = 'Write a short scene of dialogue between two characters, Jill and Steve. Output in json with this format: {\"scene\": \"\"}'\n",
280 | "\n",
281 | "# Define the messages for a chat scenario (for example purposes)\n",
282 | "messages = [\n",
283 | " {\"role\": \"user\", \"content\": prompt}\n",
284 | "]\n",
285 | "\n",
286 | "if False:\n",
287 | " inference_output = Output()\n",
288 | " debug_output = Output()\n",
289 | "    display(HTML(\"Inference Output\"))\n",
290 | " display(inference_output)\n",
291 | "    display(HTML(\"Debug Information\"))\n",
292 | " display(debug_output)\n",
293 | "\n",
294 | " n_iterations = 20\n",
295 | "\n",
296 | " valid_count = 0\n",
297 | " for i in tqdm(range(n_iterations)):\n",
298 | " output_tokens = [] \n",
299 | "\n",
300 | " output_tokens = chat_antislop(\n",
301 | " model=model,\n",
302 | " tokenizer=tokenizer,\n",
303 | " messages=messages,\n",
304 | " max_new_tokens=800, \n",
305 | " temperature=1,\n",
306 | " min_p=0.1,\n",
307 | " slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,\n",
308 | " adjustment_strength=5.0,\n",
309 | " device=device,\n",
310 | " streaming=False,\n",
311 | "            slow_debug=False, # Disable slow debugging for this eval loop\n",
312 | "            output_every_n_tokens=1, # Update the display every token\n",
313 | " debug_delay=1, # Set delay for debug output\n",
314 | " inference_output=inference_output, # Visualization of the text output\n",
315 | " enforce_json=True,\n",
316 | " debug_output=debug_output # Visualization of the debug information\n",
317 | " )\n",
318 | "\n",
319 | " inference = tokenizer.decode(output_tokens, skip_special_tokens=True)\n",
320 | " lpos = inference.find('{')\n",
321 | " rpos = inference.rfind('}')\n",
322 | " if lpos >= 0 and rpos > 0:\n",
323 | " inference = inference[lpos:rpos+1]\n",
324 | " try:\n",
325 | " parsed = json.loads(inference)\n",
326 | " valid_count += 1\n",
327 | " except Exception as e:\n",
328 | " print([inference])\n",
329 | " print(e)\n",
330 | " pass\n",
331 | "\n",
332 | " print('json successfully parsed:', valid_count, 'of', n_iterations)\n",
333 | " print('')\n"
334 | ]
335 | }
336 | ],
337 | "metadata": {
338 | "kernelspec": {
339 | "display_name": "Python 3",
340 | "language": "python",
341 | "name": "python3"
342 | },
343 | "language_info": {
344 | "codemirror_mode": {
345 | "name": "ipython",
346 | "version": 3
347 | },
348 | "file_extension": ".py",
349 | "mimetype": "text/x-python",
350 | "name": "python",
351 | "nbconvert_exporter": "python",
352 | "pygments_lexer": "ipython3",
353 | "version": "3.10.6"
354 | }
355 | },
356 | "nbformat": 4,
357 | "nbformat_minor": 2
358 | }
359 |
--------------------------------------------------------------------------------
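For readers who want the gist of what the notebook above visualises without stepping through src/antislop_generate.py, the sketch below is a deliberately simplified, conceptual rendering of the backtracking loop. It is not the repository's implementation (the real sampler caches probabilities, downregulates rather than hard-bans, and generates batches of tokens with stopping criteria); the phrase list, greedy decoding, and the hard ban at the backtrack position are illustrative assumptions.

# Conceptual sketch only -- not src/antislop_generate.py. Greedy decoding and a
# hard ban at the backtrack position are simplifications of the idea the
# backtracking notebook above visualises.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "unsloth/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

banned_phrases = ["tapestry", "testament to"]  # illustrative slop phrases

prompt = "Once upon a time, in a bustling city of Technopolis, there lived a weaver named Elara."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
prompt_len = input_ids.shape[1]
blocked = {}  # position -> set of token ids already rejected at that position

for _ in range(200):
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1]
    pos = input_ids.shape[1]
    for tok_id in blocked.get(pos, ()):
        logits[tok_id] = float("-inf")  # don't re-pick a continuation we backtracked over
    next_id = int(torch.argmax(logits))
    if next_id == tokenizer.eos_token_id:
        break
    input_ids = torch.cat([input_ids, torch.tensor([[next_id]], device=device)], dim=1)

    generated = tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True)
    hit = next((p for p in banned_phrases if p in generated.lower()), None)
    if hit is None:
        continue
    # Backtrack: truncate to the first generated position whose prefix already
    # contains the phrase, and block the token that was chosen there.
    for cut in range(prompt_len, input_ids.shape[1]):
        prefix = tokenizer.decode(input_ids[0, prompt_len:cut + 1], skip_special_tokens=True)
        if hit in prefix.lower():
            blocked.setdefault(cut, set()).add(int(input_ids[0, cut]))
            input_ids = input_ids[:, :cut]
            break

print(tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True))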
/run_api.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 | from typing import List, Dict, Union, Optional, Any, AsyncGenerator, Tuple
5 |
6 | from fastapi import FastAPI, HTTPException, Request
7 | from fastapi.responses import StreamingResponse
8 | from pydantic import BaseModel, Field
9 | import uvicorn
10 | import logging
11 | import threading
12 |
13 | import torch
14 | from transformers import (
15 | PreTrainedModel,
16 | PreTrainedTokenizer,
17 | AutoTokenizer,
18 | AutoModelForCausalLM,
19 | BitsAndBytesConfig # Ensure this is imported if you use quantization
20 | )
21 |
22 | # Set up logging
23 | logging.basicConfig(level=logging.INFO) # Set to DEBUG for more detailed logs
24 | logger = logging.getLogger(__name__)
25 |
26 | # Import your custom antislop_generate module
27 | from src.antislop_generate import chat_antislop, generate_antislop
28 |
29 | app = FastAPI(title="AntiSlop OpenAI-Compatible API")
30 |
31 | # Global variables to hold the model and tokenizer
32 | model: Optional[PreTrainedModel] = None
33 | tokenizer: Optional[PreTrainedTokenizer] = None
34 | DEFAULT_SLOP_ADJUSTMENTS: Dict[str, float] = {}
35 | device: Optional[torch.device] = None # Modified to allow dynamic setting
36 |
37 | # Variables to store model metadata
38 | model_loaded_time: Optional[int] = None
39 | model_name_loaded: Optional[str] = None
40 |
41 | # Define a global asyncio.Lock to enforce single concurrency
42 | import asyncio
43 | lock = asyncio.Lock()
44 |
45 | # Define Pydantic models for request and response schemas
46 |
47 | class CompletionRequest(BaseModel):
48 | model: Optional[str] = Field(default=None, description="Model to use for completion")
49 | prompt: Union[str, List[str]]
50 | max_tokens: Optional[int] = Field(default=None, ge=1, description="Maximum number of tokens to generate")
51 | temperature: Optional[float] = Field(default=1.0, ge=0.0, description="Sampling temperature")
52 | top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Nucleus sampling probability")
53 | top_k: Optional[int] = Field(default=None, ge=0, description="Top-K sampling")
54 | min_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Minimum probability threshold")
55 | stream: Optional[bool] = Field(default=False, description="Whether to stream back partial progress")
56 | slop_phrases: Optional[List[Tuple[str, float]]] = Field(
57 | default=None,
58 | description="List of slop phrases and their adjustment values, e.g., [['a testament to', 0.3], ['tapestry of', 0.1]]"
59 | )
60 | adjustment_strength: Optional[float] = Field(default=20.0, ge=0.0, description="Strength of adjustments")
61 | enforce_json: Optional[bool] = Field(default=False, description="Enforce JSON formatting")
62 | antislop_enabled: Optional[bool] = Field(default=True, description="Enable AntiSlop functionality")
63 |     regex_bans: Optional[List[str]] = Field(default=None, description="Ban strings matching these regex expressions")
64 |
65 |
66 | class ChatCompletionMessage(BaseModel):
67 | role: str
68 | content: str
69 |
70 |
71 | class ChatCompletionRequest(BaseModel):
72 | model: Optional[str] = Field(default=None, description="Model to use for completion")
73 | messages: List[ChatCompletionMessage]
74 | max_tokens: Optional[int] = Field(default=None, ge=1, description="Maximum number of tokens to generate")
75 | temperature: Optional[float] = Field(default=1.0, ge=0.0, description="Sampling temperature")
76 | top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Nucleus sampling probability")
77 | top_k: Optional[int] = Field(default=None, ge=0, description="Top-K sampling")
78 | min_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Minimum probability threshold")
79 | stream: Optional[bool] = Field(default=False, description="Whether to stream back partial progress")
80 | slop_phrases: Optional[List[Tuple[str, float]]] = Field(
81 | default=None,
82 | description="List of slop phrases and their adjustment values, e.g., [['a testament to', 0.3], ['tapestry of', 0.1]]"
83 | )
84 | adjustment_strength: Optional[float] = Field(default=20.0, ge=0.0, description="Strength of adjustments")
85 | enforce_json: Optional[bool] = Field(default=False, description="Enforce JSON formatting")
86 | antislop_enabled: Optional[bool] = Field(default=True, description="Enable AntiSlop functionality")
87 |     regex_bans: Optional[List[str]] = Field(default=None, description="Ban strings matching these regex expressions")
88 |
89 |
90 | class CompletionChoice(BaseModel):
91 | text: str
92 | index: int
93 | logprobs: Optional[Any] = None
94 | finish_reason: Optional[str] = None
95 |
96 |
97 | class ChatCompletionChoice(BaseModel):
98 | message: ChatCompletionMessage
99 | index: int
100 | finish_reason: Optional[str] = None
101 |
102 |
103 | class CompletionResponse(BaseModel):
104 | id: str
105 | object: str
106 | created: int
107 | model: str
108 | choices: List[CompletionChoice]
109 | usage: Dict[str, int]
110 |
111 |
112 | class ChatCompletionResponse(BaseModel):
113 | id: str
114 | object: str
115 | created: int
116 | model: str
117 | choices: List[ChatCompletionChoice]
118 | usage: Dict[str, int]
119 |
120 |
121 | # New Pydantic models for /v1/models endpoint
122 |
123 | class ModelInfo(BaseModel):
124 | id: str
125 | object: str = "model"
126 | created: int
127 | owned_by: str
128 | permission: List[Any] = []
129 | root: str
130 | parent: Optional[str] = None
131 |
132 |
133 | class ModelsResponse(BaseModel):
134 | object: str = "list"
135 | data: List[ModelInfo]
136 |
137 |
138 | # Utility functions
139 |
140 | import uuid
141 | import time
142 | import queue # Import queue for thread-safe communication
143 |
144 | def generate_id() -> str:
145 | return str(uuid.uuid4())
146 |
147 |
148 | def current_timestamp() -> int:
149 | return int(time.time())
150 |
151 |
152 | def load_slop_adjustments(file_path: Optional[str]) -> Dict[str, float]:
153 | if file_path is None:
154 | return {}
155 | if not os.path.exists(file_path):
156 | raise FileNotFoundError(f"Slop phrase adjustments file not found: {file_path}")
157 | with open(file_path, 'r', encoding='utf-8') as f:
158 | try:
159 | adjustments = dict(json.load(f))
160 | if not isinstance(adjustments, dict):
161 | raise ValueError("Slop phrase adjustments file must contain a JSON object (dictionary).")
162 | # Ensure all values are floats
163 | for key, value in adjustments.items():
164 | adjustments[key] = float(value)
165 | return adjustments
166 | except json.JSONDecodeError as e:
167 | raise ValueError(f"Error decoding JSON from slop phrase adjustments file: {e}")
168 |
169 |
170 | # Startup event to load model and tokenizer
171 | @app.on_event("startup")
172 | async def load_model_and_tokenizer():
173 | global model, tokenizer, DEFAULT_SLOP_ADJUSTMENTS, device
174 | global model_loaded_time, model_name_loaded
175 |
176 | # Set device based on GPU_ID environment variable
177 | gpu_id = os.getenv("GPU_ID")
178 | if gpu_id is not None:
179 | try:
180 | gpu_id_int = int(gpu_id)
181 | if torch.cuda.is_available() and gpu_id_int < torch.cuda.device_count():
182 | device = torch.device(f"cuda:{gpu_id_int}")
183 | else:
184 | logger.warning(f"Specified GPU ID {gpu_id_int} is not available. Falling back to default GPU or CPU.")
185 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
186 | except ValueError:
187 | logger.warning(f"Invalid GPU ID '{gpu_id}'. Falling back to default GPU or CPU.")
188 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
189 | else:
190 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191 |
192 | # Load configuration from environment variables
193 | model_name = os.getenv("MODEL_NAME", "")
194 | load_in_4bit = os.getenv("LOAD_IN_4BIT", "false").lower() == "true"
195 | load_in_8bit = os.getenv("LOAD_IN_8BIT", "false").lower() == "true"
196 | slop_adjustments_file = os.getenv("SLOP_ADJUSTMENTS_FILE", None)
197 |
198 | # Validate mutually exclusive flags
199 | if load_in_4bit and load_in_8bit:
200 | logger.error("Cannot set both LOAD_IN_4BIT and LOAD_IN_8BIT. Choose one.")
201 | raise ValueError("Cannot set both LOAD_IN_4BIT and LOAD_IN_8BIT. Choose one.")
202 |
203 | # Load slop phrase adjustments from file if provided
204 | try:
205 | DEFAULT_SLOP_ADJUSTMENTS = load_slop_adjustments(slop_adjustments_file)
206 | logger.info(f"Loaded {len(DEFAULT_SLOP_ADJUSTMENTS)} slop phrase adjustments.")
207 | except Exception as e:
208 | logger.error(f"Failed to load slop adjustments: {e}")
209 | raise ValueError("Slop adjustments file could not be loaded. Make sure you have the right file path and file structure.")
210 |
211 | logger.info(f"Using device: {device}")
212 |
213 | # Load tokenizer
214 | logger.info(f"Loading tokenizer for model '{model_name}'...")
215 | try:
216 | tokenizer = AutoTokenizer.from_pretrained(model_name)
217 | if tokenizer.pad_token is None:
218 | tokenizer.pad_token = tokenizer.eos_token # Ensure pad_token is set
219 | logger.info("Tokenizer loaded.")
220 | except Exception as e:
221 | logger.error(f"Error loading tokenizer: {e}")
222 | raise e
223 |
224 | # Load model with appropriate precision
225 | logger.info(f"Loading model '{model_name}'...")
226 | try:
227 | if load_in_4bit:
228 | # Configure 4-bit loading
229 | quantization_config = BitsAndBytesConfig(
230 | load_in_4bit=True,
231 | bnb_4bit_use_double_quant=True,
232 | bnb_4bit_quant_type='nf4',
233 | bnb_4bit_compute_dtype=torch.float16
234 | )
235 | try:
236 | import bitsandbytes # Ensure bitsandbytes is installed
237 | except ImportError:
238 | logger.error("bitsandbytes is required for 4-bit loading. Install it via 'pip install bitsandbytes'.")
239 | raise ImportError("bitsandbytes is required for 4-bit loading. Install it via 'pip install bitsandbytes'.")
240 |
241 | model = AutoModelForCausalLM.from_pretrained(
242 | model_name,
243 | quantization_config=quantization_config,
244 | device_map="auto"
245 | )
246 | logger.info("Model loaded in 4-bit precision.")
247 | elif load_in_8bit:
248 | # Configure 8-bit loading
249 | quantization_config = BitsAndBytesConfig(
250 | load_in_8bit=True
251 | )
252 | try:
253 | import bitsandbytes # Ensure bitsandbytes is installed
254 | except ImportError:
255 | logger.error("bitsandbytes is required for 8-bit loading. Install it via 'pip install bitsandbytes'.")
256 | raise ImportError("bitsandbytes is required for 8-bit loading. Install it via 'pip install bitsandbytes'.")
257 |
258 | model = AutoModelForCausalLM.from_pretrained(
259 | model_name,
260 | quantization_config=quantization_config,
261 | device_map="auto"
262 | )
263 | logger.info("Model loaded in 8-bit precision.")
264 | else:
265 | # Load model normally
266 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
267 | try:
268 | model.to(device)
269 | except Exception as e:
270 | # if an already quantised model is loaded, the model.to(device) will
271 | # throw a benign error that we can ignore.
272 | print(e)
273 | logger.info("Model loaded in 16-bit precision.")
274 | except Exception as e:
275 | logger.error(f"Error loading model: {e}")
276 | raise e # Let FastAPI handle the startup failure
277 |
278 | logger.info("Model and tokenizer loaded successfully.")
279 |
280 | # Store model metadata
281 | model_loaded_time = current_timestamp()
282 | model_name_loaded = model_name
283 |
284 |
285 | # Utility function for streaming responses
286 |
287 | def generate_id() -> str:
288 | return str(uuid.uuid4())
289 |
290 |
291 | def current_timestamp() -> int:
292 | return int(time.time())
293 |
294 |
295 | async def stream_tokens_sync(generator: Any, is_chat: bool = False) -> AsyncGenerator[str, None]:
296 | """
297 | Converts a synchronous generator to an asynchronous generator for streaming responses.
298 | Formats the output to match OpenAI's streaming response format.
299 | """
300 | q = queue.Queue()
301 |
302 | def generator_thread():
303 | try:
304 | logger.debug("Generator thread started.")
305 | for token in generator:
306 | q.put(token)
307 | logger.debug(f"Token put into queue: {token}")
308 | q.put(None) # Signal completion
309 | logger.debug("Generator thread completed.")
310 | except Exception as e:
311 | logger.error(f"Exception in generator_thread: {e}")
312 | q.put(e) # Signal exception
313 |
314 | # Start the generator in a separate daemon thread
315 | thread = threading.Thread(target=generator_thread, daemon=True)
316 | thread.start()
317 | logger.debug("Generator thread initiated.")
318 |
319 | try:
320 | tokens = []
321 | text = ''
322 | while True:
323 | token = await asyncio.to_thread(q.get)
324 | logger.debug(f"Token retrieved from queue: {token}")
325 |
326 | if token is None:
327 | # Send final finish_reason to indicate the end of the stream
328 | finish_data = {
329 | "choices": [
330 | {
331 | "delta": {},
332 | "index": 0,
333 | "finish_reason": "stop"
334 | }
335 | ]
336 | }
337 | yield f"data: {json.dumps(finish_data)}\n\n"
338 | logger.debug("Finished streaming tokens.")
339 | break
340 |
341 | if isinstance(token, Exception):
342 | # Handle exceptions by sending a finish_reason with 'error'
343 | error_data = {
344 | "choices": [
345 | {
346 | "delta": {},
347 | "index": 0,
348 | "finish_reason": "error"
349 | }
350 | ]
351 | }
352 | yield f"data: {json.dumps(error_data)}\n\n"
353 | logger.error(f"Exception during token streaming: {token}")
354 | break # Exit the loop after handling the error
355 |
356 | # Decode the token to text
357 | # Note: This is inefficient; we're decoding the whole sequence every time a
358 | # new token comes in. This is to handle cases where the tokeniser doesn't
359 | # prepend spaces to words. There's probably a better way to do this.
360 | tokens.append(token)
361 | full_text = tokenizer.decode(tokens, skip_special_tokens=True)
362 | new_text = full_text[len(text):]
363 | text = full_text
364 | logger.debug(f"Decoded token: {new_text}")
365 |
366 | # Prepare the data in OpenAI's streaming format
367 | data = {
368 | "choices": [
369 | {
370 | "delta": {"content": new_text},
371 | "index": 0,
372 | "finish_reason": None
373 | }
374 | ]
375 | }
376 |
377 | # Yield the formatted data as a Server-Sent Event (SSE)
378 | yield f"data: {json.dumps(data)}\n\n"
379 | logger.debug("Yielded token to client.")
380 |
381 | # Yield control back to the event loop
382 | await asyncio.sleep(0)
383 |
384 | except asyncio.CancelledError:
385 | logger.warning("Streaming task was cancelled by the client.")
386 | except Exception as e:
387 | logger.error(f"Unexpected error in stream_tokens_sync: {e}")
388 | finally:
389 | logger.debug("Exiting stream_tokens_sync.")
390 |
391 |
392 | # Endpoint: /v1/completions
393 | @app.post("/v1/completions", response_model=CompletionResponse)
394 | async def completions(request: CompletionRequest, req: Request):
395 | logger.info("Completion request received, waiting for processing...")
396 |
397 | if request.stream and request.regex_bans:
398 | raise HTTPException(status_code=500, detail="Streaming cannot be enabled when using regex bans.")
399 |
400 | try:
401 | if model is None or tokenizer is None:
402 | logger.error("Model and tokenizer are not loaded.")
403 | raise HTTPException(status_code=500, detail="Model and tokenizer are not loaded.")
404 |
405 | # Use the model specified in the request or default
406 | used_model = request.model if request.model else model_name_loaded
407 |
408 | # Handle prompt as string or list
409 | if isinstance(request.prompt, list):
410 | prompt = "\n".join(request.prompt)
411 | else:
412 | prompt = request.prompt
413 |
414 | # Process slop_phrases parameter
415 | if request.slop_phrases is not None:
416 | # Convert list of tuples to dictionary
417 | slop_adjustments = dict(request.slop_phrases)
418 | logger.debug(f"Slop adjustments provided: {slop_adjustments}")
419 | else:
420 | # Use default slop_phrase_prob_adjustments
421 | slop_adjustments = DEFAULT_SLOP_ADJUSTMENTS.copy()
422 | logger.debug(f"Using default slop adjustments with {len(slop_adjustments)} entries.")
423 |
424 | if request.stream:
425 | logger.info("Streaming completion request started.")
426 | # Streaming response
427 | generator_source = generate_antislop(
428 | model=model,
429 | tokenizer=tokenizer,
430 | prompt=prompt,
431 | max_new_tokens=request.max_tokens,
432 | temperature=request.temperature,
433 | top_k=request.top_k,
434 | top_p=request.top_p,
435 | min_p=request.min_p,
436 | slop_phrase_prob_adjustments=slop_adjustments,
437 | adjustment_strength=request.adjustment_strength,
438 | device=device,
439 | streaming=True,
440 | slow_debug=False, # Adjust as needed
441 | output_every_n_tokens=1,
442 | debug_delay=0.0,
443 | inference_output=None,
444 | debug_output=None,
445 | enforce_json=request.enforce_json,
446 | antislop_enabled=request.antislop_enabled,
447 | regex_bans=request.regex_bans
448 | )
449 |
450 | async def streaming_generator():
451 | async with lock:
452 | logger.info("Lock acquired for streaming completion request.")
453 | try:
454 | async for token in stream_tokens_sync(generator_source, is_chat=False):
455 | yield token
456 | except Exception as e:
457 | logger.error(f"Exception in streaming_generator: {e}")
458 | finally:
459 | logger.info("Streaming generator completed and lock released.")
460 |
461 | return StreamingResponse(
462 | streaming_generator(),
463 | media_type="text/event-stream"
464 | )
465 |
466 | else:
467 | logger.info("Non-streaming completion request started.")
468 | async with lock:
469 | logger.info("Lock acquired for non-streaming completion request.")
470 | # Non-streaming response
471 | generated_tokens = generate_antislop(
472 | model=model,
473 | tokenizer=tokenizer,
474 | prompt=prompt,
475 | max_new_tokens=request.max_tokens,
476 | temperature=request.temperature,
477 | top_k=request.top_k,
478 | top_p=request.top_p,
479 | min_p=request.min_p,
480 | slop_phrase_prob_adjustments=slop_adjustments,
481 | adjustment_strength=request.adjustment_strength,
482 | device=device,
483 | streaming=False,
484 | slow_debug=False,
485 | output_every_n_tokens=5,
486 | debug_delay=0.0,
487 | inference_output=None,
488 | debug_output=None,
489 | enforce_json=request.enforce_json,
490 | antislop_enabled=request.antislop_enabled,
491 | regex_bans=request.regex_bans
492 | )
493 |
494 | # Decode the tokens
495 | text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
496 | logger.debug(f"Generated text: {text}")
497 |
498 | # Create the response
499 | response = CompletionResponse(
500 | id=generate_id(),
501 | object="text_completion",
502 | created=current_timestamp(),
503 | model=used_model,
504 | choices=[
505 | CompletionChoice(
506 | text=text,
507 | index=0,
508 | logprobs=None,
509 | finish_reason="length" if request.max_tokens else "stop"
510 | )
511 | ],
512 | usage={
513 | "prompt_tokens": len(tokenizer.encode(prompt)),
514 | "completion_tokens": len(generated_tokens),
515 | "total_tokens": len(tokenizer.encode(prompt)) + len(generated_tokens),
516 | }
517 | )
518 | logger.info("Completion request processing completed.")
519 | return response
520 |
521 | except Exception as e:
522 | logger.error(f"Error during completion processing: {e}")
523 | raise HTTPException(status_code=500, detail=str(e))
524 | finally:
525 | logger.debug("Exiting /v1/completions endpoint.")
526 |
527 |
528 | # Endpoint: /v1/chat/completions
529 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
530 | async def chat_completions(request: ChatCompletionRequest, req: Request):
531 | logger.info("Chat completion request received, waiting for processing...")
532 |
533 | if request.stream and request.regex_bans:
534 | raise HTTPException(status_code=500, detail="Streaming cannot be enabled when using regex bans.")
535 |
536 | try:
537 | if model is None or tokenizer is None:
538 | logger.error("Model and tokenizer are not loaded.")
539 | raise HTTPException(status_code=500, detail="Model and tokenizer are not loaded.")
540 |
541 | # Use the model specified in the request or default
542 | used_model = request.model if request.model else model_name_loaded
543 |
544 | # Build the prompt from chat messages
545 | prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in request.messages])
546 | logger.debug(f"Constructed prompt from messages: {prompt}")
547 |
548 | # Process slop_phrases parameter
549 | if request.slop_phrases is not None:
550 | # Convert list of tuples to dictionary
551 | slop_adjustments = dict(request.slop_phrases)
552 | logger.debug(f"Slop adjustments provided: {slop_adjustments}")
553 | else:
554 | # Use default slop_phrase_prob_adjustments
555 | slop_adjustments = DEFAULT_SLOP_ADJUSTMENTS.copy()
556 | logger.debug(f"Using default slop adjustments with {len(slop_adjustments)} entries.")
557 |
558 | if request.stream:
559 | logger.info("Streaming chat completion request started.")
560 | # Streaming response
561 | generator_source = chat_antislop(
562 | model=model,
563 | tokenizer=tokenizer,
564 | messages=[msg.dict() for msg in request.messages],
565 | max_new_tokens=request.max_tokens,
566 | temperature=request.temperature,
567 | top_k=request.top_k,
568 | top_p=request.top_p,
569 | min_p=request.min_p,
570 | slop_phrase_prob_adjustments=slop_adjustments,
571 | adjustment_strength=request.adjustment_strength,
572 | device=device,
573 | streaming=True,
574 | slow_debug=False, # Adjust as needed
575 | output_every_n_tokens=1,
576 | debug_delay=0.0,
577 | inference_output=None,
578 | debug_output=None,
579 | enforce_json=request.enforce_json,
580 | antislop_enabled=request.antislop_enabled,
581 | regex_bans=request.regex_bans
582 | )
583 |
584 | async def streaming_generator():
585 | async with lock:
586 | logger.info("Lock acquired for streaming chat completion request.")
587 | try:
588 | async for token in stream_tokens_sync(generator_source, is_chat=True):
589 | yield token
590 | except Exception as e:
591 | logger.error(f"Exception in streaming_generator: {e}")
592 | finally:
593 | logger.info("Streaming generator completed and lock released.")
594 |
595 | return StreamingResponse(
596 | streaming_generator(),
597 | media_type="text/event-stream"
598 | )
599 |
600 | else:
601 | logger.info("Non-streaming chat completion request started.")
602 | async with lock:
603 | logger.info("Lock acquired for non-streaming chat completion request.")
604 | # Non-streaming response
605 | generated_tokens = chat_antislop(
606 | model=model,
607 | tokenizer=tokenizer,
608 | messages=[msg.dict() for msg in request.messages],
609 | max_new_tokens=request.max_tokens,
610 | temperature=request.temperature,
611 | top_k=request.top_k,
612 | top_p=request.top_p,
613 | min_p=request.min_p,
614 | slop_phrase_prob_adjustments=slop_adjustments,
615 | adjustment_strength=request.adjustment_strength,
616 | device=device,
617 | streaming=False,
618 | slow_debug=False,
619 | output_every_n_tokens=5,
620 | debug_delay=0.0,
621 | inference_output=None,
622 | debug_output=None,
623 | enforce_json=request.enforce_json,
624 | antislop_enabled=request.antislop_enabled,
625 | regex_bans=request.regex_bans
626 | )
627 |
628 | # Decode the tokens
629 | text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
630 | logger.debug(f"Generated chat text: {text}")
631 |
632 | # Create the response
633 | response = ChatCompletionResponse(
634 | id=generate_id(),
635 | object="chat.completion",
636 | created=current_timestamp(),
637 | model=used_model,
638 | choices=[
639 | ChatCompletionChoice(
640 | message=ChatCompletionMessage(role="assistant", content=text),
641 | index=0,
642 | finish_reason="length" if request.max_tokens else "stop"
643 | )
644 | ],
645 | usage={
646 | "prompt_tokens": len(tokenizer.encode(prompt)),
647 | "completion_tokens": len(generated_tokens),
648 | "total_tokens": len(tokenizer.encode(prompt)) + len(generated_tokens),
649 | }
650 | )
651 | logger.info("Chat completion request processing completed.")
652 | return response
653 |
654 | except Exception as e:
655 | logger.error(f"Error during chat completion processing: {e}")
656 | raise HTTPException(status_code=500, detail=str(e))
657 | finally:
658 | logger.debug("Exiting /v1/chat/completions endpoint.")
659 |
660 |
661 | # New Endpoint: /v1/models
662 | @app.get("/v1/models", response_model=ModelsResponse)
663 | async def get_models():
664 | logger.info("Models request received.")
665 | try:
666 | if model is None or model_name_loaded is None or model_loaded_time is None:
667 | logger.error("Model is not loaded.")
668 | raise HTTPException(status_code=500, detail="Model is not loaded.")
669 |
670 | model_info = ModelInfo(
671 | id=model_name_loaded,
672 | created=model_loaded_time,
673 | owned_by="user", # Adjust as needed
674 | permission=[], # Can be populated with actual permissions if available
675 | root=model_name_loaded,
676 | parent=None
677 | )
678 |
679 | response = ModelsResponse(
680 | data=[model_info]
681 | )
682 |
683 | logger.info("Models response prepared successfully.")
684 | return response
685 |
686 | except Exception as e:
687 | logger.error(f"Error during models processing: {e}")
688 | raise HTTPException(status_code=500, detail=str(e))
689 | finally:
690 | logger.debug("Exiting /v1/models endpoint.")
691 |
692 |
693 | # Main function to parse arguments and start Uvicorn
694 | def main():
695 | parser = argparse.ArgumentParser(description="Launch the AntiSlop OpenAI-Compatible API server.")
696 |
697 | parser.add_argument(
698 | "--model",
699 | type=str,
700 | required=True,
701 | help="Path to the model directory or HuggingFace model ID (e.g., 'gpt2')."
702 | )
703 | parser.add_argument(
704 | "--load_in_4bit",
705 | action="store_true",
706 | help="Load the model in 4-bit precision (requires appropriate support)."
707 | )
708 | parser.add_argument(
709 | "--load_in_8bit",
710 | action="store_true",
711 | help="Load the model in 8-bit precision (requires appropriate support)."
712 | )
713 | parser.add_argument(
714 | "--slop_adjustments_file",
715 | type=str,
716 | default=None,
717 | help="Path to the JSON file containing slop phrase probability adjustments."
718 | )
719 | parser.add_argument(
720 | "--host",
721 | type=str,
722 | default="0.0.0.0",
723 | help="Host address to bind the server to."
724 | )
725 | parser.add_argument(
726 | "--port",
727 | type=int,
728 | default=8000,
729 | help="Port number to bind the server to."
730 | )
731 | parser.add_argument(
732 | "--gpu",
733 | type=int,
734 | default=None,
735 | help="GPU ID to load the model on (e.g., 0, 1). Optional."
736 | )
737 |
738 | args = parser.parse_args()
739 |
740 | # Set environment variables based on parsed arguments
741 | os.environ["MODEL_NAME"] = args.model
742 | os.environ["LOAD_IN_4BIT"] = str(args.load_in_4bit)
743 | os.environ["LOAD_IN_8BIT"] = str(args.load_in_8bit)
744 | if args.slop_adjustments_file:
745 | os.environ["SLOP_ADJUSTMENTS_FILE"] = args.slop_adjustments_file
746 | if args.gpu is not None:
747 | os.environ["GPU_ID"] = str(args.gpu)
748 |
749 | # Run the app using Uvicorn with single worker and single thread
750 | uvicorn.run(
751 | "run_api:app", # Ensure this matches the filename if different
752 | host=args.host,
753 | port=args.port,
754 | reload=False,
755 | log_level="info", # Set to DEBUG for more detailed logs
756 | timeout_keep_alive=600, # 10 minutes
757 | workers=1, # Single worker to enforce global lock
758 | loop="asyncio", # Ensure using asyncio loop
759 | )
760 |
761 |
762 | if __name__ == "__main__":
763 | main()
764 |
--------------------------------------------------------------------------------
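A minimal client sketch for the server above, assuming it has been started with something like "python run_api.py --model unsloth/Llama-3.2-1B-Instruct --port 8000" and is reachable on localhost:8000. The host, port, and example slop phrases are assumptions; the endpoint paths, payload fields, and SSE chunk shape come from the Pydantic models and stream_tokens_sync in run_api.py.

# Hypothetical client for the AntiSlop OpenAI-compatible API above.
import json
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port

payload = {
    "messages": [{"role": "user", "content": "Write the opening paragraph of a story."}],
    "max_tokens": 200,
    "temperature": 1.0,
    "min_p": 0.1,
    # Per-request slop phrases override the server-side defaults loaded from
    # --slop_adjustments_file.
    "slop_phrases": [["a testament to", 0.3], ["tapestry of", 0.1]],
    "adjustment_strength": 20.0,
    # "regex_bans": ["\\bkaleidoscope\\b"],  # optional; cannot be combined with stream=True
    "stream": False,
}

# Non-streaming: the response mirrors OpenAI's chat.completion shape.
resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

# Streaming: the server emits OpenAI-style SSE lines of the form "data: {...}".
payload["stream"] = True
with requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, stream=True, timeout=600) as r:
    r.raise_for_status()
    for line in r.iter_lines():
        if not line.startswith(b"data: "):
            continue
        chunk = json.loads(line[len(b"data: "):])
        choice = chunk["choices"][0]
        if choice.get("finish_reason"):
            break
        print(choice["delta"].get("content", ""), end="", flush=True)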
/slop_phrase_prob_adjustments.json:
--------------------------------------------------------------------------------
1 | [
2 | ["symphony", 0.03125],
3 | ["testament to", 0.03125],
4 | ["kaleidoscope", 0.03125],
5 | ["delve", 0.03125],
6 | ["delving", 0.05988111477573171],
7 | ["delved", 0.03125],
8 | ["elara", 0.03125],
9 | ["tapestry", 0.03125],
10 | ["tapestries", 0.03125],
11 | ["weave", 0.03125],
12 | ["wove", 0.03125],
13 | ["weaving", 0.03125],
14 | ["elysia", 0.03125],
15 | ["barely above a whisper", 0.03125],
16 | ["barely a whisper", 0.03125],
17 | ["orchestra of", 0.03125],
18 | ["dance of", 0.03125],
19 | ["maybe, just maybe", 0.03125],
20 | ["maybe that was enough", 0.03125],
21 | ["perhaps, just perhaps", 0.03125],
22 | ["was only just beginning", 0.03125],
23 | [", once a ", 0.03125],
24 | ["world of", 0.03125],
25 | ["bustling", 0.03125],
26 | ["labyrinthine", 0.06638702636902488],
27 | ["shivers down", 0.03125],
28 | ["shivers up", 0.03125],
29 | ["shiver down", 0.03125],
30 | ["shiver up", 0.03125],
31 | ["ministrations", 0.03125],
32 | ["numeria", 0.03125],
33 | ["transcended", 0.060515730172740693],
34 | ["lyra", 0.03125],
35 | ["eira", 0.03125],
36 | ["eldoria", 0.03125],
37 | ["atheria", 0.03125],
38 | ["eluned", 0.03125],
39 | ["oakhaven", 0.03125],
40 | ["whisperwood", 0.03125],
41 | ["zephyria", 0.03125],
42 | ["elian", 0.03125],
43 | ["elias", 0.03125],
44 | ["elianore", 0.03125],
45 | ["aria", 0.03125],
46 | ["eitan", 0.03125],
47 | ["kael", 0.03125],
48 | ["jaxon", 0.10836356096002156],
49 | ["ravenswood", 0.03125],
50 | ["moonwhisper", 0.03125],
51 | ["thrummed", 0.03125],
52 | ["flickered", 0.07474260888562831],
53 | [" rasped", 0.03125],
54 | [" rasp", 0.03125],
55 | [" rasping", 0.03125],
56 | [" ,rasped", 0.03125],
57 | [" ,rasp", 0.03125],
58 | [" ,rasping", 0.03125],
59 | ["bioluminescent", 0.03125],
60 | ["glinting", 0.03125],
61 | ["twinkled", 0.07014359686367556],
62 | ["nodded", 0.04128581381864993],
63 | ["nestled", 0.03125],
64 | ["ministration", 0.03125],
65 | ["moth to a flame", 0.03125],
66 | ["canvas", 0.03125],
67 | ["eyes glinted", 0.03125],
68 | ["camaraderie", 0.03125],
69 | ["humble abode", 0.03125],
70 | ["cold and calculating", 0.03125],
71 | ["eyes never leaving", 0.03125],
72 | ["body and soul", 0.03125],
73 | ["orchestra", 0.03125],
74 | ["palpable", 0.03125],
75 | ["depths", 0.03125],
76 | ["a dance of", 0.03125],
77 | ["chuckles darkly", 0.03125],
78 | ["maybe, that was enough", 0.03125],
79 | ["they would face it together", 0.03125],
80 | ["a reminder", 0.03125],
81 | ["that was enough", 0.03125],
82 | ["for now, that was enough", 0.03125],
83 | ["for now, that's enough", 0.03125],
84 | ["with a mixture of", 0.03125],
85 | ["air was filled with anticipation", 0.03125],
86 | ["cacophony", 0.03125],
87 | ["bore silent witness to", 0.03125],
88 | ["eyes sparkling with mischief", 0.03125],
89 | ["practiced ease", 0.03125],
90 | ["ready for the challenges", 0.03125],
91 | ["only just getting started", 0.03125],
92 | ["once upon a time", 0.03125],
93 | ["nestled deep within", 0.03125],
94 | ["ethereal beauty", 0.03125],
95 | ["life would never be the same", 0.03125],
96 | ["it's important to remember", 0.03125],
97 | ["for what seemed like an eternity", 0.03125],
98 | ["little did he know", 0.03125],
99 | ["ball is in your court", 0.03125],
100 | ["game is on", 0.03125],
101 | ["choice is yours", 0.03125],
102 | ["feels like an electric shock", 0.03125],
103 | ["threatens to consume", 0.03125],
104 | ["meticulous", 0.03125],
105 | ["meticulously", 0.03125],
106 | ["navigating", 0.03125],
107 | ["complexities", 0.03125],
108 | ["realm", 0.03125],
109 | ["understanding", 0.03125],
110 | ["dive into", 0.03125],
111 | ["shall", 0.03125],
112 | ["tailored", 0.03125],
113 | ["towards", 0.03125],
114 | ["underpins", 0.03125],
115 | ["everchanging", 0.03125],
116 | ["ever-evolving", 0.03125],
117 | ["not only", 0.03125],
118 | ["alright", 0.03125],
119 | ["embark", 0.03125],
120 | ["journey", 0.03125],
121 | ["today's digital age", 0.03125],
122 | ["game changer", 0.03125],
123 | ["designed to enhance", 0.03125],
124 | ["it is advisable", 0.03125],
125 | ["daunting", 0.03125],
126 | ["when it comes to", 0.03125],
127 | ["in the realm of", 0.03125],
128 | ["unlock the secrets", 0.03125],
129 | ["unveil the secrets", 0.03125],
130 | ["and robust", 0.03125],
131 | ["elevate", 0.03125],
132 | ["unleash", 0.03125],
133 | ["cutting-edge", 0.03125],
134 | ["mastering", 0.03125],
135 | ["harness", 0.03125],
136 | ["it's important to note", 0.03125],
137 | ["in summary", 0.03125],
138 | ["remember that", 0.03125],
139 | ["take a dive into", 0.03125],
140 | ["landscape", 0.03125],
141 | ["in the world of", 0.03125],
142 | ["vibrant", 0.03125],
143 | ["metropolis", 0.03125],
144 | ["moreover", 0.03125],
145 | ["crucial", 0.03125],
146 | ["to consider", 0.03125],
147 | ["there are a few considerations", 0.03125],
148 | ["it's essential to", 0.03125],
149 | ["furthermore", 0.03125],
150 | ["vital", 0.03125],
151 | ["as a professional", 0.03125],
152 | ["thus", 0.03125],
153 | ["you may want to", 0.03125],
154 | ["on the other hand", 0.03125],
155 | ["as previously mentioned", 0.03125],
156 | ["it's worth noting that", 0.03125],
157 | ["to summarize", 0.03125],
158 | ["to put it simply", 0.03125],
159 | ["in today's digital era", 0.03125],
160 | ["reverberate", 0.03125],
161 | ["revolutionize", 0.03125],
162 | ["labyrinth", 0.03125],
163 | ["gossamer", 0.03125],
164 | ["enigma", 0.03125],
165 | ["whispering", 0.03125],
166 | ["sights unseen", 0.03125],
167 | ["sounds unheard", 0.03125],
168 | ["indelible", 0.03125],
169 | ["in conclusion", 0.03125],
170 | ["technopolis", 0.03125],
171 | ["was soft and gentle", 0.03125],
172 | ["leaving trails of fire", 0.03125],
173 | ["audible pop", 0.03125],
174 | ["rivulets of", 0.03125],
175 | ["despite herself", 0.03125],
176 | ["reckless abandon", 0.03125],
177 | ["torn between", 0.03125],
178 | ["fiery red hair", 0.03125],
179 | ["long lashes", 0.03125],
180 | ["world narrows", 0.03125],
181 | ["chestnut eyes", 0.03125],
182 | ["cheeks flaming", 0.03125],
183 | ["cheeks hollowing", 0.03125],
184 | ["understandingly", 0.03125],
185 | ["paperbound", 0.03125],
186 | ["hesitantly", 0.03125],
187 | ["piqued", 0.03125],
188 | ["curveballs", 0.03125],
189 | ["marveled", 0.03125],
190 | ["inclusivity", 0.03125],
191 | ["birdwatcher", 0.03125],
192 | ["newfound", 0.03224441869181626],
193 | ["marveling", 0.0330273218226949],
194 | ["hiroshi", 0.10679436640360451],
195 | ["greentech", 0.03433682773874149],
196 | ["thoughtfully", 0.034554614125603775],
197 | ["intently", 0.036340970869438986],
198 | ["birdwatching", 0.036507038507018016],
199 | ["amidst", 0.036622615767493],
200 | ["cherishing", 0.03678545819532492],
201 | ["attentively", 0.036925354505568025],
202 | ["interjected", 0.03833845769341088],
203 | ["serendipitous", 0.038739958252226044],
204 | ["marianne", 0.038761601988255484],
205 | ["maya", 0.09234714457485342],
206 | ["excitedly", 0.039326615662048627],
207 | ["steepled", 0.03934628705145861],
208 | ["engrossed", 0.03938357871889015],
209 | ["fostering", 0.040237606281663015],
210 | ["brainstormed", 0.040837224374359785],
211 | ["furrowed", 0.04106990333912991],
212 | ["contemplatively", 0.0415715337562441],
213 | ["jotted", 0.04185187388384613],
214 | ["mia", 0.04228346088920659],
215 | ["yesteryears", 0.04297673583870908],
216 | ["conspiratorially", 0.04318452795649216],
217 | ["poring", 0.043368809753774676],
218 | ["stumbled", 0.043572405422243304],
219 | ["strategized", 0.04421861605966468],
220 | ["hesitated", 0.04494448438749303],
221 | ["intrigued", 0.0455166486053908],
222 | ["sarah", 0.10086927179354085],
223 | ["lykos", 0.04637933901297856],
224 | ["adaptability", 0.04670427920019436],
225 | ["yoing", 0.04690684293430847],
226 | ["geocaches", 0.047245460751898956],
227 | ["furrowing", 0.047522914178588074],
228 | ["quandaries", 0.0481208833728088],
229 | ["chimed", 0.0482978999356531],
230 | ["headfirst", 0.04831461279577993],
231 | ["gruffly", 0.049202857565227666],
232 | ["skeptically", 0.04926960158439389],
233 | ["fruitville", 0.05025148031991987],
234 | ["gastronomical", 0.050289103350348655],
235 | ["sighed", 0.05051102378739721],
236 | ["warmly", 0.05079433763521114],
237 | ["approvingly", 0.05115585243689052],
238 | ["questioningly", 0.05138118422767321],
239 | ["timmy", 0.07835758599135755],
240 | ["undeterred", 0.052486122725405135],
241 | ["starlit", 0.052741028161842],
242 | ["unearthing", 0.053074267343765635],
243 | ["grappled", 0.053717801895838344],
244 | ["yumi", 0.09597346307465092],
245 | ["seabrook", 0.054387123574424684],
246 | ["geocachers", 0.05443263224688007],
247 | ["animatedly", 0.054946701090844687],
248 | ["bakersville", 0.055275865634954346],
249 | ["minji", 0.05547746089952316],
250 | ["fateful", 0.055736452807584415],
251 | ["sparkled", 0.05644124927285605],
252 | ["resonated", 0.05661086006277033],
253 | ["harmoniously", 0.05663396393468821],
254 | ["fidgeted", 0.05692129321298146],
255 | ["mwanga", 0.05772491398190184],
256 | ["gleamed", 0.05785836606810767],
257 | ["embracing", 0.05786812845368198],
258 | ["pleasantries", 0.05845295163766754],
259 | ["iostream", 0.059082502393001814],
260 | ["navigated", 0.05913171417008317],
261 | ["interconnectedness", 0.05919654794414185],
262 | ["tanginess", 0.05941208348993963],
263 | ["mouthwatering", 0.05961924098477277],
264 | ["amelia", 0.10495450758329813],
265 | ["mischievously", 0.059910554433957666],
266 | ["tirelessly", 0.05992047742518978],
267 | ["sympathetically", 0.06067860765517939],
268 | ["pondered", 0.06069437404836339],
269 | ["lingered", 0.060889076468395927],
270 | ["println", 0.06094043781330748],
271 | ["empathizing", 0.06134449183509449],
272 | ["niche", 0.06190423094543172],
273 | ["regaled", 0.06212602213421109],
274 | ["greenthumb", 0.06224864856855164],
275 | ["savored", 0.062408346752467134],
276 | ["amira", 0.11021608052671446],
277 | ["firsthand", 0.06392924590568197],
278 | ["empathetically", 0.06417835678867895],
279 | ["unshed", 0.06438631989525674],
280 | ["jenkin", 0.06442424583006488],
281 | ["enigmatically", 0.06523513565206027],
282 | ["marla", 0.06529158024829011],
283 | ["bayville", 0.06533248861052479],
284 | ["adversities", 0.06561596389981958],
285 | ["eagerly", 0.06577650960046391],
286 | ["quizzically", 0.06651670263185526],
287 | ["transcending", 0.06652015162691576],
288 | ["resilience", 0.06735287445058671],
289 | ["lily", 0.11166764730068231],
290 | ["commiserated", 0.06750658002795124],
291 | ["savoring", 0.06776227505035798],
292 | ["amara", 0.10974878019475678],
293 | ["somberly", 0.06899158731759324],
294 | ["cinephile", 0.06956640378878991],
295 | ["solace", 0.06979652935880434],
296 | ["aquascaping", 0.07031299700645137],
297 | ["rippled", 0.07081124722667731],
298 | ["reveled", 0.07100424664644575],
299 | ["greenhaven", 0.07105438878187946],
300 | ["birdwatchers", 0.07113287273958965],
301 | ["adwoa", 0.0715485451251974],
302 | ["appreciatively", 0.07161671276447462],
303 | ["awestruck", 0.07162853701765545],
304 | ["ecotech", 0.0718147029190956],
305 | ["lightheartedness", 0.07252636606936509],
306 | ["disapprovingly", 0.07263207299052071],
307 | ["exclaimed", 0.07276579734327646],
308 | ["samir", 0.07285320350285814],
309 | ["fishkeeping", 0.07309459959713759],
310 | ["sparked", 0.07360160538883663],
311 | ["welled", 0.07400001131145016],
312 | ["jotting", 0.07406314941636878],
313 | ["resourcefulness", 0.07421328466775397],
314 | ["reminisced", 0.07491988173476277],
315 | ["abernathy", 0.07570846614292619],
316 | ["unbeknownst", 0.07630735797388712],
317 | ["pattered", 0.07655890842665049],
318 | ["reassuringly", 0.07682374203555666],
319 | ["miscommunications", 0.07686405323158839],
320 | ["wafted", 0.07712987845796046],
321 | ["absentmindedly", 0.0774482165689122],
322 | ["weightiness", 0.07755972654694927],
323 | ["allyship", 0.0778681972659635],
324 | ["perseverance", 0.07827640317481438],
325 | ["mindfully", 0.07859251285955585],
326 | ["disheartened", 0.07871071570815331],
327 | ["leaned", 0.07879604294731023],
328 | ["birder", 0.07892032327792912],
329 | ["captivated", 0.07892473996549315],
330 | ["ravi", 0.07900206498656101],
331 | ["abuela", 0.07959034010494626],
332 | ["apprehensions", 0.07962652222314125],
333 | ["gestured", 0.079648614599269],
334 | ["sagely", 0.07990023651781021],
335 | ["jamie", 0.07999021650870594],
336 | ["emily", 0.08024569032217586],
337 | ["piquing", 0.08025442043095833],
338 | ["bated", 0.08045544454435882],
339 | ["\u00e9lise", 0.08118149973936264],
340 | ["cinephiles", 0.08124748123357117],
341 | ["alex", 0.08155958528922039],
342 | ["wholeheartedly", 0.08162034496563263],
343 | ["enthusiasts", 0.08196867752993525],
344 | ["enchantingly", 0.0823223018514873],
345 | ["wambui", 0.08263722346791491],
346 | ["blankly", 0.08273500150148104],
347 | ["eadric", 0.08283062478131556],
348 | ["immersing", 0.08318343859097584],
349 | ["adversity", 0.08321158200876488],
350 | ["tldr", 0.08325658701983453],
351 | ["cleanups", 0.08364043130448932],
352 | ["candidness", 0.08479322865932452],
353 | ["todayilearned", 0.08484643847316964],
354 | ["windowpanes", 0.08542326976652978],
355 | ["chuckled", 0.08575177294570574],
356 | ["jake", 0.08599683007615211],
357 | ["cobblestone", 0.08603112799196728],
358 | ["scrolling", 0.08603718332851425],
359 | ["curiosity", 0.08623856472735726],
360 | ["homebrewer", 0.08658753164552577],
361 | ["worriedly", 0.08677545881209053],
362 | ["intriguingly", 0.08692958549184135],
363 | ["brainstorming", 0.08724187110008937],
364 | ["shimmered", 0.08742261524005794],
365 | ["supportively", 0.08751018298262307],
366 | ["aldric", 0.0875538540008498],
367 | ["captivating", 0.08772967862882247],
368 | ["grumbled", 0.08775125893819308],
369 | ["flytraps", 0.08789847780208236],
370 | ["evergreen", 0.08810861754185394],
371 | ["jingled", 0.08816160213719028],
372 | ["csharp", 0.08867618121093104],
373 | ["etched", 0.08872622775971402],
374 | ["intricate", 0.08874912738325326],
375 | ["vibrantly", 0.08875022097469734],
376 | ["insights", 0.08881465445615759],
377 | ["etiquettes", 0.08881746671718654],
378 | ["jamal", 0.0892505566950269],
379 | ["serendipity", 0.08930907685467555],
380 | ["aback", 0.08933029922210305],
381 | ["tightknit", 0.089894914209471],
382 | ["fostered", 0.0902250543058865],
383 | ["unease", 0.09027239734885409],
384 | ["stammered", 0.0904076215995526],
385 | ["passions", 0.09064053503449236],
386 | ["johann", 0.09114771748950079],
387 | ["maplewood", 0.09126008685765656],
388 | ["user1", 0.0916424411590356],
389 | ["appreciating", 0.09166077289251462],
390 | ["bibliophiles", 0.09170778747944702],
391 | ["reverberated", 0.09178604774163615],
392 | ["insightfulness", 0.09183861554018956],
393 | ["amina", 0.09188898218908503],
394 | ["unwavering", 0.091899512603219],
395 | ["makena", 0.09209705774418087],
396 | ["strummed", 0.09229274609527229],
397 | ["dataisbeautiful", 0.09250139284705605],
398 | ["geocache", 0.09268887739812762],
399 | ["conlangs", 0.09280617830631038],
400 | ["geocaching", 0.09321966631692781],
401 | ["advancements", 0.09329024874974093],
402 | ["maria", 0.09333053475643405],
403 | ["shying", 0.09335338497711734],
404 | ["quaint", 0.0934056172785989],
405 | ["unforeseen", 0.09344290002506736],
406 | ["strategizing", 0.09399091709419581],
407 | ["dialogues", 0.09408175325423211],
408 | ["insurmountable", 0.09416355517968139],
409 | ["clinked", 0.09417327657033364],
410 | ["trusty", 0.09444824412826175],
411 | ["persevered", 0.0946142869116114],
412 | ["collaboratively", 0.09474179629801663],
413 | ["fascinated", 0.09484997447900628],
414 | ["thrived", 0.09510888010451658],
415 | ["anika", 0.09540764866213361],
416 | ["chanoyu", 0.09550084290809815],
417 | ["profusely", 0.09557139532288833],
418 | ["eliab", 0.09566581739075723],
419 | ["zara", 0.09584241402517779],
420 | ["solidifying", 0.09608868269878561],
421 | ["naledi", 0.09635414182326163],
422 | ["murmured", 0.09655989921185353],
423 | ["prided", 0.09682772402036069],
424 | ["curveball", 0.0969267440617325],
425 | ["belongingness", 0.09704829827815857],
426 | ["hometown", 0.09706682704637251],
427 | ["glanced", 0.09716443448468784],
428 | ["dismissiveness", 0.09731373843190905],
429 | ["kavarna", 0.09751433047300515],
430 | ["echoed", 0.09752749272947492],
431 | ["arben", 0.09776214191867087],
432 | ["clara", 0.09810844642230725],
433 | ["wonderment", 0.09839827901176737],
434 | ["ayla", 0.09865582838534621],
435 | ["aquarist", 0.09871535431592793],
436 | ["twinkling", 0.09878358270305657],
437 | ["yearned", 0.09914152313228353],
438 | ["sqrt", 0.09938872210307102],
439 | ["paused", 0.0997390693721105],
440 | ["nurturing", 0.09974100543627683],
441 | ["avid", 0.0997656533181099],
442 | ["brimming", 0.09978785592283941],
443 | ["freydis", 0.09986607107311554],
444 | ["gesturing", 0.0999584693413217],
445 | ["seasoned", 0.10002025957030909],
446 | ["zizzi", 0.10014154117375548],
447 | ["claudette", 0.10041075075701582],
448 | ["breadmaking", 0.10053842348672806],
449 | ["hyperparameters", 0.10081921281211534],
450 | ["naturedly", 0.10150071747381084],
451 | ["transformative", 0.1016973017391017],
452 | ["blossomed", 0.10245768551030358],
453 | ["pastimes", 0.10254359260186416],
454 | ["meera", 0.10270915366391394],
455 | ["slipups", 0.10302995120841944],
456 | ["intricacies", 0.10317249299511089],
457 | ["enthusiast", 0.10318098408793414],
458 | ["clinking", 0.1033139770842943],
459 | ["alexei", 0.10332938802798271],
460 | ["underscored", 0.10337155533898727],
461 | ["ramesh", 0.10351934507788225],
462 | ["huddled", 0.10356308221752378],
463 | ["jaspreet", 0.10363536575781072],
464 | ["ofrenda", 0.10390637614166792],
465 | ["gnawed", 0.10391875479983416],
466 | ["quirks", 0.10394287014271747],
467 | ["whirred", 0.10401340688051537],
468 | ["sipped", 0.10411521550710723],
469 | ["fatima", 0.10414541078497928],
470 | ["empathized", 0.1042529964272074],
471 | ["cheerily", 0.10439249128151971],
472 | ["unexpected", 0.10470011792513964],
473 | ["reflecting", 0.10515486160417738],
474 | ["nervously", 0.10518028779172277],
475 | ["melodia", 0.10524116896704774],
476 | ["unlikeliest", 0.10524773887823077],
477 | ["intertwine", 0.10609835729514719],
478 | ["perusing", 0.1062448476794195],
479 | ["towering", 0.10630270599082497],
480 | ["prioritizing", 0.10718619711752761],
481 | ["teemed", 0.1074711285533854],
482 | ["astonishment", 0.10756966481041869],
483 | ["showcasing", 0.10848411397230974],
484 | ["diligently", 0.1085706157908363],
485 | ["setbacks", 0.10864842590384427],
486 | ["exhilarated", 0.10931278623094859],
487 | ["murmurs", 0.10931293954230643],
488 | ["gleaming", 0.10936463324897595],
489 | ["coexisted", 0.10950456348152442],
490 | ["yellowed", 0.10968475409884669],
491 | ["seamlessly", 0.10982629613453346],
492 | ["ominously", 0.11001205898834783],
493 | ["quietude", 0.11031040532864991],
494 | ["adorning", 0.11031906309377383],
495 | ["teeming", 0.11033518418005982],
496 | ["countless", 0.1105164103825428],
497 | ["peculiar", 0.11094333701984035],
498 | ["precariously", 0.11107316266001163],
499 | ["deepened", 0.11113389824614932],
500 | ["embarked", 0.11115616320098191],
501 | ["empathetic", 0.11155900191845305],
502 | ["triumphantly", 0.11184299931805378],
503 | ["remorsefully", 0.11186809526521746],
504 | ["grumble", 0.1119706215959465],
505 | ["nuances", 0.11200998355477185],
506 | ["researching", 0.11250797352943258],
507 | ["pulsated", 0.11270598588075774],
508 | ["aquafaba", 0.11273062908980729],
509 | ["lila", 0.11299047934206825],
510 | ["hunched", 0.11312356011167964],
511 | ["reminiscing", 0.11313931448379727],
512 | ["shockwaves", 0.11320528290207856],
513 | ["considerately", 0.11367512342298255],
514 | ["sparking", 0.11383803128279647],
515 | ["emboldened", 0.11402368058243963],
516 | ["delectable", 0.11409343918029602],
517 | ["vigor", 0.11452660354436013],
518 | ["aisha", 0.11507399789890498]
519 | ]
--------------------------------------------------------------------------------
/slop_regexes.txt:
--------------------------------------------------------------------------------
1 | "(?i)not [^.!?]{3,60} but",
2 | "(?i)each(?:\s*\w+\s*|\s*)a",
3 | "(?i)every(?:\s*\w+\s*|\s*)a"
--------------------------------------------------------------------------------
/src/antislop_generate.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import List, Dict, Tuple, Generator, Set, Union
3 | from threading import Thread
4 | import threading
5 | import queue
6 | import torch
7 | from transformers import (
8 | PreTrainedTokenizer,
9 | PreTrainedModel,
10 | StoppingCriteriaList,
11 | TextIteratorStreamer
12 | )
13 | from IPython.display import display, HTML
14 | from ipywidgets import Output
15 | from src.validator_slop import SlopPhraseHandler, CustomSlopPhraseStoppingCriteria
16 | from src.validator_json import JSONValidator, JSONValidationStoppingCriteria
17 | from src.validator_regex import RegexValidator, RegexValidationStoppingCriteria
18 |
19 | from src.util import precompute_starting_tokens
20 | import asyncio
21 |
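    | # AntiSlopSampler wraps model.generate() with backtracking: per-position probability
    | # distributions are cached in probs_cache, and when a validator flags a slop phrase,
    | # invalid JSON, or a banned regex match, the sequence is truncated to the offending
    | # position, the offending token is downregulated, and sampling resumes from the
    | # cached distribution.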
22 | class AntiSlopSampler:
23 | def __init__(
24 | self,
25 | model: PreTrainedModel,
26 | tokenizer: PreTrainedTokenizer,
27 | slop_phrase_prob_adjustments: Dict[str, float],
28 | starting_tokens_lookup: Dict[Tuple[int, ...], Set[int]],
29 | adjustment_strength: float = 1.0,
30 | device: torch.device = torch.device('cuda'),
31 | slow_debug: bool = False,
32 | output_every_n_tokens: int = 1,
33 | debug_delay: float = 2.0,
34 | inference_output=None,
35 | debug_output=None,
36 | regex_bans: List[str] = [],
37 | enforce_json: bool = False,
38 | antislop_enabled: bool = True,
39 | ):
40 | self.model = model
41 | self.tokenizer = tokenizer
42 | self.slop_phrase_prob_adjustments = slop_phrase_prob_adjustments
43 | self.starting_tokens_lookup = starting_tokens_lookup
44 | self.adjustment_strength = adjustment_strength
45 | self.device = device
46 | self.slow_debug = slow_debug
47 | self.output_every_n_tokens = output_every_n_tokens
48 | self.debug_delay = debug_delay
49 | self.enforce_json = enforce_json
50 | self.antislop_enabled = antislop_enabled
51 | self.regex_bans = regex_bans or []
52 |
53 | self.sequence_queue = queue.Queue()
54 | self.generation_complete = threading.Event()
55 |
56 | # Output widgets
57 | self.inference_output = inference_output
58 | self.debug_output = debug_output
59 |
60 | self.probs_cache = {}
61 | self.probs_cache_longrange = {} # flags which positions in the logit cache we ignore during cleanup, as we want to keep some positions for long range constraint checks
62 |
63 | # Escaped toks used for lookups in json string repair
64 | self.escaped_tokens_lookup = {
65 | '\n': self.tokenizer.encode('\\n', add_special_tokens=False),
66 | '\t': self.tokenizer.encode('\\t', add_special_tokens=False),
67 | '\r': self.tokenizer.encode('\\r', add_special_tokens=False),
68 | '"': self.tokenizer.encode('\\"', add_special_tokens=False),
69 | ' "': self.tokenizer.encode(' \\"', add_special_tokens=False),
70 | }
71 |
72 | # Initialize Slop Phrase Handler
73 | self.slop_phrase_handler = SlopPhraseHandler(
74 | tokenizer=tokenizer,
75 | probs_cache=self.probs_cache,
76 | probs_cache_longrange=self.probs_cache_longrange,
77 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
78 | starting_tokens_lookup=starting_tokens_lookup,
79 | adjustment_strength=adjustment_strength,
80 | slow_debug=slow_debug,
81 | inference_output=inference_output,
82 | debug_output=debug_output,
83 | debug_delay=debug_delay
84 | )
85 | self.json_validator = JSONValidator(tokenizer, slow_debug, debug_delay, debug_output, self.probs_cache_longrange)
86 |
87 | # Initialize Regex Validator if regex patterns are provided
88 | if self.regex_bans:
89 | self.regex_validator = RegexValidator(
90 | tokenizer=tokenizer,
91 | regex_bans=self.regex_bans,
92 | slow_debug=slow_debug,
93 | debug_delay=debug_delay,
94 | debug_output=debug_output,
95 | probs_cache_longrange=self.probs_cache_longrange
96 | )
97 | else:
98 | self.regex_validator = None
99 |
100 | self.streamer_retval = None
101 |
102 | def _generate_streaming(self, current_input_ids, new_toks_to_generate, temperature, min_p, top_k, top_p, pad_token_id, stopping_criteria_args):
103 | streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=False)
104 |
105 | generation_kwargs = dict(
106 | input_ids=current_input_ids,
107 | attention_mask=torch.ones_like(current_input_ids),
108 | max_new_tokens=new_toks_to_generate,
109 | do_sample=True,
110 | temperature=temperature,
111 | min_p=min_p,
112 | top_k=top_k,
113 | top_p=top_p,
114 | pad_token_id=pad_token_id,
115 | num_return_sequences=1,
116 | return_dict_in_generate=True,
117 | output_logits=True,
118 | streamer=streamer,
119 | **stopping_criteria_args
120 | )
121 |
122 | # Create an Event to signal thread termination
123 | stop_event = threading.Event()
124 |
125 | # Create a Queue to store the generation output or errors
126 | output_queue = queue.Queue()
127 |
128 | # Define a function to run generation and put the result in the queue
129 | def generate_and_queue():
130 | try:
131 | output = self.model.generate(**generation_kwargs)
132 | if not stop_event.is_set():
133 | output_queue.put((output, None)) # None means no exception occurred
134 | except Exception as e:
135 | print(f"Exception during generation: {e}") # Debug print
136 | if not stop_event.is_set():
137 | output_queue.put((None, e)) # Put the exception in the queue
138 | stop_event.set()
139 |
140 | # Start the generation in a separate thread
141 | thread = Thread(target=generate_and_queue)
142 | thread.start()
143 |
144 | try:
145 | for new_text in streamer:
146 | yield new_text
147 | except Exception as e:
148 | print(f"Exception during streaming: {e}") # Debug print
149 | # Add the exception to the output queue so it is propagated to the caller
150 | if not stop_event.is_set():
151 | output_queue.put((None, e)) # Handle exception during streaming
152 |
153 | # Wait for the generation to complete or for the thread to be terminated
154 | thread.join()
155 |
156 | # Initialize default empty lists for output variables
157 | generated_sequence = []
158 | new_logits = []
159 | error = None # Default error to None
160 |
161 | # Check if there's any output in the queue
162 | if not output_queue.empty():
163 | generation_output, error = output_queue.get()
164 |
165 | # Check if an error occurred during generation or streaming
166 | if error:
167 | print(f"Generation or streaming failed: {error}")
168 | else:
169 | # Extract logits and sequence from the generation output
170 | new_logits = generation_output.logits
171 | generated_sequence = generation_output.sequences[0].tolist()
172 |
173 | # Add final debug information for empty output
174 | if not generated_sequence:
175 | print("Warning: Generated sequence is empty.")
176 | if not new_logits:
177 | print("Warning: Logits are empty.")
178 |
179 | # Return the generated sequence, logits, and any error
180 | self.streamer_retval = (generated_sequence, new_logits, error)
181 |
182 |
183 |
184 | @torch.no_grad()
185 | async def generate_stream(
186 | self,
187 | prompt: str,
188 | max_length: int = None,
189 | max_new_tokens: int = None,
190 | temperature: float = 1.0,
191 | top_k: int = None,
192 | top_p: float = None,
193 | min_p: float = None,
194 | ):
195 | """
196 | Generates text in a streaming fashion with custom downregulation and backtracking.
197 |
198 | Args:
199 | prompt (str): The initial text prompt.
200 | max_length (int, optional): The maximum length of the generated text.
201 | max_new_tokens (int, optional): The maximum number of new tokens to generate.
202 | temperature (float): Sampling temperature.
203 | top_k (int): Top-k filtering.
204 | top_p (float): Top-p (nucleus) filtering.
205 | min_p (float): Minimum probability filtering.
206 |
207 | Side effects:
208 | Completed or updated token sequences are placed on self.sequence_queue as they are produced.
209 | """
210 | try:
211 | # Encode the prompt
212 | input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
213 | generated_sequence = input_ids[0].tolist()
214 |
215 | # If the prompt already came with a bos token, we don't want to add it again
216 | if self.tokenizer.bos_token and \
217 | prompt.startswith(self.tokenizer.bos_token) and \
218 | not prompt.startswith(self.tokenizer.bos_token * 2) and \
219 | generated_sequence[0] == self.tokenizer.bos_token_id and \
220 | generated_sequence[1] == self.tokenizer.bos_token_id:
221 | generated_sequence = generated_sequence[1:]
222 |
223 | self.prompt_length = len(generated_sequence)
224 | self.prompt_length_chars = len(prompt)
225 | current_position = len(generated_sequence) # Tracks the current position in the sequence
226 | output_tokens_counter = 0
227 | pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
228 | next_token_logits = None
229 | filtered_logits = None
230 |
231 | if max_length != None:
232 | this_max_new_tokens = max_length - self.prompt_length
233 | if this_max_new_tokens < 0:
234 | this_max_new_tokens = 0
235 | if max_new_tokens == None or this_max_new_tokens < max_new_tokens:
236 | max_new_tokens = this_max_new_tokens
237 | else:
238 | if max_new_tokens == None:
239 | max_new_tokens = 8096
240 |
241 | stopping_criteria_args = {}
242 | self.stopping_criteria = []
243 |
244 | if self.enforce_json:
245 | json_stopping_criteria = JSONValidationStoppingCriteria(
246 | tokenizer=self.tokenizer,
247 | json_validator=self.json_validator,
248 | prompt_length=self.prompt_length
249 | )
250 | self.stopping_criteria.append(json_stopping_criteria)
251 |
252 | if self.antislop_enabled:
253 | antislop_stopping_criteria = CustomSlopPhraseStoppingCriteria(
254 | tokenizer=self.tokenizer,
255 | max_slop_phrase_length=self.slop_phrase_handler.max_slop_phrase_length,
256 | min_slop_phrase_length=self.slop_phrase_handler.min_slop_phrase_length,
257 | prompt_length=self.prompt_length,
258 | slop_phrase_prob_adjustments=self.slop_phrase_handler.slop_phrase_prob_adjustments
259 | )
260 | self.stopping_criteria.append(antislop_stopping_criteria)
261 |
262 | # Initialize Regex Validation Stopping Criteria
263 | if self.regex_validator:
264 | regex_stopping_criteria = RegexValidationStoppingCriteria(
265 | tokenizer=self.tokenizer,
266 | regex_validator=self.regex_validator,
267 | prompt_length=self.prompt_length
268 | )
269 | self.stopping_criteria.append(regex_stopping_criteria)
270 |
271 |
272 | if self.stopping_criteria:
273 | stopping_criteria_args = {
274 | "stopping_criteria": StoppingCriteriaList(self.stopping_criteria)
275 | }
276 |
277 |
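    | # Main generation loop: stream tokens from model.generate() until a stopping
    | # criterion fires or the token budget is reached, then run the JSON / slop /
    | # regex validators. A validator that backtracks returns a truncated sequence,
    | # and the next iteration resamples from the cached probabilities ("regenerating").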
278 | while True:
279 | if max_new_tokens is not None and len(generated_sequence) - self.prompt_length >= max_new_tokens:
280 | break
281 |
282 | new_toks_to_generate = max_new_tokens - (len(generated_sequence) - self.prompt_length)
283 |
284 | current_input_ids = torch.tensor([generated_sequence], device=self.device)
285 |
286 | regenerating = False
287 |
288 | if current_position in self.probs_cache:
289 | # We backtracked and want to use the cached logits
290 | next_token_probs = self.probs_cache[current_position]
291 | regenerating = True
292 | else:
293 | context = ""
294 | #print(new_toks_to_generate)
295 | for new_text in self._generate_streaming(
296 | current_input_ids,
297 | new_toks_to_generate,
298 | temperature,
299 | min_p,
300 | top_k,
301 | top_p,
302 | pad_token_id,
303 | stopping_criteria_args
304 | ):
305 | context += new_text
306 | output_tokens_counter += 1
307 |
308 | # sometimes model.generate adds an extra bos token so we'll manually clip it off.
309 | # otherwise we have conflicts with the originally calculated prompt_length
310 | if self.tokenizer.bos_token and \
311 | prompt.startswith(self.tokenizer.bos_token) and \
312 | not prompt.startswith(self.tokenizer.bos_token * 2) and \
313 | context.startswith(self.tokenizer.bos_token * 2):
314 | context = context[len(self.tokenizer.bos_token):]
315 |
316 | if output_tokens_counter >= self.output_every_n_tokens:
317 | output_tokens_counter = 0
318 |
319 | if self.inference_output:
320 | with self.inference_output:
321 | self.inference_output.clear_output(wait=True)
322 |
323 | display(HTML(f"{context[self.prompt_length_chars:]}"))
324 |
325 | self.sequence_queue.put(self.tokenizer.encode(context, add_special_tokens=False))
326 |
327 |
328 | torch.cuda.empty_cache()
329 |
330 | # sync with the returned vals in case the streaming came thru out of order
331 | # (not sure if this is necessary)
332 | if self.streamer_retval:
333 | generated_sequence, new_logits, error = self.streamer_retval
334 | if error:
335 | self.generation_complete.set()
336 | return
337 |
338 | # sometimes model.generate adds an extra bos token so we'll manually clip it off.
339 | # otherwise we have conflicts with the originally calculated prompt_length
340 | if self.tokenizer.bos_token and \
341 | prompt.startswith(self.tokenizer.bos_token) and \
342 | not prompt.startswith(self.tokenizer.bos_token * 2) and \
343 | generated_sequence[0] == self.tokenizer.bos_token_id and \
344 | generated_sequence[1] == self.tokenizer.bos_token_id:
345 | generated_sequence = generated_sequence[1:]
346 |
347 | self.streamer_retval = None
348 | else:
349 | print('!! error missing retval from streamer')
350 | self.generation_complete.set()
351 | return
352 |
353 | for i, logit in enumerate(new_logits):
354 | self.probs_cache[current_position + i] = torch.softmax(logit.clone(), dim=-1)
355 |
356 | next_token = generated_sequence[-1]
357 | current_position = len(generated_sequence)
358 |
359 |
360 | if regenerating:
361 | # Apply min_p, top-k and top-p filtering
362 | filtered_probs = self._filter_probs(next_token_probs, top_k, top_p, min_p)
363 | # Sample the next token
364 | next_token_index = torch.multinomial(filtered_probs, num_samples=1)
365 |
366 | next_token = next_token_index.item()
367 | # Append the new token to the sequence
368 | generated_sequence.append(next_token)
369 | #print('alt token:',self.tokenizer.decode(next_token), self.probs_cache[current_position][:, next_token])
370 | output_tokens_counter += 1
371 |
372 | if output_tokens_counter >= self.output_every_n_tokens:
373 | output_tokens_counter = 0
374 | current_text = self.tokenizer.decode(generated_sequence[self.prompt_length:])
375 | if self.inference_output:
376 | with self.inference_output:
377 | self.inference_output.clear_output(wait=True)
378 | display(HTML(f"{current_text}"))
379 | self.sequence_queue.put(generated_sequence)
380 | #print('downregulated token after reselection', self.slop_phrase_handler.probs_cache[current_position][:, self.json_validator.last_downregulated_token])
381 | current_position = len(generated_sequence)
382 |
383 | if regenerating and self.slow_debug:
384 | alt_token = self.tokenizer.decode(next_token, skip_special_tokens=True)
385 | debug_info = f"Alternate token: {[alt_token]}"
386 |
387 | self._display_debug(debug_info)
388 | if self.slow_debug:
389 | time.sleep(self.debug_delay)
390 |
391 |
392 | # Clean up the probs cache
393 | if not self.enforce_json and not self.regex_bans:
394 | # json validation needs to keep the long range dependencies
395 | # although we can probably delete the ones that aren't flagged in self.probs_cache_longrange.
396 | if self.antislop_enabled:
397 | to_del = [key for key in self.probs_cache if key < current_position - self.slop_phrase_handler.max_slop_phrase_length - 5 and not self.probs_cache_longrange.get(key, False)]
398 | for key in to_del:
399 | if key not in self.probs_cache_longrange:
400 | del self.probs_cache[key]
401 |
402 |
403 | # Check for end-of-sequence token
404 | if next_token == self.tokenizer.eos_token_id:
405 | break
406 |
407 | # JSON validation
408 | if self.enforce_json:
409 | result = self.json_validator.validate_json_string(generated_sequence, self.prompt_length, self.probs_cache)
410 | if result != False:
411 | generated_sequence = result
412 | current_position = len(generated_sequence)
413 | continue # Skip the rest of this iteration and start over
414 |
415 | # After adding the token, check for disallowed sequences
416 | if self.antislop_enabled:
417 | antislop_result = self.slop_phrase_handler.deslop(generated_sequence, self.prompt_length)
418 | if antislop_result != False:
419 | generated_sequence = antislop_result
420 | current_position = len(generated_sequence)
421 | continue
422 |
423 | # Regex validation: backtrack if a banned pattern was matched
424 | if self.regex_validator:
425 | regex_result = self.regex_validator.validate_regex_matches(generated_sequence, self.prompt_length, self.probs_cache)
426 | if regex_result != False:
427 | generated_sequence = regex_result
428 | current_position = len(generated_sequence)
429 | continue
430 |
431 |
432 |
433 | # Final display of the generated text
434 | final_text = self.tokenizer.decode(generated_sequence[self.prompt_length:], skip_special_tokens=False)
435 | if self.inference_output:
436 | with self.inference_output:
437 | self.inference_output.clear_output(wait=True)
438 | display(HTML(f"{final_text}"))
439 | self.sequence_queue.put(generated_sequence)
440 |
441 |
442 | # Clear variables to free up memory
443 | del next_token_logits, filtered_logits
444 |
445 | # signal end of generation
446 | self.generation_complete.set()
447 | except Exception as e:
448 | print(e)
449 | # signal end of generation
450 | self.generation_complete.set()
451 |
452 |
453 | def _filter_probs(self, probs: torch.FloatTensor, top_k: int, top_p: float, min_p: float) -> torch.FloatTensor:
454 | # Make a copy of the probabilities to ensure we do not modify the original tensor
455 | probs = probs.clone()
456 |
457 | # Apply min_p filtering
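    | # (tokens whose probability is below min_p * the top token's probability are zeroed)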
458 | if min_p is not None:
459 | top_prob, _ = torch.max(probs, dim=-1)
460 | scaled_min_p = min_p * top_prob
461 | probs = torch.where(probs < scaled_min_p, 0, probs)
462 |
463 | if top_k is not None and top_k > 0:
464 | top_k = min(top_k, probs.size(-1))
465 | top_k_probs, _ = torch.topk(probs, top_k)
466 | min_top_k = top_k_probs[:, -1].unsqueeze(-1)
467 | probs = torch.where(probs < min_top_k, 0, probs)
468 |
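    | # Top-p (nucleus) filtering: keep the smallest set of highest-probability tokens whose
    | # cumulative probability exceeds top_p. The one-position shift below ensures the token
    | # that crosses the threshold is kept; everything outside the nucleus is zeroed.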
469 | if top_p is not None and top_p < 1.0:
470 | sorted_probs, sorted_indices = torch.sort(probs, descending=True)
471 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
472 |
473 | sorted_indices_to_remove = cumulative_probs > top_p
474 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
475 | sorted_indices_to_remove[:, 0] = False
476 | indices_to_remove = sorted_indices_to_remove.scatter(
477 | dim=1, index=sorted_indices, src=sorted_indices_to_remove
478 | )
479 | probs = probs.masked_fill(indices_to_remove, 0)
480 |
481 | return probs
482 |
483 | def _display_debug(self, message: str):
484 | """
485 | Displays debug information in the debug_output widget.
486 | """
487 | if self.debug_output:
488 | with self.debug_output:
489 | self.debug_output.clear_output(wait=True)
490 | display(HTML(f"{message}"))
491 | else:
492 | print(message)
493 |
494 | def _clear_gpu_memory_async(self):
495 | def clear_gpu_memory():
496 | torch.cuda.empty_cache()
497 |
498 | # Create and start the daemon thread
499 | cleaner_thread = threading.Thread(target=clear_gpu_memory, daemon=True)
500 | cleaner_thread.start()
501 |
502 | # Return immediately without waiting for the thread
503 | return
504 |
505 | def cleanup(self):
506 | # Clear the queue
507 | while not self.sequence_queue.empty():
508 | try:
509 | self.sequence_queue.get_nowait()
510 | except queue.Empty:
511 | break
512 |
513 | # Clear the event
514 | self.generation_complete.clear()
515 |
516 | # Clear caches
517 | self.probs_cache.clear()
518 | self.probs_cache_longrange.clear()
519 |
520 | # Clear other attributes
521 | self.model = None
522 | self.tokenizer = None
523 | self.slop_phrase_handler = None
524 | self.json_validator = None
525 | self.regex_validator = None
526 |
527 | # Clear output widgets
528 | if self.inference_output:
529 | self.inference_output.clear_output()
530 | if self.debug_output:
531 | self.debug_output.clear_output()
532 |
533 | # Clear CUDA cache
534 | if torch.cuda.is_available():
535 | torch.cuda.empty_cache()
536 |
537 | def chat_antislop(
538 | model: PreTrainedModel,
539 | tokenizer: PreTrainedTokenizer,
540 | messages: List[Dict[str, str]],
541 | max_length: int = None,
542 | max_new_tokens: int = None,
543 | temperature: float = 1.0,
544 | top_k: int = None,
545 | top_p: float = None,
546 | min_p: float = None,
547 | slop_phrase_prob_adjustments: Dict[str, float] = None,
548 | adjustment_strength: float = 1.0, # 1.0 is no change from the provided adjustment factors.
549 | device: torch.device = torch.device('cuda'),
550 | streaming: bool = False,
551 | stream_smoothing: bool = True,
552 | slow_debug: bool = False, # Add slow_debug argument for debugging
553 | output_every_n_tokens: int = 1, # Control how frequently the output is updated
554 | debug_delay: float = 2.0, # Delay for slow debugging mode
555 | inference_output: Output = None, # For visualization during generation
556 | debug_output: Output = None, # For visualization of debug information
557 | enforce_json: bool = False,
558 | antislop_enabled: bool = True,
559 | regex_bans: List[str] = None,
560 | ):
561 | """
562 | Generates a chat response while avoiding overrepresented phrases (slop) with debugging features.
563 | This method creates a generator or a non-streamed output, depending on the streaming flag.
564 |
565 | Args:
566 | model (PreTrainedModel): The language model.
567 | tokenizer (PreTrainedTokenizer): The tokenizer.
568 | messages (List[Dict[str, str]]): The list of messages in the conversation.
569 | max_length (int, optional): The maximum length of the generated text (including the prompt).
570 | max_new_tokens (int, optional): The maximum number of new tokens to generate.
571 | temperature (float): Sampling temperature.
572 | top_k (int): Top-k filtering.
573 | top_p (float): Top-p (nucleus) filtering.
574 | min_p (float): Minimum probability filtering.
575 | slop_phrase_prob_adjustments (Dict[str, float], optional): Dictionary of target words with their respective probability adjustment factor.
576 | adjustment_strength (float, optional): Strength of the downregulation adjustment.
577 | device (torch.device, optional): The device to run the model on.
578 | streaming (bool, optional): Whether to yield tokens as they are generated.
579 | slow_debug (bool, optional): Enables slow debug mode when set to True.
580 | output_every_n_tokens (int, optional): Frequency of updating the inference output display.
581 | debug_delay (float, optional): Time in seconds to pause during slow debug steps.
582 | inference_output (Output, optional): For visualization during generation.
583 | debug_output (Output, optional): For visualization of debug information.
584 |
585 | Returns:
586 | Union[Generator[int, None, None], List[int]]:
587 | If streaming is True, yields generated token IDs as they become available.
588 | If streaming is False, returns a list of generated token IDs.
589 | """
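    | # Example usage (illustrative sketch; assumes `model` and `tokenizer` are already
    | # loaded and `slop_adjustments` holds the dict read from slop_phrase_prob_adjustments.json):
    | #
    | #     for token_id in chat_antislop(
    | #         model, tokenizer,
    | #         messages=[{"role": "user", "content": "Write a short story."}],
    | #         max_new_tokens=400,
    | #         slop_phrase_prob_adjustments=slop_adjustments,
    | #         adjustment_strength=2.0,
    | #         streaming=True,
    | #     ):
    | #         print(tokenizer.decode(token_id), end="", flush=True)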
590 |
591 | # Build the prompt using the provided messages
592 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
593 |
594 | return generate_antislop(
595 | model=model,
596 | tokenizer=tokenizer,
597 | prompt=prompt,
598 | max_length=max_length,
599 | max_new_tokens=max_new_tokens,
600 | temperature=temperature,
601 | top_k=top_k,
602 | top_p=top_p,
603 | min_p=min_p,
604 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
605 | adjustment_strength=adjustment_strength,
606 | device=device,
607 | slow_debug=slow_debug,
608 | output_every_n_tokens=output_every_n_tokens,
609 | debug_delay=debug_delay,
610 | inference_output=inference_output,
611 | debug_output=debug_output,
612 | antislop_enabled=antislop_enabled,
613 | enforce_json=enforce_json,
614 | regex_bans=regex_bans,
615 | streaming=streaming,
616 | stream_smoothing=stream_smoothing,
617 | )
618 |
619 |
620 | def generate_antislop(
621 | model: PreTrainedModel,
622 | tokenizer: PreTrainedTokenizer,
623 | prompt: str,
624 | max_length: int = None,
625 | max_new_tokens: int = None,
626 | temperature: float = 1.0,
627 | top_k: int = None,
628 | top_p: float = None,
629 | min_p: float = None,
630 | slop_phrase_prob_adjustments: Dict[str, float] = None,
631 | adjustment_strength: float = 1.0,
632 | device: torch.device = torch.device('cuda'),
633 | streaming: bool = False,
634 | stream_smoothing: bool = True,
635 | slow_debug: bool = False, # Added slow_debug
636 | output_every_n_tokens: int = 1,
637 | debug_delay: float = 2.0,
638 | inference_output: Output = None,
639 | debug_output: Output = None,
640 | enforce_json: bool = False,
641 | antislop_enabled: bool = True,
642 | regex_bans: List[str] = None,
643 | ) -> Union[Generator[int, None, None], List[int]]:
644 | """
645 | Wrapper around _generate_antislop that handles both streaming and non-streaming modes.
646 | """
647 | # Type checking and validation of input arguments
648 | if not isinstance(prompt, str):
649 | raise TypeError("prompt must be a string")
650 | if max_length is not None and not isinstance(max_length, int):
651 | raise TypeError("max_length must be an integer or None")
652 | if max_new_tokens is not None and not isinstance(max_new_tokens, int):
653 | raise TypeError("max_new_tokens must be an integer or None")
654 | if not isinstance(temperature, (int, float)):
655 | raise TypeError("temperature must be a float")
656 | if top_k is not None and not isinstance(top_k, int):
657 | raise TypeError("top_k must be an integer or None")
658 | if top_p is not None and not isinstance(top_p, float):
659 | raise TypeError("top_p must be a float or None")
660 | if min_p is not None and not isinstance(min_p, float):
661 | raise TypeError("min_p must be a float or None")
662 | if slop_phrase_prob_adjustments is not None and not isinstance(slop_phrase_prob_adjustments, dict):
663 | raise TypeError("slop_phrase_prob_adjustments must be a dictionary or None")
664 | if not isinstance(adjustment_strength, (int, float)):
665 | raise TypeError("adjustment_strength must be a float")
666 | if not isinstance(device, torch.device):
667 | raise TypeError("device must be an instance of torch.device")
668 | if not isinstance(streaming, bool):
669 | raise TypeError("streaming must be a boolean")
670 |
671 | # Value validation
672 | if max_length is not None and max_length <= 0:
673 | raise ValueError("max_length must be positive")
674 | if max_new_tokens is not None and max_new_tokens <= 0:
675 | raise ValueError("max_new_tokens must be positive")
676 | if temperature <= 0:
677 | raise ValueError("temperature must be > 0")
678 | if top_k is not None and top_k < 0:
679 | raise ValueError("top_k must be non-negative")
680 | if top_p is not None and (top_p < 0 or top_p > 1):
681 | raise ValueError("top_p must be in the range [0, 1]")
682 | if min_p is not None and (min_p < 0 or min_p > 1):
683 | print(min_p)
684 | raise ValueError("min_p must be in the range [0, 1]")
685 | if adjustment_strength < 0:
686 | raise ValueError("adjustment_strength must be non-negative")
687 |
688 | if not debug_output or not inference_output:
689 | debug_delay = 0
690 | slow_debug = False
691 |
692 | if slop_phrase_prob_adjustments:
693 | for phrase, adjustment in slop_phrase_prob_adjustments.items():
694 | if not isinstance(phrase, str):
695 | raise TypeError("All keys in slop_phrase_prob_adjustments must be strings")
696 | if not isinstance(adjustment, (int, float)):
697 | raise TypeError("All values in slop_phrase_prob_adjustments must be floats")
698 |
699 | if streaming:
700 | return _generate_antislop(
701 | model=model,
702 | tokenizer=tokenizer,
703 | prompt=prompt,
704 | max_length=max_length,
705 | max_new_tokens=max_new_tokens,
706 | temperature=temperature,
707 | top_k=top_k,
708 | top_p=top_p,
709 | min_p=min_p,
710 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
711 | adjustment_strength=adjustment_strength,
712 | device=device,
713 | slow_debug=slow_debug, # Pass slow_debug to support detailed debug output
714 | output_every_n_tokens=output_every_n_tokens,
715 | debug_delay=debug_delay,
716 | inference_output=inference_output,
717 | debug_output=debug_output,
718 | enforce_json=enforce_json,
719 | regex_bans=regex_bans,
720 | antislop_enabled=antislop_enabled,
721 | streaming=streaming,
722 | stream_smoothing=stream_smoothing,
723 | )
724 | else:
725 | generated_tokens = []
726 | for token in _generate_antislop(
727 | model=model,
728 | tokenizer=tokenizer,
729 | prompt=prompt,
730 | max_length=max_length,
731 | max_new_tokens=max_new_tokens,
732 | temperature=temperature,
733 | top_k=top_k,
734 | top_p=top_p,
735 | min_p=min_p,
736 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
737 | adjustment_strength=adjustment_strength,
738 | device=device,
739 | slow_debug=slow_debug, # Pass slow_debug to support detailed debug output
740 | output_every_n_tokens=output_every_n_tokens,
741 | debug_delay=debug_delay,
742 | inference_output=inference_output,
743 | debug_output=debug_output,
744 | enforce_json=enforce_json,
745 | regex_bans=regex_bans,
746 | antislop_enabled=antislop_enabled,
747 | streaming=streaming,
748 | stream_smoothing=stream_smoothing,
749 | ):
750 | generated_tokens.append(token)
751 | return generated_tokens
752 |
753 | def simple_thread_function():
754 | print("Simple thread function executed")
755 | time.sleep(1)
756 | print("Simple thread function completed")
757 |
758 |
759 | def _generate_antislop(
760 | model: PreTrainedModel,
761 | tokenizer: PreTrainedTokenizer,
762 | prompt: str,
763 | max_length: int = None,
764 | max_new_tokens: int = None,
765 | temperature: float = 1.0,
766 | top_k: int = None,
767 | top_p: float = None,
768 | min_p: float = None,
769 | slop_phrase_prob_adjustments: Dict[str, float] = None,
770 | adjustment_strength: float = 1.0,
771 | device: torch.device = torch.device('cuda'),
772 | slow_debug: bool = False, # Added slow_debug
773 | output_every_n_tokens: int = 1,
774 | debug_delay: float = 2.0,
775 | inference_output: 'Output' = None, # Assuming Output is defined elsewhere
776 | debug_output: 'Output' = None,
777 | streaming: bool = False,
778 | stream_smoothing: bool = True,
779 | enforce_json: bool = False,
780 | antislop_enabled: bool = True,
781 | regex_bans: List[str] = None,
782 | ) -> Generator[int, None, None]:
783 | """
784 | Generates text while avoiding overrepresented phrases (slop).
785 | This function is now always a generator with temporal buffering.
786 | """
787 |
788 | if streaming and regex_bans:
789 | raise ValueError("Streaming is not supported when using regex patterns.")
790 |
791 | # Precompute starting tokens for the slop phrases
792 | starting_tokens_lookup = precompute_starting_tokens(tokenizer, slop_phrase_prob_adjustments or {})
793 |
794 | # Initialize the sampler
795 | sampler = AntiSlopSampler(
796 | model=model,
797 | tokenizer=tokenizer,
798 | slop_phrase_prob_adjustments=slop_phrase_prob_adjustments or {},
799 | starting_tokens_lookup=starting_tokens_lookup,
800 | adjustment_strength=adjustment_strength,
801 | device=device,
802 | slow_debug=slow_debug, # Enable slow debugging
803 | output_every_n_tokens=output_every_n_tokens,
804 | debug_delay=debug_delay,
805 | inference_output=inference_output,
806 | debug_output=debug_output,
807 | enforce_json=enforce_json,
808 | antislop_enabled=antislop_enabled,
809 | regex_bans=regex_bans,
810 | )
811 |
812 | # Generate token stream
813 | generate_kwargs = {
814 | 'max_length': max_length,
815 | 'max_new_tokens': max_new_tokens,
816 | 'temperature': temperature,
817 | 'top_k': top_k,
818 | 'top_p': top_p,
819 | 'min_p': min_p,
820 | }
821 |
822 | loop = asyncio.new_event_loop()
823 |
824 | def run_event_loop(loop):
825 | asyncio.set_event_loop(loop)
826 | loop.run_until_complete(sampler.generate_stream(prompt, **generate_kwargs))
827 |
828 | thread = threading.Thread(target=run_event_loop, args=(loop,), daemon=True)
829 | thread.start()
830 |
831 | try:
832 | prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
833 | if len(prompt_tokens) == 0:
834 | print('! prompt is empty')
835 | return
836 |
837 | backtracking_buffer_size = sampler.slop_phrase_handler.max_slop_phrase_length + 5
838 | last_released_position = len(prompt_tokens) - 1
839 | generated_sequence = []
840 |
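    | # Token release policy: hold back the newest `backtracking_buffer_size` tokens so a
    | # validator backtrack never retracts a token that has already been yielded. With
    | # stream smoothing enabled, release is additionally paced by a simple moving average
    | # of recent per-token generation times.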
841 | if streaming and stream_smoothing:
842 | # Buffer to allow backtracking and also to smooth the output rate
843 | temporal_buffer_size = 30
844 | last_generation_time = time.time()
845 | token_times = [last_generation_time] * len(prompt_tokens)
846 | while True:
847 | try:
848 | generated_sequence = sampler.sequence_queue.get(timeout=0.01)
849 | except queue.Empty:
850 | if sampler.generation_complete.is_set():
851 | break
852 | continue
853 |
854 | # update token times
855 | if len(generated_sequence) <= len(token_times):
856 | # we backtracked
857 | token_times = token_times[:len(generated_sequence)-1]
858 |
859 | while len(generated_sequence) - last_released_position > backtracking_buffer_size:
860 | # get the latest sequence from the queue
861 | while True:
862 | try:
863 | generated_sequence = sampler.sequence_queue.get(timeout=0.001)
864 | if len(generated_sequence) <= len(token_times):
865 | # we backtracked
866 | token_times = token_times[:len(generated_sequence)-1]
867 | except queue.Empty:
868 | break
869 | if sampler.generation_complete.is_set():
870 | break
871 |
872 | # calculate simple moving avg of last n token times
873 | adjusted_last_released_pos = last_released_position - len(prompt_tokens)
874 | sma_tokens = token_times[len(prompt_tokens):][adjusted_last_released_pos-temporal_buffer_size:adjusted_last_released_pos]
875 | if len(sma_tokens) > 0:
876 | sma_token_time = (time.time() - sma_tokens[0]) / len(sma_tokens)
877 | else:
878 | sma_token_time = 0
879 | #print(sma_token_time)
880 |
881 | if len(generated_sequence) - last_released_position > backtracking_buffer_size + temporal_buffer_size:
882 | sleep_time = 0.02
883 | else:
884 | sleep_time = sma_token_time
885 |
886 | last_released_position += 1
887 | token_to_release = generated_sequence[last_released_position]
888 |
889 | # Sleep to smooth the output
890 | if sleep_time > 0:
891 | time.sleep(sleep_time)
892 |
893 | token_times.append(time.time())
894 |
895 | # Yield the token
896 | yield token_to_release
897 | else:
898 | # smoothing disabled
899 | while True:
900 | try:
901 | generated_sequence = sampler.sequence_queue.get(timeout=0.01)
902 | except queue.Empty:
903 | if sampler.generation_complete.is_set():
904 | break
905 | continue
906 |
907 | while len(generated_sequence) - last_released_position > backtracking_buffer_size:
908 | last_released_position += 1
909 | token_to_release = generated_sequence[last_released_position]
910 |
911 | # Yield the token
912 | yield token_to_release
913 |
914 | # Release any remaining tokens after generation is complete
915 | if last_released_position < len(generated_sequence) - 1:
916 | #print(len(generated_sequence) - last_released_position, 'to release')
917 | for tok in generated_sequence[last_released_position + 1:]:
918 | # Release remaining tokens at full rate with constant delay
919 | yield tok
920 | time.sleep(0.02) # Constant delay as per user's instruction
921 |
922 | finally:
923 | # Stop the event loop
924 | loop.call_soon_threadsafe(loop.stop)
925 | # Wait for the thread to finish
926 | thread.join()
927 | # Close the loop
928 | loop.close()
929 | # Clean up the sampler
930 | sampler.cleanup()
931 | del sampler
932 |
--------------------------------------------------------------------------------
/src/slop_index.py:
--------------------------------------------------------------------------------
1 | # Calculates a slop score for a provided text
2 |
3 | import json
4 | import re
5 | import matplotlib.pyplot as plt
6 | import matplotlib.ticker as ticker
7 | import numpy as np
8 | from joblib import Parallel, delayed
9 |
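    | # Each slop entry's weight is (1 - prob_adjustment), scaled so the largest weight is 1.0;
    | # only the first n_slop_words entries are used for scoring.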
10 | def load_and_preprocess_slop_words():
11 | with open('slop_phrase_prob_adjustments.json', 'r') as f:
12 | slop_phrases = json.load(f)
13 |
14 | phrase_weighting = [1.0 - prob_adjustment for word, prob_adjustment in slop_phrases]
15 | max_score = max(phrase_weighting)
16 | scaled_weightings = [score / max_score for score in phrase_weighting]
17 | n_slop_words = 600
18 | return {word.lower(): score for (word, _), score in zip(slop_phrases[:n_slop_words], scaled_weightings[:n_slop_words])}
19 |
20 | def extract_text_blocks(file_path, compiled_pattern):
21 | with open(file_path, 'r', encoding='utf-8') as file:
22 | content = file.read()
23 |
24 | matches = compiled_pattern.findall(content)
25 | return '\n'.join(matches)
26 |
27 | def calculate_slop_score_chunk(args):
28 | text, slop_words_chunk = args
29 | return sum(
30 | score * len(re.findall(r'\b' + re.escape(word) + r'\b', text))
31 | for word, score in slop_words_chunk.items()
32 | )
33 |
34 | def calculate_and_plot_slop_indices(slop_indices):
35 | if not slop_indices:
36 | print("No slop indices to plot.")
37 | return []
38 |
39 | # Sort the indices in descending order
40 | sorted_indices = sorted(slop_indices.items(), key=lambda x: x[1], reverse=True)
41 | models, indices = zip(*sorted_indices) if sorted_indices else ([], [])
42 |
43 | # Set the style for better aesthetics
44 | plt.style.use('seaborn-darkgrid')  # Renamed to 'seaborn-v0_8-darkgrid' in matplotlib >= 3.6; other styles like 'ggplot' or 'fivethirtyeight' also work.
45 |
46 | plt.figure(figsize=(12, 18))
47 |
48 | # Create a horizontal bar chart
49 | bars = plt.barh(models, indices, color=plt.cm.viridis(range(len(indices))))
50 |
51 | plt.title('Slop Index by Model', fontsize=16, weight='bold', pad=15)
52 | plt.xlabel('Slop Index', fontsize=14, labelpad=10)
53 | plt.ylabel('Model', fontsize=14, labelpad=10)
54 |
55 | # Invert y-axis to have the highest slop index on top
56 | plt.gca().invert_yaxis()
57 |
58 | # Add value labels to each bar
59 | for bar in bars:
60 | width = bar.get_width()
61 | plt.text(width + max(indices)*0.01, bar.get_y() + bar.get_height()/2,
62 | f'{width:.2f}', va='center', fontsize=12)
63 |
64 | # Customize x-axis ticks
65 | plt.gca().xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
66 |
67 | plt.tight_layout()
68 |
69 | # Save the figure with higher resolution
70 | plt.savefig('slop_index_chart.png', dpi=300)
71 | plt.show()
72 | plt.close()
73 |
74 | return sorted_indices
75 |
76 | def split_into_chunks(slop_words, num_chunks):
77 | slop_words_items = list(slop_words.items())
78 | chunk_size = len(slop_words_items) // num_chunks
79 | if chunk_size == 0:
80 | chunk_size = 1
81 | return [dict(slop_words_items[i:i + chunk_size]) for i in range(0, len(slop_words_items), chunk_size)]
82 |
83 |
84 | # Call this function to calculate a slop score.
85 | # This is the way it's calculated for the eqbench creative writing leaderboard.
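    | # slop_index = 1000 * sum(weight[word] * occurrences(word)) / total_words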
86 | def calculate_slop_index(extracted_text):
87 | slop_words = load_and_preprocess_slop_words()
88 |
89 | num_chunks = 12 #mp.cpu_count()
90 | slop_words_chunks = split_into_chunks(slop_words, num_chunks)
91 |
92 | if not extracted_text:
93 | slop_index = 0.0
94 | else:
95 | # Parallelize the calculation using joblib
96 | slop_scores = Parallel(n_jobs=num_chunks)(delayed(calculate_slop_score_chunk)((extracted_text, chunk)) for chunk in slop_words_chunks)
97 |
98 | slop_score = sum(slop_scores)
99 | total_words = len(extracted_text.split())
100 | slop_index = (slop_score / total_words) * 1000 if total_words > 0 else 0
101 | return slop_index
102 |
103 |
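    | # Example usage (illustrative sketch; the sample text is hypothetical and
    | # slop_phrase_prob_adjustments.json must be present in the working directory):
    | #
    | #     if __name__ == "__main__":
    | #         sample = "Elara felt a testament to her unwavering resilience."
    | #         print(f"slop index: {calculate_slop_index(sample):.2f}")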
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Tuple, Generator, Set, Union
2 | from transformers import PreTrainedTokenizer
3 |
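    | # For every slop phrase, collect the token IDs that could begin it: the first token of
    | # lower/Capitalized/UPPER variants (with and without a leading space), plus tokens for
    | # prefixes of that first token's decoded text. The sampler uses this lookup to know
    | # which candidate tokens to downregulate when it backtracks to the start of a phrase.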
4 | def precompute_starting_tokens(
5 | tokenizer: PreTrainedTokenizer, slop_phrase_prob_adjustments: Dict[str, float]
6 | ) -> Dict[Tuple[int, ...], Set[int]]:
7 | starting_tokens_lookup = {}
8 |
9 | for slop_phrase in slop_phrase_prob_adjustments.keys():
10 | starting_tokens = set()
11 | variants = [
12 | slop_phrase.lower(),
13 | slop_phrase.capitalize(),
14 | slop_phrase.upper(),
15 | f" {slop_phrase.lower()}",
16 | f" {slop_phrase.capitalize()}",
17 | f" {slop_phrase.upper()}",
18 | ]
19 |
20 | for variant in variants:
21 | token_ids = tokenizer.encode(variant, add_special_tokens=False)
22 | if token_ids:
23 | starting_tokens.add(token_ids[0])
24 | first_token_decoded = tokenizer.decode(token_ids[0], skip_special_tokens=True)
25 |
26 | for i in range(len(first_token_decoded) - 1):
27 | prefix = first_token_decoded[:-(i + 1)]
28 | if prefix == ' ':
29 | continue
30 | encoded_prefix = tokenizer.encode(prefix, add_special_tokens=False)
31 | if encoded_prefix:
32 | starting_tokens.add(encoded_prefix[0])
33 |
34 | starting_tokens_lookup[slop_phrase] = starting_tokens
35 |
36 | return starting_tokens_lookup
--------------------------------------------------------------------------------
/src/validator_json.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from transformers import StoppingCriteria
4 | from IPython.display import clear_output, display, HTML
5 |
6 | class JSONValidator:
7 | def __init__(self, tokenizer, slow_debug, debug_delay, debug_output, probs_cache_longrange):
8 | self.tokenizer = tokenizer
9 | self.slow_debug = slow_debug
10 | self.debug_delay = debug_delay
11 | self.debug_output = debug_output
12 | self.probs_cache_longrange = probs_cache_longrange
13 |
14 | # Escaped tokens lookup for common characters that need escaping in JSON strings
15 | self.escaped_tokens_lookup = {
16 | '\n': self.tokenizer.encode('\\n', add_special_tokens=False),
17 | '\t': self.tokenizer.encode('\\t', add_special_tokens=False),
18 | '\r': self.tokenizer.encode('\\r', add_special_tokens=False),
19 | '"': self.tokenizer.encode('\\"', add_special_tokens=False),
20 | ' "': self.tokenizer.encode(' \\"', add_special_tokens=False),
21 | }
22 |
23 | self.last_downregulated_token = 0
24 |
25 | def validate_json_string(self, generated_sequence, prompt_length, probs_cache):
26 | result = self._validate_json_string(generated_sequence, prompt_length, probs_cache)
27 | if result is not False:
28 | generated_sequence, problematic_token, invalid_char, reason = result
29 | if self.slow_debug:
30 | problematic_token_decoded = self.tokenizer.decode(problematic_token, skip_special_tokens=True)
31 | debug_info = f"JSON structure violation detected:\n{reason}\nProblematic char: {[invalid_char]} in token {[problematic_token_decoded]}"
32 | self._display_debug(debug_info)
33 | time.sleep(self.debug_delay)
34 |
35 | problematic_pos = len(generated_sequence)
36 |
37 | # Clear subsequent logit cache
38 | to_del = [key for key in probs_cache if key > problematic_pos]
39 | for key in to_del:
40 | del probs_cache[key]
41 |
42 | # Flag positions to keep in the logit cache
43 | self.probs_cache_longrange[problematic_pos] = True
44 |
45 | return generated_sequence
46 |
47 | return False
48 |
49 | # String validation checks for unescaped chars inside a (partial) json generation.
50 | # It works better to do this with long-range constraints (at least for double
51 | # quotes) because we aren't forcing a termination of the string if that's not what
52 | # the model intended.
53 | #
54 | # To do this, if we are in a json string and see a ", we wait for more tokens to
55 | # see if the continuation looks like it was intended to be a string terminator,
56 | # or meant to be a quotation mark within the string. If it looks like it was an
57 | # accidentally unescaped quote, we downregulate this token and upregulate the
58 | # escaped token at that position (then backtrack & resample)
59 | #
60 | # It's not foolproof, but it does fix most of the json parsing fails that occur.
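    | # For example, given '{"note": "she said "hi" to him"}', the quote before 'hi' looks like
    | # a string terminator, but the continuation ('hi' rather than , : ] or }) reveals it was
    | # meant to be an in-string quote, so that token is downregulated and '\"' is upregulated
    | # before backtracking and resampling.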
61 | def _validate_json_string(self, generated_sequence, prompt_length, probs_cache=None, validate_only=False):
62 | # Get only the generated text
63 | generated_text = self.tokenizer.decode(generated_sequence[prompt_length:])
64 |
65 | # Create character to token position mapping
66 | char_to_token_pos = self._create_char_to_token_mapping(generated_sequence[prompt_length:], prompt_length)
67 |
68 | in_json = False
69 | in_string = False
70 | escape_next = False
71 | validating_string_end = False
72 | string_end_char = -1
73 |
74 | # Tracks the nesting of JSON objects and arrays
75 | brace_stack = []
76 | expected_tokens = []
77 |
78 | for i, char in enumerate(generated_text):
79 | if not in_json:
80 | if char in '{[':
81 | in_json = True
82 | # Start tracking the JSON structure
83 | brace_stack.append(char)
84 | expected_tokens = self._get_expected_tokens(brace_stack)
85 | continue
86 |
87 | if escape_next:
88 | escape_next = False
89 | continue
90 |
91 | if validating_string_end:
92 | if char in ['\n', '\t', '\r', ' ']:
93 | continue
94 | elif char in ',:]}[{':
95 | # Valid string termination
96 | in_string = False
97 | validating_string_end = False
98 | expected_tokens = self._get_expected_tokens(brace_stack)
99 | else:
100 | # Invalid string termination, backtrack
101 | if validate_only:
102 | return True
103 |
104 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[string_end_char], upregulate_safe_continuation=True, probs_cache=probs_cache)
105 | return generated_sequence, problematic_token, char, "Unexpected string termination"
106 |
107 | if char == '\\':
108 | escape_next = True
109 | elif char == '"':
110 | if in_string:
111 | # End of string, validate its termination
112 | string_end_char = i
113 | validating_string_end = True
114 | else:
115 | # Start of string
116 | in_string = True
117 | expected_tokens = ['"']
118 | elif in_string:
119 | if char in '\n\r\t':
120 | # These characters should be escaped in JSON strings
121 | if validate_only:
122 | return True
123 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], upregulate_safe_continuation=True, probs_cache=probs_cache)
124 | return generated_sequence, problematic_token, char, "Unescaped character in string"
125 |
126 | elif char in '{[':
127 | brace_stack.append(char)
128 | expected_tokens = self._get_expected_tokens(brace_stack)
129 | elif char in '}]':
130 | if not brace_stack or (char == '}' and brace_stack[-1] != '{') or (char == ']' and brace_stack[-1] != '['):
131 | if validate_only:
132 | return True
133 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], probs_cache=probs_cache)
134 | return generated_sequence, problematic_token, char, "Unexpected '}' or ']'"
135 | brace_stack.pop()
136 | expected_tokens = self._get_expected_tokens(brace_stack)
137 | elif char == ',':
138 | if not brace_stack or brace_stack[-1] not in '{[':
139 | if validate_only:
140 | return True
141 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], probs_cache=probs_cache)
142 | return generated_sequence, problematic_token, char, "Unexpected ','"
143 | expected_tokens = self._get_expected_tokens(brace_stack)
144 | elif char == ':':
145 | if not brace_stack or brace_stack[-1] != '{':
146 | if validate_only:
147 | return True
148 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], probs_cache=probs_cache)
149 | return generated_sequence, problematic_token, char, "Unexpected ':'"
150 | expected_tokens = self._get_expected_tokens(brace_stack)
151 | elif char not in ' \n\t\r':
152 | if char not in expected_tokens:
153 | if validate_only:
154 | return True
155 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, char_to_token_pos[i], probs_cache=probs_cache)
156 | return generated_sequence, problematic_token, char, "Unexpected character"
157 |
158 | # Check if we have closed all JSON structures
159 | if not brace_stack:
160 | in_json = False
161 |
162 | return False
163 |
164 | def _get_expected_tokens(self, brace_stack):
165 | # Determine expected tokens based on current JSON structure
166 | if not brace_stack:
167 | return ['{', '[']
168 | elif brace_stack[-1] == '{':
169 | return ['"', '}', ' ', '\n', '\t', '\r']
170 | elif brace_stack[-1] == '[':
171 | return ['{', '[', ']', '"', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 't', 'f', 'n', ' ', '\n', '\t', '\r']
172 | else:
173 | return [',', '}', ']', ' ', '\n', '\t', '\r']
174 |
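    | # Maps each character offset in the decoded generation back to the absolute token
    | # position that produced it, so a violation found at a character index can be traced
    | # to the token to backtrack to.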
175 | def _create_char_to_token_mapping(self, tokens, prompt_length):
176 | char_to_token_pos = {}
177 | char_pos = 0
178 | for token_pos, token in enumerate(tokens):
179 | token_text = self.tokenizer.decode(token)
180 | for _ in range(len(token_text)):
181 | char_to_token_pos[char_pos] = token_pos + prompt_length
182 | char_pos += 1
183 | return char_to_token_pos
184 |
185 | def _backtrack_and_adjust(self, generated_sequence, problematic_pos, upregulate_safe_continuation=False, probs_cache=None):
186 | # Identify the problematic token and backtrack
187 | problematic_token = generated_sequence[problematic_pos]
188 |
189 | # Downregulate the problematic token
190 | self.last_downregulated_token = problematic_token
191 | probs_cache[problematic_pos][:, problematic_token] *= 0.0001
192 |
193 | # Upregulate the properly escaped version
194 | if upregulate_safe_continuation:
195 | self._upregulate_safe_continuation(problematic_pos, problematic_token, probs_cache)
196 |
197 | return generated_sequence[:problematic_pos], problematic_token
198 |
199 | # We need to account for cases where the problematic token is more than 1 char
200 | # let's take this approach:
201 | # - first, check if the token is in the lookup
202 | # - if so, we upregulate the escaped lookup value
203 | # - if not, we check if the 1st char of the decoded token is one of our problematic tokens
204 | # - if so, upregulate just that (escaped) token
205 | # - else, we extract the substring up to (and not including) the first problematic char
206 | # - then upregulate the first token of this encoded string, as the intended continuation
207 | def _upregulate_safe_continuation(self, problematic_pos, problematic_token, probs_cache):
208 | # this function is specifically to handle unescaped chars inside strings
209 |
210 | problematic_token_decoded = self.tokenizer.decode(problematic_token, skip_special_tokens=True)
211 |
212 | debug=False
213 |
214 | # Check if the token is in the lookup
215 | if problematic_token_decoded in self.escaped_tokens_lookup:
216 | if debug:
217 | print('upregulating1', [self.tokenizer.decode(self.escaped_tokens_lookup[problematic_token_decoded][0])])
218 | probs_cache[problematic_pos][:, self.escaped_tokens_lookup[problematic_token_decoded][0]] *= 2
219 | # normalise probs
220 | probs_cache[problematic_pos].div_(torch.sum(probs_cache[problematic_pos]))
221 |
222 | elif problematic_token_decoded[0] in self.escaped_tokens_lookup:
223 | encoded_escaped_tok = self.escaped_tokens_lookup[problematic_token_decoded[0]][0]
224 | if debug:
225 | print('upregulating2', [self.tokenizer.decode(encoded_escaped_tok)])
226 |
227 | probs_cache[problematic_pos][:, encoded_escaped_tok] *= 2
228 | # normalise probs
229 | probs_cache[problematic_pos].div_(torch.sum(probs_cache[problematic_pos]))
230 | else:
231 | # Find the first problematic character
232 | first_problematic_index = next((i for i, char in enumerate(problematic_token_decoded) if char in self.escaped_tokens_lookup), None)
233 |
234 | if first_problematic_index is not None:
235 | # Extract the substring up to the first problematic character
236 | safe_substring = problematic_token_decoded[:first_problematic_index]
237 |
238 | # Encode the safe substring (without special tokens, so we don't upregulate a BOS token)
239 | encoded_safe_substring = self.tokenizer.encode(safe_substring, add_special_tokens=False)
240 |
241 | # Upregulate the first token of the encoded safe substring
242 | if encoded_safe_substring:
243 | if debug:
244 | print('upregulating3', [safe_substring])
245 | probs_cache[problematic_pos][:, encoded_safe_substring[0]] *= 2
246 | # normalise probs
247 | probs_cache[problematic_pos].div_(torch.sum(probs_cache[problematic_pos]))
248 |
249 |
250 | def _display_debug(self, message: str):
251 | """
252 | Displays debug information in the debug_output widget.
253 | """
254 | if self.debug_output:
255 | with self.debug_output:
256 | self.debug_output.clear_output(wait=True)
257 | display(HTML(f"{message}
"))
258 |
259 |
260 | class JSONValidationStoppingCriteria(StoppingCriteria):
261 | def __init__(self, tokenizer, json_validator, prompt_length):
262 | self.tokenizer = tokenizer
263 | self.json_validator = json_validator
264 | self.prompt_length = prompt_length
265 |
266 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
267 | self.previous_tokens = input_ids[0].tolist()
268 |
269 | # Check if the generated sequence is valid JSON
270 | result = self.json_validator._validate_json_string(self.previous_tokens, self.prompt_length, validate_only=True)
271 |
272 | return result
273 |
--------------------------------------------------------------------------------
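
A minimal sketch of how JSONValidationStoppingCriteria might be wired into Hugging Face generation. This is illustrative only: the model name is a placeholder, and `json_validator` is assumed to be an already-constructed instance of the JSON validator class defined in validator_json.py above.

    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    model_name = "your-model-here"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    prompt_ids = tokenizer("Return a JSON object describing a cat.", return_tensors="pt").input_ids
    prompt_length = prompt_ids.shape[1]

    # Generation halts as soon as _validate_json_string(..., validate_only=True) reports a
    # violation; the calling loop can then backtrack via the validator and resume generation.
    stopping = StoppingCriteriaList(
        [JSONValidationStoppingCriteria(tokenizer, json_validator, prompt_length)]
    )
    output_ids = model.generate(prompt_ids, max_new_tokens=200, stopping_criteria=stopping)
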
/src/validator_regex.py:
--------------------------------------------------------------------------------
1 | # File: ai/antislop-sampler/src/validator_regex.py
2 |
3 | import re
4 | import time
5 | import torch
6 | from transformers import StoppingCriteria
7 | from IPython.display import clear_output, display, HTML
8 | from typing import List, Dict, Tuple, Optional
9 |
10 | # This implements banning of sequences using regex matching.
11 | # If the inference text matches one of the specified regex expressions,
12 | # we backtrack to the position of the first match, ban that token, then
13 | # continue inference.
14 | class RegexValidator:
15 | def __init__(self, tokenizer, regex_bans: List[str], slow_debug, debug_delay, debug_output, probs_cache_longrange):
16 | self.tokenizer = tokenizer
17 | self.regex_bans = [re.compile(pattern) for pattern in regex_bans]
18 | self.slow_debug = slow_debug
19 | self.debug_delay = debug_delay
20 | self.debug_output = debug_output
21 | self.probs_cache_longrange = probs_cache_longrange
22 |
23 | self.last_downregulated_token = 0
24 |
25 | def validate_regex_matches(self, generated_sequence, prompt_length, probs_cache):
26 | result = self._validate_regex_matches(generated_sequence, prompt_length, probs_cache)
27 | if result is not False:
28 | generated_sequence, problematic_token, invalid_match, reason = result
29 | if self.slow_debug:
30 | problematic_token_decoded = self.tokenizer.decode(problematic_token, skip_special_tokens=True)
31 | debug_info = f"Regex violation detected:\n{reason}\nProblematic match: '{invalid_match}' in token '{problematic_token_decoded}'"
32 | self._display_debug(debug_info)
33 | time.sleep(self.debug_delay)
34 |
35 | problematic_pos = len(generated_sequence)
36 |
37 | # Clear subsequent logit cache
38 | to_del = [key for key in probs_cache if key > problematic_pos]
39 | for key in to_del:
40 | del probs_cache[key]
41 |
42 | # Flag positions to keep in the logit cache
43 | self.probs_cache_longrange[problematic_pos] = True
44 |
45 | return generated_sequence
46 |
47 | return False
48 |
49 | def _validate_regex_matches(self, generated_sequence, prompt_length, probs_cache=None, validate_only=False):
50 | # Decode only the newly generated tokens
51 | generated_text = self.tokenizer.decode(generated_sequence[prompt_length:], skip_special_tokens=True)
52 |
53 | for pattern in self.regex_bans:
54 | match = pattern.search(generated_text)
55 | if match:
56 | # Find the position in tokens where the match starts
57 | match_start_char_pos = match.start()
58 |
59 | # We prepend a char because some tokenisers drop the leading space when the
60 | # first token is decoded in isolation (the space would be present if another
61 | # token preceded it). Without this, the decoded lengths can be off by one and
62 | # we'd sometimes map the match to the wrong token position.
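# For example, with a SentencePiece-style tokenizer a token representing " world" may decode
# to "world" when it is the first token passed to decode(); prepending the '|' token keeps
# the leading space so the decoded lengths line up. (Illustrative; behaviour varies by tokenizer.)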
63 | prepended_char = self.tokenizer.encode('|', add_special_tokens=False)
64 | problematic_token_pos = None
65 |
66 | for start_pos in range(len(generated_sequence)-1, prompt_length-1, -1):
67 | test_str = self.tokenizer.decode(prepended_char + generated_sequence[start_pos:], skip_special_tokens=True)
68 |
69 | if len(generated_text) - (len(test_str) - 1) <= match_start_char_pos: # -1 is to account for the prepended char
70 | problematic_token_pos = start_pos
71 | break
72 | if problematic_token_pos is None:
73 | print('!! failed to get problematic token pos')
74 | return False
75 |
76 | problematic_token = generated_sequence[problematic_token_pos]
77 | reason = f"Regex pattern '{pattern.pattern}' matched."
78 |
79 | if validate_only:
80 | return True
81 |
82 | # Adjust the generated sequence and cache
83 | generated_sequence, problematic_token = self._backtrack_and_adjust(generated_sequence, problematic_token_pos, upregulate_safe_continuation=False, probs_cache=probs_cache)
84 | return generated_sequence, problematic_token, match.group(), reason
85 |
86 | return False
87 |
88 | def _create_char_to_token_mapping(self, tokens, prompt_length):
89 | char_to_token_pos = {}
90 | char_pos = 0
91 | for token_pos, token in enumerate(tokens):
92 | token_text = self.tokenizer.decode([token], skip_special_tokens=True)
93 | for _ in range(len(token_text)):
94 | char_to_token_pos[char_pos] = token_pos + prompt_length
95 | char_pos += 1
96 | return char_to_token_pos
97 |
98 | def _backtrack_and_adjust(self, generated_sequence, problematic_pos, upregulate_safe_continuation=False, probs_cache=None):
99 | # Identify the problematic token and backtrack
100 | problematic_token = generated_sequence[problematic_pos]
101 |
102 | # Downregulate the problematic token
103 | self.last_downregulated_token = problematic_token
104 | probs_cache[problematic_pos][:, problematic_token] *= 0.0001
105 |
106 | # Flag positions to keep in the logit cache
107 | self.probs_cache_longrange[problematic_pos] = True
108 |
109 | # Clear the probs_cache ahead of start_pos since we've backtracked
110 | to_del = [key for key in probs_cache if key > problematic_pos]
111 | for key in to_del:
112 | del probs_cache[key]
113 |
114 | return generated_sequence[:problematic_pos], problematic_token
115 |
116 |
117 |
118 | def _display_debug(self, message: str):
119 | """
120 | Displays debug information in the debug_output widget.
121 | """
122 | if self.debug_output:
123 | with self.debug_output:
124 | self.debug_output.clear_output(wait=True)
125 | display(HTML(f"{message}
"))
126 |
127 | class RegexValidationStoppingCriteria(StoppingCriteria):
128 | def __init__(self, tokenizer, regex_validator, prompt_length):
129 | self.tokenizer = tokenizer
130 | self.regex_validator = regex_validator
131 | self.prompt_length = prompt_length
132 |
133 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
134 | previous_tokens = input_ids[0].tolist()
135 |
136 | # Check if any regex pattern matches the generated sequence
137 | result = self.regex_validator._validate_regex_matches(previous_tokens, self.prompt_length, validate_only=True)
138 |
139 | return result
140 |
--------------------------------------------------------------------------------
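
A rough sketch of how RegexValidator could drive backtracking from a token-by-token sampling loop. The loop is a simplification and not the repository's actual generation code (that lives in src/antislop_generate.py); `model`, `tokenizer`, `prompt_ids`, the greedy argmax sampling and the example ban pattern are assumptions made for illustration.

    import torch

    regex_validator = RegexValidator(
        tokenizer,
        regex_bans=[r"(?i)as an ai language model"],  # example ban pattern (hypothetical)
        slow_debug=False,
        debug_delay=0.0,
        debug_output=None,
        probs_cache_longrange={},
    )

    generated = prompt_ids[0].tolist()
    prompt_length = len(generated)
    probs_cache = {}  # position in generated -> probability tensor used to pick that token
    max_new_tokens = 200

    for _ in range(max_new_tokens):
        pos = len(generated)
        if pos not in probs_cache:
            with torch.no_grad():
                logits = model(torch.tensor([generated])).logits[:, -1, :]
            probs_cache[pos] = torch.softmax(logits, dim=-1)
        generated.append(int(torch.argmax(probs_cache[pos], dim=-1)))

        result = regex_validator.validate_regex_matches(generated, prompt_length, probs_cache)
        if result is not False:
            # A ban pattern matched: the offending token was downregulated in probs_cache and
            # the sequence was truncated to just before it, so sampling resumes from there
            # using the adjusted (cached) distribution at that position.
            generated = list(result)
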
/src/validator_slop.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import List, Dict, Tuple, Generator, Set, Union
3 |
4 | import torch
5 | from transformers import (
6 | PreTrainedTokenizer,
7 | StoppingCriteria,
8 | )
9 | from IPython.display import display, HTML
10 | from ipywidgets import Output
11 |
12 |
13 | # This function detects whether any of the slop phrases occur in the end segment of the inference text.
14 | # It is optimised with dict lookups so that execution time stays roughly constant
15 | # regardless of the length of the slop list.
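# Illustrative example (hypothetical numbers): with max_slop_phrase_length = 20,
# min_slop_phrase_length = 3 and check_n_chars_back = 16, each call tests at most
# 16 * (20 - 3 + 1) = 288 end-aligned substrings of the inference text, each via a
# single dict lookup, independent of how many slop phrases are loaded.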
16 | def detect_disallowed_sequence(tokenizer: PreTrainedTokenizer,
17 | inference: str,
18 | generated_sequence: List[int],
19 | prompt_length: int,
20 | slop_phrase_prob_adjustments: Dict[str, float],
21 | max_slop_phrase_length: int,
22 | min_slop_phrase_length: int,
23 | check_n_chars_back: int = 16 # this moves the detection window back n chars, so we can detect phrases that were completed further back
24 | ) -> Tuple[Tuple[int, ...], int]:
25 |
26 | inference = inference.lower()
27 |
28 | for char_offset in range(0, check_n_chars_back):
29 | for candidate_str_length in range(max_slop_phrase_length, min_slop_phrase_length - 1, -1):
30 | if candidate_str_length + char_offset > len(inference):
31 | continue
32 | candidate_str = inference[-(candidate_str_length + char_offset):len(inference)-char_offset]
33 | #print(candidate_str)
34 | if candidate_str in slop_phrase_prob_adjustments:
35 | # determine the token containing the beginning of the detected phrase
36 | #print('looking for', candidate_str,'in decoded text')
37 | for start_pos in range(len(generated_sequence)-1, prompt_length-1, -1):
38 | candidate_seq = generated_sequence[start_pos:]
39 | candidate_seq_decoded = tokenizer.decode(candidate_seq, skip_special_tokens=True).lower()
40 | #print(candidate_seq_decoded)
41 | if candidate_str in candidate_seq_decoded:
42 | #print('detected!', candidate_str, time.time() - start)
43 | return candidate_str, start_pos
44 | # if we reached here, something went wrong
45 | print('!! candidate_str not found after decoding')
46 |
47 | return None, -1
48 |
49 | class SlopPhraseHandler:
50 | def __init__(
51 | self,
52 | tokenizer: PreTrainedTokenizer,
53 | probs_cache: Dict[int, torch.FloatTensor],
54 | probs_cache_longrange: Dict[int, bool],
55 | slop_phrase_prob_adjustments: Dict[str, float],
56 | starting_tokens_lookup: Dict[Tuple[int, ...], Set[int]],
57 | adjustment_strength: float,
58 | slow_debug: bool,
59 | inference_output: Output | None,
60 | debug_output: Output | None,
61 | debug_delay: float,
62 | ):
63 | self.tokenizer = tokenizer
64 | self.probs_cache = probs_cache
65 | self.probs_cache_longrange = probs_cache_longrange
66 | self.slop_phrase_prob_adjustments = slop_phrase_prob_adjustments
67 | self.starting_tokens_lookup = starting_tokens_lookup
68 | self.adjustment_strength = adjustment_strength
69 | self.slow_debug = slow_debug
70 | self.inference_output = inference_output
71 | self.debug_output = debug_output
72 | self.debug_delay = debug_delay
73 |
74 | self.max_slop_phrase_length = max(len(seq) for seq in self.slop_phrase_prob_adjustments.keys()) if self.slop_phrase_prob_adjustments else 0
75 | self.min_slop_phrase_length = min(len(seq) for seq in self.slop_phrase_prob_adjustments.keys()) if self.slop_phrase_prob_adjustments else 0
76 | #self.stopping_criteria = SlopPhraseStoppingCriteria(tokenizer, self.slop_phrase_sequences, self.max_slop_phrase_length)
77 |
78 | # Normalise slop phrase keys to lowercase so matching is case-insensitive
79 | self.slop_phrase_prob_adjustments = {
80 |     key.lower(): value for key, value in self.slop_phrase_prob_adjustments.items()
81 | }
82 |
83 |
84 |
85 | def _handle_disallowed_sequence(
86 | self,
87 | matched_phrase: str,
88 | start_pos: int,
89 | generated_sequence: List[int],
90 | probs_cache: Dict[int, torch.FloatTensor],
91 | adjustment_strength: float,
92 | slow_debug: bool,
93 | tokenizer: PreTrainedTokenizer,
94 | inference_output: Output,
95 | debug_output: Output,
96 | debug_delay: float,
97 | ) -> List[int]:
98 | # Downregulate the relevant tokens at the start_pos
99 | adjustment = self.slop_phrase_prob_adjustments[matched_phrase.lower()]
100 |
101 | # Display debug information
102 | debug_info = f"Replacing '{matched_phrase}'"
103 | self._display_debug(debug_info)
104 |
105 | if slow_debug:
106 | time.sleep(debug_delay)
107 | if debug_output:
108 | with debug_output:
109 | debug_output.clear_output(wait=True)
110 |
111 | #print('downregulating', [tokenizer.decode([generated_sequence[start_pos]])])
112 |
113 | # Identify starting tokens to downregulate
114 | slop_phrase_starting_token = generated_sequence[start_pos]
115 | starting_tokens = self.starting_tokens_lookup.get(matched_phrase.lower(), set())
116 | starting_tokens.add(slop_phrase_starting_token)
117 |
118 | for token_id in starting_tokens:
119 | self.probs_cache[start_pos][:, token_id] *= adjustment ** adjustment_strength
120 |
121 | # Check if the starting token would still be selected after downregulation
122 | if torch.argmax(self.probs_cache[start_pos]).item() == slop_phrase_starting_token:
123 | if slow_debug:
124 | debug_info = f"Slop phrase '{matched_phrase}' prob was downregulated {round(1/(adjustment**adjustment_strength), 2)}x but still selected."
125 | self._display_debug(debug_info)
126 | return generated_sequence
127 |
128 | # Backtrack: remove tokens from the generated_sequence that are part of the disallowed sequence
129 | for _ in range(len(generated_sequence) - start_pos):
130 | generated_sequence.pop()
131 |
132 | # Clear the probs_cache ahead of start_pos since we've backtracked
133 | to_del = [key for key in self.probs_cache if key > start_pos]
134 | for key in to_del:
135 | del self.probs_cache[key]
136 |
137 | return generated_sequence
138 |
139 | def deslop(self, generated_sequence, prompt_length):
140 | self.prompt_length = prompt_length
141 | # After adding the token(s), check for disallowed sequences
142 |
143 | inference = self.tokenizer.decode(generated_sequence[prompt_length:], skip_special_tokens=True)
144 |
145 | matched_phrase, start_pos = detect_disallowed_sequence(self.tokenizer,
146 | inference,
147 | generated_sequence,
148 | prompt_length,
149 | self.slop_phrase_prob_adjustments,
150 | self.max_slop_phrase_length,
151 | self.min_slop_phrase_length)
152 |
153 | if matched_phrase:
154 | if self.slow_debug:
155 | current_text = self.tokenizer.decode(generated_sequence[prompt_length:start_pos])
156 | #print([current_text])
157 | matched_phrase_to_display = self.tokenizer.decode(generated_sequence[start_pos:], skip_special_tokens=True)
158 | #print([matched_phrase_to_display])
159 | # Add HTML formatting to display the matched_phrase in red
160 | highlighted_text = f"{current_text}{matched_phrase_to_display}"
161 |
162 | with self.inference_output:
163 | self.inference_output.clear_output(wait=True)
164 | display(HTML(f"{highlighted_text}
"))
165 |
166 | # Display debug information
167 | debug_info = f"Replacing '{matched_phrase}'"
168 | self._display_debug(debug_info)
169 |
170 | if self.slow_debug:
171 | #time.sleep(self.debug_delay)
172 | if self.debug_output:
173 | with self.debug_output:
174 | self.debug_output.clear_output(wait=True)
175 |
176 | # Handle the disallowed sequence using SlopPhraseHandler
177 | generated_sequence = self._handle_disallowed_sequence(
178 | matched_phrase=matched_phrase,
179 | start_pos=start_pos,
180 | generated_sequence=generated_sequence,
181 | probs_cache=self.probs_cache,
182 | adjustment_strength=self.adjustment_strength,
183 | slow_debug=self.slow_debug,
184 | tokenizer=self.tokenizer,
185 | inference_output=self.inference_output,
186 | debug_output=self.debug_output,
187 | debug_delay=self.debug_delay,
188 | )
189 |
190 | return generated_sequence
191 | return False
192 |
193 | def _display_debug(self, message: str):
194 | """
195 | Displays debug information in the debug_output widget.
196 | """
197 | if self.debug_output:
198 | with self.debug_output:
199 | self.debug_output.clear_output(wait=True)
200 | display(HTML(f"{message}
"))
201 |
202 |
203 | class CustomSlopPhraseStoppingCriteria(StoppingCriteria):
204 | def __init__(self, tokenizer, max_slop_phrase_length, min_slop_phrase_length, prompt_length, slop_phrase_prob_adjustments):
205 | self.tokenizer = tokenizer
206 | self.max_slop_phrase_length = max_slop_phrase_length
207 | self.min_slop_phrase_length = min_slop_phrase_length
208 | self.slop_phrase_prob_adjustments = slop_phrase_prob_adjustments
209 | self.prompt_length = prompt_length
210 |
211 | # !! For some reason this isn't reliably triggered every token
212 | # which means we might have output a slop phrase a token back or so.
213 | # Not sure why!
214 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
215 | # Combine previous tokens with newly generated tokens
216 | self.previous_tokens = input_ids[0].tolist()
217 |
218 | inference = self.tokenizer.decode(self.previous_tokens[self.prompt_length:], skip_special_tokens=True)
219 |
220 | matched_phrase, start_pos = detect_disallowed_sequence(self.tokenizer,
221 | inference,
222 | self.previous_tokens,
223 | self.prompt_length,
224 | self.slop_phrase_prob_adjustments,
225 | self.max_slop_phrase_length,
226 | self.min_slop_phrase_length)
227 | if matched_phrase:
228 | #print('matched', matched_phrase)
229 | return True
230 | return False
231 |
--------------------------------------------------------------------------------
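
A condensed sketch of how SlopPhraseHandler.deslop slots into the same style of sampling loop, again assuming `model`, `tokenizer` and `prompt_ids` from the earlier sketches. The empty starting_tokens_lookup, the adjustment_strength value, and the assumption that slop_phrase_prob_adjustments.json decodes to (phrase, adjustment) pairs are simplifications for illustration.

    import json
    import torch

    with open("slop_phrase_prob_adjustments.json") as f:
        slop_adjustments = dict(json.load(f))  # assumed: phrase -> probability adjustment factor

    probs_cache = {}
    handler = SlopPhraseHandler(
        tokenizer=tokenizer,
        probs_cache=probs_cache,
        probs_cache_longrange={},
        slop_phrase_prob_adjustments=slop_adjustments,
        starting_tokens_lookup={},   # normally precomputed per phrase; empty here for brevity
        adjustment_strength=2.0,     # placeholder strength
        slow_debug=False,
        inference_output=None,
        debug_output=None,
        debug_delay=0.0,
    )

    generated = prompt_ids[0].tolist()
    prompt_length = len(generated)

    for _ in range(200):
        pos = len(generated)
        if pos not in probs_cache:
            with torch.no_grad():
                logits = model(torch.tensor([generated])).logits[:, -1, :]
            probs_cache[pos] = torch.softmax(logits, dim=-1)
        generated.append(int(torch.argmax(probs_cache[pos], dim=-1)))

        result = handler.deslop(generated, prompt_length)
        if result is not False:
            # A slop phrase was detected: its starting tokens were downregulated at the match
            # position and (unless the phrase still wins after downregulation) the sequence
            # was truncated back to the start of the match.
            generated = result
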