├── README.md
├── config.yaml
├── examples.yaml
└── model
    ├── __init__.py
    ├── foo.py
    ├── model.py
    └── patch.py

/README.md:
--------------------------------------------------------------------------------
# Starcoder Truss

This is a [Truss](https://truss.baseten.co/) for Starcoder. Starcoder is an open-source language model trained specifically
for code auto-completion. It was trained on text from over 80 programming languages. More information about the model is available
[here](https://huggingface.co/bigcode/starcoder).

Before deploying this model, you'll need to:

1. Accept the terms of service of the Starcoder model [here](https://huggingface.co/bigcode/starcoder).
2. Retrieve your Hugging Face access token from the [settings page](https://huggingface.co/settings/tokens).
3. Set your Hugging Face token as a Baseten secret [here](https://app.baseten.co/settings/secrets) with the key `hf_api_key`. Note that you will *not* be able to successfully deploy Starcoder without doing this.

## Deploying Starcoder

To deploy the Starcoder Truss, follow these steps:

1. __Prerequisites__: Make sure you have a Baseten account and API key. You can sign up for a Baseten account [here](https://app.baseten.co/signup).

2. __Install Truss and the Baseten Python client__: If you haven't already, install the Baseten Python client and Truss in your development environment:

```
pip install --upgrade baseten truss
```

3. __Load the Starcoder Truss__: Assuming you've cloned this repo, spin up an IPython shell from the root of the repo and load the Truss into memory:

```
import truss

starcoder_truss = truss.load(".")
```

4. __Log in to Baseten__: Log in to your Baseten account using your API key (found [here](https://app.baseten.co/settings/account/api_keys)):

```
import baseten

baseten.login("PASTE_API_KEY_HERE")
```

5. __Deploy the Starcoder Truss__: Deploy the Starcoder Truss to Baseten with the following command:

```
baseten.deploy(starcoder_truss)
```

Once your Truss is deployed, you can start using the Starcoder model through the Baseten platform! Navigate to the Baseten UI to watch the model build and deploy, then invoke it via the REST API.

## Starcoder API documentation

### Input

This deployment of Starcoder takes a dictionary as input, which requires the following key:

* `prompt` - the prompt for code auto-completion

It also accepts any of the generation parameters detailed in the transformers [GenerationConfig](https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig).
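For example, a request body might look like the following. The `prompt` key is required; `max_length`, `top_p`, `do_sample`, and `temperature` are illustrative values for optional generation parameters:

```json
{
  "prompt": "def compute_fib(n):",
  "max_length": 128,
  "top_p": 0.95,
  "do_sample": true,
  "temperature": 0.2
}
```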
### Output

The result will be a dictionary containing:

* `status` - either `success` or `error`
* `data` - a dictionary with the keys `completion` (the model's generated text) and `prompt` (the prompt from the input); `null` if the request failed
* `message` - error details if the request failed, otherwise `null`

```json
{
  "status": "success",
  "data": {
    "completion": "code for fibonacci sequence: '))\n\ndef fibonacci(n):\n    if n == 0:\n        return 0\n    elif n == 1:\n        return 1\n    else:\n        return fibonacci(n-1) + fibonacci(n-2)\n\nprint(fibonacci(n))\n",
    "prompt": "code for fib"
  },
  "message": null
}
```

## Example usage

```
curl -X POST https://app.baseten.co/models/EqwKvqa/predict \
     -H 'Authorization: Api-Key {YOUR_API_KEY}' \
     -d '{"prompt": "def compute_fib(n):"}'
```
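You can make the same request from Python. Here is a minimal sketch using the `requests` library; the model ID in the URL mirrors the curl example above, so substitute your own deployment's ID and API key:

```python
import requests

# Placeholder model ID and API key; use the values for your own deployment.
resp = requests.post(
    "https://app.baseten.co/models/EqwKvqa/predict",
    headers={"Authorization": "Api-Key YOUR_API_KEY"},
    json={"prompt": "def compute_fib(n):", "max_length": 128},
)
print(resp.json())
```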
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
bundled_packages_dir: packages
data_dir: data
description: null
environment_variables: {}
examples_filename: examples.yaml
input_type: Any
model_class_filename: model.py
model_class_name: Model
model_framework: custom
model_metadata: {}
model_module_dir: model
model_name: Starcoder
model_type: custom
python_version: py39
requirements:
- torch
- transformers
- accelerate
- joblib
resources:
  cpu: "4"
  memory: 85Gi
  use_gpu: true
  accelerator: "A100:1"
secrets:
  hf_api_key: null
system_packages: []
spec_version: 2.0
--------------------------------------------------------------------------------
/examples.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basetenlabs/starcoder-truss/253e3f64e58198440e77628ee438b8a5a8022100/examples.yaml
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basetenlabs/starcoder-truss/253e3f64e58198440e77628ee438b8a5a8022100/model/__init__.py
--------------------------------------------------------------------------------
/model/foo.py:
--------------------------------------------------------------------------------
def main():
    pass
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict

from .patch import patch

CHECKPOINT = "bigcode/starcoder"
DEFAULT_MAX_LENGTH = 128
DEFAULT_TOP_P = 0.95


class Model:
    def __init__(self, data_dir: str, config: Dict, secrets: Dict, **kwargs) -> None:
        self._data_dir = data_dir
        self._config = config
        self._secrets = secrets
        self._hf_api_key = secrets["hf_api_key"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._tokenizer = None
        self._model = None

    def load(self):
        # Apply the parallel-download patch before fetching the checkpoint shards.
        patch()
        self._tokenizer = AutoTokenizer.from_pretrained(
            CHECKPOINT,
            use_auth_token=self._hf_api_key,
        )
        self._model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            use_auth_token=self._hf_api_key,
            device_map="auto",
            offload_folder="offload",
        )

    def predict(self, request: Dict) -> Dict:
        with torch.no_grad():
            try:
                prompt = request.pop("prompt")
                max_length = request.pop("max_length", DEFAULT_MAX_LENGTH)
                top_p = request.pop("top_p", DEFAULT_TOP_P)
                encoded_prompt = self._tokenizer(prompt, return_tensors="pt").input_ids

                # Any remaining keys in the request are forwarded to generate()
                # as GenerationConfig parameters.
                encoded_output = self._model.generate(
                    encoded_prompt,
                    max_length=max_length,
                    top_p=top_p,
                    **request
                )[0]
                decoded_output = self._tokenizer.decode(
                    encoded_output, skip_special_tokens=True
                )
                instance_response = {
                    "completion": decoded_output,
                    "prompt": prompt,
                }

                return {"status": "success", "data": instance_response, "message": None}
            except Exception as exc:
                return {"status": "error", "data": None, "message": str(exc)}
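
# Illustrative usage sketch: how the Model class above is driven outside of
# Baseten. The data_dir, config, and secrets values are placeholders; loading
# the real checkpoint requires a valid Hugging Face token and an A100-class GPU.
if __name__ == "__main__":
    model = Model(data_dir=".", config={}, secrets={"hf_api_key": "YOUR_HF_TOKEN"})
    model.load()
    print(model.predict({"prompt": "def compute_fib(n):", "max_length": 64}))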
--------------------------------------------------------------------------------
/model/patch.py:
--------------------------------------------------------------------------------
"""
Patch for Hugging Face Transformers:

The Transformers library currently downloads checkpoint shards serially, which can
be time-consuming when there are many shards to fetch. This patch parallelizes the
downloads using joblib, allowing them to be spread across multiple workers and
improving overall load time.
"""
import os

from joblib import Parallel, delayed
from transformers.utils import hub


def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    use_auth_token=None,
    user_agent=None,
    revision=None,
    subfolder="",
    _commit_hash=None,
):
    import json

    from huggingface_hub.utils import EntryNotFoundError
    from requests.exceptions import HTTPError
    from transformers.utils.hub import HUGGINGFACE_CO_RESOLVE_ENDPOINT, cached_file

    if not os.path.isfile(index_filename):
        raise ValueError(
            f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}."
        )

    with open(index_filename, "r") as f:
        index = json.loads(f.read())

    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
    sharded_metadata["weight_map"] = index["weight_map"].copy()

    # First, let's deal with a local folder.
    if os.path.isdir(pretrained_model_name_or_path):
        shard_filenames = [
            os.path.join(pretrained_model_name_or_path, subfolder, f)
            for f in shard_filenames
        ]
        return shard_filenames, sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    cached_filenames = []

    # Download a single shard (or resolve it from the local cache).
    def load_file_from_url(shard_filename):
        try:
            # Load from URL
            return cached_file(
                pretrained_model_name_or_path,
                shard_filename,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=_commit_hash,
            )
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            )
        except HTTPError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
                " again after checking your internet connection."
            )

    # Use joblib to download the shards in parallel
    cached_filenames = Parallel(n_jobs=4)(
        delayed(load_file_from_url)(shard_filename)
        for shard_filename in shard_filenames
    )

    # Filter out None results in case of errors
    cached_filenames = [f for f in cached_filenames if f is not None]

    return cached_filenames, sharded_metadata


def patch():
    """
    Patch Hugging Face Transformers so checkpoint shard downloads run in parallel.
    """
    hub.get_checkpoint_shard_files = get_checkpoint_shard_files
--------------------------------------------------------------------------------
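As an aside, the core pattern `patch.py` relies on — fanning a list of downloads out across joblib workers — looks like this in isolation. The URLs and worker count below are arbitrary placeholders for illustration only:

```python
import requests
from joblib import Parallel, delayed

# Placeholder URLs; in patch.py the equivalent list is the set of checkpoint shards.
urls = [
    "https://example.com/shard-00001.bin",
    "https://example.com/shard-00002.bin",
]

def fetch(url: str) -> bytes:
    # Each call runs in its own joblib worker, so slow downloads overlap.
    return requests.get(url, timeout=60).content

# n_jobs=4 mirrors the worker count used in patch.py.
blobs = Parallel(n_jobs=4)(delayed(fetch)(url) for url in urls)
print([len(b) for b in blobs])
```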