├── .gitignore
├── README.md
├── blocking_api.py
├── cli.py
├── download.py
├── requirements.txt
├── streaming_api.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
cache
characters
training/datasets
extensions/silero_tts/outputs
extensions/elevenlabs_tts/outputs
extensions/sd_api_pictures/outputs
extensions/multimodal/pipelines
logs
loras
models
repositories
softprompts
torch-dumps
*pycache*
*/*pycache*
*/*/pycache*
venv/
.venv/
.vscode
*.bak
*.ipynb
*.log

settings.json
notification.mp3
img_bot*
img_me*
prompts/[0-9]*
models/config-user.yaml

.DS_Store
Thumbs.db

cert.pem
key.pem
nohup.out

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Create a new conda environment
```
conda create -n autogptq python=3.10.9
conda activate autogptq
```
## Install PyTorch
```
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
```
## Install AutoGPTQ
### Quick Installation
```
pip install auto-gptq
```
### Install from source
```
mkdir repositories
cd repositories

git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ

pip install .
```

## Install dependencies
```
pip install -r requirements.txt
```

## Create a self-signed certificate
```
openssl req -x509 -out cert.pem -keyout key.pem \
  -newkey rsa:2048 -nodes -sha256 \
  -subj '/CN=localhost' -extensions EXT -config <( \
   printf "[dn]\nCN=localhost\n[req]\ndistinguished_name = dn\n[EXT]\nsubjectAltName=DNS:localhost\nkeyUsage=digitalSignature\nextendedKeyUsage=serverAuth")
```

## Download models
```
python download.py TheBloke/WizardCoder-15B-1.0-GPTQ
```

## Usage

1. Blocking API: update the model name and model weight path in blocking_api.py, then run:
```
python blocking_api.py
```
The server listens on port 5000 (all interfaces) over HTTPS, using the cert.pem and key.pem created above.

To generate text, send a POST request to the /api/v1/generate endpoint. The request body should be a JSON object with the following keys:
- prompt: the input prompt (required).
- min_length: the minimum length of the sequence to be generated (optional, default 0).
- max_new_tokens: the maximum number of new tokens to generate (optional, default 200).
- top_p: the nucleus sampling probability (optional, default 0.95).
- temperature: the sampling temperature (optional, default 0.1).

top_k, typical_p, do_sample, no_repeat_ngram_size, num_beams and stopping_strings are also accepted; see blocking_api.py for their defaults.

For example, you can use curl to send a request:
```
curl -k -s -X POST https://localhost:5000/api/v1/generate \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Below is an instruction that describes a task.\nWrite a response that appropriately completes the request\n### Instruction: write a for loop in typescript\n### Response:", "max_new_tokens": 1000, "temperature": 0.7}'
```
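
If you prefer Python, here is a minimal client sketch for the same endpoint (it only uses the requests package from requirements.txt; the endpoint, keys, and response shape match blocking_api.py, and verify=False is needed because the certificate above is self-signed):
```
import requests

response = requests.post(
    "https://localhost:5000/api/v1/generate",
    json={
        "prompt": "### Instruction: write a for loop in typescript\n### Response:",
        "max_new_tokens": 500,
        "temperature": 0.7,
    },
    verify=False,  # self-signed certificate
)
# The server answers with {"results": [{"text": "..."}]}
print(response.json()["results"][0]["text"])
```

2. Streaming API: update the model name and model weight path in streaming_api.py, then run:
```
python streaming_api.py
```
The websocket server listens on port 5005 and streams tokens at ws://localhost:5005/api/v1/stream. cli.py is a small interactive console client for it:
```
python cli.py
```
Under the hood, the streaming server passes model.generate a Stream stopping criterion and wraps the call in the Iteratorize helper from utils.py, which turns the per-token callback into a plain iterator. A toy sketch of that pattern (illustrative only, no model required):
```
from utils import Iteratorize

def produce(callback=None, n=3):
    # stands in for model.generate + Stream: invokes the callback once per step
    for i in range(n):
        callback(i)

with Iteratorize(produce, {"n": 3}) as generator:
    for value in generator:
        print(value)  # prints 0, 1, 2
```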
--------------------------------------------------------------------------------
/blocking_api.py:
--------------------------------------------------------------------------------
import torch
from transformers import AutoTokenizer, StoppingCriteriaList
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from utils import _SentinelTokenStoppingCriteria

import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
import ssl

import logging

DEV = "cuda:0"
# Setup logging
logging.basicConfig(level=logging.INFO)

model_name_or_path = "./models/TheBloke/WizardCoder-15B-1.0-GPTQ"

use_triton = False

quantize_config = BaseQuantizeConfig(
    bits=4,          # quantize the model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # desc_act and group_size only take effect with the triton backend
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
                                           use_safetensors=True,
                                           device=DEV,
                                           use_triton=use_triton,
                                           quantize_config=quantize_config)

model.eval()

if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token

model.config.pad_token_id = tokenizer.pad_token_id


class GenerateHandler:
    @staticmethod
    def handle_request(handler, body):
        handler.send_response(200)
        handler.send_header('Content-Type', 'application/json')
        handler.end_headers()

        text = body['prompt']
        min_length = body.get('min_length', 0)
        max_new_tokens = body.get('max_new_tokens', 200)
        top_p = body.get('top_p', 0.95)
        top_k = body.get('top_k', 40)
        typical_p = body.get('typical_p', 1)
        do_sample = body.get('do_sample', True)
        temperature = body.get('temperature', 0.1)
        no_repeat_ngram_size = body.get('no_repeat_ngram_size', 0)
        num_beams = body.get('num_beams', 1)
        stopping_strings = body.get('stopping_strings', ['Human:'])

        input_ids = tokenizer.encode(text, return_tensors="pt").to(DEV)
        # Index of the first generated token; used to strip the prompt from the output.
        starting_idx = len(input_ids[0])

        # handle stopping strings
        stopping_criteria_list = StoppingCriteriaList()
        if len(stopping_strings) > 0:
            sentinel_token_ids = [tokenizer.encode(
                string, add_special_tokens=False, return_tensors='pt').to(DEV) for string in stopping_strings]
            stopping_criteria_list.append(_SentinelTokenStoppingCriteria(
                sentinel_token_ids, starting_idx))

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                min_length=min_length,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                top_k=top_k,
                typical_p=typical_p,
                do_sample=do_sample,
                temperature=temperature,
                no_repeat_ngram_size=no_repeat_ngram_size,
                num_beams=num_beams,
                stopping_criteria=stopping_criteria_list,
            )

        # Decode only the newly generated tokens.
        generated_text = tokenizer.decode(
            [el.item() for el in generated_ids[0][starting_idx:]], skip_special_tokens=True)

        response = json.dumps(
            {'results': [{'text': generated_text.strip()}]})
        handler.wfile.write(response.encode('utf-8'))


class Handler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))
        if self.path == '/api/v1/generate':
            GenerateHandler.handle_request(self, body)
        else:
            self.send_error(404)


def _run_server(port: int, share: bool = False):
    address = '0.0.0.0'
    server = ThreadingHTTPServer((address, port), Handler)

    sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    sslctx.load_cert_chain(certfile='cert.pem', keyfile="key.pem")
    server.socket = sslctx.wrap_socket(server.socket, server_side=True)

    # Log server start (the socket is TLS-wrapped, so the scheme is https)
    logging.info('Server is running on https://{}:{}'.format(address, port))
    server.serve_forever()


def start_server(port: int, share: bool = False):
    Thread(target=_run_server, args=[port, share]).start()


if __name__ == '__main__':
    start_server(5000)

--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
import asyncio
import json
import sys

try:
    import websockets
except ImportError:
    print("Websockets package not found. Make sure it's installed.")
    sys.exit(1)

HOST = '127.0.0.1:5005'
URI = f'ws://{HOST}/api/v1/stream'


async def connect_to_server(uri):
    return await websockets.connect(uri, ping_interval=None)


async def send_to_server(websocket, message):
    # Send the prompt, then yield streamed chunks until the server signals the end.
    await websocket.send(message)

    while True:
        incoming_data = await websocket.recv()
        incoming_data = json.loads(incoming_data)

        match incoming_data['event']:
            case 'text_stream':
                yield incoming_data['text']
            case 'stream_end':
                yield None


async def print_response_stream():
    conversation = []
    websocket = await connect_to_server(URI)
    # prePrompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
    while True:
        user_input = input("\nEnter your question: ")
        previous_response = ''
        # append the user input to the conversation
        conversation.append({'role': 'Instruction', 'message': user_input})
        prompt = "\n".join([f"### {turn['role']}: {turn['message']}" for turn in conversation])
        prompt += "\n### Response:"

        async for response in send_to_server(websocket, prompt):
            if response is not None:
                print(response, end='')
                sys.stdout.flush()
                previous_response += response
            else:
                # The stream echoes the prompt, so keep only the text after the last "Response:" marker.
                assistant_response = previous_response.split("Response:")[-1].strip()
                conversation.append({'role': 'Response', 'message': assistant_response})
                previous_response = ''
                break

    await websocket.close()


if __name__ == '__main__':
    asyncio.run(print_response_stream())

--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
import argparse
import requests
import os
from tqdm import tqdm


def download_file(url, path):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KiB
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

    with open(path, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)

    progress_bar.close()


def download_model(model_name, destination_folder="models"):
    # Define the base URL and headers for the Hugging Face API
    base_url = f"https://huggingface.co/{model_name}/resolve/main"
    headers = {"User-Agent": "Hugging Face Python"}

    # Send a GET request to the Hugging Face API to get a list of all files
    response = requests.get(f"https://huggingface.co/api/models/{model_name}", headers=headers)
    response.raise_for_status()

    # Extract the list of files from the response JSON
    files_to_download = [file["rfilename"] for file in response.json()["siblings"]]

    # Ensure the directory exists
    os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True)

    # Download each file
    for file in files_to_download:
        print(f"Downloading {file}...")
        download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str, help="Name of the model to download.")
    args = parser.parse_args()

    download_model(args.model_name)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
requests
tqdm
git+https://github.com/huggingface/transformers
chardet
cchardet

websockets

--------------------------------------------------------------------------------
/streaming_api.py:
--------------------------------------------------------------------------------
import torch
from transformers import AutoTokenizer, StoppingCriteriaList
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from utils import _SentinelTokenStoppingCriteria, Iteratorize, Stream
from threading import Thread
import gc
import traceback
import asyncio
import json
from websockets.server import serve

model_name_or_path = "./models/TheBloke/WizardCoder-15B-1.0-GPTQ"
DEV = "cuda:0"

use_triton = False

quantize_config = BaseQuantizeConfig(
    bits=4,          # quantize the model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # desc_act and group_size only take effect with the triton backend
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
                                           use_safetensors=True,
                                           device=DEV,
                                           use_triton=use_triton,
                                           quantize_config=quantize_config)

model.eval()


def generate_with_callback(callback=None, **kwargs):
    # Attach a Stream criterion so `callback` receives the tokens generated so far after every step.
    kwargs['stopping_criteria'].append(Stream(callback_func=callback))
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    with torch.no_grad():
        model.generate(**kwargs)


def generate_with_streaming(**kwargs):
    return Iteratorize(generate_with_callback, kwargs, callback=None)


PATH = '/api/v1/stream'


async def _handle_connection(websocket, path):
    if path != PATH:
        print(f'Streaming api: unknown path: {path}')
        return

    async for message in websocket:
        # Use plain text for now, can change to a JSON string later.
        input_text = message

        input_ids = tokenizer.encode(
            input_text, return_tensors="pt").to(DEV)

        # handle stopping strings
        stopping_strings = ['Human:']
        stopping_criteria_list = StoppingCriteriaList()
        sentinel_token_ids = [tokenizer.encode(
            string, add_special_tokens=False, return_tensors='pt').to(DEV) for string in stopping_strings]
        starting_idx = len(input_ids[0])
        stopping_criteria_list.append(_SentinelTokenStoppingCriteria(
            sentinel_token_ids, starting_idx))

        # hardcoded generation parameters
        generate_params = {
            'input_ids': input_ids,
            'max_length': 1000,
            'temperature': 1.0,
            'do_sample': True,
            'top_p': 0.9,
            'stopping_criteria': stopping_criteria_list,
        }

        # As we stream, only send the characters that are new since the last message.
        skip_index = 0
        message_num = 0

        # Generate tokens one by one
        with Iteratorize(generate_with_callback, generate_params, callback=None) as generator:
            for output in generator:
                # Decode the entire sequence so far (the prompt is included, since `output` holds all input_ids)
                generated_text = tokenizer.decode(
                    output.cpu(), skip_special_tokens=True)
                # Only send the new part of the text
                to_send = generated_text[skip_index:]
                # remove the bos token from the first chunk
                if skip_index == 0:
                    to_send = to_send.replace(tokenizer.bos_token, "")
                    to_send = to_send.strip()
                # Advance past everything decoded so far, not just past what was sent,
                # so stripping the first chunk does not shift later offsets.
                skip_index = len(generated_text)

                await websocket.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': to_send
                }))

                await asyncio.sleep(0)

                message_num += 1

        await websocket.send(json.dumps({
            'event': 'stream_end',
            'message_num': message_num
        }))


async def _run(host: str, port: int):
    async with serve(_handle_connection, host, port, ping_interval=None):
        await asyncio.Future()  # run forever


def _run_server(port: int):
    address = '0.0.0.0'  # Listen on all addresses

    print(f'Starting streaming server at ws://{address}:{port}')

    asyncio.run(_run(host=address, port=port))


if __name__ == '__main__':
    _run_server(5005)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import transformers
import gc
import traceback
from queue import Queue
from threading import Thread


class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    """Stops generation once any of the sentinel token sequences appears after `starting_idx`."""

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx
        self.shortest = min([x.shape[-1] for x in sentinel_token_ids])

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            trimmed_len = trimmed_sample.shape[-1]
            if trimmed_len < self.shortest:
                continue

            for sentinel in self.sentinel_token_ids:
                sentinel_len = sentinel.shape[-1]
                if trimmed_len < sentinel_len:
                    continue

                window = trimmed_sample[-sentinel_len:]
                if torch.all(torch.eq(sentinel, window)):
                    return True

        return False


class Iteratorize:
    """Runs a callback-based function in a thread and exposes the callback values as an iterator."""

    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs or {}
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            ret = None
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass
            except Exception:
                traceback.print_exc()

            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


class Stream(transformers.StoppingCriteria):
    """A StoppingCriteria that never stops generation; it only forwards the tokens generated so far to a callback."""

    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False

--------------------------------------------------------------------------------