├── tests ├── tool.prompt ├── test.prompt ├── copy.prompt ├── constants.prompt ├── structured.prompt ├── reasoning.prompt ├── joke.prompt ├── test_llm_structured_output.py ├── test_logprobs.py ├── test_function_calling.py └── test_llm_generate.py ├── tasks ├── __init__.py └── main.py ├── assets └── logo.png ├── pyproject.toml ├── chainlite ├── threadsafe_dict.py ├── __init__.py ├── chain_log_handler.py ├── redis_cache.py ├── llm_output.py ├── utils.py ├── llm_config.py ├── load_prompt.py ├── llm_generate.py └── chat_lite_llm.py ├── .github └── workflows │ └── python-publish.yml ├── setup.py ├── .gitignore ├── llm_config.yaml ├── README.md └── LICENSE /tests/tool.prompt: -------------------------------------------------------------------------------- 1 | # input 2 | {{ message }} -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from tasks.main import * 2 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-oval/chainlite/HEAD/assets/logo.png -------------------------------------------------------------------------------- /tests/test.prompt: -------------------------------------------------------------------------------- 1 | # instruction 2 | Create a list of country names. The output must be in JSON format. -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /tests/copy.prompt: -------------------------------------------------------------------------------- 1 | # instruction 2 | You are a copy machine. Copy what the user inputs and print it out. 3 | 4 | # input 5 | {{ input }} -------------------------------------------------------------------------------- /tests/constants.prompt: -------------------------------------------------------------------------------- 1 | # instruction 2 | Today's date is {{ today }}. 3 | The current year is {{ current_year }}. 4 | 5 | # input 6 | {{ question }} -------------------------------------------------------------------------------- /tests/structured.prompt: -------------------------------------------------------------------------------- 1 | # instruction 2 | Extract structured data from the given text. 3 | {# This is a comment, and will be ignored anywhere in a .prompt file. Other than block definitions and comments, '#' is allowed and is treated as a normal character. #} 4 | 5 | # input 6 | {{ text }} -------------------------------------------------------------------------------- /tests/reasoning.prompt: -------------------------------------------------------------------------------- 1 | # instruction 2 | 3 | # input 4 | Write a Python program that counts the number of substrings in a given string that are palindromes. 5 | A palindrome is a string that reads the same forwards and backwards. The program should take a string as input and output the number of palindromic substrings in the string. 
6 | -------------------------------------------------------------------------------- /tests/joke.prompt: -------------------------------------------------------------------------------- 1 | # Instruction 2 | Tell a joke about the given topic. The format of the joke should be a question and response, separated by a line break. 3 | {# This is a comment, and will be ignored anywhere in a .prompt file. Other than block definitions and comments, '#' is allowed and is treated as a normal character. #} 4 | 5 | # input 6 | Physics 7 | 8 | 9 | # output 10 | Why don't scientists trust atoms? 11 | Because they make up everything! 12 | 13 | # input 14 | {{ topic }} -------------------------------------------------------------------------------- /chainlite/threadsafe_dict.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | 4 | class ThreadSafeDict: 5 | def __init__(self): 6 | self._dict = {} 7 | self._lock = threading.Lock() 8 | 9 | def __setitem__(self, key, value): 10 | with self._lock: 11 | self._dict[key] = value 12 | 13 | def __getitem__(self, key): 14 | with self._lock: 15 | return self._dict[key] 16 | 17 | def __delitem__(self, key): 18 | with self._lock: 19 | del self._dict[key] 20 | 21 | def get(self, key, default=None): 22 | with self._lock: 23 | return self._dict.get(key, default) 24 | 25 | def __contains__(self, key): 26 | with self._lock: 27 | return key in self._dict 28 | 29 | def items(self): 30 | with self._lock: 31 | return list(self._dict.items()) 32 | 33 | def keys(self): 34 | with self._lock: 35 | return list(self._dict.keys()) 36 | 37 | def values(self): 38 | with self._lock: 39 | return list(self._dict.values()) 40 | -------------------------------------------------------------------------------- /chainlite/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_core.runnables import Runnable, chain 2 | 3 | from chainlite.llm_config import ( 4 | get_all_configured_engines, 5 | get_total_cost, 6 | initialize_llm_config, 7 | ) 8 | from chainlite.llm_generate import llm_generation_chain, write_prompt_logs_to_file 9 | from chainlite.llm_output import ( 10 | ToolOutput, 11 | extract_tag_from_llm_output, 12 | lines_to_list, 13 | string_to_indices, 14 | string_to_json, 15 | ) 16 | from chainlite.load_prompt import register_prompt_constants 17 | from chainlite.utils import get_logger, pprint_chain, run_async_in_parallel 18 | 19 | __all__ = [ 20 | "llm_generation_chain", 21 | "get_logger", 22 | "initialize_llm_config", 23 | "pprint_chain", 24 | "write_prompt_logs_to_file", 25 | "get_total_cost", 26 | "chain", 27 | "ToolOutput", 28 | "get_all_configured_engines", 29 | "register_prompt_constants", 30 | "Runnable", # Exported for type hinting 31 | "run_async_in_parallel", 32 | # For processing LLM outputs: 33 | "extract_tag_from_llm_output", 34 | "lines_to_list", 35 | "string_to_indices", 36 | "string_to_json", 37 | ] 38 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 
5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.10' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="chainlite", 5 | version="0.4.4", 6 | author="Sina Semnani", 7 | author_email="sinaj@cs.stanford.edu", 8 | description="A Python package that uses LangChain and LiteLLM to call large language model APIs easily", 9 | long_description=open("README.md").read(), 10 | long_description_content_type="text/markdown", 11 | url="https://github.com/stanford-oval/chainlite", 12 | packages=find_packages(), 13 | install_requires=[ 14 | "tqdm", 15 | "langchain>=0.3.18", 16 | "langchain-community>=0.3", 17 | "langgraph>=0.2", 18 | "litellm==1.65.4.post1", # the unified interface to LLM APIs 19 | "numpydoc", # needed for function calling with LiteLLM 20 | "grandalf", # to visualize LangGraph graphs 21 | "pydantic>=2.5", 22 | "redis[hiredis]", 23 | ], 24 | extras_require={ 25 | "dev": [ 26 | "invoke", # for running tasks and scripts 27 | "pytest", # for testing 28 | "pytest-asyncio", # for testing async code 29 | "setuptools", # for building wheels 30 | "wheel", # for building wheels 31 | "twine", # for uploading to PyPI 32 | "isort", # for code formatting 33 | "black", # for code formatting 34 | "tuna", # for measuring import time 35 | ], 36 | }, 37 | classifiers=[ 38 | "Programming Language :: Python :: 3", 39 | "License :: OSI Approved :: Apache Software License", 40 | "Operating System :: OS Independent", 41 | ], 42 | python_requires=">=3.10", 43 | license="Apache License 2.0", 44 | ) 45 | -------------------------------------------------------------------------------- /tests/test_llm_structured_output.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | import string 4 | import time 5 | from datetime import datetime 6 | from typing import List 7 | from zoneinfo import ZoneInfo 8 | 9 | import pytest 10 | from langchain_core.runnables import RunnableLambda 11 | from pydantic import BaseModel 12 | 13 | from chainlite import ( 14 | get_all_configured_engines, 15 | get_logger, 16 | get_total_cost, 17 | llm_generation_chain, 18 | register_prompt_constants, 19 | write_prompt_logs_to_file, 20 | ) 21 | from chainlite.llm_config import GlobalVars 22 | from chainlite.utils import run_async_in_parallel 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | @pytest.mark.asyncio(scope="session") 28 | @pytest.mark.parametrize("engine", ["gpt-4o-openai", "gpt-4o-azure"]) 29 | async def test_structured_output(engine: str): 30 | class Debate(BaseModel): 31 | """ 32 | A Debate event 
33 | """ 34 | 35 | mention: str 36 | people: List[str] 37 | 38 | response = await llm_generation_chain( 39 | template_file="structured.prompt", 40 | engine=engine, 41 | max_tokens=1000, 42 | force_skip_cache=True, 43 | pydantic_class=Debate, 44 | ).ainvoke( 45 | { 46 | "text": "4 major candidates for California U.S. Senate seat clash in first debate" 47 | } 48 | ) 49 | 50 | assert isinstance(response, Debate) 51 | assert response.mention 52 | assert response.people 53 | 54 | 55 | @pytest.mark.asyncio(scope="session") 56 | @pytest.mark.parametrize("engine", ["gpt-4o-openai", "gpt-4o-azure"]) 57 | async def test_structured_output_engine(engine: str): 58 | class Debate(BaseModel): 59 | """ 60 | A Debate event 61 | """ 62 | 63 | mention: str 64 | people: List[str] 65 | 66 | response = await llm_generation_chain( 67 | template_file="structured.prompt", 68 | engine=engine, 69 | engine_for_structured_output=engine, 70 | max_tokens=1000, 71 | force_skip_cache=True, 72 | pydantic_class=Debate, 73 | ).ainvoke( 74 | { 75 | "text": "4 major candidates for California U.S. Senate seat clash in first debate" 76 | } 77 | ) 78 | 79 | print(response) 80 | assert isinstance(response, Debate) 81 | assert response.mention 82 | assert response.people 83 | -------------------------------------------------------------------------------- /tests/test_logprobs.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | import time 4 | import pytest 5 | 6 | from chainlite import get_logger, llm_generation_chain, get_total_cost 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | test_engine = "gpt-4o-openai" 12 | 13 | 14 | @pytest.mark.asyncio(scope="session") 15 | async def test_llm_generate_with_logprobs(): 16 | response, logprobs = await llm_generation_chain( 17 | template_file="test.prompt", # prompt path relative to one of the paths specified in `prompt_dirs` 18 | engine=test_engine, 19 | max_tokens=5, 20 | force_skip_cache=True, 21 | return_top_logprobs=10, 22 | ).ainvoke({}) 23 | 24 | assert response is not None, "The response should not be None" 25 | assert isinstance(response, str), "The response should be a string" 26 | assert len(response) > 0, "The response should not be empty" 27 | 28 | assert len(logprobs) == 5 29 | for i in range(len(logprobs)): 30 | assert "top_logprobs" in logprobs[i] 31 | assert len(logprobs[i]["top_logprobs"]) == 10 32 | 33 | 34 | @pytest.mark.asyncio(scope="session") 35 | async def test_logprob_cache(): 36 | c = llm_generation_chain( 37 | template_file="tests/copy.prompt", 38 | engine=test_engine, 39 | max_tokens=1, 40 | temperature=0.0, 41 | return_top_logprobs=20, 42 | ) 43 | # use random input so that the first call is not cached 44 | start_time = time.time() 45 | random_input = "".join(random.choices(string.ascii_letters + string.digits, k=20)) 46 | response1 = await c.ainvoke({"input": random_input}) 47 | first_time = time.time() - start_time 48 | first_cost = get_total_cost() 49 | 50 | print("First call took {:.2f} seconds".format(first_time)) 51 | print("Total cost after first call: ${:.10f}".format(first_cost)) 52 | 53 | start_time = time.time() 54 | response2 = await c.ainvoke({"input": random_input}) 55 | second_time = time.time() - start_time 56 | print("Second call took {:.2f} seconds".format(second_time)) 57 | second_cost = get_total_cost() 58 | print("Total cost after second call: ${:.10f}".format(second_cost)) 59 | 60 | assert response1 == response2 61 | assert ( 62 | second_time < first_time * 0.5 63 | 
), "The second (cached) LLM call should be much faster than the first call" 64 | assert first_cost > 0, "The cost should be greater than 0" 65 | assert ( 66 | second_cost == first_cost 67 | ), "The cost should not change after a cached LLM call" 68 | -------------------------------------------------------------------------------- /tasks/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import redis 4 | from invoke import task 5 | 6 | from chainlite.utils import get_logger 7 | 8 | logger = get_logger(__name__) 9 | 10 | DEFAULT_REDIS_PORT = 6379 11 | 12 | 13 | @task 14 | def load_api_keys(c): 15 | try: 16 | with open("API_KEYS") as f: 17 | for line in f: 18 | line = line.strip() 19 | if line and not line.startswith("#"): 20 | key, value = tuple(line.split("=", 1)) 21 | key, value = key.strip(), value.strip() 22 | os.environ[key] = value 23 | logger.debug("Loaded API key named %s", key) 24 | 25 | except Exception as e: 26 | logger.error( 27 | "Error while loading API keys from API_KEY. Make sure this file exists, and has the correct format. %s", 28 | str(e), 29 | ) 30 | 31 | 32 | @task() 33 | def start_redis(c, redis_port: int = DEFAULT_REDIS_PORT): 34 | """ 35 | Start a Redis server if it is not already running. 36 | 37 | This task attempts to connect to a Redis server on the specified port. 38 | If the connection fails (indicating that the Redis server is not running), 39 | it starts a new Redis server on that port. 40 | 41 | Parameters: 42 | - c: Context, automatically passed by invoke. 43 | - redis_port (int): The port number on which to start the Redis server. Defaults to DEFAULT_REDIS_PORT. 44 | """ 45 | try: 46 | r = redis.Redis(host="localhost", port=redis_port) 47 | r.ping() 48 | except redis.exceptions.ConnectionError: 49 | logger.info("Redis server not found, starting it now...") 50 | c.run( 51 | f"docker run --rm -d --name redis-stack -p {redis_port}:6379 -p 8001:8001 redis/redis-stack:latest" 52 | ) 53 | return 54 | 55 | logger.debug("Redis server is already running.") 56 | 57 | 58 | @task(pre=[load_api_keys, start_redis], aliases=["test"]) 59 | def tests(c, log_level="info", parallel=False, test_file: str = None): 60 | """Run tests using pytest""" 61 | 62 | if test_file: 63 | test_files = [f"./tests/{test_file}"] 64 | else: 65 | test_files = [ 66 | "./tests/test_llm_generate.py", 67 | "./tests/test_llm_structured_output.py", 68 | "./tests/test_function_calling.py", 69 | "./tests/test_logprobs.py", 70 | ] 71 | 72 | pytest_command = ( 73 | f"pytest " 74 | f"--log-cli-level={log_level} " 75 | "-rP " 76 | "--color=yes " 77 | # "--disable-warnings " 78 | "-x " # Stop after first failure 79 | ) 80 | 81 | if parallel: 82 | pytest_command += f"-n auto " 83 | 84 | pytest_command += " ".join(test_files) 85 | 86 | c.run(pytest_command, pty=True) 87 | 88 | 89 | @task 90 | def format_code(c): 91 | """Format code using black and isort""" 92 | c.run("isort --profile black .", pty=True) 93 | c.run("black .", pty=True) 94 | -------------------------------------------------------------------------------- /tests/test_function_calling.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from chainlite import llm_generation_chain, write_prompt_logs_to_file 4 | from chainlite.llm_generate import ToolOutput 5 | 6 | 7 | def get_current_weather(location: str): 8 | """ 9 | Get the current weather in a given location. 
10 | 11 | Parameters 12 | ---------- 13 | location : str 14 | The location for which to get the current weather. 15 | 16 | Returns 17 | ------- 18 | str 19 | A string describing the current weather in the specified location. 20 | """ 21 | 22 | if "boston" in location.lower(): 23 | return "The weather is 12F" 24 | 25 | 26 | def add(a: int, b: int) -> int: 27 | """Adds a and b.""" 28 | return a + b 29 | 30 | 31 | @pytest.mark.asyncio(scope="session") 32 | @pytest.mark.parametrize("engine", ["gpt-4o-openai", "gpt-4o-azure"]) 33 | async def test_function_calling(engine): 34 | test_tool_chain = llm_generation_chain( 35 | "tool.prompt", 36 | engine=engine, 37 | max_tokens=100, 38 | tools=[get_current_weather, add], 39 | force_skip_cache=True, 40 | ) 41 | # No function calling done, just output text 42 | text_output, tool_outputs = await test_tool_chain.ainvoke( 43 | {"message": "What tools do you have available?"} 44 | ) 45 | assert "weather" in text_output.lower() 46 | assert "add" in text_output.lower() 47 | assert tool_outputs == [] 48 | 49 | # Function calling needed 50 | text_output, tool_outputs = await test_tool_chain.ainvoke( 51 | {"message": "What is the weather like in Boston ?"} 52 | ) 53 | 54 | assert text_output == "" 55 | assert str(tool_outputs) == "[get_current_weather(location='Boston')]" 56 | 57 | text_output, tool_outputs = await test_tool_chain.ainvoke( 58 | {"message": "What 1021 + 9573?"} 59 | ) 60 | assert text_output == "" 61 | assert str(tool_outputs) == "[add(a=1021, b=9573)]" 62 | 63 | write_prompt_logs_to_file("tests/llm_input_outputs.jsonl") 64 | 65 | 66 | @pytest.mark.asyncio(scope="session") 67 | @pytest.mark.parametrize("engine", ["gpt-4o-openai", "gpt-4o-azure"]) 68 | async def test_forced_function_calling(engine): 69 | test_tool_chain = llm_generation_chain( 70 | "tool.prompt", 71 | engine=engine, 72 | max_tokens=100, 73 | tools=[get_current_weather, add], 74 | force_skip_cache=True, 75 | force_tool_calling=True, 76 | ) 77 | 78 | # Forcing function call when it is already needed 79 | tool_outputs = await test_tool_chain.ainvoke( 80 | {"message": "What is the weather like in New York City?"} 81 | ) 82 | 83 | assert isinstance(tool_outputs, list) 84 | assert str(tool_outputs) == "[get_current_weather(location='New York City')]" 85 | print(tool_outputs) 86 | 87 | # Forcing function call when it is not needed 88 | tool_outputs = await test_tool_chain.ainvoke({"message": "What is your name?"}) 89 | print(tool_outputs) 90 | assert isinstance(tool_outputs, list) 91 | assert len(tool_outputs) > 0 92 | assert isinstance(tool_outputs[0], ToolOutput) 93 | 94 | write_prompt_logs_to_file("tests/llm_input_outputs.jsonl") 95 | -------------------------------------------------------------------------------- /chainlite/chain_log_handler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | from uuid import UUID 3 | 4 | from langchain_core.callbacks import AsyncCallbackHandler 5 | from langchain_core.messages import BaseMessage 6 | from langchain_core.outputs import LLMResult 7 | 8 | from chainlite.llm_config import GlobalVars 9 | from chainlite.llm_output import ToolOutput 10 | 11 | 12 | class ChainLogHandler(AsyncCallbackHandler): 13 | async def on_chat_model_start( 14 | self, 15 | serialized: Dict[str, Any], 16 | messages: List[List[BaseMessage]], 17 | *, 18 | run_id: UUID, 19 | parent_run_id: Optional[UUID] = None, 20 | tags: Optional[List[str]] = None, 21 | metadata: Optional[Dict[str, Any]] 
= None, 22 | **kwargs: Any, 23 | ) -> Any: 24 | run_id = str(parent_run_id) 25 | 26 | # find the first system message 27 | system_message = [m for m in messages[0] if m.type == "system"] 28 | if len(system_message) == 0: 29 | # no system message 30 | distillation_instruction = "" 31 | else: 32 | distillation_instruction = system_message[0].content 33 | 34 | llm_input = messages[0][-1].content 35 | if messages[0][-1].type == "system": 36 | # it means the prompt did not have an `# input` block, and only has an instruction block 37 | llm_input = "" 38 | if run_id not in GlobalVars.prompt_logs: 39 | GlobalVars.prompt_logs[run_id] = {} 40 | GlobalVars.prompt_logs[run_id]["instruction"] = distillation_instruction 41 | GlobalVars.prompt_logs[run_id]["input"] = llm_input 42 | GlobalVars.prompt_logs[run_id]["template_name"] = metadata["template_name"] 43 | 44 | async def on_chain_end( 45 | self, 46 | response: LLMResult, 47 | *, 48 | run_id: UUID, 49 | parent_run_id: Optional[UUID] = None, 50 | tags: Optional[List[str]] = None, 51 | **kwargs: Any, 52 | ) -> None: 53 | """Run when LLM ends running.""" 54 | run_id = str(run_id) 55 | if run_id in GlobalVars.prompt_logs: 56 | # this is the final response in the entire chain 57 | 58 | if ( 59 | isinstance(response, tuple) 60 | and len(response) == 2 61 | and isinstance(response[1], ToolOutput) 62 | ): 63 | response = list(response) 64 | response[1] = str(response[1]) 65 | elif ( 66 | isinstance(response, tuple) 67 | and len(response) == 2 68 | and isinstance(response[1], list) 69 | ): 70 | # the second element of the tuple is a list of ChatCompletionTokenLogprob (or its converted dict) 71 | response = response[0] 72 | 73 | elif isinstance(response, ToolOutput): 74 | response = str(response) 75 | if isinstance(response, tuple) and len(response) == 2: 76 | response = list(response) 77 | # if exactly one is not None/empty, then we want to log that one 78 | if response[0] and not response[1]: 79 | response = response[0] 80 | elif not response[0] and response[1]: 81 | response = response[1] 82 | GlobalVars.prompt_logs[run_id][ 83 | "output" 84 | ] = ( 85 | response.__repr__() 86 | ) # convert to str because output might be a Pydantic object (if `pydantic_class` is provided in `llm_generation_chain()`) 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .vscode 163 | 164 | *.jsonl 165 | API_KEYS -------------------------------------------------------------------------------- /chainlite/redis_cache.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import atexit 3 | from contextlib import redirect_stdout 4 | import io 5 | from typing import Any, Optional 6 | 7 | from langchain_community.cache import AsyncRedisCache 8 | from langchain_core.caches import RETURN_VAL_TYPE 9 | from langchain_core.load.dump import dumps 10 | import redis.asyncio as redis 11 | 12 | 13 | SECONDS_IN_A_WEEK = 60 * 60 * 24 * 7 14 | 15 | 16 | class CustomAsyncRedisCache(AsyncRedisCache): 17 | """This class fixes langchain>=0.2.*'s cache issue with LiteLLM 18 | The core of the problem is that LiteLLM's `Usage`, `ChatCompletionMessageToolCall` and `ChatCompletionTokenLogprob` 19 | classes should inherit from LangChain's Serializable class, but don't. 20 | This class is the minimal fix to make it work. 21 | """ 22 | 23 | @staticmethod 24 | def _configure_pipeline_for_update( 25 | key: str, 26 | pipe: Any, 27 | return_val: RETURN_VAL_TYPE, 28 | ttl: Optional[int] = SECONDS_IN_A_WEEK, 29 | ) -> None: 30 | for r in return_val: 31 | if ( 32 | hasattr(r.message, "additional_kwargs") 33 | and "tool_calls" in r.message.additional_kwargs 34 | ): 35 | r.message.additional_kwargs["tool_calls"] = [ 36 | tool_call.dict() 37 | for tool_call in r.message.additional_kwargs["tool_calls"] 38 | ] 39 | 40 | if ( 41 | hasattr(r.message, "response_metadata") 42 | and "token_usage" in r.message.response_metadata 43 | ): 44 | r.message.response_metadata["token_usage"] = ( 45 | r.message.response_metadata["token_usage"].dict() 46 | ) 47 | if ( 48 | hasattr(r.message, "response_metadata") 49 | and "logprobs" in r.message.response_metadata 50 | ): 51 | r.message.response_metadata["logprobs"] = [ 52 | logprob.dict() 53 | for logprob in r.message.response_metadata["logprobs"] 54 | ] 55 | pipe.hset( 56 | key, 57 | mapping={ 58 | str(idx): dumps(generation) for idx, generation in enumerate(return_val) 59 | }, 60 | ) 61 | if ttl is not None: 62 | pipe.expire(key, ttl) 63 | 64 | 65 | _global_redis_client = None 66 | 67 | 68 | def init_redis_client() -> None: 69 | """Initialize the global Redis client.""" 70 | # TODO move cache setting to the config file 71 | # We do not use LiteLLM's cache since it has a bug. We use LangChain's instead 72 | 73 | global _global_redis_client 74 | _global_redis_client = redis.Redis.from_url("redis://localhost:6379") 75 | redis_cache = CustomAsyncRedisCache( 76 | _global_redis_client, 77 | ) 78 | from langchain.globals import set_llm_cache 79 | 80 | set_llm_cache(redis_cache) 81 | 82 | # Register the cleanup function so that it runs when Python exits 83 | atexit.register(_sync_close_redis_client) 84 | 85 | 86 | async def _close_redis_client() -> None: 87 | """Asynchronously close the global Redis client.""" 88 | global _global_redis_client 89 | if _global_redis_client: 90 | await _global_redis_client.close() 91 | 92 | 93 | def _sync_close_redis_client() -> None: 94 | try: 95 | # Create a new event loop to avoid pitfalls with an already closed global loop 96 | loop = asyncio.new_event_loop() 97 | asyncio.set_event_loop(loop) 98 | # Redirect stdout (or stderr) while closing the client to prevent the message from appearing. 
99 | with redirect_stdout(io.StringIO()): 100 | loop.run_until_complete(_close_redis_client()) 101 | loop.close() 102 | except RuntimeError as e: 103 | if "Event loop is closed" in str(e): 104 | pass 105 | else: 106 | print(f"Error during Redis client cleanup: {e}") 107 | -------------------------------------------------------------------------------- /llm_config.yaml: -------------------------------------------------------------------------------- 1 | # This configuration file defines the setup for how ChainLite calls various LLM APIs, and how it logs LLM inputs/outputs. 2 | # To configure it: 3 | # 1. Set directories containing the prompt files under the `prompt_dirs` section. 4 | # 2. Adjust logging settings and optionally specify which prompts you would like to skip in the `prompt_logging` section. 5 | # 3. Configure LLM endpoints under the `llm_endpoints` section, specifying the API base URL, version (if needed), API key (if needed), 6 | # and the mapping of model names to their respective deployment identifiers. The name on the left-hand side of each mapping is "engine", the shorthand 7 | # you can use in your code when calling llm_generation_chain(engine=...). 8 | # The name on the right side-hand is the "model", the specific name that LiteLLM expects: https://docs.litellm.ai/docs/providers 9 | # Note that "engine" names should be unique within this file, but "model" names do not have to be unique. 10 | # 4. Follow the examples provided for Azure, OpenAI, Groq, Together, Mistral, and local models as needed, and remove unused llm endpoints. 11 | 12 | prompt_dirs: # List of directories containing prompt files, relative to the location of this file 13 | - "./" # Current directory 14 | - "./tests/" 15 | 16 | litellm_set_verbose: false # Verbose logging setting for LiteLLM 17 | prompt_logging: 18 | log_file: ./prompt_logs.jsonl # Path to the log file for prompt logs, relative to the location of this file 19 | prompts_to_skip: 20 | - tests/test.prompt # List of prompts to exclude from logging, relative to the location of this file 21 | 22 | # Configuration for different LLM endpoints 23 | llm_endpoints: 24 | # Example configuration for OpenAI models via Azure API 25 | - api_base: https://oval-hai.openai.azure.com/ # Base URL for Azure OpenAI API 26 | api_version: 2025-01-01-preview # API version for Azure OpenAI 27 | api_key: AZURE_OPENAI_API_KEY_TEST # API key for Azure OpenAI 28 | engine_map: # Mapping of model names to Azure deployment identifiers prepended by "azure/" 29 | gpt-4o-azure: azure/gpt-4o 30 | o1-azure: azure/o1 31 | o3-mini-azure: azure/o3-mini 32 | 33 | # Example of OpenAI models via openai.com 34 | - api_base: https://api.openai.com/v1 35 | api_key: OPENAI_API_KEY_TEST 36 | engine_map: # OpenAI models don't need the "openai/" prefix 37 | gpt-35-turbo: gpt-3.5-turbo-0125 38 | gpt-35-turbo-instruct: gpt-3.5-turbo-instruct 39 | gpt-4: gpt-4-turbo-2024-04-09 40 | gpt-4o-mini: gpt-4o-mini 41 | gpt-4o-openai: gpt-4o-2024-08-06 # you can specify which version of the model you want 42 | gpt-4o: gpt-4o # you can leave it to OpenAI to select the latest model version for you 43 | gpt-4o-another-one: gpt-4o # "model" names, which are on the right side-hand of a mapping, do not need to be unique 44 | o1: o1 45 | o3-mini: o3-mini 46 | 47 | # Example of OpenAI fine-tuned model 48 | - api_base: https://api.openai.com/v1 49 | api_key: OPENAI_API_KEY 50 | prompt_format: distilled 51 | engine_map: 52 | gpt-35-turbo-finetuned: ft:gpt-3.5-turbo-1106: 53 | 54 | # Example of Groq API 
(groq.com) 55 | - api_base: https://api.groq.com/openai/v1 56 | api_key: GROQ_API_KEY 57 | engine_map: # Has limited model availability, but a very fast inference on custom hardware 58 | llama-3-70b-instruct: groq/llama3-70b-8192 59 | 60 | # Example of Together API (together.ai) 61 | - api_key: TOGETHER_API_KEY 62 | engine_map: # TODO non-instruct models don't work well because of LiteLLM's formatting issues, does not work with free accounts because of the 1 QPS limit 63 | llama-2-70b: together_ai/togethercomputer/llama-2-70b 64 | llama-3-70b-instruct: together_ai/meta-llama/Llama-3-70b-chat-hf 65 | mixtral-8-7b: together_ai/mistralai/Mixtral-8x7B-v0.1 66 | mixtral-8-7b-instruct: together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1 67 | mistral-7b: together_ai/mistralai/Mistral-7B-v0.1 68 | 69 | # Example of Mistral API (mistral.ai) 70 | - api_base: https://api.mistral.ai/v1 # https://docs.mistral.ai/platform/endpoints/ 71 | api_key: MISTRAL_API_KEY 72 | engine_map: 73 | mistral-large: mistral/mistral-large-latest 74 | mistral-medium: mistral/mistral-medium-latest 75 | mistral-small: mistral/mistral-small-latest 76 | mistral-7b-instruct: mistral/open-mistral-7b 77 | mixtral-8-7b-instruct: mistral/open-mixtral-8x7b 78 | 79 | # Example of local distilled models served via HuggingFace's text-generation-inference (https://github.com/huggingface/text-generation-inference/) 80 | # The name after huggingface/* does not matter and is unused 81 | - api_base: http://127.0.0.1:5002 82 | prompt_format: distilled 83 | engine_map: 84 | local_distilled: huggingface/local 85 | 86 | # Example of local models served via HuggingFace's text-generation-inference (https://github.com/huggingface/text-generation-inference/) 87 | # The name after huggingface/* does not matter and is unused 88 | - api_base: http://127.0.0.1:5004 89 | engine_map: 90 | local_fewshot: huggingface/local 91 | 92 | -------------------------------------------------------------------------------- /chainlite/llm_output.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Callable 4 | 5 | from langchain_core.runnables import chain 6 | from pydantic import BaseModel 7 | 8 | from chainlite.utils import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | @chain 14 | def extract_tag_from_llm_output( 15 | llm_output: str, tags: str | list[str] 16 | ) -> str | list[str]: 17 | """ 18 | Extracts content enclosed within tags from a given LLM output string. 19 | 20 | Args: 21 | llm_output (str): The output string from which to extract content. 22 | tags (str | list[str]): A single tag or a list of tags to search for in the output string. 23 | 24 | Returns: 25 | str | list[str]: The extracted content for the specified tag(s). If a single tag is provided, 26 | a single string is returned. If a list of tags is provided, a list of strings 27 | is returned, each corresponding to the content of the respective tag. 
28 | """ 29 | is_list = isinstance(tags, list) 30 | if not is_list: 31 | assert isinstance(tags, str) 32 | tags = [tags] 33 | all_extracted_tags = [] 34 | for tag in tags: 35 | extracted_tag = "" 36 | tag_start = llm_output.find(f"<{tag}>") + len(f"<{tag}>") 37 | tag_end = llm_output.find(f"", tag_start) 38 | if tag_start >= 0 and tag_end >= 0: 39 | extracted_tag = llm_output[tag_start:tag_end].strip() 40 | if tag_start >= 0 and tag_end < 0: 41 | extracted_tag = llm_output[tag_start:].strip() 42 | if extracted_tag.startswith("-"): 43 | extracted_tag = extracted_tag[1:].strip() 44 | all_extracted_tags.append(extracted_tag.strip()) 45 | 46 | if not is_list: 47 | return all_extracted_tags[0] 48 | return all_extracted_tags 49 | 50 | 51 | @chain 52 | def lines_to_list(llm_output: str) -> list[str]: 53 | """ 54 | Convert a string of lines into a list of strings, processing each line. 55 | 56 | This function processes each line of the input string `llm_output` by: 57 | - Splitting the input string by newline characters. 58 | - Ignoring empty lines. 59 | - Removing leading hyphens and trimming whitespace. 60 | - Removing starting item numbers (e.g., "1.", "2.", etc.). 61 | 62 | Args: 63 | llm_output (str): The input string containing lines to be processed. 64 | 65 | Returns: 66 | list[str]: A list of processed strings. 67 | """ 68 | ret = [] 69 | for r in llm_output.split("\n"): 70 | if not r.strip(): 71 | continue 72 | if r.startswith("-"): 73 | r = r[1:].strip() 74 | # remove starting item number 75 | r = re.split(r"^\d+\.", r)[-1] 76 | ret.append(r.strip()) 77 | 78 | return ret 79 | 80 | 81 | @chain 82 | def string_to_indices(llm_output: str, llm_output_start_index: int) -> list[int]: 83 | """ 84 | Convert a comma-separated string of n-indexed integers into a 0-indexed list of integers. 85 | 86 | This function takes a string containing integers separated by commas, 87 | removes any surrounding square brackets, and converts each integer 88 | into 0-indexed by subtracting the specified start index. 89 | 90 | Args: 91 | llm_output (str): The string containing comma-separated integers. 92 | llm_output_start_index (int): Whether the llm_output_start_index is 1-indexed or 0-indexed. 93 | 94 | Returns: 95 | list[int]: A list of indices derived from the input string. 96 | """ 97 | # Remove square brackets, if any 98 | cleaned_output = llm_output.strip("[]") 99 | 100 | # Split the string by commas and convert each element to an integer 101 | result = [] 102 | for item in cleaned_output.split(","): 103 | item = item.strip() 104 | if item.isdigit(): 105 | result.append(int(item) - llm_output_start_index) 106 | return result 107 | 108 | 109 | @chain 110 | def string_to_json(llm_output: str): 111 | """ 112 | Converts a string output from a language model (LLM) to a JSON object. Useful after a `llm_generation_chain(..., output_json=True)` 113 | Args: 114 | llm_output (str): The string output from the LLM that needs to be converted to JSON. 115 | 116 | Returns: 117 | dict or None: The JSON object if the conversion is successful, otherwise None. 118 | 119 | Raises: 120 | json.JSONDecodeError: If there is an error in decoding the JSON string. 
121 | """ 122 | try: 123 | return json.loads(llm_output) 124 | except json.JSONDecodeError as e: 125 | # Handle JSON decoding error 126 | logger.exception(f"Error decoding JSON: {e}") 127 | return None 128 | 129 | 130 | @chain 131 | def string_to_pydantic_object(llm_output: str, pydantic_class: BaseModel): 132 | try: 133 | return pydantic_class.model_validate(json.loads(llm_output)) 134 | except Exception as e: 135 | logger.exception( 136 | f"Error decoding JSON: {e}. This might be resolved by increasing `max_tokens`" 137 | ) 138 | logger.error(f"LLM output: {llm_output}") 139 | return None 140 | 141 | 142 | class ToolOutput(BaseModel): 143 | function: Callable 144 | kwargs: dict 145 | 146 | def __repr__(self): 147 | return ( 148 | f"{self.function.__name__}(" 149 | + ", ".join([f"{k}={repr(v)}" for k, v in self.kwargs.items()]) 150 | + ")" 151 | ) 152 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | ![ChainLite Logo](./assets/logo.png)
3 | 
4 | # ChainLite
5 | 
6 | ChainLite combines LangChain and LiteLLM to provide an easy-to-use and customizable interface for large language model applications.
7 | 
8 | * Logo is generated using DALL·E 3.
9 | 
10 | 
11 | ## Installation
12 | 
13 | ChainLite has been tested with Python 3.10. To install, do the following:
14 | 
15 | 
16 | 1. Install ChainLite via pip:
17 | ```bash
18 | pip install chainlite
19 | ```
20 | or
21 | ```bash
22 | pip install git+https://github.com/stanford-oval/chainlite.git
23 | ```
24 | 
25 | 
26 | 2. Copy `llm_config.yaml` to your project and follow the instructions there to update it with your own configuration.
27 | 
28 | ## Usage
29 | 
30 | Before using ChainLite, call the following function to load the configuration file. If you don't, ChainLite defaults to the `llm_config.yaml` file in the current directory (the directory you run your script from).
31 | 
32 | ```python
33 | from chainlite import initialize_llm_config
34 | initialize_llm_config("./llm_config.yaml") # The path should be relative to the directory you run the script from, usually the root directory of your project
35 | ```
36 | 
37 | Make sure the corresponding API keys are set in environment variables with the names you specified in the configuration file, e.g. `OPENAI_API_KEY`.
38 | 
39 | Then you can use the following functions in your code:
40 | 
41 | ```python
42 | llm_generation_chain(
43 |     template_file: str,
44 |     engine: str,
45 |     max_tokens: int,
46 |     temperature: float = 0.0,
47 |     stop_tokens: Optional[List[str]] = None,
48 |     top_p: float = 0.9,
49 |     output_json: bool = False,
50 |     pydantic_class: Any = None,
51 |     engine_for_structured_output: Optional[str] = None,
52 |     template_blocks: Optional[list[tuple[str, str]]] = None,
53 |     keep_indentation: bool = False,
54 |     progress_bar_desc: Optional[str] = None,
55 |     additional_postprocessing_runnable: Optional[Runnable] = None,
56 |     tools: Optional[list[Callable]] = None,
57 |     force_tool_calling: bool = False,
58 |     return_top_logprobs: int = 0,
59 |     bind_prompt_values: Optional[dict] = None,
60 |     force_skip_cache: bool = False,
61 |     reasoning_effort: Optional[str] = None,
62 | ) # returns a LangChain chain that accepts inputs and returns a string as output
63 | pprint_chain() # can be used to print the inputs or outputs of a LangChain chain
64 | write_prompt_logs_to_file(log_file: Optional[str]) # writes the instructions, inputs and outputs of all your LLM API calls to a JSONL file. Good for debugging or for collecting data using LLMs
65 | get_total_cost() # returns the total cost of all LLM API calls you have made. Resets each time you run your code.
66 | ```
67 | 
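For instance, when `pydantic_class` is set, the chain returns a parsed Pydantic object instead of a string. The sketch below is adapted from `tests/test_llm_structured_output.py`; it assumes an engine named `gpt-4o` is configured in your `llm_config.yaml` and that `structured.prompt` (like the one under `tests/`) is reachable through `prompt_dirs`:

```python
import asyncio
from typing import List

from pydantic import BaseModel

from chainlite import initialize_llm_config, llm_generation_chain

initialize_llm_config("./llm_config.yaml")


class Debate(BaseModel):
    """A debate event extracted from the input text."""

    mention: str
    people: List[str]


async def extract_debate(text: str) -> Debate:
    # `pydantic_class` makes the chain parse the LLM output into a Debate object.
    return await llm_generation_chain(
        template_file="structured.prompt",
        engine="gpt-4o",  # assumed to be one of the engines in llm_config.yaml
        max_tokens=1000,
        pydantic_class=Debate,
    ).ainvoke({"text": text})


debate = asyncio.run(
    extract_debate("4 major candidates for California U.S. Senate seat clash in first debate")
)
print(debate.mention, debate.people)
```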
68 | ## Full Example
69 | 
70 | `joke.prompt` with a 1-shot example:
71 | 
72 | ```markdown
73 | # instruction
74 | Tell a joke about the input topic. The format of the joke should be a question and response, separated by a line break.
75 | {# This is a comment, and will be ignored anywhere in a .prompt file. Other than block definitions and comments, '#' is allowed and is treated as a normal character. #}
76 | 
77 | # input
78 | Physics
79 | 
80 | # output
81 | Why don't scientists trust atoms?
82 | Because they make up everything!
83 | 
84 | # input
85 | {{ topic }}
86 | ```
87 | 
88 | `main.py`:
89 | ```python
90 | from chainlite import initialize_llm_config, llm_generation_chain, write_prompt_logs_to_file
91 | initialize_llm_config("./llm_config.yaml")
92 | 
93 | async def tell_joke(topic: str):
94 |     response = await llm_generation_chain(
95 |         template_file="joke.prompt",
96 |         engine="gpt-35-turbo",
97 |         max_tokens=100,
98 |     ).ainvoke({"topic": topic})
99 |     print(response)
100 | 
101 | import asyncio
102 | asyncio.run(tell_joke("Life as a PhD student")) # prints "Why did the PhD student bring a ladder to the library?\nTo take their research to the next level!"
103 | write_prompt_logs_to_file("llm_input_outputs.jsonl")
104 | ```
105 | 
106 | Then you will have `llm_input_outputs.jsonl`:
107 | ```json
108 | {"template_name": "joke.prompt", "instruction": "Tell a joke about the input topic. The format of the joke should be a question and response, separated by a line break.", "input": "Life as a PhD student", "output": "Why did the PhD student bring a ladder to the library?\nTo take their research to the next level!"}
109 | ```
110 | 
111 | For more examples, see `tests/test_llm_generate.py`.
112 | 
113 | ## Configuration
114 | 
115 | The `llm_config.yaml` file allows you to customize the behavior of ChainLite. Modify the file to set your preferences for the LangChain and LiteLLM integrations.
116 | 
117 | ## Syntax Highlighting
118 | If you are using VSCode, you can install [this extension](https://marketplace.visualstudio.com/items?itemName=samuelcolvin.jinjahtml) and switch `.prompt` files to use the "Jinja Markdown" syntax highlighting.
119 | 
120 | ## Contributing
121 | 
122 | We welcome contributions! Please follow these steps to contribute:
123 | 
124 | 1. Fork the repository.
125 | 2. Create a new branch for your feature or bugfix.
126 | 3. Commit your changes.
127 | 4. Push the branch to your forked repository.
128 | 5. Create a pull request with a detailed description of your changes.
129 | 
130 | ## License
131 | 
132 | ChainLite is licensed under the Apache-2.0 License. See the [LICENSE](LICENSE) file for more information.
133 | 
134 | ## Contact
135 | 
136 | For any questions or inquiries, please open an issue on the [GitHub Issues](https://github.com/stanford-oval/chainlite/issues) page.
137 | 
138 | ---
139 | 
140 | Thank you for using ChainLite!
141 | -------------------------------------------------------------------------------- /chainlite/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Any, Optional 4 | 5 | from pydantic import validate_call 6 | import rich 7 | from langchain_core.runnables import chain 8 | from tqdm.asyncio import tqdm as async_tqdm 9 | 10 | logging.getLogger("LiteLLM").setLevel(logging.WARNING) 11 | logging.getLogger("LiteLLM Router").setLevel(logging.WARNING) 12 | logging.getLogger("LiteLLM Proxy").setLevel(logging.WARNING) 13 | logging.getLogger("httpx").setLevel(logging.WARNING) 14 | 15 | 16 | def get_logger(name: Optional[str] = None): 17 | logger = logging.getLogger(name) 18 | # logger.setLevel(logging.INFO) 19 | if not logger.hasHandlers(): 20 | # without this if statement, will have duplicate logs 21 | handler = logging.StreamHandler() 22 | formatter = logging.Formatter("%(levelname)-5s : %(message)s") 23 | handler.setFormatter(formatter) 24 | logger.addHandler(handler) 25 | return logger 26 | 27 | 28 | logger = get_logger(__name__) 29 | 30 | 31 | async def run_async_in_parallel( 32 | async_function, 33 | *iterables, 34 | max_concurrency: int, 35 | timeout: float = 60 * 60 * 1, 36 | desc: str = "", 37 | progbar_update_interval: float = 1.0, 38 | ): 39 | """ 40 | Executes an asynchronous function concurrently over multiple iterables with a fixed number of worker tasks. 41 | 42 | This function schedules calls to the provided asynchronous function using a queue and worker tasks. 43 | Each worker repeatedly retrieves a tuple of arguments from the queue, executes the asynchronous function, 44 | and if the call does not complete within the specified timeout, it is retried until successful. The progress 45 | of execution is periodically updated via a progress bar that refreshes at fixed intervals. 46 | 47 | Parameters: 48 | async_function (Callable): The asynchronous function to be executed. It must accept as many arguments as there are iterables. 49 | *iterables (Iterable): One or more iterables supplying arguments for async_function. All iterables must have the same length. 50 | max_concurrency (int): The maximum number of concurrent worker tasks to run. 51 | timeout (float, optional): The maximum number of seconds to wait for async_function to complete for each call. 52 | Defaults to 3600 seconds (1 hour). 53 | desc (str, optional): Description text for the progress bar. If empty, the progress bar is disabled. 54 | Defaults to an empty string. 55 | progbar_update_interval (float, optional): The interval (in seconds) at which the progress bar is updated. 56 | Defaults to 1.0 second. 57 | 58 | Returns: 59 | list: A list of results obtained from executing async_function with the provided arguments. 60 | If a task times out or raises an exception, the corresponding result is set to None. 61 | 62 | Raises: 63 | ValueError: If the provided iterables do not have the same length. 64 | """ 65 | if not iterables: 66 | return [] 67 | 68 | length = len(iterables[0]) 69 | for it in iterables: 70 | if len(it) != length: 71 | raise ValueError("All iterables must have the same length.") 72 | 73 | # Enqueue all jobs as (index, args) pairs. 
74 | queue: asyncio.Queue[tuple[int, tuple]] = asyncio.Queue() 75 | for index, args in enumerate(zip(*iterables)): 76 | await queue.put((index, args)) 77 | 78 | results: list = [None] * length 79 | finished_count = 0 # shared progress counter 80 | pbar = async_tqdm(total=length, smoothing=0, desc=desc, disable=(not desc)) 81 | 82 | async def worker(): 83 | nonlocal finished_count 84 | while True: 85 | try: 86 | index, args = await queue.get() 87 | except asyncio.CancelledError: 88 | break 89 | try: 90 | # Retry until async_function finishes within timeout. 91 | while True: 92 | try: 93 | # Wait for the async_function with a timeout. 94 | results[index] = await asyncio.wait_for( 95 | async_function(*args), timeout 96 | ) 97 | break # success: exit retry loop 98 | except asyncio.TimeoutError: 99 | # Log or print a message here if desired. 100 | # The task timed out; retry it. 101 | continue 102 | except Exception: 103 | logger.exception(f"Exception in async worker: {index}") 104 | results[index] = None 105 | finally: 106 | finished_count += 1 107 | queue.task_done() 108 | 109 | # Refresh task to update the tqdm progress bar exactly once every progbar_update_interval seconds. 110 | stop_refresh = asyncio.Event() 111 | 112 | async def refresh_progress(): 113 | while not stop_refresh.is_set(): 114 | await asyncio.sleep(progbar_update_interval) 115 | pbar.n = finished_count 116 | pbar.refresh() 117 | 118 | # Spawn the worker tasks and the refresh task. 119 | workers = [asyncio.create_task(worker()) for _ in range(max_concurrency)] 120 | refresh_task = asyncio.create_task(refresh_progress()) 121 | 122 | # Wait until all jobs are processed. 123 | await queue.join() 124 | stop_refresh.set() 125 | await refresh_task # wait for refresh task to finish 126 | 127 | # Cancel workers and do a final progress update. 
128 | for w in workers: 129 | w.cancel() 130 | pbar.n = finished_count 131 | pbar.refresh() 132 | pbar.close() 133 | 134 | return results 135 | 136 | 137 | @chain 138 | def pprint_chain(_dict: Any) -> Any: 139 | """ 140 | Print intermediate results for debugging 141 | """ 142 | rich.print(_dict) 143 | return _dict 144 | 145 | 146 | def validate_function(): 147 | """A shortcut decorator""" 148 | return validate_call( 149 | validate_return=True, config=dict(arbitrary_types_allowed=True) 150 | ) 151 | -------------------------------------------------------------------------------- /chainlite/llm_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | 4 | import warnings 5 | from pydantic import PydanticDeprecatedSince20 6 | 7 | warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) 8 | 9 | import yaml 10 | 11 | from chainlite.load_prompt import initialize_jinja_environment 12 | from chainlite.threadsafe_dict import ThreadSafeDict 13 | 14 | 15 | class GlobalVars: 16 | prompt_logs = ThreadSafeDict() 17 | all_llm_endpoints = None 18 | prompt_dirs = None 19 | prompt_log_file = None 20 | prompts_to_skip_for_debugging = None 21 | local_engine_set = None 22 | 23 | @classmethod 24 | def get_all_configured_engines(cls): 25 | all_engines = set() 26 | for endpoint in cls.all_llm_endpoints: 27 | all_engines.update(endpoint["engine_map"].keys()) 28 | return all_engines 29 | 30 | 31 | def get_all_configured_engines(): 32 | initialize_llm_config() 33 | return GlobalVars.get_all_configured_engines() 34 | 35 | 36 | chainlite_initialized = False 37 | 38 | 39 | def initialize_llm_config(config_file: str = "./llm_config.yaml") -> None: 40 | global chainlite_initialized 41 | if chainlite_initialized: 42 | return 43 | chainlite_initialized = True 44 | 45 | import litellm 46 | from chainlite.redis_cache import init_redis_client 47 | init_redis_client() 48 | 49 | 50 | litellm.drop_params = True # Drops unsupported parameters for non-OpenAI APIs like TGI and Together.ai 51 | litellm.success_callback = [ 52 | track_cost_callback 53 | ] # Assign the cost callback function 54 | 55 | if GlobalVars.all_llm_endpoints is not None: 56 | # Configuration file is already loaded 57 | return 58 | with open(config_file, "r") as config_file: 59 | config = yaml.unsafe_load(config_file) 60 | 61 | # TODO raise errors if these values are not set, use pydantic v2 62 | GlobalVars.prompt_dirs = config.get("prompt_dirs", ["./"]) 63 | GlobalVars.prompt_log_file = config.get("prompt_logging", {}).get( 64 | "log_file", "./prompt_logs.jsonl" 65 | ) 66 | GlobalVars.prompts_to_skip_for_debugging = set( 67 | config.get("prompt_logging", {}).get("prompts_to_skip", []) 68 | ) 69 | 70 | set_verbose = config.get("litellm_set_verbose", False) 71 | if set_verbose: 72 | os.environ["LITELLM_LOG"] = "DEBUG" 73 | 74 | GlobalVars.all_llm_endpoints = config.get("llm_endpoints", []) 75 | for a in GlobalVars.all_llm_endpoints: 76 | if "api_key" in a: 77 | a["api_key"] = os.getenv(a["api_key"]) 78 | 79 | GlobalVars.all_llm_endpoints = [ 80 | a 81 | for a in GlobalVars.all_llm_endpoints 82 | if "api_key" not in a or (a["api_key"] is not None and len(a["api_key"]) > 0) 83 | ] # remove resources for which we don't have a key 84 | 85 | # tell LiteLLM how we want to map the messages to a prompt string for these non-chat models 86 | for endpoint in GlobalVars.all_llm_endpoints: 87 | if "prompt_format" in endpoint: 88 | if endpoint["prompt_format"] == "distilled": 89 | # 
{instruction}\n\n{input}\n 90 | for engine, model in endpoint["engine_map"].items(): 91 | litellm.register_prompt_template( 92 | model=model, 93 | roles={ 94 | "system": { 95 | "pre_message": "", 96 | "post_message": "\n\n", 97 | }, 98 | "user": { 99 | "pre_message": "", 100 | "post_message": "\n", 101 | }, 102 | "assistant": { 103 | "pre_message": "", 104 | "post_message": "\n", 105 | }, # this will be ignored since "distilled" formats only support one output turn 106 | }, 107 | initial_prompt_value="", 108 | final_prompt_value="", 109 | ) 110 | else: 111 | raise ValueError( 112 | f"Unsupported prompt format: {endpoint['prompt_format']}" 113 | ) 114 | GlobalVars.local_engine_set = set() 115 | 116 | for endpoint in GlobalVars.all_llm_endpoints: 117 | for engine, model in endpoint["engine_map"].items(): 118 | if model.startswith("huggingface/"): 119 | GlobalVars.local_engine_set.add(engine) 120 | 121 | initialize_jinja_environment(GlobalVars.prompt_dirs) 122 | 123 | 124 | # this code is NOT safe to use with multiprocessing, only multithreading 125 | thread_lock = threading.Lock() 126 | 127 | total_cost = 0.0 # in USD 128 | 129 | 130 | def add_to_total_cost(amount: float): 131 | global total_cost 132 | with thread_lock: 133 | total_cost += amount 134 | 135 | 136 | def get_total_cost() -> float: 137 | """ 138 | This function is used to get the total LLM cost accumulated so far 139 | 140 | Returns: 141 | float: The total cost accumulated so far in USD. 142 | """ 143 | global total_cost 144 | return total_cost 145 | 146 | 147 | async def track_cost_callback( 148 | kwargs, # kwargs to completion 149 | completion_response, # response from completion 150 | start_time, 151 | end_time, # start/end time 152 | ): 153 | from litellm import completion_cost 154 | 155 | try: 156 | if kwargs["cache_hit"]: 157 | # no cost because of caching 158 | # TODO this doesn't work with streaming 159 | return 160 | 161 | response_cost = 0 162 | # check if we have collected an entire stream response 163 | if "complete_streaming_response" in kwargs: 164 | # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost 165 | completion_response = kwargs["complete_streaming_response"] 166 | input_text = kwargs["messages"] 167 | output_text = completion_response["choices"][0]["message"]["content"] 168 | response_cost = completion_cost( 169 | model=kwargs["model"], messages=input_text, completion=output_text 170 | ) 171 | elif kwargs["stream"] != True: 172 | # for non streaming responses 173 | response_cost = completion_cost(completion_response=completion_response) 174 | if response_cost > 0: 175 | add_to_total_cost(response_cost) 176 | except: 177 | pass 178 | # This can happen for example because of local models 179 | -------------------------------------------------------------------------------- /chainlite/load_prompt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functionality to work with .prompt files in Jinja2 format. 
3 | """ 4 | 5 | import re 6 | from datetime import datetime 7 | from functools import lru_cache 8 | from typing import List, Tuple 9 | from zoneinfo import ZoneInfo # Python 3.9 and later 10 | 11 | from jinja2 import Environment, FileSystemLoader 12 | from langchain_core.prompts import ( 13 | AIMessagePromptTemplate, 14 | ChatPromptTemplate, 15 | HumanMessagePromptTemplate, 16 | SystemMessagePromptTemplate, 17 | ) 18 | 19 | jinja2_comment_pattern = re.compile(r"{#.*?#}", re.DOTALL) 20 | 21 | # Initial setup for prompt_block_identifiers remains the same 22 | prompt_block_identifiers = { 23 | "input": [ 24 | "# input\n", 25 | "# Input\n", 26 | "# INPUT\n", 27 | "# user\n", 28 | "# User\n", 29 | "# USER\n", 30 | "# human\n", 31 | "# Human\n", 32 | "# HUMAN\n", 33 | ], 34 | "output": [ 35 | "# output\n", 36 | "# Output\n", 37 | "# OUTPUT\n", 38 | "# assistant\n", 39 | "# Assistant\n", 40 | "# ASSISTANT\n", 41 | "# ai\n", 42 | "# Ai\n", 43 | "# AI\n", 44 | ], 45 | "instruction": [ 46 | "# instruction\n", 47 | "# Instruction\n", 48 | "# INSTRUCTION\n", 49 | "# System\n", 50 | "# SYSTEM\n", 51 | "# system\n", 52 | ], 53 | } 54 | 55 | jinja_environment = None # Global variable to hold the Jinja2 environment 56 | 57 | 58 | def initialize_jinja_environment(loader_paths): 59 | global jinja_environment 60 | 61 | loader = FileSystemLoader(loader_paths) 62 | jinja_environment = Environment( 63 | loader=loader, 64 | trim_blocks=True, 65 | lstrip_blocks=True, 66 | ) 67 | 68 | 69 | @lru_cache() 70 | def load_template_file(template_file: str, keep_indentation: bool) -> str: 71 | """ 72 | This function is here just so that we can cache the templates and not have to read from disk every time. 73 | Also removes comment blocks and white space at the beginning and end of each line. These are usually added to make prompt templates more readable. 74 | We remove comment blocks first, so that we can get rid of extra lines before or after them. 75 | """ 76 | raw_template = jinja_environment.loader.get_source( 77 | jinja_environment, template_file 78 | )[0] 79 | raw_template = re.sub(jinja2_comment_pattern, "", raw_template) 80 | if not keep_indentation: 81 | raw_template = "\n".join([line.strip() for line in raw_template.split("\n")]) 82 | else: 83 | raw_template = "\n".join([line.rstrip() for line in raw_template.split("\n")]) 84 | raw_template = re.sub( 85 | r"%}\s*", "%}", raw_template 86 | ) # remove the white space after {% for ... %} tags etc. 87 | 88 | return raw_template 89 | 90 | 91 | added_template_constants = {} 92 | 93 | 94 | def register_prompt_constants(constant_name_to_value_map: dict) -> None: 95 | """ 96 | Make constant values available to all prompt templates. 97 | By default, current_year, today and location are set, and you can overwrite them or add new constants using this method. 98 | 99 | Args: 100 | constant_name_to_value_map (dict): A dictionary where keys are constant names and values are the corresponding constant values. 
101 | 102 | Returns: 103 | None 104 | """ 105 | for k, v in constant_name_to_value_map.items(): 106 | added_template_constants[k] = v 107 | 108 | 109 | def add_constants_to_template( 110 | chat_prompt_template: ChatPromptTemplate, 111 | ) -> ChatPromptTemplate: 112 | # always make these useful constants available in a template 113 | # make a new function call each time since the date might change during a long-term server deployment 114 | pacific_zone = ZoneInfo("America/Los_Angeles") 115 | today = datetime.now(pacific_zone).date() 116 | 117 | template_constants = { 118 | "current_year": today.year, 119 | "today": today.strftime("%B %d, %Y"), # e.g. May 30, 2024 120 | "location": "the U.S.", 121 | } 122 | for k, v in added_template_constants.items(): 123 | template_constants[k] = v 124 | 125 | chat_prompt_template = chat_prompt_template.partial(**template_constants) 126 | 127 | return chat_prompt_template 128 | 129 | 130 | def find_all_substrings(string, substring) -> List[str]: 131 | return [match.start() for match in re.finditer(re.escape(substring), string)] 132 | 133 | 134 | def _split_prompt_to_blocks(prompt: str) -> List[Tuple[str, str]]: 135 | block_indices = [] 136 | for identifier in prompt_block_identifiers: 137 | for alternative in prompt_block_identifiers[identifier]: 138 | for i in find_all_substrings(prompt, alternative): 139 | block_indices.append((i, identifier, alternative)) 140 | 141 | block_indices = sorted( 142 | block_indices 143 | ) # sort according to the index they were found at 144 | 145 | # check the prompt format is correct 146 | assert ( 147 | len([b for b in block_indices if b[1] == "instruction"]) <= 1 148 | ), "Prompts should contain at most one instruction block" 149 | 150 | num_inputs = len([b for b in block_indices if b[1] == "input"]) 151 | num_outputs = len([b for b in block_indices if b[1] == "output"]) 152 | fewshot_start = 1 153 | assert (num_inputs == num_outputs + 1) or ( 154 | num_inputs == 0 and num_outputs == 0 155 | ), "The order of few-shot blocks in the prompt should be ((input -> output) * N) -> input" 156 | for i, b in enumerate(block_indices[fewshot_start:]): 157 | if i % 2 == 0: 158 | assert ( 159 | b[1] == "input" 160 | ), "The order of few-shot blocks in the prompt should be ((input -> output) * N) -> input" 161 | else: 162 | assert ( 163 | b[1] == "output" 164 | ), "The order of few-shot blocks in the prompt should be ((input -> output) * N) -> input" 165 | 166 | block_indices_with_end = block_indices + [(len(prompt), "end", "end")] 167 | blocks = [] 168 | for i in range(len(block_indices)): 169 | block_content = prompt[ 170 | block_indices_with_end[i][0] 171 | + len(block_indices_with_end[i][2]) : block_indices_with_end[i + 1][0] 172 | ].strip() 173 | 174 | blocks.append((block_indices_with_end[i][1], block_content)) 175 | 176 | return blocks 177 | 178 | 179 | def _prompt_blocks_to_chat_messages( 180 | blocks: List[Tuple[str, str]], is_distilled: bool 181 | ) -> Tuple[ChatPromptTemplate, str | None]: 182 | message_prompt_templates = [] 183 | 184 | # Add an instruction block if it is not present 185 | if len([b for b in blocks if b[0] == "instruction"]) == 0: 186 | blocks = [("instruction", "")] + blocks 187 | 188 | for block_type, block in blocks: 189 | if block_type == "instruction": 190 | block_type = SystemMessagePromptTemplate 191 | elif block_type == "input": 192 | block_type = HumanMessagePromptTemplate 193 | elif block_type == "output": 194 | block_type = AIMessagePromptTemplate 195 | else: 196 | assert False 197 | 
message_prompt_templates.append( 198 | block_type.from_template(block, template_format="jinja2") 199 | ) 200 | if is_distilled: 201 | # only keep the system message and the last input 202 | message_prompt_templates = [ 203 | message_prompt_templates[0], 204 | message_prompt_templates[-1], 205 | ] 206 | chat_prompt_template = ChatPromptTemplate.from_messages(message_prompt_templates) 207 | chat_prompt_template = add_constants_to_template(chat_prompt_template) 208 | 209 | return chat_prompt_template 210 | 211 | 212 | def load_fewshot_prompt_template( 213 | template_file: str, 214 | template_blocks: list[tuple[str]], 215 | is_distilled: bool, 216 | keep_indentation: bool, 217 | ) -> Tuple[ChatPromptTemplate, str | None]: 218 | if not template_blocks: 219 | fp = load_template_file(template_file, keep_indentation) 220 | template_blocks = _split_prompt_to_blocks(fp) 221 | chat_prompt_template = _prompt_blocks_to_chat_messages( 222 | template_blocks, is_distilled 223 | ) 224 | 225 | return chat_prompt_template 226 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /tests/test_llm_generate.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | import string 4 | import time 5 | from datetime import datetime 6 | from zoneinfo import ZoneInfo 7 | 8 | import pytest 9 | from langchain_core.runnables import RunnableLambda 10 | 11 | from chainlite import ( 12 | get_all_configured_engines, 13 | get_logger, 14 | get_total_cost, 15 | llm_generation_chain, 16 | register_prompt_constants, 17 | write_prompt_logs_to_file, 18 | ) 19 | from chainlite.llm_config import GlobalVars 20 | from chainlite.utils import run_async_in_parallel 21 | import time 22 | import subprocess 23 | import sys 24 | import pytest 25 | 26 | logger = get_logger(__name__) 27 | 28 | 29 | test_engine = "gpt-4o-openai" 30 | 31 | 32 | def test_chainlite_import_time(): 33 | # Measure the import time of the chainlite module in a subprocess, 34 | # using Python's -X importtime flag to get detailed import timing. 35 | script = "import chainlite" 36 | start = time.perf_counter() 37 | result = subprocess.run( 38 | [sys.executable, "-X", "importtime", "-c", script], 39 | capture_output=True, 40 | text=True, 41 | ) 42 | elapsed = time.perf_counter() - start 43 | 44 | # Check that the import was successful. 45 | assert result.returncode == 0, f"Chainlite import failed: {result.stderr}" 46 | 47 | # Set a threshold for the import time (e.g., 0.1 second). 48 | threshold = 1 49 | if elapsed >= threshold: 50 | print("Import took too long. Breakdown of import times:") 51 | print(result.stderr) 52 | # Write the detailed trace to disk. 53 | with open("chainlite_import_trace.log", "w") as f: 54 | f.write(result.stderr) 55 | 56 | assert ( 57 | elapsed < threshold 58 | ), f"Importing chainlite took too long: {elapsed:.2f} seconds" 59 | logger.info(f"Importing chainlite took {elapsed:.2f} seconds") 60 | 61 | 62 | @pytest.fixture(scope="session", autouse=True) 63 | def run_after_all_tests(): 64 | """ 65 | This fixture will run after all tests in the session. 66 | """ 67 | yield # This ensures that the fixture runs after all tests are done. 68 | 69 | write_prompt_logs_to_file("tests/llm_input_outputs.jsonl") 70 | with open("tests/llm_input_outputs.jsonl", "r") as f: 71 | prompt_logs = f.read() 72 | 73 | print(prompt_logs) 74 | assert ( 75 | "test.prompt" not in prompt_logs 76 | ), "test.prompt is in the prompts_to_skip and therefore should not be logged" 77 | logger.info(f"Total LLM cost: ${get_total_cost():.1f}") 78 | 79 | 80 | @pytest.mark.asyncio(scope="session") 81 | async def test_instruction_with_variable(): 82 | # Test a prompt where the instruction block contains a variable. 83 | template_blocks = [ 84 | ("instruction", "You are a chatbot named {{ name }}."), 85 | ("input", "{{ message }}"), 86 | ] 87 | # Invoke the chain with variables to be interpolated. 
88 | variables = {"name": "ChainLite", "message": "what is your name?"} 89 | response = await llm_generation_chain( 90 | template_file="", 91 | template_blocks=template_blocks, 92 | engine=test_engine, 93 | max_tokens=50, 94 | temperature=0, 95 | ).ainvoke(variables) 96 | 97 | assert "chainlite" in response.lower() 98 | 99 | 100 | @pytest.mark.asyncio(scope="session") 101 | async def test_llm_generate(): 102 | logger.info("All registered engines: %s", str(get_all_configured_engines())) 103 | 104 | # Check that the config file has been loaded properly 105 | assert GlobalVars.all_llm_endpoints 106 | assert GlobalVars.prompt_dirs 107 | assert GlobalVars.prompt_log_file 108 | # assert GlobalVars.prompts_to_skip_for_debugging 109 | assert GlobalVars.local_engine_set 110 | 111 | response = await llm_generation_chain( 112 | template_file="test.prompt", # prompt path relative to one of the paths specified in `prompt_dirs` 113 | engine=test_engine, 114 | max_tokens=100, 115 | force_skip_cache=True, 116 | ).ainvoke({}) 117 | # logger.info(response) 118 | 119 | assert response is not None, "The response should not be None" 120 | assert isinstance(response, str), "The response should be a string" 121 | assert len(response) > 0, "The response should not be empty" 122 | 123 | 124 | @pytest.mark.asyncio(scope="session") 125 | async def test_string_prompts(): 126 | response = await llm_generation_chain( 127 | template_file="", 128 | template_blocks=[ 129 | ("instruction", "X = 1, Y = 6."), 130 | ("input", "what is X?"), 131 | ("output", "The value of X is one"), 132 | ("input", "what is {{ variable }}?"), 133 | ], 134 | engine=test_engine, 135 | max_tokens=10, 136 | temperature=0, 137 | ).ainvoke({"variable": "Y"}) 138 | assert "The value of Y is six" in response 139 | 140 | # Without instruction block 141 | response = await llm_generation_chain( 142 | template_file="", 143 | template_blocks=[ 144 | ("input", "what is X?"), 145 | ("output", "The value of X is one"), 146 | ("input", "what is {{ variable }}?"), 147 | ], 148 | engine=test_engine, 149 | max_tokens=10, 150 | temperature=0, 151 | ).ainvoke({"variable": "Y"}) 152 | 153 | 154 | @pytest.mark.asyncio(scope="session") 155 | @pytest.mark.parametrize("engine", ["gpt-4o-openai", "gpt-4o-azure"]) 156 | async def test_llm_examples(engine): 157 | response = await llm_generation_chain( 158 | template_file="tests/joke.prompt", 159 | engine=engine, 160 | max_tokens=100, 161 | temperature=0.1, 162 | progress_bar_desc="test1", 163 | additional_postprocessing_runnable=RunnableLambda(lambda x: x[:5]), 164 | ).ainvoke({"topic": "Life as a PhD student"}) 165 | 166 | assert isinstance(response, str) 167 | assert len(response) == 5 168 | 169 | 170 | @pytest.mark.asyncio(scope="session") 171 | async def test_constants(): 172 | pacific_zone = ZoneInfo("America/Los_Angeles") 173 | today = datetime.now(pacific_zone).date().strftime("%B %d, %Y") # e.g. 
May 30, 2024 174 | response = await llm_generation_chain( 175 | template_file="tests/constants.prompt", 176 | engine=test_engine, 177 | max_tokens=10, 178 | temperature=0, 179 | ).ainvoke({"question": "What is today's date?"}) 180 | assert today in response 181 | 182 | # overwrite "today" 183 | register_prompt_constants({"today": "Thursday"}) 184 | response = await llm_generation_chain( 185 | template_file="tests/constants.prompt", 186 | engine=test_engine, 187 | max_tokens=10, 188 | temperature=0, 189 | ).ainvoke({"question": "What day of the week is today?"}) 190 | assert "thursday" in response.lower() 191 | 192 | 193 | @pytest.mark.asyncio(scope="session") 194 | async def test_batching(): 195 | chain_inputs = [ 196 | {"topic": "Ice cream"}, 197 | {"topic": "Cats"}, 198 | {"topic": "Dogs"}, 199 | {"topic": "Rabbits"}, 200 | ] 201 | response = await llm_generation_chain( 202 | template_file="tests/joke.prompt", 203 | engine=test_engine, 204 | max_tokens=10, 205 | temperature=0.1, 206 | progress_bar_desc="test2", 207 | ).abatch(chain_inputs) 208 | assert len(response) == len(chain_inputs) 209 | 210 | 211 | @pytest.mark.asyncio(scope="session") 212 | async def test_cached_batching(): 213 | c = llm_generation_chain( 214 | template_file="tests/joke.prompt", 215 | engine=test_engine, 216 | max_tokens=100, 217 | temperature=0.0, 218 | progress_bar_desc="test2", 219 | ) 220 | await c.ainvoke({"topic": "Ice cream"}) 221 | first_cost = get_total_cost() 222 | start_time = time.time() 223 | response = await c.abatch([{"topic": "Ice cream"}] * 200) 224 | elapsed_time = time.time() - start_time 225 | assert ( 226 | elapsed_time < 1 227 | ), f"The batched LLM calls should be cached and therefore very fast, but took {elapsed_time:.2f} seconds" 228 | assert ( 229 | get_total_cost() == first_cost 230 | ), "The cost should not change after a cached batched LLM call" 231 | 232 | 233 | @pytest.mark.asyncio(scope="session") 234 | @pytest.mark.parametrize("engine", ["o1", "o3-mini", "o3-mini-azure"]) 235 | async def test_reasoning_models(engine): 236 | response = await llm_generation_chain( 237 | template_file="tests/joke.prompt", 238 | engine=engine, 239 | max_tokens=2000, 240 | temperature=0.01, 241 | ).ainvoke({"topic": "A strawberry."}) 242 | assert response 243 | 244 | 245 | @pytest.mark.asyncio(scope="session") 246 | async def test_cache(): 247 | c = llm_generation_chain( 248 | template_file="tests/copy.prompt", 249 | engine=test_engine, 250 | max_tokens=100, 251 | temperature=0.0, 252 | ) 253 | # use random input so that the first call is not cached 254 | start_time = time.time() 255 | random_input = "".join(random.choices(string.ascii_letters + string.digits, k=20)) 256 | response1 = await c.ainvoke({"input": random_input}) 257 | first_time = time.time() - start_time 258 | first_cost = get_total_cost() 259 | 260 | print("First call took {:.2f} seconds".format(first_time)) 261 | print("Total cost after first call: ${:.10f}".format(first_cost)) 262 | 263 | start_time = time.time() 264 | response2 = await c.ainvoke({"input": random_input}) 265 | second_time = time.time() - start_time 266 | print("Second call took {:.2f} seconds".format(second_time)) 267 | second_cost = get_total_cost() 268 | print("Total cost after second call: ${:.10f}".format(second_cost)) 269 | 270 | assert response1 == response2 271 | assert ( 272 | second_time < first_time * 0.5 273 | ), "The second (cached) LLM call should be much faster than the first call" 274 | assert first_cost > 0, "The cost should be greater than 0" 275 | 
assert ( 276 | second_cost == first_cost 277 | ), "The cost should not change after a cached LLM call" 278 | 279 | 280 | @pytest.mark.asyncio(scope="session") 281 | @pytest.mark.parametrize("engine", ["o3-mini", "o3-mini-azure"]) 282 | async def test_reasoning_effort_cache(engine: str): 283 | c1 = llm_generation_chain( 284 | template_file="tests/copy.prompt", 285 | engine=engine, 286 | max_tokens=1000, 287 | temperature=0.0, 288 | reasoning_effort="low", 289 | ) 290 | 291 | c2 = llm_generation_chain( 292 | template_file="tests/copy.prompt", 293 | engine=engine, 294 | max_tokens=1000, 295 | temperature=0.0, 296 | reasoning_effort="medium", 297 | ) 298 | # use random input so that the first call is not cached 299 | start_time = time.time() 300 | random_input = "".join(random.choices(string.ascii_letters + string.digits, k=20)) 301 | response1 = await c1.ainvoke({"input": random_input}) 302 | first_time = time.time() - start_time 303 | first_cost = get_total_cost() 304 | 305 | print("First call took {:.2f} seconds".format(first_time)) 306 | print("Total cost after first call: ${:.10f}".format(first_cost)) 307 | 308 | start_time = time.time() 309 | response2 = await c2.ainvoke({"input": random_input}) 310 | second_time = time.time() - start_time 311 | print("Second call took {:.2f} seconds".format(second_time)) 312 | second_cost = get_total_cost() 313 | print("Total cost after second call: ${:.10f}".format(second_cost)) 314 | 315 | assert response1 == response2 316 | assert ( 317 | second_time > first_time * 0.5 318 | ), "The different reasoning efforts should not be cached" 319 | assert first_cost > 0, "The cost should be greater than 0" 320 | assert ( 321 | second_cost > first_cost 322 | ), "The cost should increase after a different reasoning effort LLM call" 323 | 324 | # another call to c1 should be cached 325 | start_time = time.time() 326 | response3 = await c1.ainvoke({"input": random_input}) 327 | third_time = time.time() - start_time 328 | third_cost = get_total_cost() 329 | print("Third call took {:.2f} seconds".format(third_time)) 330 | print("Total cost after third call: ${:.10f}".format(third_cost)) 331 | assert response1 == response3 332 | assert ( 333 | third_time < first_time * 0.5 334 | ), "The third (cached) LLM call should be much faster than the first call" 335 | assert ( 336 | third_cost == second_cost 337 | ), "The cost should not change after a cached LLM call" 338 | 339 | 340 | @pytest.mark.asyncio(scope="session") 341 | async def test_run_async_in_parallel(): 342 | 343 | async def async_function(i, j): 344 | await asyncio.sleep(1) 345 | return i 346 | 347 | test_inputs1 = range(10) 348 | test_inputs2 = range(10, 20) 349 | max_concurrency = 5 350 | desc = "test" 351 | ret = await run_async_in_parallel( 352 | async_function, 353 | test_inputs1, 354 | test_inputs2, 355 | max_concurrency=max_concurrency, 356 | desc=desc, 357 | ) 358 | assert ret == list(test_inputs1) 359 | 360 | 361 | @pytest.mark.asyncio(scope="session") 362 | async def test_o1_reasoning_effort(): 363 | for reasoning_effort in ["low", "medium", "high"]: 364 | start_time = time.time() 365 | response = await llm_generation_chain( 366 | template_file="tests/reasoning.prompt", 367 | engine="o1", 368 | max_tokens=2000, 369 | force_skip_cache=True, 370 | reasoning_effort=reasoning_effort, 371 | ).ainvoke({}) 372 | print(response) 373 | print( 374 | f"Reasoning effort: {reasoning_effort}, Time taken: {time.time() - start_time}" 375 | ) 376 | assert response 377 | 
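378 | 
379 | # Minimal illustrative sketch (assumptions: it reuses `tests/copy.prompt` and the
380 | # `test_engine` defined above; the exact wording of the model output is not guaranteed).
381 | # It demonstrates the `bind_prompt_values` parameter of `llm_generation_chain`, which
382 | # fills prompt variables ahead of time so they do not need to be passed at call time.
383 | @pytest.mark.asyncio(scope="session")
384 | async def test_bind_prompt_values():
385 |     bound_chain = llm_generation_chain(
386 |         template_file="tests/copy.prompt",
387 |         engine=test_engine,
388 |         max_tokens=20,
389 |         temperature=0,
390 |         bind_prompt_values={"input": "ChainLite"},  # pre-fills the {{ input }} variable
391 |     )
392 |     response = await bound_chain.ainvoke({})  # no remaining variables to pass
393 |     assert "chainlite" in response.lower()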
-------------------------------------------------------------------------------- /chainlite/llm_generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from datetime import datetime 5 | from typing import Any, Callable, List, Optional 6 | from uuid import UUID 7 | 8 | import warnings 9 | from pydantic import PydanticDeprecatedSince20 10 | 11 | warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) 12 | 13 | from langchain_core.callbacks import AsyncCallbackHandler 14 | from langchain_core.output_parsers import StrOutputParser 15 | from langchain_core.outputs import LLMResult 16 | from langchain_core.runnables import ( 17 | RunnableLambda, 18 | RunnablePassthrough, 19 | Runnable, 20 | chain, 21 | ) 22 | from tqdm.auto import tqdm 23 | 24 | from chainlite.chain_log_handler import ChainLogHandler 25 | from chainlite.chat_lite_llm import ChatLiteLLM 26 | from chainlite.llm_config import GlobalVars, initialize_llm_config 27 | from chainlite.llm_output import ToolOutput, string_to_pydantic_object 28 | from chainlite.load_prompt import load_fewshot_prompt_template 29 | from chainlite.utils import get_logger, validate_function 30 | 31 | logger = get_logger(__name__) 32 | 33 | 34 | def is_same_prompt(template_name_1: str, template_name_2: str) -> bool: 35 | return os.path.basename(template_name_1) == os.path.basename(template_name_2) 36 | 37 | 38 | def write_prompt_logs_to_file( 39 | log_file: Optional[str] = None, 40 | append: bool = False, 41 | include_timestamp: bool = False, 42 | ): 43 | if not log_file: 44 | log_file = GlobalVars.prompt_log_file 45 | key_order = [ 46 | "template_name", 47 | "instruction", 48 | "input", 49 | "output", 50 | ] # specifies the sort order of keys in the output, for a better viewing experience 51 | if include_timestamp: 52 | key_order = ["datetime"] + key_order 53 | 54 | mode = "w" 55 | if append: 56 | mode = "a" 57 | with open(log_file, mode) as f: 58 | for item in GlobalVars.prompt_logs.values(): 59 | should_skip = False 60 | for t in GlobalVars.prompts_to_skip_for_debugging: 61 | if is_same_prompt(t, item["template_name"]): 62 | should_skip = True 63 | break 64 | if should_skip: 65 | continue 66 | if "output" not in item: 67 | # happens if the code crashes in the middle of a an LLM call 68 | continue 69 | if include_timestamp: 70 | datetime_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 71 | item["datetime"] = datetime_str 72 | 73 | f.write( 74 | json.dumps( 75 | {key: item[key] for key in key_order}, 76 | ensure_ascii=False, 77 | ) 78 | ) 79 | f.write("\n") 80 | 81 | 82 | class ProgbarHandler(AsyncCallbackHandler): 83 | def __init__(self, desc: str): 84 | super().__init__() 85 | self.count = 0 86 | self.desc = desc 87 | 88 | # Override on_llm_end method. 
This is called after every response from LLM 89 | async def on_llm_end( 90 | self, 91 | response: LLMResult, 92 | *, 93 | run_id: UUID, 94 | parent_run_id: UUID | None = None, 95 | **kwargs: Any, 96 | ) -> Any: 97 | if self.count == 0: 98 | self.progress_bar = tqdm( 99 | total=None, 100 | desc=self.desc, 101 | unit=" LLM Calls", 102 | bar_format="{desc}: {n_fmt}{unit} ({rate_fmt})", 103 | mininterval=0, 104 | position=0, 105 | leave=True, 106 | ) # define a progress bar 107 | self.count += 1 108 | self.progress_bar.update(1) 109 | 110 | 111 | def is_list(obj): 112 | return isinstance(obj, list) 113 | 114 | 115 | def is_dict(obj): 116 | return isinstance(obj, dict) 117 | 118 | 119 | @chain 120 | async def return_response_and_tool( 121 | llm_output, tools: list[Callable], force_tool_calling: bool 122 | ): 123 | response = await StrOutputParser().ainvoke(input=llm_output) 124 | tool_output_in_json_format = llm_output.tool_calls 125 | 126 | tool_outputs = [] 127 | for t in tool_output_in_json_format: 128 | tool_name = t["name"] 129 | matching_tool = next( 130 | (tool for tool in tools if tool.__name__ == tool_name), None 131 | ) 132 | if matching_tool: 133 | tool_outputs.append(ToolOutput(function=matching_tool, kwargs=t["args"])) 134 | if force_tool_calling: 135 | return tool_outputs 136 | return response, tool_outputs 137 | 138 | 139 | @chain 140 | async def return_response_and_logprobs(llm_output): 141 | response = await StrOutputParser().ainvoke(input=llm_output) 142 | return response, llm_output.response_metadata.get("logprobs") 143 | 144 | 145 | @validate_function() 146 | def pick_llm_resource(engine: str) -> dict: 147 | initialize_llm_config() 148 | if not GlobalVars.all_llm_endpoints: 149 | raise ValueError( 150 | "No LLM API found. Make sure configuration and API_KEY files are set correctly, and that initialize_llm_config() is called before using any other function." 151 | ) 152 | 153 | # Decide which LLM resource to send this request to. 154 | potential_llm_resources = [ 155 | resource 156 | for resource in GlobalVars.all_llm_endpoints 157 | if engine in resource["engine_map"] 158 | ] 159 | if len(potential_llm_resources) == 0: 160 | raise IndexError( 161 | f"Could not find any matching engines for {engine}. Please check that llm_config.yaml is configured correctly and that the API key is set in the terminal before running this script." 
162 | ) 163 | 164 | llm_resource = random.choice(potential_llm_resources) 165 | return llm_resource 166 | 167 | 168 | def convert_to_structured_output_prompt(x: dict): 169 | main_llm_input = x["main_llm_input"] 170 | main_llm_output = x["main_llm_output"] 171 | return f"""An LLM was given these inputs: 172 | {main_llm_input} 173 | 174 | 175 | And produced this output: 176 | {main_llm_output} 177 | 178 | Convert this output to the expected JSON structured format.""" 179 | 180 | 181 | @validate_function() 182 | def llm_generation_chain( 183 | template_file: str, 184 | engine: str, 185 | max_tokens: int, 186 | temperature: float = 0.0, 187 | stop_tokens: Optional[List[str]] = None, 188 | top_p: float = 0.9, 189 | output_json: bool = False, 190 | pydantic_class: Any = None, 191 | engine_for_structured_output: Optional[str] = None, 192 | template_blocks: Optional[list[tuple[str, str]]] = None, 193 | keep_indentation: bool = False, 194 | progress_bar_desc: Optional[str] = None, 195 | additional_postprocessing_runnable: Optional[Runnable] = None, 196 | tools: Optional[list[Callable]] = None, 197 | force_tool_calling: bool = False, 198 | return_top_logprobs: int = 0, 199 | bind_prompt_values: Optional[dict] = None, 200 | force_skip_cache: bool = False, 201 | reasoning_effort: Optional[str] = None, 202 | ) -> Runnable: 203 | """ 204 | Constructs a LangChain generation chain for LLM response utilizing LLM APIs prescribed in the ChainLite config file. 205 | 206 | Parameters: 207 | template_file (str): The path to the generation template file. Must be a .prompt file. 208 | engine (str): The language model engine to employ. An engine represents the left-hand value of an `engine_map` in the ChainLite config file. 209 | max_tokens (int): The upper limit of tokens the LLM can generate. 210 | temperature (float, optional): Dictates the randomness in the generation. Must be >= 0.0. Defaults to 0.0 (deterministic). 211 | stop_tokens (List[str], optional): The list of tokens causing the LLM to stop generating. Defaults to None. 212 | top_p (float, optional): The max cumulative probability for nucleus sampling, must be within 0.0 - 1.0. Defaults to 0.9. 213 | output_json (bool, optional): If True, asks the LLM API to output a JSON. This depends on the underlying model to support. 214 | For example, GPT-4, GPT-4o and newer GPT-3.5-Turbo models support it, but require the word "json" to be present in the input. Defaults to False. 215 | pydantic_class (BaseModel, optional): If provided, will parse the output to match this Pydantic class. Only models like gpt-4o-mini and gpt-4o-2024-08-06 216 | and newer are supported 217 | engine_for_structured_output (str, optional): If provided, will use this engine to convert the base `engine`'s output to the expected json or Pydantic class. 218 | Helpful for when `engine` does not support structured output. Defaults to None. 219 | template_blocks: If provided, will use this instead of `template_file`. The format is [(role, string)] where role is one of "instruction", "input", "output" 220 | keep_indentation (bool, optional): If True, will keep indentations at the beginning of each line in the template_file. Defaults to False. 
221 | progress_bar_desc (str, optional): If provided, will display a `tqdm` progress bar using this as its description 222 | additional_postprocessing_runnable (Runnable, optional): If provided, will be applied to the output of LLM generation, and the final output will be logged 223 | tools (List[Callable], optional): If provided, will be made available to the underlying LLM, to optionally output it for function calling. Defaults to None. 224 | force_tool_calling (bool, optional): If True, will force the LLM to output the tools for function calling. Defaults to False. 225 | return_top_logprobs (int, optional): If > 0, will return the top logprobs for each token, so the output will be Tuple[str, dict]. Defaults to 0. 226 | bind_prompt_values (dict, optional): A dictionary mapping prompt variable names to values, which are bound to the prompt ahead of time. Additional variables can still be provided when the chain is called. Defaults to {}. 227 | force_skip_cache (bool, optional): If True, will force the LLM to skip the cache, and the new value won't be saved in cache either. Defaults to False. 228 | reasoning_effort (str, optional): The reasoning effort to use for reasoning models like o1. Must be one of "low", "medium", "high". Defaults to medium. The cache is sensitive to the value of this parameter, so calls that differ only in reasoning effort are cached separately. 229 | 230 | Returns: 231 | Runnable: The language model generation chain 232 | 233 | Raises: 234 | IndexError: Raised when no engine matches the provided string in the LLM APIs configured, or the API key is not found. 235 | """ 236 | 237 | assert reasoning_effort in [ 238 | None, 239 | "low", 240 | "medium", 241 | "high", 242 | ], f"Invalid reasoning_effort: {reasoning_effort}. Valid values are 'low', 'medium', 'high'." 243 | 244 | if ( 245 | sum( 246 | [ 247 | bool(pydantic_class), 248 | bool(output_json), 249 | bool(tools), 250 | return_top_logprobs > 0, 251 | ] 252 | ) 253 | > 1 254 | ): 255 | raise ValueError( 256 | "At most one of `pydantic_class`, `output_json`, `return_top_logprobs` and `tools` can be used." 257 | ) 258 | if return_top_logprobs > 0 or tools: 259 | if engine_for_structured_output: 260 | raise ValueError( 261 | "engine_for_structured_output cannot be used with return_top_logprobs or tools." 262 | ) 263 | if engine_for_structured_output and not pydantic_class and not output_json: 264 | raise ValueError( 265 | "engine_for_structured_output requires either pydantic_class or output_json to be set."
266 | ) 267 | 268 | llm_resource = pick_llm_resource(engine) 269 | model = llm_resource["engine_map"][engine] 270 | 271 | # ChatLiteLLM expects these parameters in a separate dictionary for some reason 272 | model_kwargs = {} 273 | 274 | # TODO move these to ChatLiteLLM 275 | if engine in GlobalVars.local_engine_set: 276 | if temperature > 0: 277 | model_kwargs["do_sample"] = True 278 | else: 279 | model_kwargs["do_sample"] = False 280 | if top_p == 1: 281 | top_p = None 282 | 283 | if model.startswith("mistral/"): 284 | # Mistral API expects top_p to be 1 when greedy decoding 285 | if temperature == 0: 286 | top_p = 1 287 | 288 | should_cache = (temperature == 0) and not force_skip_cache 289 | 290 | is_distilled = ( 291 | "prompt_format" in llm_resource and llm_resource["prompt_format"] == "distilled" 292 | ) 293 | 294 | prompt = load_fewshot_prompt_template( 295 | template_file, 296 | template_blocks, 297 | is_distilled=is_distilled, 298 | keep_indentation=keep_indentation, 299 | ) 300 | 301 | if engine_for_structured_output: 302 | # Apply the structured output-related settings to the structured output engine 303 | structured_model_kwargs = model_kwargs.copy() 304 | if output_json: 305 | structured_model_kwargs["response_format"] = {"type": "json_object"} 306 | elif pydantic_class: 307 | structured_model_kwargs["response_format"] = pydantic_class 308 | structure_output_resource = pick_llm_resource(engine_for_structured_output) 309 | structured_output_llm = ChatLiteLLM( 310 | model=structure_output_resource["engine_map"][engine_for_structured_output], 311 | api_base=( 312 | structure_output_resource["api_base"] 313 | if "api_base" in structure_output_resource 314 | else None 315 | ), 316 | api_key=( 317 | structure_output_resource["api_key"] 318 | if "api_key" in structure_output_resource 319 | else None 320 | ), 321 | api_version=( 322 | structure_output_resource["api_version"] 323 | if "api_version" in structure_output_resource 324 | else None 325 | ), 326 | cache=should_cache, 327 | max_tokens=max_tokens, 328 | temperature=temperature, 329 | top_p=top_p, 330 | stop=stop_tokens, 331 | model_kwargs=structured_model_kwargs, 332 | ) 333 | else: 334 | # Apply the structured output-related settings to the main engine 335 | if output_json: 336 | model_kwargs["response_format"] = {"type": "json_object"} 337 | elif pydantic_class: 338 | model_kwargs["response_format"] = pydantic_class 339 | 340 | if return_top_logprobs > 0: 341 | model_kwargs["logprobs"] = True 342 | model_kwargs["top_logprobs"] = return_top_logprobs 343 | 344 | if reasoning_effort: 345 | # only include it when explicitly set, because most models do not support it 346 | model_kwargs["reasoning_effort"] = reasoning_effort 347 | 348 | if tools: 349 | from litellm.utils import function_to_dict 350 | 351 | function_json = [ 352 | {"type": "function", "function": function_to_dict(t)} for t in tools 353 | ] 354 | model_kwargs["tools"] = function_json 355 | model_kwargs["tool_choice"] = "required" if force_tool_calling else "auto" 356 | 357 | callbacks = [] 358 | if progress_bar_desc: 359 | cb = ProgbarHandler(progress_bar_desc) 360 | callbacks.append(cb) 361 | 362 | main_llm = ChatLiteLLM( 363 | model=model, 364 | api_base=llm_resource["api_base"] if "api_base" in llm_resource else None, 365 | api_key=llm_resource["api_key"] if "api_key" in llm_resource else None, 366 | api_version=( 367 | llm_resource["api_version"] if "api_version" in llm_resource else None 368 | ), 369 | cache=should_cache, 370 | max_tokens=max_tokens, 371 | 
temperature=temperature, 372 | top_p=top_p, 373 | stop=stop_tokens, 374 | callbacks=callbacks, 375 | model_kwargs=model_kwargs, 376 | ) 377 | 378 | if engine_for_structured_output: 379 | main_llm = ( 380 | { 381 | "main_llm_output": main_llm | StrOutputParser(), 382 | "main_llm_input": RunnablePassthrough(), 383 | } 384 | | RunnableLambda(convert_to_structured_output_prompt) 385 | | structured_output_llm 386 | ) 387 | 388 | if bind_prompt_values: 389 | prompt = prompt.partial(**bind_prompt_values) 390 | 391 | llm_generation_chain = prompt | main_llm 392 | if tools: 393 | llm_generation_chain = llm_generation_chain | return_response_and_tool.bind( 394 | tools=tools, force_tool_calling=force_tool_calling 395 | ) 396 | else: 397 | if return_top_logprobs > 0: 398 | llm_generation_chain = llm_generation_chain | return_response_and_logprobs 399 | else: 400 | llm_generation_chain = llm_generation_chain | StrOutputParser() 401 | 402 | if pydantic_class: 403 | llm_generation_chain = llm_generation_chain | string_to_pydantic_object.bind( 404 | pydantic_class=pydantic_class 405 | ) 406 | 407 | if additional_postprocessing_runnable: 408 | llm_generation_chain = llm_generation_chain | additional_postprocessing_runnable 409 | return llm_generation_chain.with_config( 410 | callbacks=[ChainLogHandler()], 411 | metadata={ 412 | "template_name": os.path.basename(template_file), 413 | }, 414 | ) # for logging to file 415 | -------------------------------------------------------------------------------- /chainlite/chat_lite_llm.py: -------------------------------------------------------------------------------- 1 | """Wrapper around LiteLLM. Modified from https://python.langchain.com/api_reference/_modules/langchain_community/chat_models/litellm.html#ChatLiteLLM""" 2 | 3 | from __future__ import annotations 4 | 5 | import json 6 | import logging 7 | from typing import ( 8 | Any, 9 | AsyncIterator, 10 | Callable, 11 | Dict, 12 | Iterator, 13 | List, 14 | Mapping, 15 | Optional, 16 | Sequence, 17 | Tuple, 18 | Type, 19 | Union, 20 | ) 21 | 22 | import warnings 23 | from pydantic import PydanticDeprecatedSince20 24 | 25 | warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) 26 | 27 | from langchain_core.callbacks import ( 28 | AsyncCallbackManagerForLLMRun, 29 | CallbackManagerForLLMRun, 30 | ) 31 | from langchain_core.language_models import LanguageModelInput 32 | from langchain_core.language_models.chat_models import ( 33 | BaseChatModel, 34 | agenerate_from_stream, 35 | generate_from_stream, 36 | ) 37 | from langchain_core.messages import ( 38 | AIMessage, 39 | AIMessageChunk, 40 | BaseMessage, 41 | BaseMessageChunk, 42 | ChatMessage, 43 | ChatMessageChunk, 44 | FunctionMessage, 45 | FunctionMessageChunk, 46 | HumanMessage, 47 | HumanMessageChunk, 48 | SystemMessage, 49 | SystemMessageChunk, 50 | ToolCall, 51 | ToolCallChunk, 52 | ToolMessage, 53 | ) 54 | from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult 55 | from langchain_core.runnables import Runnable 56 | from langchain_core.tools import BaseTool 57 | from langchain_core.utils import pre_init 58 | from langchain_core.utils.function_calling import convert_to_openai_tool 59 | from pydantic import BaseModel, Field 60 | 61 | from tenacity import ( 62 | retry, 63 | retry_base, 64 | retry_if_exception_type, 65 | stop_after_attempt, 66 | wait_exponential, 67 | ) 68 | 69 | from chainlite.llm_config import GlobalVars 70 | from chainlite.llm_output import ToolOutput 71 | 72 | logger = logging.getLogger(__name__) 73 
| 74 | 75 | class ChatLiteLLMException(Exception): 76 | """Error with the `LiteLLM I/O` library""" 77 | 78 | 79 | def _create_retry_decorator(llm) -> Callable[[Any], Any]: 80 | """Returns a tenacity retry decorator, configured to handle LLM exceptions""" 81 | import litellm 82 | 83 | errors = [ 84 | litellm.Timeout, 85 | litellm.APIError, 86 | litellm.APIConnectionError, 87 | litellm.RateLimitError, 88 | ] 89 | retry_instance: retry_base = retry_if_exception_type(errors[0]) 90 | for error in errors[1:]: 91 | retry_instance = retry_instance | retry_if_exception_type(error) 92 | return retry( 93 | # reraise=True, 94 | stop=stop_after_attempt(llm.max_retries), 95 | wait=wait_exponential(multiplier=1, min=4, max=20), 96 | retry=retry_instance, 97 | # before_sleep=_before_sleep, 98 | ) 99 | 100 | 101 | def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: 102 | role = _dict["role"] 103 | if role == "user": 104 | return HumanMessage(content=_dict["content"]) 105 | elif role == "assistant": 106 | # Fix for azure 107 | # Also OpenAI returns None for tool invocations 108 | content = _dict.get("content", "") or "" 109 | 110 | additional_kwargs = {} 111 | if _dict.get("function_call"): 112 | additional_kwargs["function_call"] = dict(_dict["function_call"]) 113 | 114 | if _dict.get("tool_calls"): 115 | additional_kwargs["tool_calls"] = _dict["tool_calls"] 116 | 117 | return AIMessage(content=content, additional_kwargs=additional_kwargs) 118 | elif role == "system": 119 | return SystemMessage(content=_dict["content"]) 120 | elif role == "function": 121 | return FunctionMessage(content=_dict["content"], name=_dict["name"]) 122 | else: 123 | return ChatMessage(content=_dict["content"], role=role) 124 | 125 | 126 | async def acompletion_with_retry( 127 | llm: ChatLiteLLM, 128 | run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, 129 | **kwargs: Any, 130 | ) -> Any: 131 | """Use tenacity to retry the async completion call.""" 132 | retry_decorator = _create_retry_decorator(llm) 133 | 134 | @retry_decorator 135 | async def _completion_with_retry(**kwargs: Any) -> Any: 136 | import litellm 137 | 138 | return await litellm.acreate(**kwargs) 139 | 140 | return await _completion_with_retry(**kwargs) 141 | 142 | 143 | def _convert_delta_to_message_chunk( 144 | _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] 145 | ) -> BaseMessageChunk: 146 | role = _dict.get("role") 147 | content = _dict.get("content") or "" 148 | if _dict.get("function_call"): 149 | additional_kwargs = {"function_call": dict(_dict["function_call"])} 150 | else: 151 | additional_kwargs = {} 152 | 153 | tool_call_chunks = [] 154 | if raw_tool_calls := _dict.get("tool_calls"): 155 | additional_kwargs["tool_calls"] = raw_tool_calls 156 | try: 157 | tool_call_chunks = [ 158 | ToolCallChunk( 159 | name=rtc["function"].get("name"), 160 | args=rtc["function"].get("arguments"), 161 | id=rtc.get("id"), 162 | index=rtc["index"], 163 | ) 164 | for rtc in raw_tool_calls 165 | ] 166 | except KeyError: 167 | pass 168 | 169 | if role == "user" or default_class == HumanMessageChunk: 170 | return HumanMessageChunk(content=content) 171 | elif role == "assistant" or default_class == AIMessageChunk: 172 | return AIMessageChunk( 173 | content=content, 174 | additional_kwargs=additional_kwargs, 175 | tool_call_chunks=tool_call_chunks, 176 | ) 177 | elif role == "system" or default_class == SystemMessageChunk: 178 | return SystemMessageChunk(content=content) 179 | elif role == "function" or default_class == 
FunctionMessageChunk: 180 | return FunctionMessageChunk(content=content, name=_dict["name"]) 181 | elif role or default_class == ChatMessageChunk: 182 | return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] 183 | else: 184 | return default_class(content=content) # type: ignore[call-arg] 185 | 186 | 187 | def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict: 188 | return { 189 | "type": "function", 190 | "id": tool_call["id"], 191 | "function": { 192 | "name": tool_call["name"], 193 | "arguments": json.dumps(tool_call["args"]), 194 | }, 195 | } 196 | 197 | 198 | def _convert_message_to_dict(message: BaseMessage) -> dict: 199 | message_dict: Dict[str, Any] = {"content": message.content} 200 | if isinstance(message, ChatMessage): 201 | message_dict["role"] = message.role 202 | elif isinstance(message, HumanMessage): 203 | message_dict["role"] = "user" 204 | elif isinstance(message, AIMessage): 205 | message_dict["role"] = "assistant" 206 | if "function_call" in message.additional_kwargs: 207 | message_dict["function_call"] = message.additional_kwargs["function_call"] 208 | if message.tool_calls: 209 | message_dict["tool_calls"] = [ 210 | _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls 211 | ] 212 | elif "tool_calls" in message.additional_kwargs: 213 | message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] 214 | elif isinstance(message, SystemMessage): 215 | message_dict["role"] = "system" 216 | elif isinstance(message, FunctionMessage): 217 | message_dict["role"] = "function" 218 | message_dict["name"] = message.name 219 | elif isinstance(message, ToolMessage): 220 | message_dict["role"] = "tool" 221 | message_dict["tool_call_id"] = message.tool_call_id 222 | else: 223 | raise ValueError(f"Got unknown type {message}") 224 | if "name" in message.additional_kwargs: 225 | message_dict["name"] = message.additional_kwargs["name"] 226 | return message_dict 227 | 228 | 229 | class ChatLiteLLM(BaseChatModel): 230 | """Chat model that uses the LiteLLM API.""" 231 | 232 | model: str = "" 233 | api_key: Optional[str] = None 234 | api_base: Optional[str] = None 235 | api_version: Optional[str] = None 236 | streaming: bool = False 237 | temperature: Optional[float] = 0 238 | model_kwargs: Dict[str, Any] = Field(default_factory=dict) 239 | top_p: Optional[float] = None 240 | top_k: Optional[int] = None 241 | max_tokens: Optional[int] = None 242 | template_file: Optional[str] = None 243 | instruction: Optional[str] = None 244 | 245 | max_retries: int = 6 246 | 247 | @property 248 | def _default_params(self) -> Dict[str, Any]: 249 | """Get the default parameters for calling LLM API.""" 250 | return { 251 | "model": self.model, 252 | "max_tokens": self.max_tokens, 253 | "stream": self.streaming, 254 | "temperature": self.temperature, 255 | "model": self.model, 256 | "api_base": self.api_base, 257 | "api_key": self.api_key, 258 | "api_version": self.api_version, 259 | **self.model_kwargs, 260 | } 261 | 262 | def completion_with_retry( 263 | self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any 264 | ) -> Any: 265 | """Use tenacity to retry the completion call.""" 266 | retry_decorator = _create_retry_decorator(self, run_manager=run_manager) 267 | 268 | @retry_decorator 269 | def _completion_with_retry(**kwargs: Any) -> Any: 270 | import litellm 271 | 272 | return litellm.completion(**kwargs) 273 | 274 | return _completion_with_retry(**kwargs) 275 | 276 | @pre_init 277 | def validate_environment(cls, values: Dict) -> 


class ChatLiteLLM(BaseChatModel):
    """Chat model that uses the LiteLLM API."""

    model: str = ""
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    streaming: bool = False
    temperature: Optional[float] = 0
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    template_file: Optional[str] = None
    instruction: Optional[str] = None

    max_retries: int = 6

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the LLM API."""
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "temperature": self.temperature,
            "api_base": self.api_base,
            "api_key": self.api_key,
            "api_version": self.api_version,
            **self.model_kwargs,
        }

    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            import litellm

            return litellm.completion(**kwargs)

        return _completion_with_retry(**kwargs)

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the api key, temperature, top_p, and top_k."""
        values["api_key"] = values.get("api_key", "")

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages)
        params = {**params, **kwargs}
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {"token_usage": token_usage, "model": self.model}
        if (
            "logprobs" in response["choices"][0]
            and "content" in response["choices"][0]["logprobs"]
        ):
            llm_output["logprobs"] = (
                response["choices"][0].get("logprobs").get("content")
            )
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _stream(
        self,
        messages: List[BaseMessage],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
            yield cg_chunk
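
    # Note: `_stream` and `_astream` mirror each other. Both carry the class of the
    # most recent chunk forward in `default_chunk_class`, because follow-up deltas
    # omit the "role" field and `_convert_delta_to_message_chunk` falls back to the
    # previous chunk class to keep emitting AIMessageChunk instances.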

    async def _astream(
        self,
        messages: List[BaseMessage],
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
            yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages=messages, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages)
        params = {**params, **kwargs}
        response = await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        LiteLLM expects the tools argument in OpenAI format.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function, or
                "auto" to automatically determine which function to call
                (if any), or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)
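
    # For reference, `convert_to_openai_tool` turns a plain callable such as
    #
    #     def get_weather(city: str) -> str:
    #         """Look up the current weather for a city."""
    #
    # into an OpenAI-format tool dict roughly shaped like the following
    # (illustrative; the exact schema depends on the langchain-core version):
    #
    #     {"type": "function",
    #      "function": {"name": "get_weather",
    #                   "description": "Look up the current weather for a city.",
    #                   "parameters": {"type": "object",
    #                                  "properties": {"city": {"type": "string"}},
    #                                  "required": ["city"]}}}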

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters.
        These are used as the cache key for LLM calls."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "reasoning_effort": self.model_kwargs.get("reasoning_effort"),
            "max_tokens": self.max_tokens,
        }

    @property
    def _llm_type(self) -> str:
        return "litellm-chat"

    def _log_inputs_and_outputs(
        self, messages: List[BaseMessage], response: Any
    ) -> None:
        """Log the inputs and outputs of the model."""
        llm_input = messages[0][-1].content
        if messages[0][-1].type == "system":
            # The prompt had no `# input` block, only an `# instruction` block.
            llm_input = ""

        prompt_log = {
            "template_name": self.template_file,
            "instruction": self.instruction,
            "input": llm_input,
        }

        if (
            isinstance(response, tuple)
            and len(response) == 2
            and isinstance(response[1], ToolOutput)
        ):
            response = list(response)
            response[1] = str(response[1])
        elif isinstance(response, ToolOutput):
            response = str(response)
        if isinstance(response, tuple) and len(response) == 2:
            response = list(response)
            # If exactly one of the two elements is non-empty, log only that one.
            if response[0] and not response[1]:
                response = response[0]
            elif not response[0] and response[1]:
                response = response[1]

        prompt_log["output"] = str(response)

        GlobalVars.prompt_logs.append(prompt_log)

--------------------------------------------------------------------------------