├── ExllamaV2_TensorParallel_Files
│   ├── 20240830-16-23-25.json
│   ├── exllamav2.py
│   ├── exllamav2_hf.py
│   └── shared.py
├── README.md
└── sd_api_pictures
    ├── with_ADetailer
    │   └── script.py
    └── without_ADetailer
        └── script.py

/ExllamaV2_TensorParallel_Files/20240830-16-23-25.json:
--------------------------------------------------------------------------------
1 | {
2 | "internal": [
3 | [
4 | "<|BEGIN-VISIBLE-CHAT|>",
5 | "How can I help you today?"
6 | ],
7 | [
8 | "Hello, how are you functioning today?",
9 | "I'm functioning well, thank you for asking! I'm here and ready to assist you. How about you? How's your day going?"
10 | ],
11 | [
12 | "I'm doing well too. There is this quantization for LLMs called exllamaV2, recently it has implemented tensor parallelism, so folks running models on multi-gpu systems can inference faster....much faster in fact. ExllamaV2 is a loader in a framework called textgeneration webui. This allows different loaders such as ExllamaV2.\n\nSo I have exllamaV2 working with textgen using the exllamav2 loader, but i want to also get it working with the exllamav2_HF loader, this is a huggingface specific implementation of exllamav2. I'll give you the code for both:\n\nCurrently working exllamav2 code:\n #!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Aug 29 18:06:29 2024\n\n@author: myself\n\"\"\"\n\nimport traceback\nfrom pathlib import Path\n\nimport torch\nfrom exllamav2 import (\n ExLlamaV2,\n ExLlamaV2Config,\n ExLlamaV2Cache,\n ExLlamaV2Cache_8bit,\n ExLlamaV2Cache_Q4,\n ExLlamaV2Cache_Q6,\n ExLlamaV2Cache_Q8,\n ExLlamaV2Cache_TP,\n ExLlamaV2Tokenizer,\n model_init,\n)\nfrom exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator\n\nfrom modules import shared\nfrom modules.logging_colors import logger\nfrom modules.text_generation import get_max_prompt_length\n\ntry:\n import flash_attn\nexcept ModuleNotFoundError:\n logger.warning(\n 'You are running ExLlamaV2 without flash-attention. 
This will cause the VRAM usage '\n 'to be a lot higher than it could be.\\n'\n 'Try installing flash-attention following the instructions here: '\n 'https://github.com/Dao-AILab/flash-attention#installation-and-features'\n )\n pass\nexcept Exception:\n logger.warning('Failed to load flash-attention due to the following error:\\n')\n traceback.print_exc()\n\n\nclass Exllamav2Model:\n def __init__(self):\n pass\n\n @classmethod\n def from_pretrained(self, path_to_model):\n\n path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)\n\n config = ExLlamaV2Config()\n config.model_dir = str(path_to_model)\n config.prepare()\n\n config.max_seq_len = shared.args.max_seq_len\n config.scale_pos_emb = shared.args.compress_pos_emb\n config.scale_alpha_value = shared.args.alpha_value\n config.no_flash_attn = shared.args.no_flash_attn\n config.no_xformers = shared.args.no_xformers\n config.no_sdpa = shared.args.no_sdpa\n config.num_experts_per_token = int(shared.args.num_experts_per_token)\n\n model = ExLlamaV2(config)\n\n # Check if TP is enabled and load model with TP\n if shared.args.enable_tp:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(\",\")]\n model.load_tp(split) # Ensure TP loading is used\n else:\n if not shared.args.autosplit:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(\",\")]\n model.load(split)\n\n # Determine the correct cache type\n if shared.args.cache_8bit:\n cache_type = ExLlamaV2Cache_8bit\n elif shared.args.cache_4bit:\n cache_type = ExLlamaV2Cache_Q4\n else:\n cache_type = ExLlamaV2Cache\n\n # Use TP if specified\n if shared.args.enable_tp:\n cache = ExLlamaV2Cache_TP(model, base=cache_type)\n else:\n cache = cache_type(model, lazy=shared.args.autosplit)\n\n if shared.args.autosplit and not shared.args.enable_tp:\n model.load_autosplit(cache)\n\n tokenizer = ExLlamaV2Tokenizer(config)\n generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)\n\n result = self()\n result.model = model\n result.cache = cache\n result.tokenizer = tokenizer\n result.generator = generator\n result.loras = None\n return result, result\n\n def encode(self, string, **kwargs):\n return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)\n\n def decode(self, ids, **kwargs):\n if isinstance(ids, list):\n ids = torch.tensor([ids])\n elif isinstance(ids, torch.Tensor) and ids.numel() == 1:\n ids = ids.view(1, -1)\n\n return self.tokenizer.decode(ids, decode_special_tokens=True)[0]\n\n def get_logits(self, token_ids, **kwargs):\n self.cache.current_seq_len = 0\n if token_ids.shape[-1] > 1:\n self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)\n\n return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()\n\n def generate_with_streaming(self, prompt, state):\n settings = ExLlamaV2Sampler.Settings()\n\n settings.token_repetition_penalty = state['repetition_penalty']\n settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']\n\n settings.token_frequency_penalty = state['frequency_penalty']\n settings.token_presence_penalty = state['presence_penalty']\n\n settings.temperature = state['temperature']\n settings.top_k = state['top_k']\n settings.top_p = state['top_p']\n settings.top_a = state['top_a']\n settings.min_p = state['min_p']\n settings.tfs = state['tfs']\n settings.typical 
= state['typical_p']\n\n settings.temperature_last = state['temperature_last']\n\n settings.mirostat = state['mirostat_mode'] == 2\n settings.mirostat_tau = state['mirostat_tau']\n settings.mirostat_eta = state['mirostat_eta']\n\n if state['ban_eos_token']:\n settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])\n\n if state['custom_token_bans']:\n to_ban = [int(x) for x in state['custom_token_bans'].split(',')]\n if len(to_ban) > 0:\n settings.disallow_tokens(self.tokenizer, to_ban)\n\n ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)\n ids = ids[:, -get_max_prompt_length(state):]\n\n if state['auto_max_new_tokens']:\n max_new_tokens = state['truncation_length'] - ids.shape[-1]\n else:\n max_new_tokens = state['max_new_tokens']\n\n self.generator.begin_stream(ids, settings, loras=self.loras)\n\n decoded_text = ''\n for i in range(max_new_tokens):\n chunk, eos, _ = self.generator.stream()\n if eos or shared.stop_everything:\n break\n\n decoded_text += chunk\n yield decoded_text\n\n def generate(self, prompt, state):\n output = ''\n for output in self.generate_with_streaming(prompt, state):\n pass\n\n return output\n\n\nExllamav2_HF loader I want to use tensor parallelism with:\n\nimport os\nimport traceback\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom exllamav2 import (\n ExLlamaV2,\n ExLlamaV2Cache,\n ExLlamaV2Cache_8bit,\n ExLlamaV2Cache_Q4,\n ExLlamaV2Config\n)\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import GenerationConfig, PretrainedConfig, PreTrainedModel\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom modules import shared\nfrom modules.logging_colors import logger\n\ntry:\n import flash_attn\nexcept ModuleNotFoundError:\n logger.warning(\n 'You are running ExLlamaV2 without flash-attention. 
This will cause the VRAM usage '\n 'to be a lot higher than it could be.\\n'\n 'Try installing flash-attention following the instructions here: '\n 'https://github.com/Dao-AILab/flash-attention#installation-and-features'\n )\n pass\nexcept Exception:\n logger.warning('Failed to load flash-attention due to the following error:\\n')\n traceback.print_exc()\n\n\nclass Exllamav2HF(PreTrainedModel):\n def __init__(self, config: ExLlamaV2Config):\n super().__init__(PretrainedConfig())\n self.ex_config = config\n self.loras = None\n self.generation_config = GenerationConfig()\n\n self.ex_model = ExLlamaV2(config)\n\n if not shared.args.autosplit:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(\",\")]\n\n self.ex_model.load(split)\n\n if shared.args.cache_8bit:\n self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)\n elif shared.args.cache_4bit:\n self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)\n else:\n self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)\n\n if shared.args.autosplit:\n self.ex_model.load_autosplit(self.ex_cache)\n\n self.past_seq = None\n if shared.args.cfg_cache:\n if shared.args.cache_8bit:\n self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)\n elif shared.args.cache_4bit:\n self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model)\n else:\n self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)\n\n self.past_seq_negative = None\n\n def _validate_model_class(self):\n pass\n\n def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):\n pass\n\n def prepare_inputs_for_generation(self, input_ids, **kwargs):\n return {'input_ids': input_ids, **kwargs}\n\n @property\n def device(self) -> torch.device:\n return torch.device(0)\n\n def __call__(self, *args, **kwargs):\n use_cache = kwargs.get('use_cache', True)\n labels = kwargs.get('labels', None)\n past_key_values = kwargs.get('past_key_values', None)\n\n if len(args) > 0:\n if not shared.args.cfg_cache:\n logger.error(\"Please enable the cfg-cache option to use CFG with ExLlamav2_HF.\")\n return\n\n input_ids = args[0]\n is_negative = True\n past_seq = self.past_seq_negative\n ex_cache = self.ex_cache_negative\n else:\n input_ids = kwargs['input_ids']\n is_negative = False\n past_seq = self.past_seq\n ex_cache = self.ex_cache\n\n seq = input_ids[0].tolist()\n if is_negative and past_key_values is not None:\n seq = past_key_values + seq\n\n seq_tensor = torch.tensor(seq)\n reset = True\n\n # Make the forward call\n if labels is None:\n if past_seq is not None:\n min_length = min(past_seq.shape[0], seq_tensor.shape[0])\n indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))\n if len(indices) > 0:\n longest_prefix = indices[0].item()\n else:\n longest_prefix = min_length\n\n if longest_prefix > 0:\n reset = False\n ex_cache.current_seq_len = longest_prefix\n if len(seq_tensor) - longest_prefix > 1:\n self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)\n elif len(seq_tensor) == longest_prefix:\n # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,\n # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!\n ex_cache.current_seq_len -= 1\n\n if reset:\n ex_cache.current_seq_len = 0\n if len(seq_tensor) > 1:\n self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, 
loras=self.loras)\n\n logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()\n else:\n ex_cache.current_seq_len = 0\n logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()\n\n if is_negative:\n self.past_seq_negative = seq_tensor\n else:\n self.past_seq = seq_tensor\n\n loss = None\n if labels is not None:\n # Shift so that tokens < n predict n\n shift_logits = logits[..., :-1, :].contiguous()\n shift_labels = labels[..., 1:].contiguous()\n # Flatten the tokens\n loss_fct = CrossEntropyLoss()\n shift_logits = shift_logits.view(-1, logits.shape[-1])\n shift_labels = shift_labels.view(-1)\n # Enable model parallelism\n shift_labels = shift_labels.to(shift_logits.device)\n loss = loss_fct(shift_logits, shift_labels)\n\n return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)\n\n @classmethod\n def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):\n assert len(model_args) == 0 and len(kwargs) == 0, \"extra args is currently not supported\"\n if isinstance(pretrained_model_name_or_path, str):\n pretrained_model_name_or_path = Path(pretrained_model_name_or_path)\n\n pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)\n\n config = ExLlamaV2Config()\n config.model_dir = str(pretrained_model_name_or_path)\n config.prepare()\n\n config.max_seq_len = shared.args.max_seq_len\n config.scale_pos_emb = shared.args.compress_pos_emb\n config.scale_alpha_value = shared.args.alpha_value\n config.no_flash_attn = shared.args.no_flash_attn\n config.no_xformers = shared.args.no_xformers\n config.no_sdpa = shared.args.no_sdpa\n config.num_experts_per_token = int(shared.args.num_experts_per_token)\n\n return Exllamav2HF(config)", 13 | "To integrate tensor parallelism (TP) into the `Exllamav2HF` loader, you'll need to make a few modifications similar to what you have in the `Exllamav2Model` class. Specifically, you'll need to add the logic to handle TP when loading the model and adjust the cache initialization accordingly.\n\nHere's how you can modify the `Exllamav2HF` class to include tensor parallelism:\n\n```python\nimport os\nimport traceback\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom exllamav2 import (\n ExLlamaV2,\n ExLlamaV2Cache,\n ExLlamaV2Cache_8bit,\n ExLlamaV2Cache_Q4,\n ExLlamaV2Config,\n ExLlamaV2Cache_TP,\n)\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import GenerationConfig, PretrainedConfig, PreTrainedModel\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom modules import shared\nfrom modules.logging_colors import logger\n\ntry:\n import flash_attn\nexcept ModuleNotFoundError:\n logger.warning(\n 'You are running ExLlamaV2 without flash-attention. 
This will cause the VRAM usage '\n 'to be a lot higher than it could be.\\n'\n 'Try installing flash-attention following the instructions here: '\n 'https://github.com/Dao-AILab/flash-attention#installation-and-features'\n )\n pass\nexcept Exception:\n logger.warning('Failed to load flash-attention due to the following error:\\n')\n traceback.print_exc()\n\nclass Exllamav2HF(PreTrainedModel):\n def __init__(self, config: ExLlamaV2Config):\n super().__init__(PretrainedConfig())\n self.ex_config = config\n self.loras = None\n self.generation_config = GenerationConfig()\n\n self.ex_model = ExLlamaV2(config)\n\n # Check if TP is enabled and load model with TP\n if shared.args.enable_tp:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(\",\")]\n self.ex_model.load_tp(split) # Ensure TP loading is used\n else:\n if not shared.args.autosplit:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(\",\")]\n self.ex_model.load(split)\n\n # Determine the correct cache type\n if shared.args.cache_8bit:\n cache_type = ExLlamaV2Cache_8bit\n elif shared.args.cache_4bit:\n cache_type = ExLlamaV2Cache_Q4\n else:\n cache_type = ExLlamaV2Cache\n\n # Use TP if specified\n if shared.args.enable_tp:\n self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)\n else:\n self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)\n\n if shared.args.autosplit and not shared.args.enable_tp:\n self.ex_model.load_autosplit(self.ex_cache)\n\n self.past_seq = None\n if shared.args.cfg_cache:\n if shared.args.cache_8bit:\n self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)\n elif shared.args.cache_4bit:\n self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model)\n else:\n self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)\n\n self.past_seq_negative = None\n\n def _validate_model_class(self):\n pass\n\n def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):\n pass\n\n def prepare_inputs_for_generation(self, input_ids, **kwargs):\n return {'input_ids': input_ids, **kwargs}\n\n @property\n def device(self) -> torch.device:\n return torch.device(0)\n\n def __call__(self, *args, **kwargs):\n use_cache = kwargs.get('use_cache', True)\n labels = kwargs.get('labels', None)\n past_key_values = kwargs.get('past_key_values', None)\n\n if len(args) > 0:\n if not shared.args.cfg_cache:\n logger.error(\"Please enable the cfg-cache option to use CFG with ExLlamav2_HF.\")\n return\n\n input_ids = args[0]\n is_negative = True\n past_seq = self.past_seq_negative\n ex_cache = self.ex_cache_negative\n else:\n input_ids = kwargs['input_ids']\n is_negative = False\n past_seq = self.past_seq\n ex_cache = self.ex_cache\n\n seq = input_ids[0].tolist()\n if is_negative and past_key_values is not None:\n seq = past_key_values + seq\n\n seq_tensor = torch.tensor(seq)\n reset = True\n\n # Make the forward call\n if labels is None:\n if past_seq is not None:\n min_length = min(past_seq.shape[0], seq_tensor.shape[0])\n indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))\n if len(indices) > 0:\n longest_prefix = indices[0].item()\n else:\n longest_prefix = min_length\n\n if longest_prefix > 0:\n reset = False\n ex_cache.current_seq_len = longest_prefix\n if len(seq_tensor) - longest_prefix > 1:\n self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)\n elif len(seq_tensor) == longest_prefix:\n # Very tricky: if 
the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,\n # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!\n ex_cache.current_seq_len -= 1\n\n if reset:\n ex_cache.current_seq_len = 0\n if len(seq_tensor) > 1:\n self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)\n\n logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()\n else:\n ex_cache.current_seq_len = 0\n logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()\n\n if is_negative:\n self.past_seq_negative = seq_tensor\n else:\n self.past_seq = seq_tensor\n\n loss = None\n if labels is not None:\n # Shift so that tokens < n predict n\n shift_logits = logits[..., :-1, :].contiguous()\n shift_labels = labels[..., 1:].contiguous()\n # Flatten the tokens\n loss_fct = CrossEntropyLoss()\n shift_logits = shift_logits.view(-1, logits.shape[-1])\n shift_labels = shift_labels.view(-1)\n # Enable model parallelism\n shift_labels = shift_labels.to(shift_logits.device)\n loss = loss_fct(shift_logits, shift_labels)\n\n return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)\n\n @classmethod\n def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):\n assert len(model_args) == 0 and len(kwargs) == 0, \"extra args is currently not supported\"\n if isinstance(pretrained_model_name_or_path, str):\n pretrained_model_name_or_path = Path(pretrained_model_name_or_path)\n\n pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)\n\n config = ExLlamaV2Config()\n config.model_dir = str(pretrained_model_name_or_path)\n config.prepare()\n\n config.max_seq_len = shared.args.max_seq_len\n config.scale_pos_emb = shared.args.compress_pos_emb\n config.scale_alpha_value = shared.args.alpha_value\n config.no_flash_attn = shared.args.no_flash_attn\n config.no_xformers = shared.args.no_xformers\n config.no_sdpa = shared.args.no_sdpa\n config.num_experts_per_token = int(shared.args.num_experts_per_token)\n\n return Exllamav2HF(config)\n```\n\nIn this modified version, I've added the logic to check if tensor parallelism (TP) is enabled and to load the model with TP if it is. I've also adjusted the cache initialization to use `ExLlamaV2Cache_TP` if TP is enabled. This should allow you to use tensor parallelism with the `Exllamav2HF` loader." 14 | ] 15 | ], 16 | "visible": [ 17 | [ 18 | "", 19 | "How can I help you today?" 20 | ], 21 | [ 22 | "Hello, how are you functioning today?", 23 | "I'm functioning well, thank you for asking! I'm here and ready to assist you. How about you? How's your day going?" 24 | ], 25 | [ 26 | "I'm doing well too. There is this quantization for LLMs called exllamaV2, recently it has implemented tensor parallelism, so folks running models on multi-gpu systems can inference faster....much faster in fact. ExllamaV2 is a loader in a framework called textgeneration webui. This allows different loaders such as ExllamaV2.\n\nSo I have exllamaV2 working with textgen using the exllamav2 loader, but i want to also get it working with the exllamav2_HF loader, this is a huggingface specific implementation of exllamav2. 
I'll give you the code for both:\n\nCurrently working exllamav2 code:\n #!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n"""\nCreated on Thu Aug 29 18:06:29 2024\n\n@author: myself\n"""\n\nimport traceback\nfrom pathlib import Path\n\nimport torch\nfrom exllamav2 import (\n ExLlamaV2,\n ExLlamaV2Config,\n ExLlamaV2Cache,\n ExLlamaV2Cache_8bit,\n ExLlamaV2Cache_Q4,\n ExLlamaV2Cache_Q6,\n ExLlamaV2Cache_Q8,\n ExLlamaV2Cache_TP,\n ExLlamaV2Tokenizer,\n model_init,\n)\nfrom exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator\n\nfrom modules import shared\nfrom modules.logging_colors import logger\nfrom modules.text_generation import get_max_prompt_length\n\ntry:\n import flash_attn\nexcept ModuleNotFoundError:\n logger.warning(\n 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '\n 'to be a lot higher than it could be.\\n'\n 'Try installing flash-attention following the instructions here: '\n 'https://github.com/Dao-AILab/flash-attention#installation-and-features'\n )\n pass\nexcept Exception:\n logger.warning('Failed to load flash-attention due to the following error:\\n')\n traceback.print_exc()\n\n\nclass Exllamav2Model:\n def __init__(self):\n pass\n\n @classmethod\n def from_pretrained(self, path_to_model):\n\n path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)\n\n config = ExLlamaV2Config()\n config.model_dir = str(path_to_model)\n config.prepare()\n\n config.max_seq_len = shared.args.max_seq_len\n config.scale_pos_emb = shared.args.compress_pos_emb\n config.scale_alpha_value = shared.args.alpha_value\n config.no_flash_attn = shared.args.no_flash_attn\n config.no_xformers = shared.args.no_xformers\n config.no_sdpa = shared.args.no_sdpa\n config.num_experts_per_token = int(shared.args.num_experts_per_token)\n\n model = ExLlamaV2(config)\n\n # Check if TP is enabled and load model with TP\n if shared.args.enable_tp:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]\n model.load_tp(split) # Ensure TP loading is used\n else:\n if not shared.args.autosplit:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]\n model.load(split)\n\n # Determine the correct cache type\n if shared.args.cache_8bit:\n cache_type = ExLlamaV2Cache_8bit\n elif shared.args.cache_4bit:\n cache_type = ExLlamaV2Cache_Q4\n else:\n cache_type = ExLlamaV2Cache\n\n # Use TP if specified\n if shared.args.enable_tp:\n cache = ExLlamaV2Cache_TP(model, base=cache_type)\n else:\n cache = cache_type(model, lazy=shared.args.autosplit)\n\n if shared.args.autosplit and not shared.args.enable_tp:\n model.load_autosplit(cache)\n\n tokenizer = ExLlamaV2Tokenizer(config)\n generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)\n\n result = self()\n result.model = model\n result.cache = cache\n result.tokenizer = tokenizer\n result.generator = generator\n result.loras = None\n return result, result\n\n def encode(self, string, **kwargs):\n return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)\n\n def decode(self, ids, **kwargs):\n if isinstance(ids, list):\n ids = torch.tensor([ids])\n elif isinstance(ids, torch.Tensor) and ids.numel() == 1:\n ids = ids.view(1, -1)\n\n return self.tokenizer.decode(ids, decode_special_tokens=True)[0]\n\n def get_logits(self, token_ids, **kwargs):\n self.cache.current_seq_len = 0\n if token_ids.shape[-1] > 1:\n self.model.forward(token_ids[:, :-1], self.cache, 
input_mask=None, preprocess_only=True, loras=self.loras)\n\n return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()\n\n def generate_with_streaming(self, prompt, state):\n settings = ExLlamaV2Sampler.Settings()\n\n settings.token_repetition_penalty = state['repetition_penalty']\n settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']\n\n settings.token_frequency_penalty = state['frequency_penalty']\n settings.token_presence_penalty = state['presence_penalty']\n\n settings.temperature = state['temperature']\n settings.top_k = state['top_k']\n settings.top_p = state['top_p']\n settings.top_a = state['top_a']\n settings.min_p = state['min_p']\n settings.tfs = state['tfs']\n settings.typical = state['typical_p']\n\n settings.temperature_last = state['temperature_last']\n\n settings.mirostat = state['mirostat_mode'] == 2\n settings.mirostat_tau = state['mirostat_tau']\n settings.mirostat_eta = state['mirostat_eta']\n\n if state['ban_eos_token']:\n settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])\n\n if state['custom_token_bans']:\n to_ban = [int(x) for x in state['custom_token_bans'].split(',')]\n if len(to_ban) > 0:\n settings.disallow_tokens(self.tokenizer, to_ban)\n\n ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)\n ids = ids[:, -get_max_prompt_length(state):]\n\n if state['auto_max_new_tokens']:\n max_new_tokens = state['truncation_length'] - ids.shape[-1]\n else:\n max_new_tokens = state['max_new_tokens']\n\n self.generator.begin_stream(ids, settings, loras=self.loras)\n\n decoded_text = ''\n for i in range(max_new_tokens):\n chunk, eos, _ = self.generator.stream()\n if eos or shared.stop_everything:\n break\n\n decoded_text += chunk\n yield decoded_text\n\n def generate(self, prompt, state):\n output = ''\n for output in self.generate_with_streaming(prompt, state):\n pass\n\n return output\n\n\nExllamav2_HF loader I want to use tensor parallelism with:\n\nimport os\nimport traceback\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom exllamav2 import (\n ExLlamaV2,\n ExLlamaV2Cache,\n ExLlamaV2Cache_8bit,\n ExLlamaV2Cache_Q4,\n ExLlamaV2Config\n)\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import GenerationConfig, PretrainedConfig, PreTrainedModel\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom modules import shared\nfrom modules.logging_colors import logger\n\ntry:\n import flash_attn\nexcept ModuleNotFoundError:\n logger.warning(\n 'You are running ExLlamaV2 without flash-attention. 
This will cause the VRAM usage '\n 'to be a lot higher than it could be.\\n'\n 'Try installing flash-attention following the instructions here: '\n 'https://github.com/Dao-AILab/flash-attention#installation-and-features'\n )\n pass\nexcept Exception:\n logger.warning('Failed to load flash-attention due to the following error:\\n')\n traceback.print_exc()\n\n\nclass Exllamav2HF(PreTrainedModel):\n def __init__(self, config: ExLlamaV2Config):\n super().__init__(PretrainedConfig())\n self.ex_config = config\n self.loras = None\n self.generation_config = GenerationConfig()\n\n self.ex_model = ExLlamaV2(config)\n\n if not shared.args.autosplit:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]\n\n self.ex_model.load(split)\n\n if shared.args.cache_8bit:\n self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)\n elif shared.args.cache_4bit:\n self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)\n else:\n self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)\n\n if shared.args.autosplit:\n self.ex_model.load_autosplit(self.ex_cache)\n\n self.past_seq = None\n if shared.args.cfg_cache:\n if shared.args.cache_8bit:\n self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)\n elif shared.args.cache_4bit:\n self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model)\n else:\n self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)\n\n self.past_seq_negative = None\n\n def _validate_model_class(self):\n pass\n\n def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):\n pass\n\n def prepare_inputs_for_generation(self, input_ids, **kwargs):\n return {'input_ids': input_ids, **kwargs}\n\n @property\n def device(self) -> torch.device:\n return torch.device(0)\n\n def __call__(self, *args, **kwargs):\n use_cache = kwargs.get('use_cache', True)\n labels = kwargs.get('labels', None)\n past_key_values = kwargs.get('past_key_values', None)\n\n if len(args) > 0:\n if not shared.args.cfg_cache:\n logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")\n return\n\n input_ids = args[0]\n is_negative = True\n past_seq = self.past_seq_negative\n ex_cache = self.ex_cache_negative\n else:\n input_ids = kwargs['input_ids']\n is_negative = False\n past_seq = self.past_seq\n ex_cache = self.ex_cache\n\n seq = input_ids[0].tolist()\n if is_negative and past_key_values is not None:\n seq = past_key_values + seq\n\n seq_tensor = torch.tensor(seq)\n reset = True\n\n # Make the forward call\n if labels is None:\n if past_seq is not None:\n min_length = min(past_seq.shape[0], seq_tensor.shape[0])\n indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))\n if len(indices) > 0:\n longest_prefix = indices[0].item()\n else:\n longest_prefix = min_length\n\n if longest_prefix > 0:\n reset = False\n ex_cache.current_seq_len = longest_prefix\n if len(seq_tensor) - longest_prefix > 1:\n self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)\n elif len(seq_tensor) == longest_prefix:\n # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,\n # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!\n ex_cache.current_seq_len -= 1\n\n if reset:\n ex_cache.current_seq_len = 0\n if len(seq_tensor) > 1:\n self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, 
loras=self.loras)\n\n logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()\n else:\n ex_cache.current_seq_len = 0\n logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()\n\n if is_negative:\n self.past_seq_negative = seq_tensor\n else:\n self.past_seq = seq_tensor\n\n loss = None\n if labels is not None:\n # Shift so that tokens < n predict n\n shift_logits = logits[..., :-1, :].contiguous()\n shift_labels = labels[..., 1:].contiguous()\n # Flatten the tokens\n loss_fct = CrossEntropyLoss()\n shift_logits = shift_logits.view(-1, logits.shape[-1])\n shift_labels = shift_labels.view(-1)\n # Enable model parallelism\n shift_labels = shift_labels.to(shift_logits.device)\n loss = loss_fct(shift_logits, shift_labels)\n\n return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)\n\n @classmethod\n def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):\n assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"\n if isinstance(pretrained_model_name_or_path, str):\n pretrained_model_name_or_path = Path(pretrained_model_name_or_path)\n\n pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)\n\n config = ExLlamaV2Config()\n config.model_dir = str(pretrained_model_name_or_path)\n config.prepare()\n\n config.max_seq_len = shared.args.max_seq_len\n config.scale_pos_emb = shared.args.compress_pos_emb\n config.scale_alpha_value = shared.args.alpha_value\n config.no_flash_attn = shared.args.no_flash_attn\n config.no_xformers = shared.args.no_xformers\n config.no_sdpa = shared.args.no_sdpa\n config.num_experts_per_token = int(shared.args.num_experts_per_token)\n\n return Exllamav2HF(config)", 27 | "To integrate tensor parallelism (TP) into the `Exllamav2HF` loader, you'll need to make a few modifications similar to what you have in the `Exllamav2Model` class. Specifically, you'll need to add the logic to handle TP when loading the model and adjust the cache initialization accordingly.\n\nHere's how you can modify the `Exllamav2HF` class to include tensor parallelism:\n\n```python\nimport os\nimport traceback\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom exllamav2 import (\n ExLlamaV2,\n ExLlamaV2Cache,\n ExLlamaV2Cache_8bit,\n ExLlamaV2Cache_Q4,\n ExLlamaV2Config,\n ExLlamaV2Cache_TP,\n)\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import GenerationConfig, PretrainedConfig, PreTrainedModel\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom modules import shared\nfrom modules.logging_colors import logger\n\ntry:\n import flash_attn\nexcept ModuleNotFoundError:\n logger.warning(\n 'You are running ExLlamaV2 without flash-attention. 
This will cause the VRAM usage '\n 'to be a lot higher than it could be.\\n'\n 'Try installing flash-attention following the instructions here: '\n 'https://github.com/Dao-AILab/flash-attention#installation-and-features'\n )\n pass\nexcept Exception:\n logger.warning('Failed to load flash-attention due to the following error:\\n')\n traceback.print_exc()\n\nclass Exllamav2HF(PreTrainedModel):\n def __init__(self, config: ExLlamaV2Config):\n super().__init__(PretrainedConfig())\n self.ex_config = config\n self.loras = None\n self.generation_config = GenerationConfig()\n\n self.ex_model = ExLlamaV2(config)\n\n # Check if TP is enabled and load model with TP\n if shared.args.enable_tp:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]\n self.ex_model.load_tp(split) # Ensure TP loading is used\n else:\n if not shared.args.autosplit:\n split = None\n if shared.args.gpu_split:\n split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]\n self.ex_model.load(split)\n\n # Determine the correct cache type\n if shared.args.cache_8bit:\n cache_type = ExLlamaV2Cache_8bit\n elif shared.args.cache_4bit:\n cache_type = ExLlamaV2Cache_Q4\n else:\n cache_type = ExLlamaV2Cache\n\n # Use TP if specified\n if shared.args.enable_tp:\n self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)\n else:\n self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)\n\n if shared.args.autosplit and not shared.args.enable_tp:\n self.ex_model.load_autosplit(self.ex_cache)\n\n self.past_seq = None\n if shared.args.cfg_cache:\n if shared.args.cache_8bit:\n self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)\n elif shared.args.cache_4bit:\n self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model)\n else:\n self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)\n\n self.past_seq_negative = None\n\n def _validate_model_class(self):\n pass\n\n def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):\n pass\n\n def prepare_inputs_for_generation(self, input_ids, **kwargs):\n return {'input_ids': input_ids, **kwargs}\n\n @property\n def device(self) -> torch.device:\n return torch.device(0)\n\n def __call__(self, *args, **kwargs):\n use_cache = kwargs.get('use_cache', True)\n labels = kwargs.get('labels', None)\n past_key_values = kwargs.get('past_key_values', None)\n\n if len(args) > 0:\n if not shared.args.cfg_cache:\n logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")\n return\n\n input_ids = args[0]\n is_negative = True\n past_seq = self.past_seq_negative\n ex_cache = self.ex_cache_negative\n else:\n input_ids = kwargs['input_ids']\n is_negative = False\n past_seq = self.past_seq\n ex_cache = self.ex_cache\n\n seq = input_ids[0].tolist()\n if is_negative and past_key_values is not None:\n seq = past_key_values + seq\n\n seq_tensor = torch.tensor(seq)\n reset = True\n\n # Make the forward call\n if labels is None:\n if past_seq is not None:\n min_length = min(past_seq.shape[0], seq_tensor.shape[0])\n indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))\n if len(indices) > 0:\n longest_prefix = indices[0].item()\n else:\n longest_prefix = min_length\n\n if longest_prefix > 0:\n reset = False\n ex_cache.current_seq_len = longest_prefix\n if len(seq_tensor) - longest_prefix > 1:\n self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)\n elif len(seq_tensor) == longest_prefix:\n # Very tricky: if the 
prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,\n # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!\n ex_cache.current_seq_len -= 1\n\n if reset:\n ex_cache.current_seq_len = 0\n if len(seq_tensor) > 1:\n self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)\n\n logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()\n else:\n ex_cache.current_seq_len = 0\n logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()\n\n if is_negative:\n self.past_seq_negative = seq_tensor\n else:\n self.past_seq = seq_tensor\n\n loss = None\n if labels is not None:\n # Shift so that tokens < n predict n\n shift_logits = logits[..., :-1, :].contiguous()\n shift_labels = labels[..., 1:].contiguous()\n # Flatten the tokens\n loss_fct = CrossEntropyLoss()\n shift_logits = shift_logits.view(-1, logits.shape[-1])\n shift_labels = shift_labels.view(-1)\n # Enable model parallelism\n shift_labels = shift_labels.to(shift_logits.device)\n loss = loss_fct(shift_logits, shift_labels)\n\n return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)\n\n @classmethod\n def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):\n assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"\n if isinstance(pretrained_model_name_or_path, str):\n pretrained_model_name_or_path = Path(pretrained_model_name_or_path)\n\n pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)\n\n config = ExLlamaV2Config()\n config.model_dir = str(pretrained_model_name_or_path)\n config.prepare()\n\n config.max_seq_len = shared.args.max_seq_len\n config.scale_pos_emb = shared.args.compress_pos_emb\n config.scale_alpha_value = shared.args.alpha_value\n config.no_flash_attn = shared.args.no_flash_attn\n config.no_xformers = shared.args.no_xformers\n config.no_sdpa = shared.args.no_sdpa\n config.num_experts_per_token = int(shared.args.num_experts_per_token)\n\n return Exllamav2HF(config)\n```\n\nIn this modified version, I've added the logic to check if tensor parallelism (TP) is enabled and to load the model with TP if it is. I've also adjusted the cache initialization to use `ExLlamaV2Cache_TP` if TP is enabled. This should allow you to use tensor parallelism with the `Exllamav2HF` loader." 
28 | ] 29 | ] 30 | } -------------------------------------------------------------------------------- /ExllamaV2_TensorParallel_Files/exllamav2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Aug 29 18:06:29 2024 5 | 6 | @author: myself 7 | """ 8 | 9 | import traceback 10 | from pathlib import Path 11 | 12 | import torch 13 | from exllamav2 import ( 14 | ExLlamaV2, 15 | ExLlamaV2Config, 16 | ExLlamaV2Cache, 17 | ExLlamaV2Cache_8bit, 18 | ExLlamaV2Cache_Q4, 19 | ExLlamaV2Cache_Q6, 20 | ExLlamaV2Cache_Q8, 21 | ExLlamaV2Cache_TP, 22 | ExLlamaV2Tokenizer, 23 | model_init, 24 | ) 25 | from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator 26 | 27 | from modules import shared 28 | from modules.logging_colors import logger 29 | from modules.text_generation import get_max_prompt_length 30 | 31 | try: 32 | import flash_attn 33 | except ModuleNotFoundError: 34 | logger.warning( 35 | 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage ' 36 | 'to be a lot higher than it could be.\n' 37 | 'Try installing flash-attention following the instructions here: ' 38 | 'https://github.com/Dao-AILab/flash-attention#installation-and-features' 39 | ) 40 | pass 41 | except Exception: 42 | logger.warning('Failed to load flash-attention due to the following error:\n') 43 | traceback.print_exc() 44 | 45 | 46 | class Exllamav2Model: 47 | def __init__(self): 48 | pass 49 | 50 | @classmethod 51 | def from_pretrained(self, path_to_model): 52 | 53 | path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model) 54 | 55 | config = ExLlamaV2Config() 56 | config.model_dir = str(path_to_model) 57 | config.prepare() 58 | 59 | config.max_seq_len = shared.args.max_seq_len 60 | config.scale_pos_emb = shared.args.compress_pos_emb 61 | config.scale_alpha_value = shared.args.alpha_value 62 | config.no_flash_attn = shared.args.no_flash_attn 63 | config.no_xformers = shared.args.no_xformers 64 | config.no_sdpa = shared.args.no_sdpa 65 | config.num_experts_per_token = int(shared.args.num_experts_per_token) 66 | 67 | model = ExLlamaV2(config) 68 | 69 | # Check if TP is enabled and load model with TP 70 | if shared.args.enable_tp: 71 | split = None 72 | if shared.args.gpu_split: 73 | split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] 74 | model.load_tp(split) # Ensure TP loading is used 75 | else: 76 | if not shared.args.autosplit: 77 | split = None 78 | if shared.args.gpu_split: 79 | split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] 80 | model.load(split) 81 | 82 | # Determine the correct cache type 83 | if shared.args.cache_8bit: 84 | cache_type = ExLlamaV2Cache_8bit 85 | elif shared.args.cache_4bit: 86 | cache_type = ExLlamaV2Cache_Q4 87 | else: 88 | cache_type = ExLlamaV2Cache 89 | 90 | # Use TP if specified 91 | if shared.args.enable_tp: 92 | cache = ExLlamaV2Cache_TP(model, base=cache_type) 93 | else: 94 | cache = cache_type(model, lazy=shared.args.autosplit) 95 | 96 | if shared.args.autosplit and not shared.args.enable_tp: 97 | model.load_autosplit(cache) 98 | 99 | tokenizer = ExLlamaV2Tokenizer(config) 100 | generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer) 101 | 102 | result = self() 103 | result.model = model 104 | result.cache = cache 105 | result.tokenizer = tokenizer 106 | result.generator = generator 107 | result.loras = None 108 | return result, result 109 | 110 | def encode(self, string, 
**kwargs): 111 | return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True) 112 | 113 | def decode(self, ids, **kwargs): 114 | if isinstance(ids, list): 115 | ids = torch.tensor([ids]) 116 | elif isinstance(ids, torch.Tensor) and ids.numel() == 1: 117 | ids = ids.view(1, -1) 118 | 119 | return self.tokenizer.decode(ids, decode_special_tokens=True)[0] 120 | 121 | def get_logits(self, token_ids, **kwargs): 122 | self.cache.current_seq_len = 0 123 | if token_ids.shape[-1] > 1: 124 | self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras) 125 | 126 | return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu() 127 | 128 | def generate_with_streaming(self, prompt, state): 129 | settings = ExLlamaV2Sampler.Settings() 130 | 131 | settings.token_repetition_penalty = state['repetition_penalty'] 132 | settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range'] 133 | 134 | settings.token_frequency_penalty = state['frequency_penalty'] 135 | settings.token_presence_penalty = state['presence_penalty'] 136 | 137 | settings.temperature = state['temperature'] 138 | settings.top_k = state['top_k'] 139 | settings.top_p = state['top_p'] 140 | settings.top_a = state['top_a'] 141 | settings.min_p = state['min_p'] 142 | settings.tfs = state['tfs'] 143 | settings.typical = state['typical_p'] 144 | 145 | settings.temperature_last = state['temperature_last'] 146 | 147 | settings.mirostat = state['mirostat_mode'] == 2 148 | settings.mirostat_tau = state['mirostat_tau'] 149 | settings.mirostat_eta = state['mirostat_eta'] 150 | 151 | if state['ban_eos_token']: 152 | settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id]) 153 | 154 | if state['custom_token_bans']: 155 | to_ban = [int(x) for x in state['custom_token_bans'].split(',')] 156 | if len(to_ban) > 0: 157 | settings.disallow_tokens(self.tokenizer, to_ban) 158 | 159 | ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True) 160 | ids = ids[:, -get_max_prompt_length(state):] 161 | 162 | if state['auto_max_new_tokens']: 163 | max_new_tokens = state['truncation_length'] - ids.shape[-1] 164 | else: 165 | max_new_tokens = state['max_new_tokens'] 166 | 167 | self.generator.begin_stream(ids, settings, loras=self.loras) 168 | 169 | decoded_text = '' 170 | for i in range(max_new_tokens): 171 | chunk, eos, _ = self.generator.stream() 172 | if eos or shared.stop_everything: 173 | break 174 | 175 | decoded_text += chunk 176 | yield decoded_text 177 | 178 | def generate(self, prompt, state): 179 | output = '' 180 | for output in self.generate_with_streaming(prompt, state): 181 | pass 182 | 183 | return output 184 | -------------------------------------------------------------------------------- /ExllamaV2_TensorParallel_Files/exllamav2_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | from pathlib import Path 4 | from typing import Any, Dict, Optional, Union 5 | 6 | import torch 7 | from exllamav2 import ( 8 | ExLlamaV2, 9 | ExLlamaV2Cache, 10 | ExLlamaV2Cache_8bit, 11 | ExLlamaV2Cache_Q4, 12 | ExLlamaV2Config, 13 | ExLlamaV2Cache_TP, 14 | ) 15 | from torch.nn import CrossEntropyLoss 16 | from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel 17 | from transformers.modeling_outputs import CausalLMOutputWithPast 18 | 19 | from modules import 
shared 20 | from modules.logging_colors import logger 21 | 22 | try: 23 | import flash_attn 24 | except ModuleNotFoundError: 25 | logger.warning( 26 | 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage ' 27 | 'to be a lot higher than it could be.\\n' 28 | 'Try installing flash-attention following the instructions here: ' 29 | 'https://github.com/Dao-AILab/flash-attention#installation-and-features' 30 | ) 31 | pass 32 | except Exception: 33 | logger.warning('Failed to load flash-attention due to the following error:\\n') 34 | traceback.print_exc() 35 | 36 | class Exllamav2HF(PreTrainedModel): 37 | def __init__(self, config: ExLlamaV2Config): 38 | super().__init__(PretrainedConfig()) 39 | self.ex_config = config 40 | self.loras = None 41 | self.generation_config = GenerationConfig() 42 | 43 | self.ex_model = ExLlamaV2(config) 44 | 45 | # Check if TP is enabled and load model with TP 46 | if shared.args.enable_tp: 47 | split = None 48 | if shared.args.gpu_split: 49 | split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] 50 | self.ex_model.load_tp(split) # Ensure TP loading is used 51 | else: 52 | if not shared.args.autosplit: 53 | split = None 54 | if shared.args.gpu_split: 55 | split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] 56 | self.ex_model.load(split) 57 | 58 | # Determine the correct cache type 59 | if shared.args.cache_8bit: 60 | cache_type = ExLlamaV2Cache_8bit 61 | elif shared.args.cache_4bit: 62 | cache_type = ExLlamaV2Cache_Q4 63 | else: 64 | cache_type = ExLlamaV2Cache 65 | 66 | # Use TP if specified 67 | if shared.args.enable_tp: 68 | self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type) 69 | else: 70 | self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit) 71 | 72 | if shared.args.autosplit and not shared.args.enable_tp: 73 | self.ex_model.load_autosplit(self.ex_cache) 74 | 75 | self.past_seq = None 76 | if shared.args.cfg_cache: 77 | if shared.args.cache_8bit: 78 | self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model) 79 | elif shared.args.cache_4bit: 80 | self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model) 81 | else: 82 | self.ex_cache_negative = ExLlamaV2Cache(self.ex_model) 83 | 84 | self.past_seq_negative = None 85 | 86 | def _validate_model_class(self): 87 | pass 88 | 89 | def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): 90 | pass 91 | 92 | def prepare_inputs_for_generation(self, input_ids, **kwargs): 93 | return {'input_ids': input_ids, **kwargs} 94 | 95 | @property 96 | def device(self) -> torch.device: 97 | return torch.device(0) 98 | 99 | def __call__(self, *args, **kwargs): 100 | use_cache = kwargs.get('use_cache', True) 101 | labels = kwargs.get('labels', None) 102 | past_key_values = kwargs.get('past_key_values', None) 103 | 104 | if len(args) > 0: 105 | if not shared.args.cfg_cache: 106 | logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.") 107 | return 108 | 109 | input_ids = args[0] 110 | is_negative = True 111 | past_seq = self.past_seq_negative 112 | ex_cache = self.ex_cache_negative 113 | else: 114 | input_ids = kwargs['input_ids'] 115 | is_negative = False 116 | past_seq = self.past_seq 117 | ex_cache = self.ex_cache 118 | 119 | seq = input_ids[0].tolist() 120 | if is_negative and past_key_values is not None: 121 | seq = past_key_values + seq 122 | 123 | seq_tensor = torch.tensor(seq) 124 | reset = True 125 | 126 | # Make the forward call 127 | if labels is None: 128 | if past_seq is not None: 129 | min_length = 
min(past_seq.shape[0], seq_tensor.shape[0]) 130 | indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) 131 | if len(indices) > 0: 132 | longest_prefix = indices[0].item() 133 | else: 134 | longest_prefix = min_length 135 | 136 | if longest_prefix > 0: 137 | reset = False 138 | ex_cache.current_seq_len = longest_prefix 139 | if len(seq_tensor) - longest_prefix > 1: 140 | self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras) 141 | elif len(seq_tensor) == longest_prefix: 142 | # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one, 143 | # because we feed input_ids[-1] to forward() below, but that last token is already in the cache! 144 | ex_cache.current_seq_len -= 1 145 | 146 | if reset: 147 | ex_cache.current_seq_len = 0 148 | if len(seq_tensor) > 1: 149 | self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras) 150 | 151 | logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float() 152 | else: 153 | ex_cache.current_seq_len = 0 154 | logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float() 155 | 156 | if is_negative: 157 | self.past_seq_negative = seq_tensor 158 | else: 159 | self.past_seq = seq_tensor 160 | 161 | loss = None 162 | if labels is not None: 163 | # Shift so that tokens < n predict n 164 | shift_logits = logits[..., :-1, :].contiguous() 165 | shift_labels = labels[..., 1:].contiguous() 166 | # Flatten the tokens 167 | loss_fct = CrossEntropyLoss() 168 | shift_logits = shift_logits.view(-1, logits.shape[-1]) 169 | shift_labels = shift_labels.view(-1) 170 | # Enable model parallelism 171 | shift_labels = shift_labels.to(shift_logits.device) 172 | loss = loss_fct(shift_logits, shift_labels) 173 | 174 | return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss) 175 | 176 | @classmethod 177 | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): 178 | assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported" 179 | if isinstance(pretrained_model_name_or_path, str): 180 | pretrained_model_name_or_path = Path(pretrained_model_name_or_path) 181 | 182 | pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) 183 | 184 | config = ExLlamaV2Config() 185 | config.model_dir = str(pretrained_model_name_or_path) 186 | config.prepare() 187 | 188 | config.max_seq_len = shared.args.max_seq_len 189 | config.scale_pos_emb = shared.args.compress_pos_emb 190 | config.scale_alpha_value = shared.args.alpha_value 191 | config.no_flash_attn = shared.args.no_flash_attn 192 | config.no_xformers = shared.args.no_xformers 193 | config.no_sdpa = shared.args.no_sdpa 194 | config.num_experts_per_token = int(shared.args.num_experts_per_token) 195 | 196 | return Exllamav2HF(config) 197 | 198 | -------------------------------------------------------------------------------- /ExllamaV2_TensorParallel_Files/shared.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import sys 5 | from collections import OrderedDict 6 | from pathlib import Path 7 | 8 | import yaml 9 | 10 | from modules.logging_colors import logger 11 | 12 | # Model variables 13 | model = 
None 14 | tokenizer = None 15 | model_name = 'None' 16 | is_seq2seq = False 17 | model_dirty_from_training = False 18 | lora_names = [] 19 | 20 | # Generation variables 21 | stop_everything = False 22 | generation_lock = None 23 | processing_message = '*Is typing...*' 24 | 25 | # UI variables 26 | gradio = {} 27 | persistent_interface_state = {} 28 | need_restart = False 29 | 30 | # UI defaults 31 | settings = { 32 | 'dark_theme': True, 33 | 'show_controls': True, 34 | 'start_with': '', 35 | 'mode': 'chat-instruct', 36 | 'chat_style': 'cai-chat', 37 | 'prompt-default': 'QA', 38 | 'prompt-notebook': 'QA', 39 | 'preset': 'min_p', 40 | 'max_new_tokens': 512, 41 | 'max_new_tokens_min': 1, 42 | 'max_new_tokens_max': 4096, 43 | 'negative_prompt': '', 44 | 'seed': -1, 45 | 'truncation_length': 2048, 46 | 'max_tokens_second': 0, 47 | 'max_updates_second': 0, 48 | 'prompt_lookup_num_tokens': 0, 49 | 'custom_stopping_strings': '', 50 | 'custom_token_bans': '', 51 | 'auto_max_new_tokens': False, 52 | 'ban_eos_token': False, 53 | 'add_bos_token': True, 54 | 'skip_special_tokens': True, 55 | 'stream': True, 56 | 'character': 'Assistant', 57 | 'name1': 'You', 58 | 'user_bio': '', 59 | 'custom_system_message': '', 60 | 'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}", 61 | 'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}", 62 | 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', 63 | 'autoload_model': False, 64 | 'default_extensions': [], 65 | } 66 | 67 | default_settings = copy.deepcopy(settings) 68 | 69 | # Parser copied from https://github.com/vladmandic/automatic 70 | parser = argparse.ArgumentParser(description="Text generation web UI", conflict_handler='resolve', add_help=True, formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=55, indent_increment=2, width=200)) 71 | 72 | # Basic settings 73 | group = parser.add_argument_group('Basic settings') 74 | group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. 
Warning: this is likely not safe for sharing publicly.') 75 | group.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.') 76 | group.add_argument('--model', type=str, help='Name of the model to load by default.') 77 | group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.') 78 | group.add_argument('--model-dir', type=str, default='models/', help='Path to directory with all the models.') 79 | group.add_argument('--lora-dir', type=str, default='loras/', help='Path to directory with all the loras.') 80 | group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.') 81 | group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See settings-template.yaml for an example. If you create a file called settings.yaml, this file will be loaded by default without the need to use the --settings flag.') 82 | group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') 83 | group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') 84 | group.add_argument('--chat-buttons', action='store_true', help='Show buttons on the chat tab instead of a hover menu.') 85 | group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.') 86 | 87 | # Model loader 88 | group = parser.add_argument_group('Model loader') 89 | group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, AutoGPTQ.') 90 | 91 | # Transformers/Accelerate 92 | group = parser.add_argument_group('Transformers/Accelerate') 93 | group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.') 94 | group.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') 95 | group.add_argument('--gpu-memory', type=str, nargs='+', help='Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.') 96 | group.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.') 97 | group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') 98 | group.add_argument('--disk-cache-dir', type=str, default='cache', help='Directory to save the disk cache to. Defaults to "cache".') 99 | group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).') 100 | group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') 101 | group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. 
This reduces VRAM usage slightly, but it comes at a performance cost.') 102 | group.add_argument('--trust-remote-code', action='store_true', help='Set trust_remote_code=True while loading the model. Necessary for some models.') 103 | group.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.') 104 | group.add_argument('--no_use_fast', action='store_true', help='Set use_fast=False while loading the tokenizer (it\'s True by default). Use this if you have any problems related to use_fast.') 105 | group.add_argument('--use_flash_attention_2', action='store_true', help='Set use_flash_attention_2=True while loading the model.') 106 | group.add_argument('--use_eager_attention', action='store_true', help='Set attn_implementation= eager while loading the model.') 107 | 108 | # bitsandbytes 4-bit 109 | group = parser.add_argument_group('bitsandbytes 4-bit') 110 | group.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).') 111 | group.add_argument('--use_double_quant', action='store_true', help='use_double_quant for 4-bit.') 112 | group.add_argument('--compute_dtype', type=str, default='float16', help='compute dtype for 4-bit. Valid options: bfloat16, float16, float32.') 113 | group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for 4-bit. Valid options: nf4, fp4.') 114 | 115 | # llama.cpp 116 | group = parser.add_argument_group('llama.cpp') 117 | group.add_argument('--flash-attn', action='store_true', help='Use flash-attention.') 118 | group.add_argument('--tensorcores', action='store_true', help='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This may increase performance on newer cards.') 119 | group.add_argument('--n_ctx', type=int, default=2048, help='Size of the prompt context.') 120 | group.add_argument('--threads', type=int, default=0, help='Number of threads to use.') 121 | group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.') 122 | group.add_argument('--no_mul_mat_q', action='store_true', help='Disable the mulmat kernels.') 123 | group.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.') 124 | group.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.') 125 | group.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.') 126 | group.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.') 127 | group.add_argument('--tensor_split', type=str, default=None, help='Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.') 128 | group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.') 129 | group.add_argument('--logits_all', action='store_true', help='Needs to be set for perplexity evaluation to work. Otherwise, ignore it, as it makes prompt processing slower.') 130 | group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.') 131 | group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (llama-cpp-python). Examples: 2000MiB, 2GiB. 
When provided without units, bytes will be assumed.') 132 | group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.') 133 | group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') 134 | group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.') 135 | group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from this folder. Meant to be used with llamacpp_HF through the command-line.') 136 | 137 | # ExLlamaV2 138 | group = parser.add_argument_group('ExLlamaV2') 139 | group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.') 140 | group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.') 141 | group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.') 142 | group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.') 143 | group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.') 144 | group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.') 145 | group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.') 146 | group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.') 147 | group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.') 148 | group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.') 149 | group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.') 150 | 151 | # AutoGPTQ 152 | group = parser.add_argument_group('AutoGPTQ') 153 | group.add_argument('--triton', action='store_true', help='Use triton.') 154 | group.add_argument('--no_inject_fused_mlp', action='store_true', help='Triton mode only: disable the use of fused MLP, which will use less VRAM at the cost of slower inference.') 155 | group.add_argument('--no_use_cuda_fp16', action='store_true', help='This can make models faster on some systems.') 156 | group.add_argument('--desc_act', action='store_true', help='For models that do not have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.') 157 | group.add_argument('--disable_exllama', action='store_true', help='Disable ExLlama kernel, which can improve inference speed on some systems.') 158 | group.add_argument('--disable_exllamav2', action='store_true', help='Disable ExLlamav2 kernel.') 159 | group.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.') 160 | group.add_argument('--groupsize', type=int, default=-1, help='Group size.') 161 | 162 | # HQQ 163 | group = parser.add_argument_group('HQQ') 164 | group.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='Backend for the HQQ loader. 
Valid options: PYTORCH, PYTORCH_COMPILE, ATEN.') 165 | 166 | # TensorRT-LLM 167 | group = parser.add_argument_group('TensorRT-LLM') 168 | group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.') 169 | 170 | # DeepSpeed 171 | group = parser.add_argument_group('DeepSpeed') 172 | group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') 173 | group.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.') 174 | group.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.') 175 | 176 | # RoPE 177 | group = parser.add_argument_group('RoPE') 178 | group.add_argument('--alpha_value', type=float, default=1, help='Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.') 179 | group.add_argument('--rope_freq_base', type=int, default=0, help='If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).') 180 | group.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.") 181 | 182 | # Gradio 183 | group = parser.add_argument_group('Gradio') 184 | group.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') 185 | group.add_argument('--listen-port', type=int, help='The listening port that the server will use.') 186 | group.add_argument('--listen-host', type=str, help='The hostname that the server will use.') 187 | group.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') 188 | group.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') 189 | group.add_argument('--gradio-auth', type=str, help='Set Gradio authentication password in the format "username:password". Multiple credentials can also be supplied with "u1:p1,u2:p2,u3:p3".', default=None) 190 | group.add_argument('--gradio-auth-path', type=str, help='Set the Gradio authentication file path. The file should contain one or more user:password pairs in the same format as above.', default=None) 191 | group.add_argument('--ssl-keyfile', type=str, help='The path to the SSL certificate key file.', default=None) 192 | group.add_argument('--ssl-certfile', type=str, help='The path to the SSL certificate cert file.', default=None) 193 | group.add_argument('--subpath', type=str, help='Customize the subpath for gradio, use with reverse proxy') 194 | 195 | # API 196 | group = parser.add_argument_group('API') 197 | group.add_argument('--api', action='store_true', help='Enable the API extension.') 198 | group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.') 199 | group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. 
Use together with public-api option.', default=None) 200 | group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.') 201 | group.add_argument('--api-key', type=str, default='', help='API authentication key.') 202 | group.add_argument('--admin-key', type=str, default='', help='API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key.') 203 | group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.') 204 | 205 | # Multimodal 206 | group = parser.add_argument_group('Multimodal') 207 | group.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.') 208 | 209 | # Deprecated parameters 210 | group = parser.add_argument_group('Deprecated') 211 | group.add_argument('--model_type', type=str, help='DEPRECATED') 212 | group.add_argument('--pre_layer', type=int, nargs='+', help='DEPRECATED') 213 | group.add_argument('--checkpoint', type=str, help='DEPRECATED') 214 | group.add_argument('--monkey-patch', action='store_true', help='DEPRECATED') 215 | group.add_argument('--no_inject_fused_attention', action='store_true', help='DEPRECATED') 216 | 217 | args = parser.parse_args() 218 | args_defaults = parser.parse_args([]) 219 | provided_arguments = [] 220 | for arg in sys.argv[1:]: 221 | arg = arg.lstrip('-').replace('-', '_') 222 | if hasattr(args, arg): 223 | provided_arguments.append(arg) 224 | 225 | deprecated_args = [] 226 | 227 | 228 | def do_cmd_flags_warnings(): 229 | 230 | # Deprecation warnings 231 | for k in deprecated_args: 232 | if getattr(args, k): 233 | logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.') 234 | 235 | # Security warnings 236 | if args.trust_remote_code: 237 | logger.warning('trust_remote_code is enabled. This is dangerous.') 238 | if 'COLAB_GPU' not in os.environ and not args.nowebui: 239 | if args.share: 240 | logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. 
Use it with care.") 241 | if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)): 242 | logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.") 243 | if args.multi_user: 244 | logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.') 245 | 246 | 247 | def fix_loader_name(name): 248 | if not name: 249 | return name 250 | 251 | name = name.lower() 252 | if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']: 253 | return 'llama.cpp' 254 | if name in ['llamacpp_hf', 'llama.cpp_hf', 'llama-cpp-hf', 'llamacpp-hf', 'llama.cpp-hf']: 255 | return 'llamacpp_HF' 256 | elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']: 257 | return 'Transformers' 258 | elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']: 259 | return 'AutoGPTQ' 260 | elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']: 261 | return 'ExLlama' 262 | elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']: 263 | return 'ExLlamav2' 264 | elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']: 265 | return 'ExLlamav2_HF' 266 | elif name in ['hqq']: 267 | return 'HQQ' 268 | elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']: 269 | return 'TensorRT-LLM' 270 | 271 | 272 | def add_extension(name, last=False): 273 | if args.extensions is None: 274 | args.extensions = [name] 275 | elif last: 276 | args.extensions = [x for x in args.extensions if x != name] 277 | args.extensions.append(name) 278 | elif name not in args.extensions: 279 | args.extensions.append(name) 280 | 281 | 282 | def is_chat(): 283 | return True 284 | 285 | 286 | def load_user_config(): 287 | ''' 288 | Loads custom model-specific settings 289 | ''' 290 | if Path(f'{args.model_dir}/config-user.yaml').exists(): 291 | file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip() 292 | 293 | if file_content: 294 | user_config = yaml.safe_load(file_content) 295 | else: 296 | user_config = {} 297 | else: 298 | user_config = {} 299 | 300 | return user_config 301 | 302 | 303 | args.loader = fix_loader_name(args.loader) 304 | 305 | # Activate the multimodal extension 306 | if args.multimodal_pipeline is not None: 307 | add_extension('multimodal') 308 | 309 | # Activate the API extension 310 | if args.api or args.public_api: 311 | add_extension('openai', last=True) 312 | 313 | # Load model-specific settings 314 | with Path(f'{args.model_dir}/config.yaml') as p: 315 | if p.exists(): 316 | model_config = yaml.safe_load(open(p, 'r').read()) 317 | else: 318 | model_config = {} 319 | 320 | # Load custom model-specific settings 321 | user_config = load_user_config() 322 | 323 | model_config = OrderedDict(model_config) 324 | user_config = OrderedDict(user_config) 325 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TextGenTips 2 | Collection of tips for using textgen in various ways 3 | 4 | # Single GPU TIPS 5 | A collection of resources to help folks get a lot out of a single 24GB GPU (these 
instructions are for Nvidia cards) 6 | 7 | Even though models are offloaded, they still use a little bit of VRAM. So if your favorite model fills your GPU completely it may not work; you may need to quantize the model or move Whisper to CPU mode. 8 | 9 | This is in reference to this reddit post: 10 | 11 | https://old.reddit.com/r/LocalLLaMA/comments/1e3aboz/folks_with_one_24gb_gpu_you_can_use_an_llm_sdxl/ 12 | https://www.reddit.com/r/LocalLLaMA/comments/1e3aboz/folks_with_one_24gb_gpu_you_can_use_an_llm_sdxl/ 13 | 14 | It's important to familiarize yourself with each extension on its own before trying to use them all at once, and to watch your GPU VRAM to make sure things are working as expected. 15 | 16 | This is the order of my CMD_FLAGS entry; the order is important. If you don't want to use an extension, remove it from the list but don't change the order. (If you are clicking the UI radio buttons to load your extensions, you need to click each button in the same sequence as the --extensions flag shows.) 17 | 18 | `--extensions text-generation-webui-model_ducking Lucid_Vision alltalk_tts whisper_stt sd_api_pictures` 19 | 20 | 1. Get text-generation-webui here: https://github.com/oobabooga/text-generation-webui/releases 21 | 22 | 2. After installation, install the text-generation-webui-model_ducking extension here: https://github.com/oobabooga/text-generation-webui-extensions?tab=readme-ov-file#model-ducking 23 | 24 | 3. Load a model in textgen and test out the model_ducking extension; make sure it is working for you. 25 | 26 | 4. Install Lucid_Vision here: https://github.com/oobabooga/text-generation-webui-extensions?tab=readme-ov-file#lucid_vision 27 | 28 | 5. If you do install Lucid_Vision, do a test image in the UI without involving the LLM. I don't know what the issue is, but often gradio will time out the first time you load a model (sometimes LLM models); when this happens the vision model receives nothing and the UI will say "error". Just try another picture and things should work well from that point on (this is mentioned in the Lucid_Vision repo as well). 29 | 30 | 6. I recommend disabling DeepseekVL when you first use Lucid_Vision: https://github.com/RandomInternetPreson/Lucid_Vision?tab=readme-ov-file#model-information The reason is that to use it you need to import its dependencies, which may cause issues. If things are working for you and you want to use DeepseekVL, feel free to install it as per the instructions; it's unlikely to cause conflicts, but that's not guaranteed. 31 | 32 | 7. Lucid_Vision relies on your LLM's ability to work out how to use the extension on its own; the manual image upload will work regardless of your LLM. 33 | 34 | 8. Install alltalk_tts here: https://github.com/oobabooga/text-generation-webui-extensions?tab=readme-ov-file#alltalk-tts 35 | 36 | 9. erew123 has a lot of good documentation on how to get AllTalk running in Windows and Linux. If you want the render speeds shown in the video, you need to install DeepSpeed. It works in Linux just fine, but Windows needs the prebuilt DeepSpeed wheels: https://github.com/erew123/alltalk_tts?#-deepspeed-installation-options 37 | 38 | 10. To install DeepSpeed, this is how I do it (these are not complete instructions, there are other steps in erew123's documentation, but every new install only needs roughly these steps); read erew123's documentation for your system (Windows is a bit easier): 39 | 40 | ``` 41 | *Open the textgen environment via the cmd_linux.sh file (or the equivalent for your OS), run the command below, then close the terminal* 42 | pip install libaio 43 | 44 | *Start Textgen* 45 | 46 | *Open the textgen environment via the cmd_linux.sh file (or the equivalent for your OS); I have CUDA 12.1 installed* 47 | export CUDA_HOME=/usr/local/cuda-12.1 48 | 49 | cd extensions/alltalk_tts 50 | 51 | chmod +x atsetup.sh 52 | 53 | ./atsetup.sh 54 | 55 | *follow the installation instructions, including the instructions on how to install DeepSpeed* 56 | ``` 57 | 58 | 11. If you want to use AllTalk on Linux and are having issues, this is a solution for a permanent export CUDA_HOME="/usr/local/cuda": 59 | https://github.com/erew123/alltalk_tts/issues/107 60 | 61 | 12. Whisper_STT comes with textgen; you just need to install the dependencies. oobabooga has created an update_wizard file for each operating system; simply run it and follow the instructions to install the required Whisper_STT dependencies. 62 | 63 | 13. Okay, now the last extension: sd_api_pictures. You don't need to install anything for this extension, but if you want the same functionality as in the video, you need to replace the script.py file in the sd_api_pictures extension folder with this version: https://raw.githubusercontent.com/RandomInternetPreson/TextGenTips/main/sd_api_pictures/with_ADetailer/script.py 64 | 65 | This is a simplified version of the original with some changes that unload and load the model between GPU VRAM and CPU RAM. Auto1111's API will be used when the words "send" and "prompt" are seen in the user's message to the LLM. So if you type "send me a prompt of a kitten attacking a great white shark underwater", the LLM's response to you will be fed to the stable diffusion model. 66 |
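For reference, the picture trigger in this script.py is just a regular-expression check on the user's message (after *emphasized* text is stripped out). Here is a minimal sketch using the same pattern as the script's triggers_are_in() function, which you can run to see which phrasings will fire the picture response:

```python
import re

# The trigger pattern from triggers_are_in() in script.py: a send/mail/message verb
# followed later in the message by selfie, meme, or prompt (optionally plural).
TRIGGER = re.compile(r'(?aims)(send|mail|message)\b.+?\b(selfie|meme|prompt)s?\b')

print(bool(TRIGGER.search("send me a prompt of a kitten attacking a shark")))  # True
print(bool(TRIGGER.search("describe a kitten attacking a shark")))             # False
```

As the commented-out patterns in script.py show, the stock extension also reacts to wording like image, picture, or photo; this simplified version only listens for selfie, meme, or prompt.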
67 | This code creates images that are 1024x1024, so if you are streaming this on a mobile device it is best to use "desktop mode"; otherwise the pictures might be too big for the UI. 68 | 69 | It is these two functions that are the main difference: 70 | 71 | ``` 72 | def load_sd_model(): 73 | print("Loading the Stable Diffusion model into VRAM...") 74 | response = requests.post(url='http://127.0.0.1:7861/sdapi/v1/reload-checkpoint', json='') 75 | response.raise_for_status() 76 | del response 77 | 78 | def unload_sd_model(): 79 | print("Unloading the Stable Diffusion model from VRAM...") 80 | response = requests.post(url='http://127.0.0.1:7861/sdapi/v1/unload-checkpoint', json='') 81 | response.raise_for_status() 82 | del response 83 | ``` 84 | 85 | The original code has something like this too, but I could not get it to work with text-generation-webui-model_ducking. 86 | 87 | Additionally, the code uses ADetailer, so if you don't use that extension in auto1111, use this version of the code instead: 88 | https://raw.githubusercontent.com/RandomInternetPreson/TextGenTips/main/sd_api_pictures/without_ADetailer/script.py 89 | 90 | You need to edit your webui-user file in the auto1111 install directory so that Auto1111 is running on port 7861, or you can edit the address in the script.py file. 91 | 92 | 93 | Now for my slightly janky implementation: you need to start Auto1111 first, then open the UI and manually move the model to CPU RAM by clicking the "Unload the SD Checkpoint to RAM" button. You can close the auto1111 UI at this point, and you don't need to restart it if you need to restart textgen. I'll work on making the code do this automatically later. 94 | 95 | Go to Settings, then Actions at the bottom. 96 | 97 | ![image](https://github.com/user-attachments/assets/0f1cfece-6410-4246-89b0-63a2cdbc4663) 98 | 99 | Then scroll up and click the "Unload the SD Checkpoint to RAM" button. 100 | 101 | ![image](https://github.com/user-attachments/assets/87ce6127-d7a7-4f8e-af82-0df8fc4656e9) 102 | 103 | Now start textgen and you should be good to go. 104 |
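If you'd rather not click through the Auto1111 UI for this unload step, the same thing can be done with a couple of API calls, since they are the same endpoints that unload_sd_model() and load_sd_model() already use. A minimal sketch, assuming Auto1111 was launched with its API enabled (the --api launch flag) and is listening on the extension's default address of 127.0.0.1:7861:

```python
import requests

ADDRESS = 'http://127.0.0.1:7861'  # default Auto1111 address used by this extension

# Move the SD checkpoint out of VRAM (same endpoint unload_sd_model() calls)
response = requests.post(url=f'{ADDRESS}/sdapi/v1/unload-checkpoint', json='')
response.raise_for_status()

# Later, to pull the last checkpoint back into VRAM:
response = requests.post(url=f'{ADDRESS}/sdapi/v1/reload-checkpoint', json='')
response.raise_for_status()
```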
105 | 14. Test out the sd_api_pictures extension and make sure it is working for you. You need to be clever about how you explain to your model how the extension works; usually I edit the LLM's first attempt if it's not in the right format, and then it follows that example for subsequent prompt requests. 106 | 107 | You can try telling your LLM this: 108 | 109 | ``` 110 | I want you to send me a prompt of a starfish in a tuxedo. And so in doing that, just respond back to me with exactly this phrase, "starfish wearing tuxedo." Do not preface or add any additional text to your response back to me. 111 | ``` 112 | 113 | 114 | 115 | # ExllamaV2 tensor parallelism for OOB V1.14 116 | 117 | 1. Install the updated exllamav2 repo using the textgen cmd terminal for your OS (on Linux, for example, it is cmd_linux.sh). 118 | 119 | Within the terminal, navigate to the repositories folder: "cd repositories" 120 | 121 | Run these commands (if using Linux, first run "export CUDA_HOME=/usr/local/cuda-12.1" for your version of CUDA; I'm using 12.1, but you might be running something different, so look at the directory location to figure out which one you have): 122 | ``` 123 | git clone https://github.com/turboderp/exllamav2 124 | cd exllamav2 125 | pip install -r requirements.txt 126 | pip install . 127 | ``` 128 | 2. After you've done this, replace your exllamav2.py and shared.py files (located in the modules folder) with these: https://github.com/RandomInternetPreson/TextGenTips/tree/main/ExllamaV2_TensorParallel_Files 129 | 3.
When running textgen add 130 | ``` 131 | --enable_tp 132 | ``` 133 | to your CMD_FLAGS.txt file 134 | -------------------------------------------------------------------------------- /sd_api_pictures/with_ADetailer/script.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import re 4 | import time 5 | from datetime import date 6 | from pathlib import Path 7 | 8 | import gradio as gr 9 | import requests 10 | import torch 11 | from PIL import Image 12 | 13 | from modules import shared 14 | from modules.models import reload_model, unload_model 15 | from modules.ui import create_refresh_button 16 | 17 | torch._C._jit_set_profiling_mode(False) 18 | 19 | # parameters which can be customized in settings.json of webui 20 | params = { 21 | 'address': 'http://127.0.0.1:7861', 22 | 'mode': 1, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on) 23 | 'manage_VRAM': False, 24 | 'save_img': True, 25 | 'SD_model': 'NeverEndingDream', # not used right now 26 | 'prompt_prefix': 'detailed', 27 | 'negative_prompt': 'cartoon, blurry, distorted, cgi', 28 | 'width': 1024, 29 | 'height': 1024, 30 | 'denoising_strength': 0.61, 31 | 'restore_faces': False, 32 | 'enable_hr': False, 33 | 'hr_upscaler': 'ESRGAN_4x', 34 | 'hr_scale': '1.0', 35 | 'seed': -1, 36 | 'sampler_name': 'DPM++ 2M Karras', 37 | 'steps': 32, 38 | 'cfg_scale': 10.5, 39 | 'textgen_prefix': 'detailed, photorealisitc', 40 | 'sd_checkpoint': ' ', 41 | 'checkpoint_list': [" "] 42 | } 43 | 44 | 45 | def give_VRAM_priority(actor): 46 | global shared, params 47 | 48 | if actor == 'SD': 49 | unload_model() 50 | print("Requesting Auto1111 to re-load last checkpoint used...") 51 | response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='') 52 | response.raise_for_status() 53 | 54 | elif actor == 'LLM': 55 | print("Requesting Auto1111 to vacate VRAM...") 56 | response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='') 57 | response.raise_for_status() 58 | reload_model() 59 | 60 | elif actor == 'set': 61 | print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...") 62 | response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='') 63 | response.raise_for_status() 64 | 65 | elif actor == 'reset': 66 | print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint") 67 | response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='') 68 | response.raise_for_status() 69 | 70 | else: 71 | raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!') 72 | 73 | response.raise_for_status() 74 | del response 75 | 76 | 77 | if params['manage_VRAM']: 78 | give_VRAM_priority('set') 79 | 80 | SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select 81 | 82 | picture_response = False # specifies if the next model response should appear as a picture 83 | 84 | 85 | def remove_surrounded_chars(string): 86 | # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR 87 | # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' 88 | return re.sub('\*[^\*]*?(\*|$)', '', string) 89 | 90 | 91 | def triggers_are_in(string): 92 | string = remove_surrounded_chars(string) 93 | # regex searches for send|main|message|me (at the end of the word) followed by 94 | # a whole 
word of image|pic|picture|photo|snap|snapshot|selfie|meme(s), 95 | # (?aims) are regex parser flags 96 | #return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string)) 97 | # return bool(re.search('(?aims)(send|mail|message)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme|prompt)s?\\b', string)) 98 | return bool(re.search('(?aims)(send|mail|message)\\b.+?\\b(selfie|meme|prompt)s?\\b', string)) 99 | 100 | 101 | def state_modifier(state): 102 | if picture_response: 103 | state['stream'] = False 104 | 105 | return state 106 | 107 | 108 | def input_modifier(string): 109 | """ 110 | This function is applied to your text inputs before 111 | they are fed into the model. 112 | """ 113 | 114 | global params 115 | 116 | if not params['mode'] == 1: # if not in immersive/interactive mode, do nothing 117 | return string 118 | 119 | if triggers_are_in(string): # if we're in it, check for trigger words 120 | toggle_generation(True) 121 | string = string.lower() 122 | #if "of" in string: 123 | #subject = string.split('of', 1)[1] # subdivide the string once by the first 'of' instance and get what's coming after it 124 | #string = params['textgen_prefix'].replace("[subject]", subject) 125 | #else: 126 | #string = params['textgen_prefix'].replace("[subject]", "a description of what the user requested, limited to 2 sentences using the picture examples provided to you ") 127 | #string = params['textgen_prefix'].replace("[subject]", "your appearance, your surroundings and what you are doing right now") 128 | 129 | return string 130 | 131 | 132 | 133 | def load_sd_model(): 134 | print("Loading the Stable Diffusion model into VRAM...") 135 | response = requests.post(url='http://127.0.0.1:7861/sdapi/v1/reload-checkpoint', json='') 136 | response.raise_for_status() 137 | del response 138 | 139 | def unload_sd_model(): 140 | print("Unloading the Stable Diffusion model from VRAM...") 141 | response = requests.post(url='http://127.0.0.1:7861/sdapi/v1/unload-checkpoint', json='') 142 | response.raise_for_status() 143 | del response 144 | 145 | # Get and save the Stable Diffusion-generated picture 146 | def get_SD_pictures(description, character): 147 | load_sd_model() 148 | global params 149 | 150 | if params['manage_VRAM']: 151 | give_VRAM_priority('SD') 152 | 153 | description = re.sub('', ' ', description) 154 | description = f"({description}:1)" 155 | 156 | payload = { 157 | "prompt": params['textgen_prefix'] + description, 158 | "alwayson_scripts": { 159 | "ADetailer": { 160 | "args": [ 161 | { 162 | "ad_model": "face_yolov8n.pt", 163 | "ad_prompt": params['prompt_prefix'], 164 | "ad_inpaint_width": 1024, 165 | "ad_inpaint_height": 1024, 166 | "ad_negative_prompt": 'drawing, unrealistic eyeballs, cartoon, blurry eyes, ugly, perfect skin' 167 | } 168 | ] 169 | } 170 | }, 171 | "seed": params['seed'], 172 | "sampler_name": params['sampler_name'], 173 | "enable_hr": params['enable_hr'], 174 | "hr_scale": params['hr_scale'], 175 | "hr_upscaler": params['hr_upscaler'], 176 | "denoising_strength": params['denoising_strength'], 177 | "steps": params['steps'], 178 | "cfg_scale": params['cfg_scale'], 179 | "width": params['width'], 180 | "height": params['height'], 181 | "restore_faces": params['restore_faces'], 182 | "override_settings_restore_afterwards": True, 183 | "negative_prompt": params['negative_prompt'] 184 | 185 | } 186 | 187 | 188 | print(f'Prompting the image generator via the API on {params["address"]}...') 189 | response = 
requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload) 190 | response.raise_for_status() 191 | r = response.json() 192 | 193 | visible_result = "" 194 | for img_str in r['images']: 195 | if params['save_img']: 196 | img_data = base64.b64decode(img_str) 197 | image = Image.open(io.BytesIO(img_data)) ##new 198 | image = image.resize((800, 800)) ##new 199 | 200 | variadic = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}' 201 | output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png') 202 | #output_file = Path(f'Z:/outputs/{variadic}.png') 203 | #output_file = Path(f'Z:/{variadic}.png') 204 | output_file.parent.mkdir(parents=True, exist_ok=True) 205 | 206 | with open(output_file.as_posix(), 'wb') as f: 207 | #f.write(img_data) 208 | image.save(f, format="PNG") ##new 209 | 210 | visible_result = visible_result + f'{description}\n' 211 | else: 212 | image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0]))) 213 | # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history 214 | image.thumbnail((300, 300)) 215 | buffered = io.BytesIO() 216 | image.save(buffered, format="JPEG") 217 | buffered.seek(0) 218 | image_bytes = buffered.getvalue() 219 | img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode() 220 | visible_result = visible_result + f'{description}\n' 221 | 222 | if params['manage_VRAM']: 223 | give_VRAM_priority('LLM') 224 | unload_sd_model() 225 | return visible_result 226 | 227 | # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) 228 | # and replace it with 'text' for the purposes of logging? 229 | def output_modifier(string, state): 230 | """ 231 | This function is applied to the model outputs. 232 | """ 233 | 234 | global picture_response, params 235 | 236 | if not picture_response: 237 | return string 238 | 239 | string = remove_surrounded_chars(string) 240 | string = string.replace('"', '') 241 | string = string.replace('“', '') 242 | string = string.replace('\n', ' ') 243 | string = string.strip() 244 | 245 | if string == '': 246 | string = 'no viable description in reply, try regenerating' 247 | return string 248 | 249 | text = "" 250 | if (params['mode'] < 2): 251 | toggle_generation(False) 252 | text = f'*Sends a picture which portrays: “{string}”*' 253 | else: 254 | text = string 255 | 256 | string = get_SD_pictures(string, state['character_menu']) + "\n" + text 257 | 258 | return string 259 | 260 | 261 | def bot_prefix_modifier(string): 262 | """ 263 | This function is only applied in chat mode. It modifies 264 | the prefix text for the Bot and can be used to bias its 265 | behavior. 
266 | """ 267 | 268 | return string 269 | 270 | 271 | def toggle_generation(*args): 272 | global picture_response, shared 273 | 274 | if not args: 275 | picture_response = not picture_response 276 | else: 277 | picture_response = args[0] 278 | 279 | shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*" 280 | 281 | 282 | def filter_address(address): 283 | address = address.strip() 284 | # address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash 285 | address = re.sub('\/$', '', address) # remove trailing /s 286 | if not address.startswith('http'): 287 | address = 'http://' + address 288 | return address 289 | 290 | 291 | def SD_api_address_update(address): 292 | global params 293 | 294 | msg = "✔️ SD API is found on:" 295 | address = filter_address(address) 296 | params.update({"address": address}) 297 | try: 298 | response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models') 299 | response.raise_for_status() 300 | # r = response.json() 301 | except: 302 | msg = "❌ No SD API endpoint on:" 303 | 304 | return gr.Textbox.update(label=msg) 305 | 306 | 307 | def custom_css(): 308 | path_to_css = Path(__file__).parent.resolve() / 'style.css' 309 | return open(path_to_css, 'r').read() 310 | 311 | 312 | def get_checkpoints(): 313 | global params 314 | 315 | try: 316 | models = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models') 317 | options = requests.get(url=f'{params["address"]}/sdapi/v1/options') 318 | options_json = options.json() 319 | params['sd_checkpoint'] = options_json['sd_model_checkpoint'] 320 | params['checkpoint_list'] = [result["title"] for result in models.json()] 321 | except: 322 | params['sd_checkpoint'] = "" 323 | params['checkpoint_list'] = [] 324 | 325 | return gr.update(choices=params['checkpoint_list'], value=params['sd_checkpoint']) 326 | 327 | 328 | def load_checkpoint(checkpoint): 329 | payload = { 330 | "sd_model_checkpoint": checkpoint 331 | } 332 | 333 | try: 334 | requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload) 335 | except: 336 | pass 337 | 338 | 339 | def get_samplers(): 340 | try: 341 | response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers') 342 | response.raise_for_status() 343 | samplers = [x["name"] for x in response.json()] 344 | except: 345 | samplers = [] 346 | 347 | return samplers 348 | 349 | 350 | def ui(): 351 | 352 | # Gradio elements 353 | # gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title 354 | with gr.Accordion("Parameters", open=True, elem_classes="SDAP"): 355 | with gr.Row(): 356 | address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address') 357 | modes_list = ["Manual", "Immersive/Interactive", "Picturebook/Adventure"] 358 | mode = gr.Dropdown(modes_list, value=modes_list[params['mode']], label="Mode of operation", type="index") 359 | with gr.Column(scale=1, min_width=300): 360 | manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM') 361 | save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat') 362 | 363 | force_pic = gr.Button("Force the picture response") 364 | suppr_pic = gr.Button("Suppress the picture response") 365 | with gr.Row(): 366 | checkpoint = gr.Dropdown(params['checkpoint_list'], value=params['sd_checkpoint'], label="Checkpoint", type="value") 367 | update_checkpoints = gr.Button("Get list of checkpoints") 368 | 369 | 
with gr.Accordion("Generation parameters", open=False): 370 | prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix Used for ADetailer Input') 371 | #prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)') 372 | textgen_prefix = gr.Textbox(placeholder=params['textgen_prefix'], value=params['textgen_prefix'], label='textgen prefix what is sent at the beginning of description') 373 | negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt') 374 | with gr.Row(): 375 | with gr.Column(): 376 | width = gr.Slider(64, 2048, value=params['width'], step=64, label='Width') 377 | height = gr.Slider(64, 2048, value=params['height'], step=64, label='Height') 378 | with gr.Column(variant="compact", elem_id="sampler_col"): 379 | with gr.Row(elem_id="sampler_row"): 380 | sampler_name = gr.Dropdown(value=params['sampler_name'], allow_custom_value=True, label='Sampling method', elem_id="sampler_box") 381 | create_refresh_button(sampler_name, lambda: None, lambda: {'choices': get_samplers()}, 'refresh-button') 382 | steps = gr.Slider(1, 150, value=params['steps'], step=1, label="Sampling steps", elem_id="steps_box") 383 | with gr.Row(): 384 | seed = gr.Number(label="Seed", value=params['seed'], elem_id="seed_box") 385 | cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box") 386 | with gr.Column() as hr_options: 387 | restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces') 388 | enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix') 389 | with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options: 390 | hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by') 391 | denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength') 392 | hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler') 393 | 394 | # Event functions to update the parameters in the backend 395 | address.change(lambda x: params.update({"address": filter_address(x)}), address, None) 396 | mode.select(lambda x: params.update({"mode": x}), mode, None) 397 | mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None) 398 | manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None) 399 | manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None) 400 | save_img.change(lambda x: params.update({"save_img": x}), save_img, None) 401 | 402 | address.submit(fn=SD_api_address_update, inputs=address, outputs=address) 403 | prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None) 404 | textgen_prefix.change(lambda x: params.update({"textgen_prefix": x}), textgen_prefix, None) 405 | negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None) 406 | width.change(lambda x: params.update({"width": x}), width, None) 407 | height.change(lambda x: params.update({"height": x}), height, None) 408 | hr_scale.change(lambda x: params.update({"hr_scale": x}), hr_scale, None) 409 | denoising_strength.change(lambda x: params.update({"denoising_strength": x}), denoising_strength, None) 410 | restore_faces.change(lambda x: params.update({"restore_faces": x}), 
restore_faces, None) 411 | hr_upscaler.change(lambda x: params.update({"hr_upscaler": x}), hr_upscaler, None) 412 | enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None) 413 | enable_hr.change(lambda x: hr_options.update(visible=params["enable_hr"]), enable_hr, hr_options) 414 | update_checkpoints.click(get_checkpoints, None, checkpoint) 415 | checkpoint.change(lambda x: params.update({"sd_checkpoint": x}), checkpoint, None) 416 | checkpoint.change(load_checkpoint, checkpoint, None) 417 | 418 | sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None) 419 | steps.change(lambda x: params.update({"steps": x}), steps, None) 420 | seed.change(lambda x: params.update({"seed": x}), seed, None) 421 | cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None) 422 | 423 | force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None) 424 | suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None) 425 | -------------------------------------------------------------------------------- /sd_api_pictures/without_ADetailer/script.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import re 4 | import time 5 | from datetime import date 6 | from pathlib import Path 7 | 8 | import gradio as gr 9 | import requests 10 | import torch 11 | from PIL import Image 12 | 13 | from modules import shared 14 | from modules.models import reload_model, unload_model 15 | from modules.ui import create_refresh_button 16 | 17 | torch._C._jit_set_profiling_mode(False) 18 | 19 | # parameters which can be customized in settings.json of webui 20 | params = { 21 | 'address': 'http://127.0.0.1:7861', 22 | 'mode': 1, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on) 23 | 'manage_VRAM': False, 24 | 'save_img': True, 25 | 'SD_model': 'NeverEndingDream', # not used right now 26 | 'prompt_prefix': 'detailed', 27 | 'negative_prompt': 'cartoon, blurry, distorted, cgi', 28 | 'width': 1024, 29 | 'height': 1024, 30 | 'denoising_strength': 0.61, 31 | 'restore_faces': False, 32 | 'enable_hr': False, 33 | 'hr_upscaler': 'ESRGAN_4x', 34 | 'hr_scale': '1.0', 35 | 'seed': -1, 36 | 'sampler_name': 'DPM++ 2M Karras', 37 | 'steps': 32, 38 | 'cfg_scale': 10.5, 39 | 'textgen_prefix': 'detailed, photorealisitc', 40 | 'sd_checkpoint': ' ', 41 | 'checkpoint_list': [" "] 42 | } 43 | 44 | 45 | def give_VRAM_priority(actor): 46 | global shared, params 47 | 48 | if actor == 'SD': 49 | unload_model() 50 | print("Requesting Auto1111 to re-load last checkpoint used...") 51 | response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='') 52 | response.raise_for_status() 53 | 54 | elif actor == 'LLM': 55 | print("Requesting Auto1111 to vacate VRAM...") 56 | response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='') 57 | response.raise_for_status() 58 | reload_model() 59 | 60 | elif actor == 'set': 61 | print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...") 62 | response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='') 63 | response.raise_for_status() 64 | 65 | elif actor == 'reset': 66 | print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint") 67 | response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='') 68 | response.raise_for_status() 69 | 
70 | else: 71 | raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!') 72 | 73 | response.raise_for_status() 74 | del response 75 | 76 | 77 | if params['manage_VRAM']: 78 | give_VRAM_priority('set') 79 | 80 | SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select 81 | 82 | picture_response = False # specifies if the next model response should appear as a picture 83 | 84 | 85 | def remove_surrounded_chars(string): 86 | # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR 87 | # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' 88 | return re.sub('\*[^\*]*?(\*|$)', '', string) 89 | 90 | 91 | def triggers_are_in(string): 92 | string = remove_surrounded_chars(string) 93 | # regex searches for send|main|message|me (at the end of the word) followed by 94 | # a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s), 95 | # (?aims) are regex parser flags 96 | #return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string)) 97 | # return bool(re.search('(?aims)(send|mail|message)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme|prompt)s?\\b', string)) 98 | return bool(re.search('(?aims)(send|mail|message)\\b.+?\\b(selfie|meme|prompt)s?\\b', string)) 99 | 100 | 101 | def state_modifier(state): 102 | if picture_response: 103 | state['stream'] = False 104 | 105 | return state 106 | 107 | 108 | def input_modifier(string): 109 | """ 110 | This function is applied to your text inputs before 111 | they are fed into the model. 112 | """ 113 | 114 | global params 115 | 116 | if not params['mode'] == 1: # if not in immersive/interactive mode, do nothing 117 | return string 118 | 119 | if triggers_are_in(string): # if we're in it, check for trigger words 120 | toggle_generation(True) 121 | string = string.lower() 122 | #if "of" in string: 123 | #subject = string.split('of', 1)[1] # subdivide the string once by the first 'of' instance and get what's coming after it 124 | #string = params['textgen_prefix'].replace("[subject]", subject) 125 | #else: 126 | #string = params['textgen_prefix'].replace("[subject]", "a description of what the user requested, limited to 2 sentences using the picture examples provided to you ") 127 | #string = params['textgen_prefix'].replace("[subject]", "your appearance, your surroundings and what you are doing right now") 128 | 129 | return string 130 | 131 | 132 | 133 | def load_sd_model(): 134 | print("Loading the Stable Diffusion model into VRAM...") 135 | response = requests.post(url='http://127.0.0.1:7861/sdapi/v1/reload-checkpoint', json='') 136 | response.raise_for_status() 137 | del response 138 | 139 | def unload_sd_model(): 140 | print("Unloading the Stable Diffusion model from VRAM...") 141 | response = requests.post(url='http://127.0.0.1:7861/sdapi/v1/unload-checkpoint', json='') 142 | response.raise_for_status() 143 | del response 144 | 145 | # Get and save the Stable Diffusion-generated picture 146 | def get_SD_pictures(description, character): 147 | load_sd_model() 148 | global params 149 | 150 | if params['manage_VRAM']: 151 | give_VRAM_priority('SD') 152 | 153 | description = re.sub('', ' ', description) 154 | description = f"({description}:1)" 155 | 156 | payload = { 157 | "prompt": params['textgen_prefix'] + description, 158 | "seed": params['seed'], 159 | "sampler_name": params['sampler_name'], 160 | "enable_hr": params['enable_hr'], 
161 | "hr_scale": params['hr_scale'], 162 | "hr_upscaler": params['hr_upscaler'], 163 | "denoising_strength": params['denoising_strength'], 164 | "steps": params['steps'], 165 | "cfg_scale": params['cfg_scale'], 166 | "width": params['width'], 167 | "height": params['height'], 168 | "restore_faces": params['restore_faces'], 169 | "override_settings_restore_afterwards": True, 170 | "negative_prompt": params['negative_prompt'] 171 | 172 | } 173 | 174 | 175 | print(f'Prompting the image generator via the API on {params["address"]}...') 176 | response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload) 177 | response.raise_for_status() 178 | r = response.json() 179 | 180 | visible_result = "" 181 | for img_str in r['images']: 182 | if params['save_img']: 183 | img_data = base64.b64decode(img_str) 184 | image = Image.open(io.BytesIO(img_data)) ##new 185 | image = image.resize((800, 800)) ##new 186 | 187 | variadic = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}' 188 | output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png') 189 | #output_file = Path(f'Z:/outputs/{variadic}.png') 190 | #output_file = Path(f'Z:/{variadic}.png') 191 | output_file.parent.mkdir(parents=True, exist_ok=True) 192 | 193 | with open(output_file.as_posix(), 'wb') as f: 194 | #f.write(img_data) 195 | image.save(f, format="PNG") ##new 196 | 197 | visible_result = visible_result + f'{description}\n' 198 | else: 199 | image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0]))) 200 | # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history 201 | image.thumbnail((300, 300)) 202 | buffered = io.BytesIO() 203 | image.save(buffered, format="JPEG") 204 | buffered.seek(0) 205 | image_bytes = buffered.getvalue() 206 | img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode() 207 | visible_result = visible_result + f'{description}\n' 208 | 209 | if params['manage_VRAM']: 210 | give_VRAM_priority('LLM') 211 | unload_sd_model() 212 | return visible_result 213 | 214 | # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) 215 | # and replace it with 'text' for the purposes of logging? 216 | def output_modifier(string, state): 217 | """ 218 | This function is applied to the model outputs. 219 | """ 220 | 221 | global picture_response, params 222 | 223 | if not picture_response: 224 | return string 225 | 226 | string = remove_surrounded_chars(string) 227 | string = string.replace('"', '') 228 | string = string.replace('“', '') 229 | string = string.replace('\n', ' ') 230 | string = string.strip() 231 | 232 | if string == '': 233 | string = 'no viable description in reply, try regenerating' 234 | return string 235 | 236 | text = "" 237 | if (params['mode'] < 2): 238 | toggle_generation(False) 239 | text = f'*Sends a picture which portrays: “{string}”*' 240 | else: 241 | text = string 242 | 243 | string = get_SD_pictures(string, state['character_menu']) + "\n" + text 244 | 245 | return string 246 | 247 | 248 | def bot_prefix_modifier(string): 249 | """ 250 | This function is only applied in chat mode. It modifies 251 | the prefix text for the Bot and can be used to bias its 252 | behavior. 
253 | """ 254 | 255 | return string 256 | 257 | 258 | def toggle_generation(*args): 259 | global picture_response, shared 260 | 261 | if not args: 262 | picture_response = not picture_response 263 | else: 264 | picture_response = args[0] 265 | 266 | shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*" 267 | 268 | 269 | def filter_address(address): 270 | address = address.strip() 271 | # address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash 272 | address = re.sub('\/$', '', address) # remove trailing /s 273 | if not address.startswith('http'): 274 | address = 'http://' + address 275 | return address 276 | 277 | 278 | def SD_api_address_update(address): 279 | global params 280 | 281 | msg = "✔️ SD API is found on:" 282 | address = filter_address(address) 283 | params.update({"address": address}) 284 | try: 285 | response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models') 286 | response.raise_for_status() 287 | # r = response.json() 288 | except: 289 | msg = "❌ No SD API endpoint on:" 290 | 291 | return gr.Textbox.update(label=msg) 292 | 293 | 294 | def custom_css(): 295 | path_to_css = Path(__file__).parent.resolve() / 'style.css' 296 | return open(path_to_css, 'r').read() 297 | 298 | 299 | def get_checkpoints(): 300 | global params 301 | 302 | try: 303 | models = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models') 304 | options = requests.get(url=f'{params["address"]}/sdapi/v1/options') 305 | options_json = options.json() 306 | params['sd_checkpoint'] = options_json['sd_model_checkpoint'] 307 | params['checkpoint_list'] = [result["title"] for result in models.json()] 308 | except: 309 | params['sd_checkpoint'] = "" 310 | params['checkpoint_list'] = [] 311 | 312 | return gr.update(choices=params['checkpoint_list'], value=params['sd_checkpoint']) 313 | 314 | 315 | def load_checkpoint(checkpoint): 316 | payload = { 317 | "sd_model_checkpoint": checkpoint 318 | } 319 | 320 | try: 321 | requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload) 322 | except: 323 | pass 324 | 325 | 326 | def get_samplers(): 327 | try: 328 | response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers') 329 | response.raise_for_status() 330 | samplers = [x["name"] for x in response.json()] 331 | except: 332 | samplers = [] 333 | 334 | return samplers 335 | 336 | 337 | def ui(): 338 | 339 | # Gradio elements 340 | # gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title 341 | with gr.Accordion("Parameters", open=True, elem_classes="SDAP"): 342 | with gr.Row(): 343 | address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address') 344 | modes_list = ["Manual", "Immersive/Interactive", "Picturebook/Adventure"] 345 | mode = gr.Dropdown(modes_list, value=modes_list[params['mode']], label="Mode of operation", type="index") 346 | with gr.Column(scale=1, min_width=300): 347 | manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM') 348 | save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat') 349 | 350 | force_pic = gr.Button("Force the picture response") 351 | suppr_pic = gr.Button("Suppress the picture response") 352 | with gr.Row(): 353 | checkpoint = gr.Dropdown(params['checkpoint_list'], value=params['sd_checkpoint'], label="Checkpoint", type="value") 354 | update_checkpoints = gr.Button("Get list of checkpoints") 355 | 356 | 
        with gr.Accordion("Generation parameters", open=False):
            prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix Used for ADetailer Input')
            # prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
            textgen_prefix = gr.Textbox(placeholder=params['textgen_prefix'], value=params['textgen_prefix'], label='Textgen prefix (sent at the beginning of the description)')
            negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
            with gr.Row():
                with gr.Column():
                    width = gr.Slider(64, 2048, value=params['width'], step=64, label='Width')
                    height = gr.Slider(64, 2048, value=params['height'], step=64, label='Height')
                with gr.Column(variant="compact", elem_id="sampler_col"):
                    with gr.Row(elem_id="sampler_row"):
                        sampler_name = gr.Dropdown(value=params['sampler_name'], allow_custom_value=True, label='Sampling method', elem_id="sampler_box")
                        create_refresh_button(sampler_name, lambda: None, lambda: {'choices': get_samplers()}, 'refresh-button')
                    steps = gr.Slider(1, 150, value=params['steps'], step=1, label="Sampling steps", elem_id="steps_box")
            with gr.Row():
                seed = gr.Number(label="Seed", value=params['seed'], elem_id="seed_box")
                cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
                with gr.Column():
                    restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
                    enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
            # This row is the one toggled by the 'Hires. fix' checkbox
            with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
                hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
                denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
                hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')

    # Event functions to update the parameters in the backend
    address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
    mode.select(lambda x: params.update({"mode": x}), mode, None)
    mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
    manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
    manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
    save_img.change(lambda x: params.update({"save_img": x}), save_img, None)

    address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
    prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
    textgen_prefix.change(lambda x: params.update({"textgen_prefix": x}), textgen_prefix, None)
    negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
    width.change(lambda x: params.update({"width": x}), width, None)
    height.change(lambda x: params.update({"height": x}), height, None)
    hr_scale.change(lambda x: params.update({"hr_scale": x}), hr_scale, None)
    denoising_strength.change(lambda x: params.update({"denoising_strength": x}), denoising_strength, None)
    restore_faces.change(lambda x: params.update({"restore_faces": x}), restore_faces, None)
    hr_upscaler.change(lambda x: params.update({"hr_upscaler": x}), hr_upscaler, None)
    enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None)
    enable_hr.change(lambda x: gr.update(visible=x), enable_hr, hr_options)
    update_checkpoints.click(get_checkpoints, None, checkpoint)
    checkpoint.change(lambda x: params.update({"sd_checkpoint": x}), checkpoint, None)
    checkpoint.change(load_checkpoint, checkpoint, None)

    sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
    steps.change(lambda x: params.update({"steps": x}), steps, None)
    seed.change(lambda x: params.update({"seed": x}), seed, None)
    cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)

    force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
    suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)
--------------------------------------------------------------------------------