├── README.md
├── examples
│   ├── curl_example.bash
│   ├── javascript_example.js
│   └── python_example.py
├── inference.py
├── inference_batch.py
├── server.py
└── server_batch.py

/README.md:
--------------------------------------------------------------------------------
# gpt-j-api-huggingface 🤗

An API to interact with the GPT-J language model from Hugging Face!

## Launching the server

To launch the API server, simply run:

```
python3 server.py
```

To launch the API server with dynamic batching, run:

```
python3 server_batch.py
```


## Using the API

Execute the following commands to run an example in your preferred language:

### CURL

```
bash examples/curl_example.bash
```

### Python

```
python3 examples/python_example.py
```

### JavaScript

```
cd examples
yarn add axios
node javascript_example.js
```

--------------------------------------------------------------------------------
/examples/curl_example.bash:
--------------------------------------------------------------------------------
curl -X 'POST' 'http://0.0.0.0:5000/generate' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{"prompt":"In a shocking finding, scientist discovered", "max_length":100}'
--------------------------------------------------------------------------------
/examples/javascript_example.js:
--------------------------------------------------------------------------------
const axios = require('axios');

const json = JSON.stringify({ prompt: "In a shocking finding, scientist discovered", max_length: 100 });
axios.post('http://0.0.0.0:5000/generate', json, {
  headers: {
    // Overwrite the Content-Type that axios sets automatically for string bodies
    'Content-Type': 'application/json'
  }
}).then(res => {
  console.log(res.data); // the generated text returned by the server
});
--------------------------------------------------------------------------------
/examples/python_example.py:
--------------------------------------------------------------------------------
import requests, json

payload = {
    "prompt": "In a shocking finding, scientist discovered",
    "max_length": 100
}
headers = {'Content-type': 'application/json'}
response = requests.post("http://0.0.0.0:5000/generate", data=json.dumps(payload), headers=headers).json()
print(response)
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
from transformers import AutoTokenizer, GPTJForCausalLM
import torch

# Load GPT-J in half precision; this requires a CUDA-capable GPU with enough VRAM.
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision='float16',
                                        torch_dtype=torch.half, low_cpu_mem_usage=True).cuda()
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

def run_inference(params_json):
    input_ids = tokenizer(params_json['prompt'],
                          return_tensors="pt").input_ids.cuda()

    # Optional sampling parameters with their defaults
    temperature = params_json.get("temperature", 1.0)
    top_k = params_json.get("top_k", 50)
    top_p = params_json.get("top_p", 1.0)
    max_length = params_json.get("max_length", 20)

    gen_tokens = model.generate(input_ids, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k,
                                max_length=max_length)
    gen_text = tokenizer.batch_decode(gen_tokens)[0]

    return gen_text
--------------------------------------------------------------------------------
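For quick experimentation, `run_inference` can also be called directly, without going through the HTTP server. A minimal sketch, assuming a CUDA-capable GPU with enough VRAM to hold GPT-J in float16; the parameter values below are only illustrative:

```
# Sketch: calling run_inference directly, bypassing the HTTP layer.
# Assumes inference.py is importable and a suitable GPU is available.
from inference import run_inference

params = {
    "prompt": "In a shocking finding, scientist discovered",
    "max_length": 100,    # optional, defaults to 20
    "temperature": 0.8,   # optional, defaults to 1.0
    "top_k": 50,          # optional, defaults to 50
    "top_p": 0.95,        # optional, defaults to 1.0
}

print(run_inference(params))
```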
/inference_batch.py:
--------------------------------------------------------------------------------
from transformers import AutoTokenizer, GPTJForCausalLM
import torch

TEMPERATURE = 1.0
TOP_K = 50
TOP_P = 1.0

class InferenceModel:
    def __init__(self):
        self.model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision='float16',
                                                     torch_dtype=torch.half, low_cpu_mem_usage=True).cuda()
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token  # to avoid an error


    def run_batch_inference(self, params_json_list):
        inputs = self.tokenizer([params_json['prompt'] for params_json in params_json_list],
                                return_tensors="pt", padding=True)

        # Currently, we cannot apply the following parameters to each item in a batch individually.
        # See this issue: https://github.com/huggingface/transformers/issues/14530
        temperature = TEMPERATURE
        top_k = TOP_K
        top_p = TOP_P

        max_length = max(params_json["max_length"] for params_json in params_json_list)

        output_sequences = self.model.generate(input_ids=inputs['input_ids'].cuda(),
                                               attention_mask=inputs['attention_mask'].cuda(),
                                               do_sample=True,
                                               temperature=temperature,
                                               top_p=top_p,
                                               top_k=top_k,
                                               pad_token_id=self.tokenizer.eos_token_id,
                                               max_length=max_length)
        outputs = [self.tokenizer.decode(x) for x in output_sequences]

        return outputs
--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Request, HTTPException
from inference import run_inference
import uvicorn

app = FastAPI()

@app.post("/generate")
async def root(request: Request):
    params_json = await request.json()

    if "prompt" not in params_json:
        raise HTTPException(status_code=400, detail="A prompt needs to be provided as an input parameter")

    output = run_inference(params_json)

    return {"output": output}

uvicorn.run(app, host="0.0.0.0", port=5000)
--------------------------------------------------------------------------------
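Because `server.py` passes the whole JSON body straight to `run_inference`, the optional sampling parameters can be supplied alongside `prompt`. A minimal sketch, assuming the server above is already running on port 5000; the parameter values are only illustrative:

```
# Sketch: calling the /generate endpoint of server.py with the optional sampling parameters.
# Assumes server.py is already running on http://0.0.0.0:5000.
import requests

payload = {
    "prompt": "In a shocking finding, scientist discovered",
    "max_length": 100,
    "temperature": 0.8,
    "top_k": 50,
    "top_p": 0.95,
}

response = requests.post("http://0.0.0.0:5000/generate", json=payload)
print(response.json()["output"])
```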
/server_batch.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Request, HTTPException
from inference_batch import InferenceModel
import asyncio

MAX_TOKENS = 4778  # HACK: the maximum number of tokens for an Nvidia T4 (16 GB VRAM), determined
                   # by manually tuning the number of items in a batch with items of sequence length 1024
MAX_QUEUE_SIZE = 20
DEFAULT_MAX_SEQ_LEN = 20

inference_model = InferenceModel()

app = FastAPI()

q = asyncio.Queue(MAX_QUEUE_SIZE)

async def process_batches():
    first_item = None
    while True:
        if first_item is None:
            first_item = await q.get()  # wait for the next item (blocking)
        items = [first_item]
        batch_max_len = first_item[0]['max_length']

        # dequeue until get_nowait() raises QueueEmpty, which means that the queue is empty
        while True:
            q_item = None
            try:
                q_item = q.get_nowait()
            except asyncio.QueueEmpty:
                first_item = None
                break
            if q_item is not None:
                curr_max_len = max(batch_max_len, q_item[0]['max_length'])
                # Account for the item we are about to add when checking the token budget.
                if curr_max_len * (len(items) + 1) < MAX_TOKENS:
                    items.append(q_item)
                    batch_max_len = curr_max_len
                else:
                    # Since adding this item to the batch would produce a batch that is too large to fit into memory,
                    # we use it as the first item of the next iteration.
                    first_item = q_item
                    break

        outputs = inference_model.run_batch_inference([item[0] for item in items])

        # resolve the futures (set the results)
        for i, output in enumerate(outputs):
            items[i][1].set_result(output)
            # HACK: the following line ensures that the futures are resolved
            # before moving on to processing the next items
            await asyncio.sleep(0)

loop = asyncio.get_event_loop()
task = loop.create_task(process_batches())

@app.post("/generate")
async def root(request: Request):
    params_json = await request.json()

    if "prompt" not in params_json:
        raise HTTPException(status_code=400, detail="A prompt needs to be provided as an input parameter")

    if "max_length" not in params_json:
        params_json["max_length"] = DEFAULT_MAX_SEQ_LEN

    fut = loop.create_future()
    try:
        q.put_nowait((params_json, fut))
    except asyncio.QueueFull:
        raise HTTPException(status_code=500, detail="The maximum number of queued requests has been reached. Try again later")
    print(q.qsize())
    res = await fut
    return res

from uvicorn import Config, Server
config = Config(app=app, loop=loop, host="0.0.0.0", port=5000)
server = Server(config)
loop.run_until_complete(server.serve())
--------------------------------------------------------------------------------
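Dynamic batching only pays off when several requests arrive close together, so a single curl call will not exercise it. Below is a minimal sketch that fires a few concurrent requests at `server_batch.py`, assuming it is running on port 5000; note that, unlike `server.py`, this endpoint resolves each future with the generated text itself, so the response is the text directly rather than an object with an `output` field. The example prompts are only illustrative:

```
# Sketch: sending several concurrent requests so that server_batch.py can group them into one batch.
# Assumes server_batch.py is already running on http://0.0.0.0:5000.
from concurrent.futures import ThreadPoolExecutor
import requests

prompts = [
    "In a shocking finding, scientist discovered",
    "The first rule of machine learning is",
    "Once upon a time",
]

def generate(prompt):
    payload = {"prompt": prompt, "max_length": 100}
    return requests.post("http://0.0.0.0:5000/generate", json=payload).json()

with ThreadPoolExecutor(max_workers=len(prompts)) as pool:
    for text in pool.map(generate, prompts):
        print(text)
```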