├── README.md ├── caching.py ├── guidance_gen.py ├── model_info.py ├── processor.py ├── requirements.txt ├── script.py ├── streamer.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # Guidance API: An Extension for oobabooga/text-generation-webui 2 | 3 | 4 | 5 | Guidance API is an extension for oobabooga/text-generation-webui that combines the feature-rich, easy-to-use interface of OOGA with the structured-generation capabilities of Guidance. It answers Guidance's network calls with the model already loaded in the web UI, so Guidance programs can drive that model without running a separate inference server. 6 | 7 | ## Features 8 | 9 | - **Seamless Integration with oobabooga/text-generation-webui**: Guidance API extends OOGA's feature set while preserving its ease of use. 10 | 11 | - **Network Calls with Guidance**: The extension exposes an HTTP API that Guidance's `TGWUI` backend calls, so you can run Guidance programs against the model loaded in the web UI. 12 | 13 | - **Rich Output Structure**: Supports multiple generations, selections, conditionals, tool use, and more, so Guidance can build a rich output structure. 14 | 15 | - **Smart Generation Caching**: Completions are cached on the server with a seed-based disk cache, so repeated calls with the same prompt and parameters are served without regenerating. 16 | 17 | - **Compatibility with Role-Based Chat Models**: Coming soon. 18 | 19 | Note: the `select` tag in Guidance is currently a work in progress. 20 | 21 | --- 22 | 23 | 24 | 25 | ## Getting Started 26 | Example of the `CMD_FLAGS` configuration in `webui.py`: 27 | ``` 28 | CMD_FLAGS = " --chat --model-menu --model decapoda-research_llama-7b-hf --extensions guidance_api" 29 | 30 | ``` 31 | 32 | Then in Guidance: 33 | 34 | 35 | ``` 36 | import guidance 37 | import json,requests 38 | import re,sys 39 | 40 | guidance.llm = guidance.llms.TGWUI("http://127.0.0.1:9555") 41 | 42 | character_maker = guidance("""The following is a character profile for an RPG game in JSON format. 43 | ```json 44 | { 45 | "id": "{{id}}", 46 | "description": "{{description}}", 47 | "name": "{{gen 'name'}}", 48 | "class": "{{gen 'class'}}", 49 | 50 | }```""") 51 | 52 | # generate a character 53 | res=character_maker( 54 | id="e1f491f7-7ab8-4dac-8c20-c92b5e7d883d", 55 | description="A quick and nimble fighter.", 56 | ) 57 | 58 | print(res) 59 | ``` 60 | 61 | Feel free to submit feedback; this repository is under active development. 62 | 63 | -------------------------------------------------------------------------------- /caching.py: -------------------------------------------------------------------------------- 1 | import json 2 | import hashlib 3 | from typing import Any, Dict, Optional 4 | from abc import ABC, abstractmethod 5 | import os 6 | 7 | import diskcache 8 | import platformdirs 9 | 10 | 11 | class Cache(ABC): 12 | @abstractmethod 13 | def __getitem__(self, key: str) -> str: 14 | """get an item from the cache or throw key error""" 15 | pass 16 | 17 | @abstractmethod 18 | def __setitem__(self, key: str, value: str) -> None: 19 | """set an item in the cache""" 20 | pass 21 | 22 | @abstractmethod 23 | def __contains__(self, key: str) -> bool: 24 | """see if we can return a cached value for the passed key""" 25 | pass 26 | 27 | def create_key(self, llm: str, **kwargs: Dict[str, Any]) -> str: 28 | """Define a lookup key for a call to the given llm with the given kwargs. 29 | One of the keyword args could be `cache_key` in which case this function should respect that 30 | and use it. 
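        For example (hypothetical values), create_key("llama-7b", prompt="Hi", temperature=0)
        serializes the kwargs to JSON with sorted keys, prepends the llm name, and returns the
        md5 hex digest of that string; passing cache_key="my-key" skips the hash and uses that
        key directly.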
31 | """ 32 | if "cache_key" in kwargs: 33 | return str(kwargs["cache_key"]) 34 | 35 | hasher = hashlib.md5() 36 | options_str = json.dumps(kwargs, sort_keys=True) 37 | 38 | combined = "{}{}".format(llm, options_str).encode() 39 | 40 | hasher.update(combined) 41 | return hasher.hexdigest() 42 | 43 | def clear(self): 44 | raise NotImplementedError() 45 | 46 | 47 | 48 | 49 | 50 | class DiskCache(Cache): 51 | """DiskCache is a cache that uses diskcache lib.""" 52 | def __init__(self, llm_name: str): 53 | self._diskcache = diskcache.Cache( 54 | os.path.join( 55 | platformdirs.user_cache_dir("guidance"), f"_{llm_name}.diskcache" 56 | ) 57 | ) 58 | 59 | def __getitem__(self, key: str) -> str: 60 | return self._diskcache[key] 61 | 62 | def __setitem__(self, key: str, value: str) -> None: 63 | self._diskcache[key] = value 64 | 65 | def __contains__(self, key: str) -> bool: 66 | return key in self._diskcache 67 | 68 | def clear(self): 69 | self._diskcache.clear() 70 | 71 | 72 | -------------------------------------------------------------------------------- /guidance_gen.py: -------------------------------------------------------------------------------- 1 | from modules import shared 2 | from modules.text_generation import encode, generate_reply,decode 3 | from .util import build_parameters 4 | from typing import Any, Dict, Optional, Callable 5 | import os 6 | import time 7 | import collections 8 | import regex 9 | import pygtrie 10 | import queue 11 | import torch 12 | import threading 13 | import logging 14 | import transformers 15 | from .processor import TokenHealingLogitsProcessor,BiasLogitsProcessor,RegexLogitsProcessor, RegexStoppingCriteria 16 | from .caching import Cache, DiskCache 17 | from .model_info import setup_model_data 18 | def printc(obj, color): 19 | color_code = { 20 | 'black': '30', 'red': '31', 'green': '32', 'yellow': '33', 21 | 'blue': '34', 'magenta': '35', 'cyan': '36', 'white': '37' 22 | } 23 | colored_text = f"\033[{color_code[color]}m{obj}\033[0m" if color in color_code else obj 24 | print(colored_text) 25 | 26 | 27 | 28 | class GuidanceGenerator: 29 | llm_name: str = shared.args.model 30 | 31 | def __init__(self): 32 | super().__init__() 33 | self.llm_model = shared.model 34 | self._call_counts = {} 35 | self.tokenizer = shared.tokenizer 36 | self.data= setup_model_data() 37 | 38 | self.bos_token= self.data['bos_token'] 39 | self.eos_token= self.data['eos_token'] 40 | self.eos_token_id = self.token_to_id(self.data['eos_token']) 41 | self.token_healing=True 42 | 43 | self.model_name = shared.args.model 44 | self.cache = DiskCache(llm_name=self.model_name) 45 | self.cache.clear() 46 | self.cache_version=1 47 | self._past_key_values = None 48 | self._prefix_cache = [] 49 | self._token_prefix_map = self._build_token_prefix_map() 50 | self.data['token_prefix_map_length']=len(self._token_prefix_map) 51 | printc(self.data,"green") 52 | 53 | def id_to_token(self, id): 54 | return decode(int(id)) 55 | 56 | def token_to_id(self, token): 57 | return encode(token) 58 | 59 | def encode(self, string, as_list=True): 60 | tmp= None 61 | if as_list: 62 | tmp= encode(string).tolist()[0] 63 | else: 64 | tmp= encode(string) 65 | return tmp 66 | 67 | 68 | 69 | def decode(self, id): 70 | tmp = decode(id) 71 | return tmp 72 | 73 | 74 | def _build_token_prefix_map(self): 75 | """ Build a map from token to index. 
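        Several token ids can decode to the same string, so each trie entry stores a list of
        ids; prefix_matches() then flattens those lists for every key that starts with the
        given prefix.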
76 | """ 77 | printc(("vocab_size: ",self.tokenizer.vocab_size),"cyan") 78 | token_map = pygtrie.CharTrie() 79 | for i in range(self.tokenizer.vocab_size): 80 | s = self.id_to_token(i) 81 | if s in token_map: 82 | token_map[s].append(i) 83 | else: 84 | token_map[s] = [i] 85 | return token_map 86 | 87 | def new_string_builder(self, starting_ids=None): 88 | return TransformersStringBuilder(self.tokenizer, starting_ids) 89 | 90 | 91 | def prefix_matches(self, prefix): 92 | """ Return the list of tokens that match the given prefix. 93 | """ 94 | return [v for arr in self._token_prefix_map.values(prefix=prefix) for v in arr] 95 | def _gen_key(self, args_dict): 96 | return "_---_".join([str(v) for v in ([args_dict[k] for k in args_dict] + [self.model_name, self.__class__.__name__, self.cache_version])]) 97 | 98 | 99 | def _cache_params(self, args_dict) -> Dict[str, Any]: 100 | """get the parameters for generating the cache key""" 101 | key = self._gen_key(args_dict) 102 | # if we have non-zero temperature we include the call count in the cache key 103 | if args_dict.get("temperature", 0) > 0: 104 | args_dict["call_count"] = self._call_counts.get(key, 0) 105 | self._call_counts[key] = args_dict["call_count"] + 1 106 | args_dict["model_name"] = self.model_name 107 | args_dict["cache_version"] = self.cache_version 108 | args_dict["class_name"] =self.__class__.__name__ 109 | return args_dict 110 | 111 | def _update_prefix_cache(self, streamer): 112 | # note what we now have cached and ready for our next call in this session 113 | if self._past_key_values and len(streamer.generated_sequence) == 1: 114 | self._prefix_cache = streamer.generated_sequence[0][:self._past_key_values[0][0].shape[-2]] # self._past_key_values is already saved, this just aligns with it 115 | 116 | def _stream_then_save(self, streamer, key, thread): 117 | list_out = [] 118 | for out in streamer: 119 | list_out.append(out) 120 | yield out 121 | thread.join() # clean up the thread 122 | self.llm.cache[key] = list_out 123 | self._update_prefix_cache(streamer) 124 | self._last_computed_key = key 125 | 126 | 127 | 128 | 129 | def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, max_tokens=1000, logprobs=None,top_p=1.0, echo=False, logit_bias=None, token_healing=None, pattern=None, stream=False,cache_seed=0, caching=None, **generate_kwargs): 130 | """ Generate a completion of the given prompt. 
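        Results are stored in the disk cache keyed by the full argument set (see
        Cache.create_key); when temperature > 0 a per-key call count is mixed into the key so
        repeated sampled calls are cached separately rather than collapsed into one entry.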
131 | """ 132 | 133 | args={ 134 | "prompt":prompt, "stop": stop, "stop_regex":stop_regex, "temperature": temperature, "n":n, 135 | "max_tokens":max_tokens, "logprobs":logprobs, "top_p":top_p, "echo":echo, "logit_bias":logit_bias, 136 | "token_healing":token_healing, "pattern":pattern, "stream":stream, "cache_seed":cache_seed, 137 | "caching":caching, "generate_kwargs":generate_kwargs, "model_name": self.model_name, 138 | "cache_version":self.cache_version, "class_name":self.__class__.__name__ 139 | } 140 | cache_params = self._cache_params(args) 141 | llm_cache = self.cache 142 | key = llm_cache.create_key(self.model_name, **cache_params) 143 | if stop is not None: 144 | if isinstance(stop, str): 145 | stop_regex = [regex.escape(stop)] 146 | else: 147 | stop_regex = [regex.escape(s) for s in stop] 148 | if isinstance(stop_regex, str): 149 | stop_regex = [stop_regex] 150 | if stop_regex is None: 151 | stop_regex = [] 152 | stop_regex.append(regex.escape(self.eos_token)) # make sure the end of sequence token is always included 153 | 154 | input_ids= encode(prompt) 155 | healed_token_ids = [] 156 | processors = [] 157 | stoppers = [] 158 | coded_prompt = decode(input_ids[0]) 159 | if token_healing: 160 | healer = TokenHealingLogitsProcessor(self, self.tokenizer.vocab_size, input_ids[0]) 161 | healed_token_ids = healer.healed_token_ids 162 | if len(healed_token_ids) > 0: 163 | input_ids = input_ids[:,:-len(healed_token_ids)] 164 | max_tokens += len(healed_token_ids) 165 | processors.append(healer) 166 | if logit_bias is not None: 167 | processors.append(BiasLogitsProcessor(self, self.tokenizer.vocab_size-1, logit_bias)) 168 | 169 | max_context = shared.settings['max_new_tokens_max'] 170 | 171 | if max_tokens + len(input_ids[0]) > max_context: 172 | max_tokens = max_context - len(input_ids[0]) 173 | 174 | prefix_match_len = 0 175 | if prefix_match_len == len(input_ids[0]): 176 | prefix_match_len -= 1 177 | 178 | #may cause issues 179 | if pattern is not None: 180 | processors.append(RegexLogitsProcessor(pattern, stop_regex, self, self.tokenizer.vocab_size-1, temperature == 0, len(coded_prompt), self.eos_token_id)) 181 | 182 | if stop_regex is not None: 183 | stoppers.append(RegexStoppingCriteria(stop_regex, self, len(coded_prompt))) 184 | 185 | streamer = TransformersStreamer( 186 | llm=self, 187 | input_ids=input_ids, 188 | stop_regex=stop_regex, 189 | healed_token_ids=healed_token_ids, 190 | prefix_length=len(coded_prompt), 191 | string_builder=self.new_string_builder, 192 | max_new_tokens=max_tokens, 193 | logprobs=logprobs 194 | ) 195 | 196 | generate_args = dict( 197 | inputs=input_ids, 198 | temperature=temperature, 199 | max_new_tokens=max_tokens, 200 | top_p=top_p, 201 | pad_token_id=self.llm_model.config.pad_token_id, 202 | logits_processor=transformers.LogitsProcessorList(processors), 203 | stopping_criteria=transformers.StoppingCriteriaList(stoppers), 204 | output_scores=logprobs is not None and logprobs > 0, 205 | return_dict_in_generate=True, 206 | **generate_kwargs 207 | ) 208 | 209 | do_sample = True 210 | if do_sample is True and temperature == 0: 211 | generate_args["do_sample"] = False 212 | elif do_sample is False and temperature > 0: 213 | generate_args["do_sample"] = True 214 | 215 | temperature = 0.005 if args['temperature'] == 0.0 else args['temperature'] 216 | body = { 217 | 'prompt': prompt, 218 | 'max_new_tokens': args['max_tokens'], 219 | 'do_sample': True, 220 | 'temperature': temperature, 221 | 'top_p': args['top_p'] 222 | } 223 | 224 | print(body) 225 | 
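        # NOTE: generation below is delegated to the web UI's generate_reply() using the
        # parameters built by build_parameters(body); the HF-style generate_args assembled
        # above (including logits_processor and stopping_criteria) is not passed to it, so
        # stop sequences are instead trimmed later by the TransformersStreamer in put().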
printc("generating sequence","yellow") 226 | prompt = body['prompt'] 227 | generate_params = build_parameters(body) 228 | stopping_strings = generate_params.pop('stopping_strings') 229 | generate_params['stream'] = False 230 | 231 | generated_sequence = generate_reply(prompt, generate_params, stopping_strings=stopping_strings, is_chat=self.data['instruction_following']) 232 | 233 | 234 | answer = '' 235 | for a in generated_sequence: 236 | answer = a 237 | printc(answer,"yellow") 238 | out = self.encode(answer, as_list=False) 239 | streamer.put(out) 240 | self.cache[key] = streamer.__next__() 241 | self._update_prefix_cache(streamer) 242 | 243 | return llm_cache[key] 244 | 245 | # return answer 246 | 247 | 248 | 249 | 250 | def __exit__(self, exc_type, exc_value, traceback): 251 | """ Restore the model to its original state by removing monkey patches. 252 | """ 253 | if getattr(self.llm.model_obj, "_orig_prepare_method", None) is not None: 254 | self.llm.model_obj.prepare_inputs_for_generation = self.llm.model_obj._orig_prepare_method 255 | del self.llm.model_obj._orig_prepare_method 256 | if getattr(self.llm.model_obj, "_orig_update_method", None) is not None: 257 | self.llm.model_obj._update_model_kwargs_for_generation = self.llm.model_obj._orig_update_method 258 | del self.llm.model_obj._orig_update_method 259 | return False 260 | 261 | # __call__ method 262 | class TransformersStringBuilder(): 263 | """This deals with the complexity of building up a string from tokens bit by bit.""" 264 | def __init__(self, tokenizer, llm, starting_ids=None): 265 | 266 | self.tokenizer = tokenizer 267 | self.token_strings = [] 268 | self._joint_string = "" 269 | if starting_ids is not None: 270 | self.extend(starting_ids) 271 | 272 | def extend(self, new_ids): 273 | new_token_strings = self.tokenizer.convert_ids_to_tokens(new_ids) 274 | self.token_strings.extend(new_token_strings) 275 | new_str = self.tokenizer.convert_tokens_to_string(self.token_strings) 276 | diff_str = new_str[len(self._joint_string):] 277 | self._joint_string = new_str 278 | return diff_str 279 | 280 | def pop(self): 281 | """Remove the last token from the string and return text it removed.""" 282 | self.token_strings.pop() 283 | new_str = self.tokenizer.convert_tokens_to_string(self.token_strings) 284 | diff_str = self._joint_string[len(new_str):] 285 | self._joint_string = new_str 286 | return diff_str 287 | 288 | def __str__(self): 289 | return self._joint_string 290 | 291 | def __len__(self): 292 | return len(self._joint_string) 293 | 294 | class TransformersStreamer(): 295 | def __init__(self, llm, input_ids, stop_regex, healed_token_ids, prefix_length, string_builder, max_new_tokens, logprobs, timeout=None): 296 | self.llm = llm 297 | self.input_ids = input_ids 298 | self.stop_regex = stop_regex 299 | self.healed_token_ids = healed_token_ids 300 | self.logprobs = logprobs 301 | self.string_builder=string_builder 302 | self.max_total_tokens = max_new_tokens + len(input_ids[0]) 303 | self.timeout = timeout 304 | self.str_pos = [prefix_length for i in range(len(self.input_ids))] 305 | self.out_queue = queue.Queue() 306 | self.sequence_pos = [len(self.input_ids[0]) for i in range(len(self.input_ids))] 307 | self.generated_sequence = [[] for i in range(len(self.input_ids))] 308 | self.display_logprobs = [[] for i in range(len(self.input_ids))] 309 | self.generated_string = [self.string_builder(input_ids[0]) for i in range(len(self.input_ids))] 310 | # 311 | self.prefix_cache = [] 312 | 313 | def put(self, token_obj): 314 | if 
isinstance(token_obj, torch.Tensor): 315 | new_tokens = token_obj 316 | else: 317 | new_tokens = token_obj['sequences'] 318 | 319 | if isinstance(new_tokens, torch.Tensor): 320 | new_tokens = new_tokens.cpu() 321 | 322 | # if we are given a single sequence, then make it a batch of size 1 323 | if len(new_tokens.shape) == 1: 324 | new_tokens = new_tokens.unsqueeze(0) 325 | 326 | # extract the scores if we are given them (and format them to be the same shape as the tokens) 327 | if self.logprobs: 328 | assert len(new_tokens) == 1, "logprobs are not supported for batched generation right now in guidance.llms.Transformers" 329 | new_scores = [torch.nn.functional.log_softmax(x, dim=-1).cpu() for x in token_obj['scores']] 330 | len_diff = len(new_tokens[0]) - len(new_scores) 331 | if len_diff > 0: 332 | new_scores = [None for i in range(len_diff)] + new_scores 333 | new_scores = [new_scores] 334 | 335 | out = {"choices": [None for i in range(len(self.input_ids))]} 336 | put_data = False 337 | for i in range(len(self.input_ids)): 338 | self.generated_sequence[i].extend(list(new_tokens[i])) 339 | 340 | # save logprobs if needed 341 | if self.logprobs: 342 | for scores in new_scores[i]: 343 | if scores is None: 344 | self.display_logprobs[i].append(None) 345 | else: 346 | top_inds = scores[0].argsort(descending=True)[:self.logprobs] # TODO: verify the [0] is always correct 347 | self.display_logprobs[i].append({self.llm.id_to_token(j): float(scores[0][j]) for j in top_inds}) 348 | 349 | if self.sequence_pos[i] < len(self.generated_sequence[i]): 350 | display_tokens = list(self.generated_sequence[i][self.sequence_pos[i]:]) 351 | val = self.generated_string[i].extend(display_tokens) 352 | if self.str_pos[i] < len(self.generated_string[i]): 353 | val = str(self.generated_string[i])[self.str_pos[i]:] 354 | finish_reason = None 355 | 356 | # check why we stopped 357 | stop_pos = len(val) + 1 358 | if len(self.generated_sequence[i]) >= self.max_total_tokens: 359 | finish_reason = "length" 360 | elif self.generated_sequence[i][-1] == self.llm.tokenizer.eos_token_id: 361 | finish_reason = "endoftext" 362 | eos_str = self.generated_string[i].pop() # remove the end of text token 363 | stop_pos = len(val) - len(eos_str) 364 | 365 | # trim off the stop regex matches if needed 366 | found_partial = False 367 | stop_text = None 368 | if self.stop_regex is not None:# and (finish_reason is None or len(self.input_ids) > 1): 369 | stop_regex_obj = [regex.compile(s) for s in self.stop_regex] 370 | for s in stop_regex_obj: 371 | m = s.search(val, partial=True) 372 | if m: 373 | span = m.span() 374 | if span[1] > span[0]: 375 | if m.partial: # we might be starting a stop sequence, so we can't emit anything yet 376 | found_partial = True 377 | break 378 | else: 379 | stop_text = val[span[0]:span[1]] 380 | stop_pos = min(span[0], stop_pos) 381 | break 382 | 383 | # record the reason we stopped (if we have stopped) 384 | if stop_pos <= len(val): 385 | finish_reason = "stop" 386 | 387 | # emit the data if we are not potentially in the middle of a stop sequence 388 | if not found_partial or finish_reason is not None: 389 | out["choices"][i] = { 390 | "text": val[:stop_pos], 391 | "finish_reason": finish_reason, 392 | "stop_text": stop_text, 393 | "logprobs": { 394 | # "token_healing_prefix": self.last_token_str, 395 | "top_logprobs": self.display_logprobs[i][self.sequence_pos[i]:] 396 | } 397 | } 398 | self.str_pos[i] = len(self.generated_string[i]) 399 | put_data = True 400 | self.sequence_pos[i] = 
len(self.generated_sequence[i]) 401 | 402 | if put_data: 403 | self.out_queue.put(out) 404 | 405 | def end(self): 406 | 407 | for i in range(len(self.input_ids)): 408 | assert self.str_pos[i] >= len(self.generated_string[i]), "Not all data was flushed, this means generation stopped for an unknown reason!" 409 | 410 | self.out_queue.put(None) 411 | 412 | def __iter__(self): 413 | return self 414 | 415 | def __next__(self): 416 | value = self.out_queue.get(timeout=self.timeout) 417 | if value is None: 418 | raise StopIteration() 419 | else: 420 | return value 421 | 422 | 423 | def _update_prefix_cache(self, streamer): 424 | # note what we now have cached and ready for our next call in this session 425 | if self._past_key_values and len(streamer.generated_sequence) == 1: 426 | self._prefix_cache = streamer.generated_sequence[0][:self._past_key_values[0][0].shape[-2]] 427 | 428 | 429 | 430 | @staticmethod 431 | def role_start(role): 432 | raise NotImplementedError("In order to use chat role tags you need to use a chat-specific subclass of Transformers for your LLM from guidance.transformers.*!") 433 | 434 | -------------------------------------------------------------------------------- /model_info.py: -------------------------------------------------------------------------------- 1 | from modules import shared 2 | from pathlib import Path 3 | from modules.text_generation import encode, generate_reply,decode 4 | import yaml 5 | 6 | 7 | def id_to_token(id): 8 | return decode(int(id)) 9 | 10 | def token_to_id(token): 11 | return encode(token) 12 | def setup_model_data(): 13 | config_dict = vars(shared.model.config) 14 | data={} 15 | 16 | 17 | if shared.settings['instruction_template'] is not None: 18 | template=shared.settings['instruction_template'] 19 | 20 | if template.lower() == 'none': 21 | data['instruction_following']=False 22 | data['instruction_template'] = None 23 | else: 24 | filepath = Path(f'characters/instruction-following/{template}.yaml') 25 | if filepath.exists(): 26 | with open(filepath, 'r', encoding='utf-8') as f: 27 | data['instruction_template'] = yaml.safe_load(f) 28 | 29 | data['instruction_following']=True 30 | else: 31 | data['instruction_following']= False 32 | data['instruction_template'] = None 33 | 34 | if shared.settings['add_bos_token'] is not None: 35 | data['add_bos_token']= shared.settings['add_bos_token'] 36 | if shared.settings['ban_eos_token'] is not None: 37 | data['ban_eos_token']= shared.settings['ban_eos_token'] 38 | data['vocab_size']= shared.tokenizer.vocab_size 39 | if shared.model_name is not None: 40 | data['model_name']=shared.model_name 41 | else: 42 | data['model_name']="unknown_model_name" 43 | 44 | data["bos_token_id"]=config_dict["bos_token_id"] 45 | data["eos_token_id"]=config_dict["eos_token_id"] 46 | data["bos_token"]=id_to_token(data["bos_token_id"]) 47 | data["eos_token"]=id_to_token(data["eos_token_id"]) 48 | 49 | if shared.tokenizer.eos_token is not None: 50 | data["eos_token"]= shared.tokenizer.eos_token 51 | if shared.tokenizer.bos_token is not None: 52 | data["bos_token"]= shared.tokenizer.bos_token 53 | 54 | if shared.tokenizer.eos_token_id is not None: 55 | data["eos_token_id"]= shared.tokenizer.eos_token_id 56 | if shared.tokenizer.bos_token_id is not None: 57 | data["bos_token_id"]= shared.tokenizer.bos_token_id 58 | 59 | if "bias" in config_dict: 60 | data["bias"]= config_dict["bias"] 61 | if "temperature" in config_dict: 62 | data["temperature"]= config_dict["temperature"] 63 | if "top_p" in config_dict: 64 | data["top_p"]= 
config_dict["top_p"] 65 | if "top_k" in config_dict: 66 | data["top_k"]= config_dict["top_k"] 67 | 68 | 69 | 70 | 71 | if "falcon" in shared.args.model.lower(): 72 | data['eos_token']= None 73 | data['bos_token']= None 74 | if "instruct" in shared.args.model.lower(): 75 | data['instruction_template']={ 76 | "user": '>>QUESTION<<', 77 | "bot": '>>ANSWER<<' 78 | } 79 | 80 | print(data) 81 | return data -------------------------------------------------------------------------------- /processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import collections 4 | import regex 5 | import pygtrie 6 | import queue 7 | import threading 8 | import logging 9 | # __call__ method 10 | 11 | 12 | class TokenHealingLogitsProcessor(): 13 | """ Token healing. 14 | 15 | When we tokenize the prompt the last token(s) we get are not the last token(s) we would 16 | have gotten if the prompt + generation was concatenated and then tokenized. This 17 | is not good because it does not align with the pretraining of the model, so 18 | we "heal" this boundary by backing up as many tokens as needed and then forcing the first tokens 19 | generated to start with the prefix of the tokens we removed from the prompt. This could 20 | result in the same tokens at the end of the prompt, or some suffix of the tokens we removed 21 | could be replaced by a single longer one that crosses the prompt boundary. 22 | """ 23 | 24 | def __init__(self, model, vocab_size, prompt_ids, bias_value=100.): 25 | """ Build a new TokenHealingLogitsProcessor. 26 | 27 | Note that bias_value is in score space (log-odds normally) and should be 28 | enough to ensure those tokens are the only ones used. 29 | """ 30 | 31 | # loop backwards through the prompt tokens looking for places where there are possible 32 | # extensions that cross the prompt boundary 33 | self.model=model 34 | prefix_str = "" 35 | self.extension_tokens = [] 36 | for i in range(len(prompt_ids)-1, max(len(prompt_ids)-10, -1), -1): 37 | token_str = model.id_to_token(prompt_ids[i]) 38 | prefix_str = token_str + prefix_str 39 | try: 40 | extensions = model.prefix_matches(prefix_str) 41 | except KeyError: # this must be a special token outside the vocab, so we assume it does not have any valid extensions 42 | extensions = [] 43 | self.extension_tokens.append(extensions) 44 | if i != len(prompt_ids)-1: 45 | self.extension_tokens[-1].append(prompt_ids[i]) # add the token used in the input prompt to the list of possible extensions 46 | self.extension_tokens = self.extension_tokens[::-1] 47 | 48 | # prune off any extension token positions that don't have multiple possible extensions 49 | found_extensions = False 50 | for i in range(len(self.extension_tokens)): 51 | if len(self.extension_tokens[i]) > 1: 52 | self.extension_tokens = self.extension_tokens[i:] 53 | found_extensions = True 54 | break 55 | if found_extensions: 56 | self.healed_token_ids = prompt_ids[len(prompt_ids)-len(self.extension_tokens):] 57 | else: 58 | self.extension_tokens = [] 59 | self.healed_token_ids = [] 60 | 61 | # if we have multiple possible completions past the last token, then biasing is needed 62 | if len(self.extension_tokens) > 0: 63 | import torch 64 | 65 | # build a set of masks for each possible extension position 66 | self.token_masks = [] 67 | for i in range(len(self.extension_tokens)): 68 | token_mask = torch.zeros(vocab_size) 69 | token_mask.scatter_(0, torch.tensor(self.extension_tokens[i]), bias_value) 70 | 
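                # each mask adds bias_value to exactly the token ids that are valid
                # extensions at this healed position, so once the mask is applied those
                # tokens dominate the scores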
self.token_masks.append(token_mask) 71 | 72 | self.num_extensions = 0 73 | 74 | def __call__(self, input_ids, scores): 75 | 76 | # we only bias the first token generated 77 | if self.num_extensions >= len(self.extension_tokens): 78 | return scores 79 | self.num_extensions += 1 80 | 81 | # check if the last token was from the original prompt (if not then we have already "healed" by choosing a token that crosses the prompt boundary) 82 | if self.num_extensions > 1 and input_ids[0][-1] != self.healed_token_ids[self.num_extensions-2]: 83 | return scores 84 | 85 | # handle list inputs 86 | if isinstance(scores, list): 87 | import torch 88 | scores = torch.tensor(scores) 89 | 90 | # make only allowed tokens possible 91 | # Check size mismatch and correct 92 | if scores.shape[1] != self.token_masks[self.num_extensions-1].shape[0]: 93 | scores = scores[:, :-1] 94 | 95 | token_mask = self.token_masks[self.num_extensions-1].to(scores.device) 96 | 97 | res = (scores + token_mask ) 98 | # dg=(res).tolist() 99 | 100 | 101 | return res 102 | # __call__ method 103 | class BiasLogitsProcessor(): 104 | """ Simple token biasing. 105 | """ 106 | 107 | def __init__(self, model, vocab_size, logit_bias): 108 | """ Build a new BiasLogitsProcessor. 109 | """ 110 | import torch 111 | 112 | self.bias_vector = torch.zeros(vocab_size) 113 | for token, bias in logit_bias.items(): 114 | self.bias_vector[token] = bias 115 | self.bias_vector = self.bias_vector.to(model.device) 116 | 117 | def __call__(self, input_ids, scores): 118 | 119 | # handle list inputs 120 | if isinstance(scores, list): 121 | import torch 122 | scores = torch.tensor(scores) 123 | 124 | return scores + self.bias_vector 125 | 126 | 127 | # __call__ method 128 | class RegexLogitsProcessor(): 129 | """ Pattern guiding. 130 | 131 | Guide generation to match a regular expression. 132 | TODO: currently slow, could be made much faster by doing rejection sampling inline with the sampling/greedy process. 133 | """ 134 | 135 | def __init__(self, pattern, stop_regex, llm, vocab_size, is_greedy, prefix_length, eos_token_id, max_consider=500000): 136 | """ Build a new TokenHealingLogitsProcessor. 137 | 138 | Parameters 139 | ---------- 140 | pattern : str 141 | The regex pattern we are seeking to match. 142 | stop_regex : str or list of str 143 | The stop regex(s) allowed to come after this pattern. 144 | llm : function 145 | The llm. 146 | vocab_size : int 147 | The size of the vocabulary. 148 | is_greedy : bool 149 | The token selection mode currently in use. We need to know this so we can 150 | effectively take over that sampling process inside this logit processor. 151 | eos_token_id : int 152 | The end of the stop token of the model. 153 | max_consider : int 154 | How many top values to bias. Note that we could remove this option once this 155 | processor is performance optimized (by integrating it into the sampling/greedy process). 
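            Candidates outside the top max_consider scores are never biased.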
156 | """ 157 | import torch 158 | 159 | if isinstance(stop_regex, str): 160 | stop_regex = [stop_regex] 161 | self.pattern_no_stop = regex.compile(pattern) 162 | self.pattern = regex.compile(pattern + "(" + "|".join(stop_regex) + ")?") 163 | self.llm = llm 164 | self.is_greedy = is_greedy 165 | self.prefix_length = prefix_length 166 | self.max_consider = max_consider 167 | self.bias_vector = torch.zeros(vocab_size) 168 | self.current_strings = None 169 | self.current_length = 0 170 | self.forced_chars = 0 171 | self.eos_token_id = eos_token_id 172 | 173 | def __call__(self, input_ids, scores): 174 | import torch 175 | 176 | # handle 1D inputs 177 | one_dim = False 178 | if not isinstance(input_ids[0], collections.abc.Sequence) and not (hasattr(input_ids[0], "shape") and len(input_ids[0].shape) > 0): 179 | one_dim = True 180 | input_ids = torch.tensor(input_ids).unsqueeze(0) 181 | scores = torch.tensor(scores).unsqueeze(0) 182 | 183 | # extend our current strings 184 | if self.current_strings is None: 185 | self.current_strings = [self.llm.new_string_builder() for i in range(len(input_ids))] 186 | for i in range(len(self.current_strings)): 187 | self.current_strings[i].extend(input_ids[i][self.current_length:]) 188 | 189 | assert len(self.current_strings) == 1, "Regex patterns guides do not support batched inference with Transformers yet!" 190 | 191 | self.current_length = len(input_ids[0]) 192 | 193 | # compute the bias values 194 | self.bias_vector[:] = 0 195 | sort_inds = torch.argsort(scores, 1, True) 196 | to_bias = [] 197 | for i in range(min(sort_inds.shape[1], self.max_consider)): 198 | self.current_strings[0].extend([sort_inds[0,i]]) 199 | proposed_string = str(self.current_strings[0])[self.prefix_length:] 200 | self.current_strings[0].pop() 201 | m = self.pattern.fullmatch(proposed_string, partial=True) # partial means we don't match currently but might as the string grows 202 | if m: 203 | to_bias.append(int(sort_inds[0, i])) 204 | if self.is_greedy: # TODO: make this much faster for non-greedy sampling (by tracking how much prob mass we have looked through perhaps...) 
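                    # for non-greedy sampling we keep scanning so that every token which can
                    # still complete the pattern receives the bias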
205 | break # we are done if we are doing greedy sampling and we found the top valid hit 206 | 207 | # if we found no more valid tokens then we just end the sequence 208 | if not len(to_bias): 209 | to_bias = [self.eos_token_id] 210 | 211 | # bias allowed tokens 212 | min_to_bias = float(scores[0, to_bias].min()) 213 | bias_value = scores[0, sort_inds[0, 0]] - min_to_bias + 10 # make sure the tokens that fit the pattern have higher scores than the top value 214 | for x in to_bias: 215 | self.bias_vector[x] = bias_value 216 | out = scores + self.bias_vector.to(scores.device) 217 | if one_dim: 218 | return out[0] 219 | else: 220 | return out 221 | # __call__ method 222 | class RegexStoppingCriteria(): 223 | def __init__(self, stop_pattern, llm, prefix_length): 224 | if isinstance(stop_pattern, str): 225 | self.stop_patterns = [regex.compile(stop_pattern)] 226 | else: 227 | self.stop_patterns = [regex.compile(pattern) for pattern in stop_pattern] 228 | self.prefix_length = prefix_length 229 | self.llm = llm 230 | self.current_strings = None 231 | self.current_length = 0 232 | 233 | def __call__(self, input_ids, scores, **kwargs): 234 | 235 | # handle 1D inputs 236 | if not isinstance(input_ids[0], collections.abc.Sequence) and not (hasattr(input_ids[0], "shape") and len(input_ids[0].shape) > 0): 237 | input_ids = [input_ids] 238 | 239 | # extend our current strings 240 | if self.current_strings is None: 241 | self.current_strings = [self.llm.new_string_builder() for _ in range(len(input_ids))] 242 | for i in range(len(self.current_strings)): 243 | self.current_strings[i].extend(input_ids[i][self.current_length:]) 244 | 245 | self.current_length = len(input_ids[0]) 246 | 247 | # check if all of the strings match a stop string (and hence we can stop the batch inference) 248 | all_done = True 249 | for i in range(len(self.current_strings)): 250 | found = False 251 | print(self.current_strings) 252 | for s in self.stop_patterns: 253 | 254 | if s.search(str(self.current_strings[i])[self.prefix_length:]): 255 | found = True 256 | if not found: 257 | all_done = False 258 | break 259 | 260 | return all_done -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flask_cloudflared==0.0.12 2 | sentence-transformers -------------------------------------------------------------------------------- /script.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os,sys 4 | import time 5 | import torch 6 | import requests 7 | import yaml 8 | from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer 9 | from threading import Thread 10 | from modules.utils import get_available_models 11 | from .guidance_gen import GuidanceGenerator 12 | import numpy as np 13 | 14 | import traceback 15 | import unittest 16 | from modules import shared 17 | 18 | port=9555 19 | def printc(obj, color): 20 | color_code = { 21 | 'black': '30', 'red': '31', 'green': '32', 'yellow': '33', 22 | 'blue': '34', 'magenta': '35', 'cyan': '36', 'white': '37' 23 | } 24 | colored_text = f"\033[{color_code[color]}m{obj}\033[0m" if color in color_code else obj 25 | print(colored_text) 26 | 27 | 28 | 29 | class Handler(BaseHTTPRequestHandler): 30 | 31 | def __init__(self, *args, gen=None, **kwargs): 32 | self.gen = gen 33 | super().__init__(*args, **kwargs) 34 | 35 | def do_GET(self): 36 | if self.path == '/api/v1/model': 37 | 
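            # GET /api/v1/model returns the metadata dict assembled by setup_model_data()
            # (model name, special tokens, instruction template, vocab size, ...)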
self.send_response(200) 38 | self.end_headers() 39 | 40 | 41 | response = json.dumps({"results":self.gen.data}) 42 | self.wfile.write(response.encode('utf-8')) 43 | else: 44 | self.send_error(404) 45 | 46 | def do_POST(self): 47 | content_length = int(self.headers['Content-Length']) 48 | body = json.loads(self.rfile.read(content_length).decode('utf-8')) 49 | 50 | if self.path == '/api/v1/call': 51 | self.send_response(200) 52 | self.send_header('Content-Type', 'application/json') 53 | self.end_headers() 54 | printc("Call request received, accuiring generation lock", "green") 55 | printc(body, "blue") 56 | res="" 57 | try: 58 | res= self.gen.__call__( 59 | prompt=body["prompt"], stop=body["stop"], stop_regex=body["stop_regex"], 60 | temperature=body["temperature"], n=body["n"], max_tokens=body["max_tokens"], 61 | logprobs=body["logprobs"],top_p=body["top_p"], echo=body["echo"], logit_bias=body["logit_bias"], 62 | token_healing=body["token_healing"], pattern=body["pattern"],stream=False,cache_seed=-1, 63 | caching=False, 64 | ) 65 | printc(res, "green") 66 | except Exception as e: 67 | printc("An error occurred: " + str(e), "red") 68 | finally: 69 | print("Call request fulfilled, releasing generation lock") 70 | 71 | 72 | response = json.dumps({ 73 | 'choices': [{ 74 | 'text': res 75 | }] 76 | }) 77 | 78 | self.wfile.write(response.encode('utf-8')) 79 | 80 | elif self.path == '/api/v1/encode': 81 | self.send_response(200) 82 | self.send_header('Content-Type', 'application/json') 83 | self.end_headers() 84 | printc("Encode request received", "green") 85 | printc(body, "blue") 86 | string = body['text'] 87 | res=self.gen.encode(string) 88 | 89 | response = json.dumps({ 90 | 'results': [{ 91 | 'tokens':res 92 | }] 93 | }) 94 | 95 | self.wfile.write(response.encode('utf-8')) 96 | elif self.path == '/api/v1/decode': 97 | self.send_response(200) 98 | self.send_header('Content-Type', 'application/json') 99 | self.end_headers() 100 | printc("decode request received", "green") 101 | printc(body, "blue") 102 | tokens = (body['tokens']) 103 | print("my tokens",tokens, type(tokens)) 104 | decoded_sequences =self.gen.decode(tokens) 105 | response = json.dumps({ 106 | 'results': [{ 107 | 'ids': decoded_sequences 108 | }] 109 | }) 110 | 111 | self.wfile.write(response.encode('utf-8')) 112 | 113 | 114 | 115 | 116 | def _run_server(port: int, gen): 117 | 118 | 119 | address = '0.0.0.0' if shared.args.listen else '127.0.0.1' 120 | 121 | class CustomHandler(Handler): 122 | def __init__(self, *args, **kwargs): 123 | super().__init__(*args, gen=gen, **kwargs) 124 | 125 | server = ThreadingHTTPServer((address, port), CustomHandler) 126 | 127 | def on_start(public_url: str): 128 | print(f'Starting non-streaming server at public url {public_url}/api') 129 | 130 | server.serve_forever() 131 | 132 | def setup(): 133 | printc("starting guidance server","green") 134 | gen =GuidanceGenerator() 135 | 136 | Thread(target=_run_server, args=[port,gen], daemon=True).start() -------------------------------------------------------------------------------- /streamer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import collections 4 | import regex 5 | import pygtrie 6 | import queue 7 | import threading 8 | import logging 9 | import collections.abc 10 | from modules import shared 11 | from modules.text_generation import encode, generate_reply, decode 12 | from typing import Any, Dict, Optional, Callable 13 | 14 | 15 | class TransformersStringBuilder(): 16 | 
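    # Note: near-duplicate of TransformersStringBuilder in guidance_gen.py, minus the unused llm argument.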
"""This deals with the complexity of building up a string from tokens bit by bit.""" 17 | def __init__(self, tokenizer, starting_ids=None): 18 | self.tokenizer = tokenizer 19 | self.token_strings = [] 20 | self._joint_string = "" 21 | if starting_ids is not None: 22 | self.extend(starting_ids) 23 | 24 | def extend(self, new_ids): 25 | new_token_strings = self.tokenizer.convert_ids_to_tokens(new_ids) 26 | self.token_strings.extend(new_token_strings) 27 | new_str = self.tokenizer.convert_tokens_to_string(self.token_strings) 28 | diff_str = new_str[len(self._joint_string):] 29 | self._joint_string = new_str 30 | return diff_str 31 | 32 | def pop(self): 33 | """Remove the last token from the string and return text it removed.""" 34 | self.token_strings.pop() 35 | new_str = self.tokenizer.convert_tokens_to_string(self.token_strings) 36 | diff_str = self._joint_string[len(new_str):] 37 | self._joint_string = new_str 38 | return diff_str 39 | 40 | def __str__(self): 41 | return self._joint_string 42 | 43 | def __len__(self): 44 | return len(self._joint_string) 45 | 46 | class TransformersStreamer(): 47 | def __init__(self, input_ids, stop_regex, healed_token_ids, prefix_length, llm, max_new_tokens, logprobs, timeout=None): 48 | 49 | self.input_ids = input_ids 50 | self.stop_regex = stop_regex 51 | self.healed_token_ids = healed_token_ids 52 | print(logprobs) 53 | self.logprobs = logprobs 54 | self.llm = llm 55 | self.max_total_tokens = max_new_tokens + len(input_ids[0]) 56 | self.timeout = timeout 57 | self.str_pos = [prefix_length for i in range(len(self.input_ids))] 58 | self.out_queue = queue.Queue() 59 | self.sequence_pos = [len(self.input_ids[0]) for i in range(len(self.input_ids))] 60 | self.generated_sequence = [[] for i in range(len(self.input_ids))] 61 | self.display_logprobs = [[] for i in range(len(self.input_ids))] 62 | self.generated_string = [self.llm.new_string_builder(starting_ids=input_ids[0]) for i in range(len(self.input_ids))] 63 | self.prefix_cache = [] 64 | 65 | def put(self, token_obj): 66 | print(self.display_logprobs) 67 | import torch 68 | if isinstance(token_obj, torch.Tensor): 69 | new_tokens = token_obj 70 | else: 71 | new_tokens = token_obj['sequences'] 72 | 73 | if isinstance(new_tokens, torch.Tensor): 74 | new_tokens = new_tokens.cpu() 75 | 76 | # if we are given a single sequence, then make itstop=', a batch of size 1 77 | if len(new_tokens.shape) == 1: 78 | new_tokens = new_tokens.unsqueeze(0) 79 | 80 | # extract the scores if we are given them (and format them to be the same shape as the tokens) 81 | if self.logprobs: 82 | assert len(new_tokens) == 1, "logprobs are not supported for batched generation right now in guidance.llms.Transformers" 83 | new_scores = [torch.nn.functional.log_softmax(x, dim=-1).cpu() for x in token_obj['scores']] 84 | len_diff = len(new_tokens[0]) - len(new_scores) 85 | if len_diff > 0: 86 | new_scores = [None for i in range(len_diff)] + new_scores 87 | new_scores = [new_scores] 88 | 89 | out = {"choices": [None for i in range(len(self.input_ids))]} 90 | put_data = False 91 | for i in range(len(self.input_ids)): 92 | self.generated_sequence[i].extend(list(new_tokens[i])) 93 | 94 | # save logprobs if needed 95 | if self.logprobs: 96 | for scores in new_scores[i]: 97 | if scores is None: 98 | self.display_logprobs[i].append(None) 99 | else: 100 | top_inds = scores[0].argsort(descending=True)[:self.logprobs] # TODO: verify the [0] is always correct 101 | self.display_logprobs[i].append({self.llm.id_to_token(j): float(scores[0][j]) for j in 
top_inds}) 102 | 103 | if self.sequence_pos[i] < len(self.generated_sequence[i]): 104 | display_tokens = list(self.generated_sequence[i][self.sequence_pos[i]:]) 105 | val = self.generated_string[i].extend(display_tokens) 106 | 107 | if self.str_pos[i] < len(self.generated_string[i]): 108 | val = str(self.generated_string[i])[self.str_pos[i]:] 109 | finish_reason = None 110 | 111 | # check why we stopped 112 | stop_pos = len(val) + 1 113 | if len(self.generated_sequence[i]) >= self.max_total_tokens: 114 | finish_reason = "length" 115 | elif self.generated_sequence[i][-1] == self.llm.tokenizer.eos_token_id: 116 | finish_reason = "endoftext" 117 | eos_str = self.generated_string[i].pop() # remove the end of text token 118 | stop_pos = len(val) - len(eos_str) 119 | 120 | # record the reason we stopped (if we have stopped) 121 | if finish_reason is not None: 122 | out["choices"][i] = { 123 | "text": val[:stop_pos], 124 | "finish_reason": finish_reason, 125 | "stop_text": None, # no stop text since stop is None 126 | "logprobs": { 127 | "top_logprobs": self.display_logprobs[i][self.sequence_pos[i]:] 128 | } 129 | } 130 | self.str_pos[i] = len(self.generated_string[i]) 131 | put_data = True 132 | self.sequence_pos[i] = len(self.generated_sequence[i]) 133 | 134 | if put_data: 135 | self.out_queue.put(out) 136 | 137 | 138 | def end(self): 139 | # make sure we have flushed all of the data 140 | for i in range(len(self.input_ids)): 141 | assert self.str_pos[i] >= len(self.generated_string[i]), "Not all data was flushed, this means generation stopped for an unknown reason!" 142 | 143 | self.out_queue.put(None) 144 | 145 | def __iter__(self): 146 | return self 147 | 148 | def __next__(self): 149 | value = self.out_queue.get(timeout=self.timeout) 150 | if value is None: 151 | raise StopIteration() 152 | else: 153 | return value 154 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import time 2 | import traceback 3 | from threading import Thread 4 | from typing import Callable, Optional 5 | 6 | from modules import shared 7 | from modules.chat import load_character_memoized 8 | 9 | 10 | def build_parameters(body, chat=False): 11 | 12 | generate_params = { 13 | 'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))), 14 | 'do_sample': bool(body.get('do_sample', True)), 15 | 'temperature': float(body.get('temperature', 0.5)), 16 | 'top_p': float(body.get('top_p', 1)), 17 | 'typical_p': float(body.get('typical_p', body.get('typical', 1))), 18 | 'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)), 19 | 'eta_cutoff': float(body.get('eta_cutoff', 0)), 20 | 'tfs': float(body.get('tfs', 1)), 21 | 'top_a': float(body.get('top_a', 0)), 22 | 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))), 23 | 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)), 24 | 'top_k': int(body.get('top_k', 0)), 25 | 'min_length': int(body.get('min_length', 0)), 26 | 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)), 27 | 'num_beams': int(body.get('num_beams', 1)), 28 | 'penalty_alpha': float(body.get('penalty_alpha', 0)), 29 | 'length_penalty': float(body.get('length_penalty', 1)), 30 | 'early_stopping': bool(body.get('early_stopping', False)), 31 | 'mirostat_mode': int(body.get('mirostat_mode', 0)), 32 | 'mirostat_tau': float(body.get('mirostat_tau', 5)), 33 | 'mirostat_eta': float(body.get('mirostat_eta', 
0.1)), 34 | 'seed': int(body.get('seed', -1)), 35 | 'add_bos_token': bool(body.get('add_bos_token', True)), 36 | 'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))), 37 | 'ban_eos_token': bool(body.get('ban_eos_token', False)), 38 | 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), 39 | 'custom_stopping_strings': '', # leave this blank 40 | 'stopping_strings': body.get('stopping_strings', []), 41 | } 42 | 43 | if chat: 44 | character = body.get('character') 45 | instruction_template = body.get('instruction_template') 46 | name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), shared.settings['name2'], instruct=False) 47 | name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True) 48 | generate_params.update({ 49 | 'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])), 50 | 'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])), 51 | 'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])), 52 | 'mode': str(body.get('mode', 'chat')), 53 | 'name1': name1, 54 | 'name2': name2, 55 | 'context': context, 56 | 'greeting': greeting, 57 | 'name1_instruct': name1_instruct, 58 | 'name2_instruct': name2_instruct, 59 | 'context_instruct': context_instruct, 60 | 'turn_template': turn_template, 61 | 'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])), 62 | }) 63 | 64 | return generate_params 65 | 66 | 67 | def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): 68 | Thread(target=_start_cloudflared, args=[ 69 | port, max_attempts, on_start], daemon=True).start() 70 | 71 | 72 | def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): 73 | try: 74 | from flask_cloudflared import _run_cloudflared 75 | except ImportError: 76 | print('You should install flask_cloudflared manually') 77 | raise Exception( 78 | 'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.') 79 | 80 | for _ in range(max_attempts): 81 | try: 82 | public_url = _run_cloudflared(port, port + 1) 83 | 84 | if on_start: 85 | on_start(public_url) 86 | 87 | return 88 | except Exception: 89 | traceback.print_exc() 90 | time.sleep(3) 91 | 92 | raise Exception('Could not start cloudflared.') 93 | --------------------------------------------------------------------------------
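As a usage sketch, the snippet below calls the extension's HTTP endpoints directly with `requests`. The port (9555) and the request/response field names are taken from the handlers in script.py; the prompt text and parameter values are made-up examples.

```python
import requests

# script.py binds 127.0.0.1:9555 unless the web UI is started with --listen
BASE = "http://127.0.0.1:9555/api/v1"

# GET /model returns the metadata dict assembled by setup_model_data()
info = requests.get(f"{BASE}/model").json()["results"]
print(info["model_name"], info["eos_token"])

# POST /call mirrors the field names read in Handler.do_POST
payload = {
    "prompt": "The following is a character profile:",  # example prompt (made up)
    "stop": None, "stop_regex": None,
    "temperature": 0.0, "n": 1, "max_tokens": 64,
    "logprobs": None, "top_p": 1.0, "echo": False,
    "logit_bias": None, "token_healing": True, "pattern": None,
}
result = requests.post(f"{BASE}/call", json=payload).json()
print(result["choices"][0]["text"])

# POST /encode and /decode wrap the web UI tokenizer; note that the decoded
# string comes back under the 'ids' key
tokens = requests.post(f"{BASE}/encode", json={"text": "Hello"}).json()["results"][0]["tokens"]
decoded = requests.post(f"{BASE}/decode", json={"tokens": tokens}).json()["results"][0]["ids"]
print(tokens, decoded)
```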