├── setup.py
├── llmx
    ├── generators
    │   ├── text
    │   │   ├── __init__.py
    │   │   ├── providers.py
    │   │   ├── base_textgen.py
    │   │   ├── textgen.py
    │   │   ├── cohere_textgen.py
    │   │   ├── openai_textgen.py
    │   │   ├── anthropic_textgen.py
    │   │   ├── palm_textgen.py
    │   │   └── hf_textgen.py
    │   └── __init__.py
    ├── version.py
    ├── __init__.py
    ├── cli.py
    ├── datamodel.py
    ├── configs
    │   └── config.default.yml
    └── utils.py
├── MANIFEST.in
├── notebooks
    ├── research
    │   └── travelbenchmark.ipynb
    └── tutorial.ipynb
├── LICENSE
├── pyproject.toml
├── tests
    └── test_generators.py
├── .gitignore
└── README.md

/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | setup()
--------------------------------------------------------------------------------
/llmx/generators/text/__init__.py:
--------------------------------------------------------------------------------
1 | from .textgen import llm
2 |
--------------------------------------------------------------------------------
/llmx/version.py:
--------------------------------------------------------------------------------
1 | VERSION = "0.0.21a"
2 | APP_NAME = "llmx"
3 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-exclude notebooks *
2 | recursive-exclude configs *
3 | recursive-exclude tests *
--------------------------------------------------------------------------------
/llmx/generators/__init__.py:
--------------------------------------------------------------------------------
1 | # from .text.textgen import TextGenerator
2 | from .text.textgen import llm
3 | from .text.base_textgen import TextGenerator
4 |
--------------------------------------------------------------------------------
/llmx/generators/text/providers.py:
--------------------------------------------------------------------------------
1 | # This file contains the list of providers and models that are supported by LLMX.
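# The loaded `providers` mapping mirrors the `providers` section of configs/config.default.yml
# (or the file pointed to by the LLMX_CONFIG_PATH environment variable): each key (openai, palm,
# cohere, huggingface, anthropic) holds a `name`, a `description`, and a `models` list.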
2 | 3 | 4 | from llmx.utils import load_config 5 | 6 | 7 | config = load_config() 8 | providers = providers = config["providers"] if "providers" in config else None 9 | 10 | providers = config["providers"] 11 | -------------------------------------------------------------------------------- /llmx/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from .generators.text.textgen import llm 3 | from .datamodel import TextGenerationConfig, TextGenerationResponse, Message 4 | from .generators.text.base_textgen import TextGenerator 5 | from .generators.text.providers import providers 6 | 7 | if sys.version_info < (3, 9): 8 | raise RuntimeError("llmx requires Python 3.10+") 9 | -------------------------------------------------------------------------------- /llmx/cli.py: -------------------------------------------------------------------------------- 1 | import typer 2 | from .generators.text.providers import providers 3 | 4 | app = typer.Typer() 5 | 6 | 7 | @app.command() 8 | def models(): 9 | print("Available models:") 10 | for provider in providers.items(): 11 | print(f"Provider: {provider[1]['name']}") 12 | for model in provider[1]["models"]: 13 | print(f" - {model['name']}") 14 | 15 | 16 | @app.command() 17 | def list(): 18 | print("list") 19 | 20 | 21 | def run(): 22 | app() 23 | 24 | 25 | if __name__ == "__main__": 26 | app() 27 | -------------------------------------------------------------------------------- /notebooks/research/travelbenchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from llmx.generators.text.textgen import TextGenerator\n", 10 | "from llmx.datamodel import TextGenerationConfig \n", 11 | "\n", 12 | "config = TextGenerationConfig( \n", 13 | " n=1,\n", 14 | " temperature=0.0,\n", 15 | " max_tokens=100,\n", 16 | " top_p=1.0,\n", 17 | " top_k=50,\n", 18 | " frequency_penalty=0.0,\n", 19 | " presence_penalty=0.0,\n", 20 | " messages = [\n", 21 | " {\"role\": \"user\", \"content\": \"What is the height of the Eiffel Tower?\"},\n", 22 | " ]\n", 23 | ")" 24 | ] 25 | } 26 | ], 27 | "metadata": { 28 | "language_info": { 29 | "name": "python" 30 | }, 31 | "orig_nbformat": 4 32 | }, 33 | "nbformat": 4, 34 | "nbformat_minor": 2 35 | } 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 Victor Dibia. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /llmx/generators/text/base_textgen.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, List, Dict 3 | from diskcache import Cache 4 | from ...utils import get_user_cache_dir 5 | from ...datamodel import TextGenerationConfig, TextGenerationResponse 6 | from ...version import APP_NAME 7 | from abc import ABC, abstractmethod 8 | 9 | 10 | class TextGenerator(ABC): 11 | 12 | def __init__(self, provider: str = "openai", **kwargs): 13 | self.provider = provider 14 | self.model_name = kwargs.get("model_name", "gpt-3.5-turbo") 15 | 16 | app_name = APP_NAME 17 | cache_dir_default = get_user_cache_dir(app_name) 18 | cache_dir_based_on_model = os.path.join(cache_dir_default, self.provider, self.model_name) 19 | self.cache_dir = kwargs.get("cache_dir", cache_dir_based_on_model) 20 | self.cache = Cache(self.cache_dir) 21 | 22 | @abstractmethod 23 | def generate( 24 | self, messages: Union[List[Dict], 25 | str], 26 | config: TextGenerationConfig = TextGenerationConfig(), 27 | **kwargs) -> TextGenerationResponse: 28 | pass 29 | 30 | @abstractmethod 31 | def count_tokens(self, text) -> int: 32 | pass 33 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llmx" 7 | authors = [ 8 | { name="Victor Dibia", email="victor.dibia@gmail.com" }, 9 | ] 10 | description = "LLMX: A library for LLM Text Generation" 11 | readme = "README.md" 12 | license = { file="LICENSE" } 13 | requires-python = ">=3.9" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | 20 | 21 | dependencies = [ 22 | "pydantic", 23 | "openai", 24 | "tiktoken", 25 | "diskcache", 26 | "cohere", 27 | "google.auth", 28 | "anthropic", 29 | "typer", 30 | "pyyaml", 31 | ] 32 | optional-dependencies = {web = ["fastapi", "uvicorn"], transformers = ["transformers[torch]>=4.26","accelerate", "bitsandbytes"]} 33 | 34 | dynamic = ["version"] 35 | 36 | [tool.setuptools] 37 | include-package-data = true 38 | 39 | 40 | [tool.setuptools.dynamic] 41 | version = {attr = "llmx.version.VERSION"} 42 | readme = {file = ["README.md"]} 43 | 44 | [tool.setuptools.packages.find] 45 | include = ["llmx*"] 46 | exclude = ["*.tests*"] 47 | namespaces = false 48 | 49 | [tool.setuptools.package-data] 50 | "llmx" = ["*.*"] 51 | 52 | [tool.pytest.ini_options] 53 | filterwarnings = [ 54 | "ignore:Deprecated call to `pkg_resources\\.declare_namespace\\('.*'\\):DeprecationWarning", 55 | "ignore::DeprecationWarning:google.rpc", 56 | ] 57 | 58 | 59 | [project.urls] 60 | "Homepage" = "https://github.com/victordibia/llmx" 61 | "Bug Tracker" = "https://github.com/victordibia/llmx/issues" 62 | 63 | [project.scripts] 64 | llmx = "llmx.cli:run" -------------------------------------------------------------------------------- /llmx/datamodel.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from typing import Any, Optional, Union, List 3 | from pydantic.dataclasses import dataclass 4 | 5 | 6 | @dataclass 7 | class Message: 8 | role: str 9 | content: str 10 | 11 | def __post_init__(self): 12 | self._fields_dict = asdict(self) 13 | 14 | def __getitem__(self, key: Union[str, int]) -> Any: 15 | return self._fields_dict.get(key) 16 | 17 | def to_dict(self): 18 | return self._fields_dict 19 | 20 | def __iter__(self): 21 | for key, value in self._fields_dict.items(): 22 | yield key, value 23 | 24 | 25 | @dataclass 26 | class TextGenerationConfig: 27 | n: int = 1 28 | temperature: float = 0.1 29 | max_tokens: Union[int, None] = None 30 | top_p: float = 1.0 31 | top_k: int = 50 32 | frequency_penalty: float = 0.0 33 | presence_penalty: float = 0.0 34 | provider: Union[str, None] = None 35 | model: Optional[str] = None 36 | stop: Union[List[str], str, None] = None 37 | use_cache: bool = True 38 | 39 | def __post_init__(self): 40 | self._fields_dict = asdict(self) 41 | 42 | def __getitem__(self, key: Union[str, int]) -> Any: 43 | return self._fields_dict.get(key) 44 | 45 | def __iter__(self): 46 | for key, value in self._fields_dict.items(): 47 | yield key, value 48 | 49 | 50 | @dataclass 51 | class TextGenerationResponse: 52 | """Response from a text generation""" 53 | 54 | text: List[Message] 55 | config: Any 56 | logprobs: Optional[Any] = None # logprobs if available 57 | usage: Optional[Any] = None # usage statistics from the provider 58 | response: Optional[Any] = None # full response from the provider 59 | 60 | def __post_init__(self): 61 | self._fields_dict = asdict(self) 62 | 63 | def __getitem__(self, key: Union[str, int]) -> Any: 64 | return self._fields_dict.get(key) 65 | 66 | def __iter__(self): 67 | for key, value in self._fields_dict.items(): 68 | yield key, value 69 | 70 | def to_dict(self): 71 | return self._fields_dict 72 | 73 | def __json__(self): 74 | return self._fields_dict 75 | -------------------------------------------------------------------------------- /tests/test_generators.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | from llmx import llm 4 | from llmx.datamodel import TextGenerationConfig 5 | 6 | 7 | config = TextGenerationConfig( 8 | n=2, 9 | temperature=0.4, 10 | max_tokens=100, 11 | top_p=1.0, 12 | top_k=50, 13 | frequency_penalty=0.0, 14 | presence_penalty=0.0, 15 | use_cache=False 16 | ) 17 | 18 | messages = [ 19 | {"role": "user", 20 | "content": "What is the capital of France? 
Only respond with the exact answer"}] 21 | 22 | def test_anthropic(): 23 | anthropic_gen = llm(provider="anthropic", api_key=os.environ.get("ANTHROPIC_API_KEY", None)) 24 | config.model = "claude-3-5-sonnet-20240620" # or any other Anthropic model you want to test 25 | anthropic_response = anthropic_gen.generate(messages, config=config) 26 | answer = anthropic_response.text[0].content 27 | print(anthropic_response.text[0].content) 28 | 29 | assert ("paris" in answer.lower()) 30 | assert len(anthropic_response.text) == 1 31 | 32 | def test_openai(): 33 | openai_gen = llm(provider="openai") 34 | openai_response = openai_gen.generate(messages, config=config) 35 | answer = openai_response.text[0].content 36 | print(openai_response.text[0].content) 37 | 38 | assert ("paris" in answer.lower()) 39 | assert len(openai_response.text) == 2 40 | 41 | 42 | def test_google(): 43 | google_gen = llm(provider="palm", api_key=os.environ.get("PALM_API_KEY", None)) 44 | config.model = "chat-bison-001" 45 | google_response = google_gen.generate(messages, config=config) 46 | answer = google_response.text[0].content 47 | print(google_response.text[0].content) 48 | 49 | assert ("paris" in answer.lower()) 50 | # assert len(google_response.text) == 2 palm may chose to return 1 or 2 responses 51 | 52 | 53 | def test_cohere(): 54 | cohere_gen = llm(provider="cohere") 55 | config.model = "command" 56 | cohere_response = cohere_gen.generate(messages, config=config) 57 | answer = cohere_response.text[0].content 58 | print(cohere_response.text[0].content) 59 | 60 | assert ("paris" in answer.lower()) 61 | assert len(cohere_response.text) == 2 62 | 63 | 64 | @pytest.mark.skipif(os.environ.get("LLMX_RUNALL", None) is None 65 | or os.environ.get("LLMX_RUNALL", None) == "False", reason="takes too long") 66 | def test_hf_local(): 67 | hf_local_gen = llm( 68 | provider="hf", 69 | model="TheBloke/Llama-2-7b-chat-fp16", 70 | device_map="auto") 71 | hf_local_response = hf_local_gen.generate(messages, config=config) 72 | answer = hf_local_response.text[0].content 73 | print(hf_local_response.text[0].content) 74 | 75 | assert ("paris" in answer.lower()) 76 | assert len(hf_local_response.text) == 2 77 | -------------------------------------------------------------------------------- /llmx/generators/text/textgen.py: -------------------------------------------------------------------------------- 1 | from ...utils import load_config 2 | from .openai_textgen import OpenAITextGenerator 3 | from .palm_textgen import PalmTextGenerator 4 | from .cohere_textgen import CohereTextGenerator 5 | from .anthropic_textgen import AnthropicTextGenerator 6 | import logging 7 | 8 | logger = logging.getLogger("llmx") 9 | 10 | 11 | def sanitize_provider(provider: str): 12 | if provider.lower() == "openai" or provider.lower() == "default" or provider.lower() == "azureopenai" or provider.lower() == "azureoai": 13 | return "openai" 14 | elif provider.lower() == "palm" or provider.lower() == "google": 15 | return "palm" 16 | elif provider.lower() == "cohere": 17 | return "cohere" 18 | elif provider.lower() == "hf" or provider.lower() == "huggingface": 19 | return "hf" 20 | elif provider.lower() == "anthropic" or provider.lower() == "claude": 21 | return "anthropic" 22 | else: 23 | raise ValueError( 24 | f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'." 25 | ) 26 | 27 | 28 | def llm(provider: str = None, **kwargs): 29 | 30 | # load config. 
This will load the default config from 31 | # configs/config.default.yml if no path to a config file is specified. in 32 | # the environment variable LLMX_CONFIG_PATH 33 | config = load_config() 34 | if provider is None: 35 | # provider is not explicitly specified, use the default provider from the config file 36 | provider = config["model"]["provider"] if "provider" in config["model"] else None 37 | kwargs = config["model"]["parameters"] if "parameters" in config["model"] else {} 38 | if provider is None: 39 | logger.info("No provider specified. Defaulting to 'openai'.") 40 | provider = "openai" 41 | 42 | # sanitize provider 43 | provider = sanitize_provider(provider) 44 | 45 | # set the list of available models based on the config file 46 | models = config["providers"][provider]["models"] if "providers" in config and provider in config["providers"] else {} 47 | 48 | kwargs["provider"] = kwargs["provider"] if "provider" in kwargs else provider 49 | kwargs["models"] = kwargs["models"] if "models" in kwargs else models 50 | 51 | # print(kwargs) 52 | 53 | if provider.lower() == "openai": 54 | return OpenAITextGenerator(**kwargs) 55 | elif provider.lower() == "palm": 56 | return PalmTextGenerator(**kwargs) 57 | elif provider.lower() == "cohere": 58 | return CohereTextGenerator(**kwargs) 59 | elif provider.lower() == "anthropic": 60 | return AnthropicTextGenerator(**kwargs) 61 | elif provider.lower() == "hf": 62 | try: 63 | import transformers 64 | except ImportError: 65 | raise ImportError( 66 | "Please install the `transformers` package to use the HFTextGenerator class. pip install llmx[transformers]" 67 | ) 68 | 69 | # Check if torch package is installed 70 | try: 71 | import torch 72 | except ImportError: 73 | raise ImportError( 74 | "Please install the `torch` package to use the HFTextGenerator class. pip install llmx[transformers]" 75 | ) 76 | 77 | from .hf_textgen import HFTextGenerator 78 | 79 | return HFTextGenerator(**kwargs) 80 | 81 | else: 82 | raise ValueError( 83 | f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'." 84 | ) -------------------------------------------------------------------------------- /llmx/generators/text/cohere_textgen.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | import os 3 | import cohere 4 | from dataclasses import asdict 5 | 6 | from .base_textgen import TextGenerator 7 | from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message 8 | from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages 9 | from ..text.providers import providers 10 | 11 | 12 | class CohereTextGenerator(TextGenerator): 13 | def __init__( 14 | self, 15 | api_key: str = None, 16 | provider: str = "cohere", 17 | model: str = None, 18 | models: Dict = None, 19 | ): 20 | super().__init__(provider=provider) 21 | api_key = api_key or os.environ.get("COHERE_API_KEY", None) 22 | if api_key is None: 23 | raise ValueError( 24 | "Cohere API key is not set. Please set the COHERE_API_KEY environment variable." 
25 | ) 26 | self.client = cohere.Client(api_key) 27 | self.model_max_token_dict = get_models_maxtoken_dict(models) 28 | self.model_name = model or "command" 29 | 30 | def format_messages(self, messages): 31 | prompt = "" 32 | for message in messages: 33 | if message["role"] == "system": 34 | prompt += message["content"] + "\n" 35 | else: 36 | prompt += message["role"] + ": " + message["content"] + "\n" 37 | 38 | return prompt 39 | 40 | def generate( 41 | self, 42 | messages: Union[list[dict], str], 43 | config: TextGenerationConfig = TextGenerationConfig(), 44 | **kwargs, 45 | ) -> TextGenerationResponse: 46 | use_cache = config.use_cache 47 | messages = self.format_messages(messages) 48 | self.model_name = config.model or self.model_name 49 | 50 | max_tokens = ( 51 | self.model_max_token_dict[self.model_name] 52 | if config.model in self.model_max_token_dict else 1024) 53 | 54 | cohere_config = { 55 | "model": self.model_name, 56 | "prompt": messages, 57 | "max_tokens": config.max_tokens or max_tokens, 58 | "temperature": config.temperature, 59 | "k": config.top_k, 60 | "p": config.top_p, 61 | "num_generations": config.n, 62 | "stop_sequences": config.stop, 63 | "frequency_penalty": config.frequency_penalty, 64 | "presence_penalty": config.presence_penalty, 65 | } 66 | 67 | # print("calling cohere ***************", config) 68 | 69 | cache_key_params = cohere_config | {"messages": messages} 70 | if use_cache: 71 | response = cache_request(cache=self.cache, params=cache_key_params) 72 | if response: 73 | return TextGenerationResponse(**response) 74 | 75 | co_response = self.client.generate(**cohere_config) 76 | 77 | response_text = [ 78 | Message( 79 | role="system", 80 | content=x.text, 81 | ) 82 | for x in co_response.generations 83 | ] 84 | 85 | response = TextGenerationResponse( 86 | text=response_text, 87 | logprobs=[], # You may need to extract log probabilities from the response if needed 88 | config=cohere_config, 89 | usage={}, 90 | response=co_response, 91 | ) 92 | 93 | cache_request( 94 | cache=self.cache, params=cache_key_params, values=asdict(response) 95 | ) 96 | return response 97 | 98 | def count_tokens(self, text) -> int: 99 | return num_tokens_from_messages(text) 100 | -------------------------------------------------------------------------------- /llmx/generators/text/openai_textgen.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Dict 2 | from .base_textgen import TextGenerator 3 | from ...datamodel import Message, TextGenerationConfig, TextGenerationResponse 4 | from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages 5 | import os 6 | from openai import AzureOpenAI, OpenAI 7 | from dataclasses import asdict 8 | 9 | 10 | class OpenAITextGenerator(TextGenerator): 11 | def __init__( 12 | self, 13 | api_key: str = os.environ.get("OPENAI_API_KEY", None), 14 | provider: str = "openai", 15 | organization: str = None, 16 | api_type: str = None, 17 | api_version: str = None, 18 | azure_endpoint: str = None, 19 | model: str = None, 20 | models: Dict = None, 21 | ): 22 | super().__init__(provider=provider) 23 | self.api_key = api_key or os.environ.get("OPENAI_API_KEY", None) 24 | 25 | if self.api_key is None: 26 | raise ValueError( 27 | "OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable." 
28 | ) 29 | 30 | self.client_args = { 31 | "api_key": self.api_key, 32 | "organization": organization, 33 | "api_version": api_version, 34 | "azure_endpoint": azure_endpoint, 35 | } 36 | # remove keys with None values 37 | self.client_args = {k: v for k, 38 | v in self.client_args.items() if v is not None} 39 | 40 | if api_type: 41 | if api_type == "azure": 42 | self.client = AzureOpenAI(**self.client_args) 43 | else: 44 | raise ValueError(f"Unknown api_type: {api_type}") 45 | else: 46 | self.client = OpenAI(**self.client_args) 47 | 48 | self.model_name = model or "gpt-3.5-turbo" 49 | self.model_max_token_dict = get_models_maxtoken_dict(models) 50 | 51 | def generate( 52 | self, 53 | messages: Union[List[dict], str], 54 | config: TextGenerationConfig = TextGenerationConfig(), 55 | **kwargs, 56 | ) -> TextGenerationResponse: 57 | use_cache = config.use_cache 58 | model = config.model or self.model_name 59 | prompt_tokens = num_tokens_from_messages(messages) 60 | max_tokens = max( 61 | self.model_max_token_dict.get( 62 | model, 4096) - prompt_tokens - 10, 200 63 | ) 64 | 65 | oai_config = { 66 | "model": model, 67 | "temperature": config.temperature, 68 | "max_tokens": max_tokens, 69 | "top_p": config.top_p, 70 | "frequency_penalty": config.frequency_penalty, 71 | "presence_penalty": config.presence_penalty, 72 | "n": config.n, 73 | "messages": messages, 74 | } 75 | 76 | self.model_name = model 77 | cache_key_params = (oai_config) | {"messages": messages} 78 | if use_cache: 79 | response = cache_request(cache=self.cache, params=cache_key_params) 80 | if response: 81 | return TextGenerationResponse(**response) 82 | 83 | oai_response = self.client.chat.completions.create(**oai_config) 84 | 85 | response = TextGenerationResponse( 86 | text=[Message(**x.message.model_dump()) 87 | for x in oai_response.choices], 88 | logprobs=[], 89 | config=oai_config, 90 | usage=dict(oai_response.usage), 91 | ) 92 | # if use_cache: 93 | cache_request( 94 | cache=self.cache, params=cache_key_params, values=asdict(response) 95 | ) 96 | return response 97 | 98 | def count_tokens(self, text) -> int: 99 | return num_tokens_from_messages(text) 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .release.sh 3 | llmx/generators/cache 4 | llmx.egg-info 5 | notebooks/test.ipynb 6 | notebooks/data 7 | notebooks/.env 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | configs/config.yml 13 | 14 | .DS_Store 15 | n 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | # .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/#use-with-ide 121 | .pdm.toml 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
171 | #.idea/ 172 | -------------------------------------------------------------------------------- /llmx/configs/config.default.yml: -------------------------------------------------------------------------------- 1 | # Sets the the default model to use for llm() when no provider parameter is set. 2 | model: 3 | provider: openai 4 | parameters: 5 | api_key: null 6 | 7 | # list of supported providers. 8 | providers: 9 | anthropic: 10 | name: Anthropic 11 | description: Anthropic's Claude models. 12 | models: 13 | - name: claude-3-5-sonnet-20240620 14 | max_tokens: 8192 15 | model: 16 | provider: anthropic 17 | parameters: 18 | model: claude-3-5-sonnet-20240620 19 | openai: 20 | name: OpenAI 21 | description: OpenAI's and AzureOpenAI GPT-3 and GPT-4 models. 22 | models: 23 | - name: gpt-4o # general model name, can be anything 24 | max_tokens: 4096 # max supported tokens 25 | model: 26 | provider: openai 27 | parameters: 28 | model: gpt-4o 29 | - name: gpt-4 # general model name, can be anything 30 | max_tokens: 8192 # max supported tokens 31 | model: 32 | provider: openai 33 | parameters: 34 | model: gpt-4 # model actual name, required 35 | - name: gpt-4-32k 36 | max_tokens: 32768 37 | model: 38 | provider: openai 39 | parameters: 40 | model: gpt-4-32k 41 | - name: gpt-3.5-turbo 42 | max_tokens: 4096 43 | model: 44 | provider: openai 45 | parameters: 46 | model: gpt-3.5-turbo 47 | - name: gpt-3.5-turbo-0301 48 | max_tokens: 4096 49 | model: 50 | provider: openai 51 | parameters: 52 | model: gpt-3.5-turbo-0301 53 | - name: gpt-3.5-turbo-16k 54 | max_tokens: 16384 55 | model: 56 | provider: openai 57 | parameters: 58 | model: gpt-3.5-turbo-16k 59 | - name: gpt-3.5-turbo-azure 60 | max_tokens: 4096 61 | model: 62 | provider: azureopenai 63 | parameters: 64 | api_key: 65 | api_type: azure 66 | api_base: 67 | api_version: 68 | organization: # or null 69 | model: gpt-3.5-turbo-0316 70 | palm: 71 | name: Google 72 | description: Google's LLM models. 73 | models: 74 | - name: chat-bison-vertexai 75 | max_tokens: 1024 76 | model: 77 | provider: palm 78 | parameters: 79 | model: codechat-bison@001 80 | project_id: 81 | project_location: 82 | palm_key_file: 83 | - name: chat-bison-makersuite 84 | max_tokens: 1024 85 | model: 86 | provider: palm 87 | parameters: 88 | model: chat-bison-001 89 | api_key: 90 | - name: codechat-bison-makersuite 91 | max_tokens: 1024 92 | model: 93 | provider: palm 94 | parameters: 95 | model: codechat-bison-001 96 | api_key: 97 | - name: codechat-bison-32k 98 | max_tokens: 32768 99 | model: 100 | provider: palm 101 | parameters: 102 | model: codechat-bison-32k 103 | project_id: 104 | project_location: 105 | palm_key_file: 106 | - name: chat-bison-32k 107 | max_tokens: 32768 108 | model: 109 | provider: palm 110 | parameters: 111 | model: codechat-bison-32k 112 | project_id: 113 | project_location: 114 | palm_key_file: 115 | cohere: 116 | name: Cohere 117 | description: Cohere's LLM models. 118 | models: 119 | - name: command 120 | max_tokens: 4096 121 | model: 122 | provider: cohere 123 | parameters: 124 | model: command 125 | - name: command-nightly 126 | max_tokens: 4096 127 | model: 128 | provider: cohere 129 | parameters: 130 | model: command-nightly 131 | huggingface: 132 | name: HuggingFace 133 | description: HuggingFace's LLM models. 
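# For local HuggingFace models, `max_tokens` is read by get_models_maxtoken_dict() in llmx/utils.py
# (keyed by the nested parameters.model id) and used as a fallback when sizing max_new_tokens at generation time.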
134 | models: 135 | - name: TheBloke/Llama-2-7b-chat-fp16 136 | max_tokens: 4096 137 | model: 138 | provider: huggingface 139 | parameters: 140 | model: TheBloke/Llama-2-7b-chat-fp16 141 | device_map: auto 142 | - name: hermes-orca-platypus-13b 143 | max_tokens: 4096 144 | model: 145 | provider: huggingface 146 | parameters: 147 | model: uukuguy/speechless-llama2-hermes-orca-platypus-13b 148 | device_map: auto 149 | trust_remote_code: true 150 | -------------------------------------------------------------------------------- /llmx/generators/text/anthropic_textgen.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Dict 2 | import os 3 | import anthropic 4 | from dataclasses import asdict 5 | 6 | from .base_textgen import TextGenerator 7 | from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message 8 | from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages 9 | 10 | 11 | class AnthropicTextGenerator(TextGenerator): 12 | def __init__( 13 | self, 14 | api_key: str = None, 15 | provider: str = "anthropic", 16 | model: str = None, 17 | models: Dict = None, 18 | ): 19 | super().__init__(provider=provider) 20 | api_key = api_key or os.environ.get("ANTHROPIC_API_KEY", None) 21 | if api_key is None: 22 | raise ValueError( 23 | "Anthropic API key is not set. Please set the ANTHROPIC_API_KEY environment variable." 24 | ) 25 | self.client = anthropic.Anthropic( 26 | api_key=api_key, 27 | default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}, 28 | ) 29 | self.model_max_token_dict = get_models_maxtoken_dict(models) 30 | self.model_name = model or "claude-3-5-sonnet-20240620" 31 | 32 | def format_messages(self, messages): 33 | formatted_messages = [] 34 | for message in messages: 35 | formatted_message = {"role": message["role"], "content": message["content"]} 36 | formatted_messages.append(formatted_message) 37 | return formatted_messages 38 | 39 | 40 | def generate( 41 | self, 42 | messages: Union[List[Dict], str], 43 | config: TextGenerationConfig = TextGenerationConfig(), 44 | **kwargs, 45 | ) -> TextGenerationResponse: 46 | use_cache = config.use_cache 47 | model = config.model or self.model_name 48 | prompt_tokens = num_tokens_from_messages(messages) 49 | max_tokens = max( 50 | self.model_max_token_dict.get(model, 8192) - prompt_tokens - 10, 200 51 | ) 52 | 53 | # Process messages 54 | system_message = None 55 | other_messages = [] 56 | for message in messages: 57 | message["content"] = message["content"].strip() 58 | if message["role"] == "system": 59 | if system_message is None: 60 | system_message = message["content"] 61 | else: 62 | # If multiple system messages, concatenate them 63 | system_message += "\n" + message["content"] 64 | else: 65 | other_messages.append(message) 66 | 67 | if not other_messages: 68 | raise ValueError("At least one message is required") 69 | 70 | # Check if inversion is needed 71 | needs_inversion = other_messages[0]["role"] == "assistant" 72 | if needs_inversion: 73 | other_messages = self.invert_messages(other_messages) 74 | 75 | anthropic_config = { 76 | "model": model, 77 | "max_tokens": config.max_tokens or max_tokens, 78 | "temperature": config.temperature, 79 | "top_p": config.top_p, 80 | "messages": other_messages, 81 | } 82 | 83 | if system_message: 84 | anthropic_config["system"] = system_message 85 | 86 | self.model_name = model 87 | cache_key_params = anthropic_config.copy() 88 | cache_key_params["messages"] = messages # 
Keep original messages for caching 89 | 90 | if use_cache: 91 | response = cache_request(cache=self.cache, params=cache_key_params) 92 | if response: 93 | return TextGenerationResponse(**response) 94 | anthropic_response = self.client.messages.create(**anthropic_config) 95 | 96 | response_content = anthropic_response.content[0].text 97 | 98 | # Always strip "Human: " prefix, regardless of inversion 99 | if response_content.startswith("Human: "): 100 | response_content = response_content[7:] 101 | 102 | response = TextGenerationResponse( 103 | text=[Message(role="assistant", content=response_content)], 104 | logprobs=[], 105 | config=anthropic_config, 106 | usage={ 107 | "prompt_tokens": anthropic_response.usage.input_tokens, 108 | "completion_tokens": anthropic_response.usage.output_tokens, 109 | "total_tokens": anthropic_response.usage.input_tokens 110 | + anthropic_response.usage.output_tokens, 111 | }, 112 | response=anthropic_response, 113 | ) 114 | 115 | cache_request( 116 | cache=self.cache, params=cache_key_params, values=asdict(response) 117 | ) 118 | return response 119 | 120 | def invert_messages(self, messages): 121 | inverted = [] 122 | for i, message in enumerate(messages): 123 | if i % 2 == 0: 124 | inverted.append({"role": "user", "content": message["content"]}) 125 | else: 126 | inverted.append({"role": "assistant", "content": message["content"]}) 127 | return inverted 128 | def count_tokens(self, text) -> int: 129 | return num_tokens_from_messages(text) 130 | -------------------------------------------------------------------------------- /llmx/generators/text/palm_textgen.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | import os 3 | import logging 4 | from typing import Dict, Union 5 | from .base_textgen import TextGenerator 6 | from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message 7 | from ...utils import ( 8 | cache_request, 9 | gcp_request, 10 | get_models_maxtoken_dict, 11 | num_tokens_from_messages, 12 | get_gcp_credentials, 13 | ) 14 | 15 | logger = logging.getLogger("llmx") 16 | 17 | 18 | class PalmTextGenerator(TextGenerator): 19 | def __init__( 20 | self, 21 | api_key: str = os.environ.get("PALM_API_KEY", None), 22 | palm_key_file: str = os.environ.get("PALM_SERVICE_ACCOUNT_KEY_FILE", None), 23 | project_id: str = os.environ.get("PALM_PROJECT_ID", None), 24 | project_location=os.environ.get("PALM_PROJECT_LOCATION", "us-central1"), 25 | provider: str = "palm", 26 | model: str = None, 27 | models: Dict = None, 28 | ): 29 | super().__init__(provider=provider) 30 | 31 | if api_key is None and palm_key_file is None: 32 | raise ValueError( 33 | "PALM_API_KEY or PALM_SERVICE_ACCOUNT_KEY_FILE must be set." 
34 | ) 35 | if api_key: 36 | self.api_key = api_key 37 | self.credentials = None 38 | self.project_id = None 39 | self.project_location = None 40 | else: 41 | self.project_id = project_id 42 | self.project_location = project_location 43 | self.api_key = None 44 | self.credentials = get_gcp_credentials(palm_key_file) if palm_key_file else None 45 | 46 | self.model_max_token_dict = get_models_maxtoken_dict(models) 47 | self.model_name = model or "chat-bison" 48 | 49 | def format_messages(self, messages): 50 | palm_messages = [] 51 | system_messages = "" 52 | for message in messages: 53 | if message["role"] == "system": 54 | system_messages += message["content"] + "\n" 55 | else: 56 | if not palm_messages or palm_messages[-1]["author"] != message["role"]: 57 | palm_message = { 58 | "author": message["role"], 59 | "content": message["content"], 60 | } 61 | palm_messages.append(palm_message) 62 | else: 63 | palm_messages[-1]["content"] += "\n" + message["content"] 64 | 65 | if palm_messages and len(palm_messages) % 2 == 0: 66 | merged_content = ( 67 | palm_messages[-2]["content"] + "\n" + palm_messages[-1]["content"] 68 | ) 69 | palm_messages[-2]["content"] = merged_content 70 | palm_messages.pop() 71 | 72 | if len(palm_messages) == 0: 73 | logger.info("No messages to send to PALM") 74 | 75 | return system_messages, palm_messages 76 | 77 | def generate( 78 | self, 79 | messages: Union[list[dict], str], 80 | config: TextGenerationConfig = TextGenerationConfig(), 81 | **kwargs, 82 | ) -> TextGenerationResponse: 83 | use_cache = config.use_cache 84 | model = config.model or self.model_name 85 | 86 | system_messages, messages = self.format_messages(messages) 87 | self.model_name = model 88 | 89 | max_tokens = self.model_max_token_dict[model] if model in self.model_max_token_dict else 1024 90 | palm_config = { 91 | "temperature": config.temperature, 92 | "maxOutputTokens": config.max_tokens or max_tokens, 93 | "candidateCount": config.n, 94 | } 95 | 96 | api_url = "" 97 | if self.api_key: 98 | api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{model}:generateMessage?key={self.api_key}" 99 | 100 | palm_payload = { 101 | "prompt": {"messages": messages, "context": system_messages}, 102 | "temperature": config.temperature, 103 | "candidateCount": config.n, 104 | "topP": config.top_p, 105 | "topK": config.top_k, 106 | } 107 | 108 | else: 109 | api_url = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.project_location}/publishers/google/models/{model}:predict" 110 | 111 | palm_payload = { 112 | "instances": [ 113 | { 114 | "messages": messages, 115 | "context": system_messages, 116 | "examples": [], 117 | } 118 | ], 119 | "parameters": palm_config, 120 | } 121 | 122 | cache_key_params = {**palm_payload, "model": model, "api_url": api_url} 123 | 124 | if use_cache: 125 | response = cache_request(cache=self.cache, params=cache_key_params) 126 | if response: 127 | return TextGenerationResponse(**response) 128 | 129 | palm_response = gcp_request( 130 | url=api_url, body=palm_payload, method="POST", credentials=self.credentials 131 | ) 132 | 133 | candidates = palm_response["candidates"] if self.api_key else palm_response["predictions"][ 134 | 0]["candidates"] 135 | 136 | response_text = [ 137 | Message( 138 | role="assistant" if x["author"] == "1" else x["author"], 139 | content=x["content"], 140 | ) 141 | for x in candidates 142 | ] 143 | 144 | response = TextGenerationResponse( 145 | text=response_text, 146 | logprobs=[], 147 | 
config=palm_config, 148 | usage={ 149 | "total_tokens": num_tokens_from_messages( 150 | response_text, model=self.model_name 151 | ) 152 | }, 153 | response=palm_response, 154 | ) 155 | 156 | cache_request( 157 | cache=self.cache, params=(cache_key_params), values=asdict(response) 158 | ) 159 | return response 160 | 161 | def count_tokens(self, text) -> int: 162 | return num_tokens_from_messages(text) 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLMX - An API for Chat Fine-Tuned Language Models 2 | 3 | [![PyPI version](https://badge.fury.io/py/llmx.svg)](https://badge.fury.io/py/llmx) 4 | 5 | A simple python package that provides a unified interface to several LLM providers of chat fine-tuned models [OpenAI, AzureOpenAI, PaLM, Cohere and local HuggingFace Models]. 6 | 7 | > **Note** 8 | > llmx wraps multiple api providers and its interface _may_ change as the providers as well as the general field of LLMs evolve. 9 | 10 | There is nothing particularly special about this library, but some of the requirements I needed when I started building this (that other libraries did not have): 11 | 12 | - **Unified Model Interface**: Single interface to create LLM text generators with support for **multiple LLM providers**. 13 | 14 | ```python 15 | from llmx import llm 16 | 17 | gen = llm(provider="openai") # support azureopenai models too. 18 | gen = llm(provider="palm") # or google 19 | gen = llm(provider="cohere") # or palm 20 | gen = llm(provider="hf", model="HuggingFaceH4/zephyr-7b-beta", device_map="auto") # run huggingface model locally 21 | ``` 22 | 23 | - **Unified Messaging Interface**. Standardizes on the OpenAI ChatML message format and is designed for _chat finetuned_ models. For example, the standard prompt sent a model is formatted as an array of objects, where each object has a role (`system`, `user`, or `assistant`) and content (see below). A single request is list of only one message (e.g., write code to plot a cosine wave signal). A conversation is a list of messages e.g. write code for x, update the axis to y, etc. Same format for all models. 24 | 25 | ```python 26 | messages = [ 27 | {"role": "user", "content": "You are a helpful assistant that can explain concepts clearly to a 6 year old child."}, 28 | {"role": "user", "content": "What is gravity?"} 29 | ] 30 | ``` 31 | 32 | - **Good Utils (e.g., Caching etc)**: E.g. being able to use caching for faster responses. General policy is that cache is used if config (including messages) is the same. If you want to force a new response, set `use_cache=False` in the `generate` call. 33 | 34 | ```python 35 | response = gen.generate(messages=messages, config=TextGeneratorConfig(n=1, use_cache=True)) 36 | ``` 37 | 38 | Output looks like 39 | 40 | ```bash 41 | 42 | TextGenerationResponse( 43 | text=[Message(role='assistant', content="Gravity is like a magical force that pulls things towards each other. It's what keeps us on the ground and stops us from floating away into space. ... ")], 44 | config=TextGenerationConfig(n=1, temperature=0.1, max_tokens=8147, top_p=1.0, top_k=50, frequency_penalty=0.0, presence_penalty=0.0, provider='openai', model='gpt-4', stop=None), 45 | logprobs=[], usage={'prompt_tokens': 34, 'completion_tokens': 69, 'total_tokens': 103}) 46 | 47 | ``` 48 | 49 | Are there other libraries that do things like this really well? Yes! 
I'd recommend looking at [guidance](https://github.com/microsoft/guidance) which does a lot more. Interested in optimized inference? Try somthing like [vllm](https://github.com/vllm-project/vllm). 50 | 51 | ## Installation 52 | 53 | Install from pypi. Please use **python3.10** or higher. 54 | 55 | ```bash 56 | pip install llmx 57 | ``` 58 | 59 | Install in development mode 60 | 61 | ```bash 62 | git clone 63 | cd llmx 64 | pip install -e . 65 | ``` 66 | 67 | Note that you may want to use the latest version of pip to install this package. 68 | `python3 -m pip install --upgrade pip` 69 | 70 | ## Usage 71 | 72 | Set your api keys first for each service. 73 | 74 | ```bash 75 | # for openai and cohere 76 | export OPENAI_API_KEY= 77 | export COHERE_API_KEY= 78 | 79 | # for PALM via MakerSuite 80 | export PALM_API_KEY= 81 | 82 | # for PaLM (Vertex AI), setup a gcp project, and get a service account key file 83 | export PALM_SERVICE_ACCOUNT_KEY_FILE= 84 | export PALM_PROJECT_ID= 85 | export PALM_PROJECT_LOCATION= 86 | ``` 87 | 88 | You can also set the default provider and list of supported providers via a config file. Use the yaml format in this [sample `config.default.yml` file](llmx/configs/config.default.yml) and set the `LLMX_CONFIG_PATH` to the path of the config file. 89 | 90 | ```python 91 | from llmx import llm 92 | from llmx.datamodel import TextGenerationConfig 93 | 94 | messages = [ 95 | {"role": "system", "content": "You are a helpful assistant that can explain concepts clearly to a 6 year old child."}, 96 | {"role": "user", "content": "What is gravity?"} 97 | ] 98 | 99 | openai_gen = llm(provider="openai") 100 | openai_config = TextGenerationConfig(model="gpt-4", max_tokens=50) 101 | openai_response = openai_gen.generate(messages, config=openai_config, use_cache=True) 102 | print(openai_response.text[0].content) 103 | 104 | ``` 105 | 106 | See the [tutorial](/notebooks/tutorial.ipynb) for more examples. 107 | 108 | ## A Note on Using Local HuggingFace Models 109 | 110 | While llmx can use the huggingface transformers library to run inference with local models, you might get more mileage from using a well-optimized server endpoint like [vllm](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html#openai-compatible-server), or FastChat. The general idea is that these tools let you provide an openai-compatible endpoint but also implement optimizations such as dynamic batching, quantization etc to improve throughput. The general steps are: 111 | 112 | - install vllm, setup endpoint e.g., on port `8000` 113 | - use openai as your provider to access that endpoint. 114 | 115 | ```python 116 | from llmx import llm 117 | hfgen_gen = llm( 118 | provider="openai", 119 | api_base="http://localhost:8000", 120 | api_key="EMPTY, 121 | ) 122 | ... 123 | ``` 124 | 125 | ## Current Work 126 | 127 | - Supported models 128 | - [x] OpenAI 129 | - [x] PaLM ([MakerSuite](https://developers.generativeai.google/api/rest/generativelanguage), [Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models)) 130 | - [x] Cohere 131 | - [x] HuggingFace (local) 132 | 133 | ## Caveats 134 | 135 | - **Prompting**. llmx makes some assumptions around how prompts are constructed e.g., how the chat message interface is assembled into a prompt for each model type. If your application or use case requires more control over the prompt, you may want to use a different library (ideally query the LLM models directly). 136 | - **Inference Optimization**. 
For hosted models (GPT-4, PalM, Cohere) etc, this library provides an excellent unified interface as the hosted api already takes care of inference optimizations. However, if you are looking for a library that is optimized for inference with **_local models_(e.g., huggingface)** (tensor parrelization, distributed inference etc), I'd recommend looking at [vllm](https://github.com/vllm-project/vllm) or [tgi](https://github.com/huggingface/text-generation-inference). 137 | 138 | ## Citation 139 | 140 | If you use this library in your work, please cite: 141 | 142 | ```bibtex 143 | @software{victordibiallmx, 144 | author = {Victor Dibia}, 145 | license = {MIT}, 146 | month = {10}, 147 | title = {LLMX - An API for Chat Fine-Tuned Language Models}, 148 | url = {https://github.com/victordibia/llmx}, 149 | year = {2023} 150 | } 151 | ``` 152 | -------------------------------------------------------------------------------- /llmx/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | import logging 3 | import json 4 | from typing import Any, Union, Dict 5 | import tiktoken 6 | from diskcache import Cache 7 | import hashlib 8 | import os 9 | import platform 10 | import google.auth 11 | import google.auth.transport.requests 12 | from google.oauth2 import service_account 13 | import requests 14 | import yaml 15 | 16 | logger = logging.getLogger("llmx") 17 | 18 | 19 | def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"): 20 | """Returns the number of tokens used by a list of messages.""" 21 | try: 22 | encoding = tiktoken.encoding_for_model(model) 23 | except KeyError: 24 | encoding = tiktoken.get_encoding("cl100k_base") 25 | if ( 26 | model == "gpt-3.5-turbo-0301" or True 27 | ): # note: future models may deviate from this 28 | num_tokens = 0 29 | for message in messages: 30 | if not isinstance(message, dict): 31 | message = asdict(message) 32 | 33 | num_tokens += ( 34 | 4 # every message follows {role/name}\n{content}\n 35 | ) 36 | 37 | for key, value in message.items(): 38 | num_tokens += len(encoding.encode(value)) 39 | if key == "name": # if there's a name, the role is omitted 40 | num_tokens += -1 # role is always required and always 1 token 41 | num_tokens += 2 # every reply is primed with assistant 42 | return num_tokens 43 | 44 | 45 | def cache_request(cache: Cache, params: dict, values: Union[Dict, None] = None) -> Any: 46 | # Generate a unique key for the request 47 | 48 | key = hashlib.md5(json.dumps(params, sort_keys=True).encode("utf-8")).hexdigest() 49 | # Check if the request is cached 50 | if key in cache and values is None: 51 | # print("retrieving from cache") 52 | return cache[key] 53 | 54 | # Cache the provided values and return them 55 | if values: 56 | # print("saving to cache") 57 | cache[key] = values 58 | return values 59 | 60 | 61 | def get_user_cache_dir(app_name: str) -> str: 62 | system = platform.system() 63 | if system == "Windows": 64 | cache_path = os.path.join(os.getenv("LOCALAPPDATA"), app_name, "Cache") 65 | elif system == "Darwin": 66 | cache_path = os.path.join(os.path.expanduser("~/Library/Caches"), app_name) 67 | else: # Linux and other UNIX-like systems 68 | cache_path = os.path.join( 69 | os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")), app_name 70 | ) 71 | os.makedirs(cache_path, exist_ok=True) 72 | return cache_path 73 | 74 | 75 | def get_gcp_credentials(service_account_key_file: str = None, scopes: list[str] = [ 76 | 
'https://www.googleapis.com/auth/cloud-platform']): 77 | try: 78 | # Attempt to use Application Default Credentials 79 | credentials, project_id = google.auth.default(scopes=scopes) 80 | auth_req = google.auth.transport.requests.Request() 81 | credentials.refresh(auth_req) 82 | return credentials 83 | except google.auth.exceptions.DefaultCredentialsError: 84 | # Fall back to using service account key 85 | if service_account_key_file is None: 86 | raise ValueError( 87 | "Service account key file is not set. Please set the PALM_SERVICE_ACCOUNT_KEY_FILE environment variable." 88 | ) 89 | credentials = service_account.Credentials.from_service_account_file( 90 | service_account_key_file, scopes=scopes) 91 | auth_req = google.auth.transport.requests.Request() 92 | credentials.refresh(auth_req) 93 | return credentials 94 | 95 | 96 | def gcp_request( 97 | url: str, 98 | method: str = "POST", 99 | body: dict = None, 100 | headers: dict = None, 101 | credentials: google.auth.credentials.Credentials = None, 102 | request_timeout: int = 60, 103 | **kwargs, 104 | ): 105 | 106 | headers = headers or {} 107 | 108 | if "key" not in url: 109 | if credentials is None: 110 | credentials = get_gcp_credentials() 111 | auth_req = google.auth.transport.requests.Request() 112 | if credentials.expired: 113 | credentials.refresh(auth_req) 114 | headers["Authorization"] = f"Bearer {credentials.token}" 115 | headers["Content-Type"] = "application/json" 116 | 117 | response = requests.request( 118 | method=method, url=url, json=body, headers=headers, timeout=request_timeout, **kwargs 119 | ) 120 | 121 | if response.status_code not in range(200, 300): 122 | try: 123 | error_message = response.json().get("error", {}).get("message", "") 124 | except json.JSONDecodeError: 125 | error_message = response.content 126 | raise Exception( 127 | f"Request failed with status code {response.status_code}: {error_message}" 128 | ) 129 | 130 | return response.json() 131 | 132 | 133 | def load_config(): 134 | try: 135 | config_path = os.environ.get("LLMX_CONFIG_PATH", None) 136 | if config_path is None or os.path.exists(config_path) is False: 137 | config_path = os.path.join( 138 | os.path.dirname(__file__), 139 | "configs/config.default.yml") 140 | logger.info( 141 | "Info: LLMX_CONFIG_PATH environment variable is not set to a valid config file. Using default config file at '%s'.", 142 | config_path) 143 | if config_path is not None: 144 | try: 145 | with open(config_path, "r", encoding="utf-8") as f: 146 | config = yaml.safe_load(f) 147 | logger.info( 148 | "Loaded config from '%s'.", 149 | config_path) 150 | return config 151 | except FileNotFoundError as file_not_found: 152 | logger.info( 153 | "Error: Config file not found at '%s'. Please check the LLMX_CONFIG_PATH environment variable. %s", 154 | config_path, 155 | str(file_not_found)) 156 | except IOError as io_error: 157 | logger.info( 158 | "Error: Could not read the config file at '%s'. %s", 159 | config_path, str(io_error)) 160 | except yaml.YAMLError as yaml_error: 161 | logger.info( 162 | "Error: Malformed YAML in config file at '%s'. %s", 163 | config_path, str(yaml_error)) 164 | else: 165 | logger.info( 166 | "Info:LLMX_CONFIG_PATH environment variable is not set. 
Please set it to the path of your config file to setup your default model.") 167 | except Exception as error: 168 | logger.info("Error: An unexpected error occurred: %s", str(error)) 169 | 170 | return None 171 | 172 | 173 | def get_models_maxtoken_dict(models_list): 174 | if not models_list: 175 | return {} 176 | 177 | models_dict = {} 178 | for model in models_list: 179 | if "model" in model and "parameters" in model["model"]: 180 | details = model["model"]["parameters"] 181 | models_dict[details["model"]] = model["max_tokens"] 182 | return models_dict 183 | -------------------------------------------------------------------------------- /llmx/generators/text/hf_textgen.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | from dataclasses import asdict, dataclass 3 | from transformers import (AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig) 4 | import torch 5 | 6 | 7 | from .base_textgen import TextGenerator 8 | from ...datamodel import TextGenerationConfig, TextGenerationResponse 9 | from ...utils import cache_request, get_models_maxtoken_dict 10 | 11 | 12 | @dataclass 13 | class DialogueTemplate: 14 | system: str = None 15 | dialogue_type: str = "default" 16 | messages: list[dict[str, str]] = None 17 | system_token: str = "<|system|>" 18 | user_token: str = "<|user|>" 19 | assistant_token: str = "<|assistant|>" 20 | end_token: str = "<|end|>" 21 | 22 | def get_inference_prompt(self) -> str: 23 | if self.dialogue_type == "default": 24 | prompt = "" 25 | system_prompt = ( 26 | self.system_token + "\n" + self.system + self.end_token + "\n" 27 | if self.system 28 | else "" 29 | ) 30 | if self.messages is None: 31 | raise ValueError("Dialogue template must have at least one message.") 32 | for message in self.messages: 33 | if message["role"] == "system": 34 | system_prompt += ( 35 | self.system_token 36 | + "\n" 37 | + message["content"] 38 | + self.end_token 39 | + "\n" 40 | ) 41 | elif message["role"] == "user": 42 | prompt += ( 43 | self.user_token 44 | + "\n" 45 | + message["content"] 46 | + self.end_token 47 | + "\n" 48 | ) 49 | else: 50 | prompt += ( 51 | self.assistant_token 52 | + "\n" 53 | + message["content"] 54 | + self.end_token 55 | + "\n" 56 | ) 57 | prompt += self.assistant_token 58 | if system_prompt: 59 | prompt = system_prompt + prompt 60 | return prompt 61 | elif self.dialogue_type == "alpaca": 62 | prompt = ( 63 | self.user_token + "\n" + (self.system + "\n" if self.system else "") 64 | ) 65 | for message in self.messages: 66 | prompt += message["content"] + "\n" 67 | prompt = prompt + " " + self.assistant_token + "\n" 68 | # print(instruction) 69 | return prompt 70 | elif self.dialogue_type == "llama2": 71 | prompt = "[INST]" 72 | system_prompt = "" 73 | other_prompt = "" 74 | 75 | for message in self.messages: 76 | if message["role"] == "system": 77 | system_prompt += message["content"] + "\n" 78 | elif message["role"] == "assistant": 79 | other_prompt += message["content"] + " \n" 80 | else: 81 | other_prompt += "[INST] " + message["content"] + "[/INST]\n" 82 | 83 | prompt = ( 84 | prompt 85 | + f" <> {system_prompt} <> \n" 86 | + other_prompt 87 | + "[/INST]" 88 | ) 89 | 90 | 91 | class HFTextGenerator(TextGenerator): 92 | def __init__(self, 93 | provider: str = "huggingface", 94 | models: Dict = None, 95 | device_map=None, **kwargs): 96 | 97 | super().__init__(provider=provider) 98 | 99 | self.dialogue_type = kwargs.get("dialogue_type", "alpaca") 100 | 101 | 
self.model_name = kwargs.get("model", "uukuguy/speechless-llama2-hermes-orca-platypus-13b") 102 | self.quantization_config = kwargs.get("quantization_config", BitsAndBytesConfig()) 103 | self.trust_remote_code = kwargs.get("trust_remote_code", False) 104 | self.device = kwargs.get("device", self.get_default_device()) 105 | 106 | # load tokenizer and model 107 | self.tokenizer = AutoTokenizer.from_pretrained( 108 | self.model_name, trust_remote_code=self.trust_remote_code) 109 | self.model = AutoModelForCausalLM.from_pretrained( 110 | self.model_name, 111 | device_map=device_map, 112 | quantization_config=self.quantization_config, 113 | trust_remote_code=self.trust_remote_code, 114 | ) 115 | if not device_map: 116 | self.model.to(self.device) 117 | self.model.config.pad_token_id = self.tokenizer.pad_token_id 118 | 119 | self.max_length = kwargs.get("max_length", 1024) 120 | 121 | self.model_max_token_dict = get_models_maxtoken_dict(models) 122 | self.max_context_tokens = kwargs.get( 123 | "max_context_tokens", self.model.config.max_position_embeddings 124 | ) or self.model_max_token_dict[self.model_name] 125 | 126 | if self.dialogue_type == "alpaca": 127 | self.dialogue_template = DialogueTemplate( 128 | dialogue_type="alpaca", 129 | end_token=self.tokenizer.eos_token, 130 | user_token="### Instruction:", 131 | assistant_token="### Response:", 132 | ) 133 | self.model.config.pad_token_id = self.tokenizer.pad_token_id = 0 # unk 134 | self.model.config.bos_token_id = 1 135 | self.model.config.eos_token_id = 2 136 | else: 137 | self.dialogue_template = DialogueTemplate(end_token=self.tokenizer.eos_token) 138 | 139 | def get_default_device(self): 140 | """Pick GPU if available, else CPU""" 141 | if torch.cuda.is_available(): 142 | return torch.device("cuda") 143 | elif torch.backends.mps.is_available(): 144 | return torch.device("mps") 145 | else: 146 | return torch.device("cpu") 147 | 148 | def post_process_response(self, response): 149 | response = ( 150 | response.split(self.dialogue_template.assistant_token)[-1] 151 | .replace(self.dialogue_template.end_token, "") 152 | .strip() 153 | ) 154 | response = {"role": "assistant", "content": response} 155 | return response 156 | 157 | def messages_to_instruction(self, messages): 158 | instruction = "### Instruction: " 159 | for message in messages: 160 | instruction += message["content"] + "\n" 161 | instruction = instruction + "### Response: " 162 | # print(instruction) 163 | return instruction 164 | 165 | def generate( 166 | self, messages: Union[list[dict], 167 | str], 168 | config: TextGenerationConfig = TextGenerationConfig(), 169 | **kwargs) -> TextGenerationResponse: 170 | use_cache = config.use_cache 171 | config.model = self.model_name 172 | cache_key_params = { 173 | **asdict(config), 174 | **kwargs, 175 | "messages": messages, 176 | "dialogue_type": self.dialogue_type} 177 | if use_cache: 178 | response = cache_request(cache=self.cache, params=(cache_key_params)) 179 | if response: 180 | return TextGenerationResponse(**response) 181 | 182 | self.dialogue_template.messages = messages 183 | prompt = self.dialogue_template.get_inference_prompt() 184 | batch = self.tokenizer( 185 | prompt, return_tensors="pt", return_token_type_ids=False 186 | ).to(self.model.device) 187 | input_ids = batch["input_ids"] 188 | 189 | max_new_tokens = kwargs.get( 190 | "max_new_tokens", self.max_context_tokens - input_ids.shape[-1] 191 | ) 192 | # print( 193 | # "***********Prompt tokens: ", 194 | # input_ids.shape[-1], 195 | # " | Max new tokens: ", 196 | # 
max_new_tokens, 197 | # ) 198 | 199 | top_k = kwargs.get("top_k", config.top_k) 200 | min_new_tokens = kwargs.get("min_new_tokens", 32) 201 | repetition_penalty = kwargs.get("repetition_penalty", 1.0) 202 | 203 | gen_config = GenerationConfig( 204 | max_new_tokens=max_new_tokens, 205 | temperature=max(config.temperature, 0.01), 206 | top_p=config.top_p, 207 | top_k=top_k, 208 | num_return_sequences=config.n, 209 | do_sample=True, 210 | pad_token_id=self.tokenizer.eos_token_id, 211 | eos_token_id=self.tokenizer.eos_token_id, 212 | min_new_tokens=min_new_tokens, 213 | repetition_penalty=repetition_penalty, 214 | ) 215 | with torch.no_grad(): 216 | generated_ids = self.model.generate(**batch, generation_config=gen_config) 217 | 218 | text_response = self.tokenizer.batch_decode( 219 | generated_ids, skip_special_tokens=False 220 | ) 221 | 222 | # print(text_response, "*************") 223 | prompt_tokens = len(batch["input_ids"][0]) 224 | total_tokens = 0 225 | for row in generated_ids: 226 | total_tokens += len(row) 227 | 228 | usage = { 229 | "prompt_tokens": prompt_tokens, 230 | "completion_tokens": total_tokens - prompt_tokens, 231 | "total_tokens": total_tokens, 232 | } 233 | 234 | response = TextGenerationResponse( 235 | text=[self.post_process_response(x) for x in text_response], 236 | logprobs=[], 237 | config=config, 238 | usage=usage, 239 | ) 240 | # if use_cache: 241 | cache_request(cache=self.cache, params=(cache_key_params), values=asdict(response)) 242 | return response 243 | 244 | def count_tokens(self, text: str): 245 | return len(self.tokenizer(text)["input_ids"]) 246 | -------------------------------------------------------------------------------- /notebooks/tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from llmx import llm, TextGenerationConfig\n", 10 | "import os " 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "config = TextGenerationConfig( \n", 20 | " n=1,\n", 21 | " temperature=0.8,\n", 22 | " max_tokens=100,\n", 23 | " top_p=1.0,\n", 24 | " top_k=50,\n", 25 | " frequency_penalty=0.0,\n", 26 | " presence_penalty=0.0,\n", 27 | ")\n", 28 | "messages = [\n", 29 | " {\"role\": \"system\", \"content\": \"You are a helpful assistant that can explain concepts clearly to a 6 year old child.\"},\n", 30 | " {\"role\": \"user\", \"content\": \"What is gravity?\"}\n", 31 | "]" 32 | ] 33 | }, 34 | { 35 | "attachments": {}, 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## llmx Supports Multiple Providers \n", 40 | "\n", 41 | "### OpenAI" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "Gravity is like a big invisible force that pulls things towards each other. It's what keeps us on the ground and makes things fall down instead of floating away. Imagine if you threw a ball up in the air, gravity would pull it back down to the ground. 
It's like a super strong magnet that pulls everything together.\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "openai_gen = llm(provider=\"openai\", api_key=os.environ[\"OPENAI_API_KEY\"])\n", 59 | "openai_config = TextGenerationConfig(model=\"gpt-3.5-turbo\", use_cache=True)\n", 60 | "openai_response = openai_gen.generate(messages, config=openai_config)\n", 61 | "print(openai_response.text[0].content)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "### Azure OpenAI" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 9, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "Gravity is like a big invisible force that pulls things towards each other. It's what keeps us on the ground and makes things fall down when we drop them. It's like a big hug from the Earth that keeps us close to it.\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "from dotenv import load_dotenv \n", 86 | "\n", 87 | "load_dotenv(override=True)\n", 88 | "\n", 89 | "azure_openai_gen = llm(\n", 90 | " provider=\"openai\",\n", 91 | " api_type=\"azure\",\n", 92 | " azure_endpoint=os.environ[\"AZURE_OPENAI_BASE\"],\n", 93 | " api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n", 94 | " api_version=\"2023-07-01-preview\",\n", 95 | ")\n", 96 | "openai_config = TextGenerationConfig(model=\"gpt-35-turbo-0613\", use_cache=True)\n", 97 | "openai_response = azure_openai_gen.generate(messages, config=openai_config)\n", 98 | "print(openai_response.text[0].content)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "### PaLM (Google) \n", 106 | "\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "#### PaLM: MakerSuite API \n", 114 | "\n", 115 | "- Visit [https://makersuite.google.com/](https://makersuite.google.com/) to get an api key. \n", 116 | "- Also note that the list of supported models might vary." 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "Gravity is a force that pulls objects towards each other. The more massive an object is, the stronger its gravitational pull. The Earth is very massive, so it has a strong gravitational pull. This is why we don't float off into space. The Moon is also massive, but it is much smaller than the Earth. This means that its gravitational pull is not as strong. This is why the Moon orbits the Earth, instead of the other way around.\n", 129 | "\n", 130 | "Gravity is a very important force in the universe. It is what keeps the planets in orbit around the Sun, and it is what keeps the Moon in orbit around the Earth. It is also what keeps us on the ground. Without gravity, we would all float off into space.\n", 131 | "\n", 132 | "Gravity is a very mysterious force. We don't really know what causes it. We do know that it is related to mass, but we don't know exactly how. Scientists are still working on trying to understand gravity.\n", 133 | "\n", 134 | "One way to think about gravity is to imagine a trampoline. If you put a bowling ball in the middle of the trampoline, it will make a dent in the trampoline. If you then put a marble on the trampoline, the marble will roll towards the bowling ball. 
This is because the bowling ball is more massive than the marble, and it has a stronger gravitational pull.\n", 135 | "\n", 136 | "The Earth is like the bowling ball, and we are like the marble. The Earth's gravity pulls us towards the center of the Earth, which is why we don't float off into space.\n", 137 | "\n", 138 | "Gravity is a very important force in the universe. It is what keeps the planets in orbit around the Sun, and it is what keeps the Moon in orbit around the Earth. It is also what keeps us on the ground. Without gravity, we would all float off into space.\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "palm_gen = llm(\n", 144 | " provider=\"palm\",\n", 145 | " api_key=os.environ[\"PALM_API_KEY\"],\n", 146 | ")\n", 147 | "palm_config = TextGenerationConfig(\n", 148 | " model=\"chat-bison-001\", temperature=0, use_cache=True\n", 149 | ")\n", 150 | "palm_response = palm_gen.generate(messages, config=palm_config)\n", 151 | "print(palm_response.text[0].content)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "#### PaLM: Vertex AI\n", 159 | "Uses the same API as Google Cloud AI Platform. You will need to setup a service account and download the key." 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 6, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | " Gravity is a force that pulls objects towards each other. It is what keeps us on the ground and keeps the planets in orbit around the sun. Gravity is always pulling on us, but we don't notice it because we are used to it. But if you jump up in the air, you will feel the force of gravity pulling you back down to the ground.\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "palm_gen = llm(\n", 177 | " provider=\"palm\",\n", 178 | " palm_key_file=os.environ[\"PALM_SERVICE_ACCOUNT_KEY_FILE\"],\n", 179 | " project_id=os.environ[\"PALM_PROJECT_ID\"],\n", 180 | " project_location=os.environ[\"PALM_PROJECT_LOCATION\"],\n", 181 | " api_key=None\n", 182 | ")\n", 183 | "palm_config = TextGenerationConfig(\n", 184 | " model=\"codechat-bison\", temperature=0, use_cache=True\n", 185 | ")\n", 186 | "palm_response = palm_gen.generate(messages, config=palm_config)\n", 187 | "print(palm_response.text[0].content)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "### Cohere" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "Gravity is a force that pulls things together. It is what makes things fall to the ground and what holds us on the earth. Gravity is a fundamental force of nature that affects everything around us. It is a property of all matter, and it is what makes things heavy. Gravity is also what causes the moon to orbit the earth and the planets to orbit the sun. 
It is a very important force that plays a big role in our lives.\n" 207 | ] 208 | } 209 | ], 210 | "source": [ 211 | "cohere_gen = llm(provider=\"cohere\")\n", 212 | "cohere_config = TextGenerationConfig(model=\"command\", max_tokens=4050, use_cache=True)\n", 213 | "cohere_response = cohere_gen.generate(messages, config=cohere_config)\n", 214 | "print(cohere_response.text[0].content)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "### Local HuggingFace Model" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 10, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "name": "stderr", 231 | "output_type": "stream", 232 | "text": [ 233 | "/home/victordibia/miniconda3/envs/llmx/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 234 | " from .autonotebook import tqdm as notebook_tqdm\n", 235 | "/home/victordibia/miniconda3/envs/llmx/lib/python3.11/site-packages/transformers/utils/hub.py:124: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n", 236 | " warnings.warn(\n", 237 | "Loading checkpoint shards: 100%|██████████| 8/8 [00:06<00:00, 1.27it/s]\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "hf_generator = llm(provider=\"hf\", model=\"HuggingFaceH4/zephyr-7b-beta\", device_map=\"auto\")" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 11, 248 | "metadata": {}, 249 | "outputs": [ 250 | { 251 | "name": "stdout", 252 | "output_type": "stream", 253 | "text": [ 254 | "Gravity is a special kind of force that pulls things down towards the ground. It's what makes apples fall from trees and why we don't float away into space! Gravity is also what keeps the Earth spinning around and around, so we don't fall off! It's a very strong force that we can't see, but we can feel it pulling us down when we jump into the air. Gravity is a very important force that helps keep everything in the universe in its place!\n" 255 | ] 256 | } 257 | ], 258 | "source": [ 259 | "hf_config = TextGenerationConfig(temperature=0, max_tokens=650, use_cache=False)\n", 260 | "hf_response = hf_generator.generate(messages, config=hf_config)\n", 261 | "print(hf_response.text[0].content)" 262 | ] 263 | } 264 | ], 265 | "metadata": { 266 | "kernelspec": { 267 | "display_name": "base", 268 | "language": "python", 269 | "name": "python3" 270 | }, 271 | "language_info": { 272 | "codemirror_mode": { 273 | "name": "ipython", 274 | "version": 3 275 | }, 276 | "file_extension": ".py", 277 | "mimetype": "text/x-python", 278 | "name": "python", 279 | "nbconvert_exporter": "python", 280 | "pygments_lexer": "ipython3", 281 | "version": "3.11.4" 282 | }, 283 | "orig_nbformat": 4 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 2 287 | } 288 | --------------------------------------------------------------------------------