├── .bumpversion.cfg
├── .github
└── workflows
│ └── tests.yaml
├── .gitignore
├── LICENSE
├── README.md
├── architecture.excalidraw
├── commit.sh
├── docs
├── examples
│ ├── agent_example.py
│ ├── agent_stream_example.py
│ ├── function_calling.py
│ └── llm_chat.py
└── get_started.md
├── release.sh
├── setup.py
├── tinyllm.yaml
└── tinyllm
├── __init__.py
├── agent
├── __init__.py
├── agent.py
├── agent_stream.py
└── tool
│ ├── __init__.py
│ ├── tool.py
│ ├── toolkit.py
│ └── tools
│ ├── __init__.py
│ ├── code_interpreter.py
│ ├── think_plan.py
│ └── wikipedia.py
├── constants.py
├── eval
├── __init__.py
├── evaluation_model.py
├── evaluator.py
├── evaluators
│ ├── __init__.py
│ ├── answer_accuracy_evaluator.py
│ └── retrieval_evaluator.py
├── qa_generator.py
└── rag_eval_pipeline.py
├── examples
├── __init__.py
├── example_manager.py
└── example_selector.py
├── exceptions.py
├── function.py
├── function_stream.py
├── llms
├── __init__.py
├── constants.py
├── lite_llm.py
├── lite_llm_stream.py
└── tiny_function.py
├── memory
├── __init__.py
└── memory.py
├── prompt_manager.py
├── rag
├── __init__.py
├── document
│ ├── __init__.py
│ ├── document.py
│ └── store.py
├── loaders
│ ├── __init__.py
│ ├── image_loader.py
│ └── loaders.py
├── rerank.py
└── vector_store.py
├── state.py
├── test_vector_store.py
├── tests
├── __init__.py
├── base.py
├── test_agent.py
├── test_agent_stream.py
├── test_document_store.py
├── test_evaluators.py
├── test_example_selector.py
├── test_function.py
├── test_litellm.py
├── test_litellm_stream.py
├── test_memory.py
├── test_rag.py
├── test_tiny_function.py
└── test_tracing.py
├── tracing
├── __init__.py
├── helpers.py
└── langfuse_context.py
├── util
├── __init__.py
├── ai_util.py
├── db_util.py
├── helpers.py
├── message.py
├── os_util.py
├── parse_util.py
└── prompt_util.py
└── validator.py
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.1.1
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:setup.py]
7 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yaml:
--------------------------------------------------------------------------------
1 | name: Tinyllm tests
2 |
3 | on:
4 | push:
5 | branches: [ none ]
6 |
7 | jobs:
8 | build:
9 |
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v2
14 | - name: Set up Python
15 | uses: actions/setup-python@v2
16 | with:
17 | python-version: '3.11'
18 |
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install .
23 |
24 | - name: Test with pytest
25 | run: |
26 | pip install pytest
27 | pytest
28 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *DS_Store
6 | *egg-info
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 zozoheir
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | # 🚀 What is tinyllm?
8 | tinyllm is a lightweight framework for developing, debugging and monitoring LLM and Agent powered applications at scale. The main goal of the library is to keep code as simple and readable as possible while allowing user to create complex agents or LLM workflows in production.
9 |
10 | `Function` and its streaming equivalent `FunctionStream` are the core classes in tinyllm. They are designed to standardize and control LLM, ToolStore and any relevant calls for scalable production use in stream mode and otherwise.
11 |
12 | It provides a structured approach to handle various aspects of function execution, including input/output validation, output processing, error handling, evaluation, all while keeping code readable. You can create a chain with its own prompt, LLM model and evaluators all in a single file. No need to jump through many class definitions, no spaghetti code. Any other library agent/chain (langchain/llama-index...) can also seamlessly be imported as a tinyllm Function.
13 |
14 |
15 | ## 🚀 Install
16 | ```
17 | pip install tinyllm
18 | ```
19 |
20 | ## 🚀 Getting started
21 | * #### [Setup](https://github.com/zozoheir/tinyllm/blob/main/docs/get_started.md)
22 | * #### [Examples](https://github.com/zozoheir/tinyllm/blob/main/docs/examples/)
23 |
24 |
25 | ## 🚀 Features
26 | #### Build LLM apps with:
27 | - **LiteLLM integration**: 20+ model providers available (OpenAI, Huggingface etc ...)
28 | - **Langfuse integration**: Monitor trace and debug LLMs, Agents, Tools, RAG pipelines etc in structured run trees
29 | - **Agents:** An agent is an LLM with Memory, a Toolkit and an ExampleManager
30 | - **ToolStore and Toolkits**: let your Agent run python functions using ToolStore
31 | - **Example manager**: constant examples + variable examples using and example selector with similarity search
32 | - **Memory:** conversations history
33 | - **Retrieval Augmented Generation**: RAG tools to search and generate answers
34 | - **Evaluation:** Evaluators can be defined to evaluate and log the quality of the function's output in real-time
35 | - **PGVector store:** PostgreSQL DB with the pgvector extension for vector storage.
36 | - **Prompt engineering tools:** utility modules for prompt engineering, optimization and string formatting
37 |
38 | #### 🚀 Deploy to production with:
39 | - **Layered validation:** 3 validations happen during the Function lifecycle: input, output and output processing.
40 | - **IO Standardization:** Maintains consistent response patterns and failure handling across different function implementations.
41 | - **Observability:** Integrates with Langfuse for
42 | - **Logging:** Records detailed logs for debugging and auditing purposes.
43 | - **Finite State Machine design:** Manages the function's lifecycle through defined states, ensuring controlled and predictable execution.
44 |
45 | #### Tiny function wrapper
46 |
47 | ```python
48 | class RiskScoreOutput(BaseModel):
49 | risk_score: float
50 |
51 | @tiny_function(output_model=RiskScoreOutput)
52 | async def calculate_risk_score(bank_account_history: str, employment_history: str):
53 | """
54 |
55 | Extract a Risk Score between 0 and 1 for a Credit Card application based on bank account and employment history.
56 |
57 |
58 |
59 | Given the bank account history: {bank_account_history}
60 | And the employment history: {employment_history}
61 | Calculate the risk score for a credit card application.
62 |
63 | """
64 | pass
65 | ```
66 |
67 | #### Tracing with Langfuse
68 |
69 |
70 |
71 |
72 |
73 | ## Background and goals
74 | Many of the LLM libraries today (langchain, llama-index, deep pavlov...) have made serious software design commitments which I believe were too early to make given the infancy of the industry.
75 | The goals of tinyllm are:
76 | * **Solve painpoints from current libraries**: lack of composability (within + between libraries), complex software designs, code readability, debugging and logging.
77 | * **High level, robust abstractions**: tinyllm is designed to be as simple as possible to use and integrate with existing and living codebases.
78 | * **Human and machine readable code** to enable AI powered and autonomous chain development
79 |
80 | ## API model
81 | LLM Functions are designed to behave like a web API. All Functions will always, even if failed, return a dictionary response.
82 |
83 | #### Validation
84 | Validations are defined through a Pydantic model and are provided to the Function using input_validator, output_validator and output_processing_validator args to a Function
85 |
86 | ## Tracing
87 | tinyllm is integrated with Langfuse for tracing chains, functions and agents.
88 | 
89 |
90 | ### Managing configs and credentials
91 | Configs are managed through a tinyllm.yaml file. It gets picked up at runtime in tinyllm.__init__ and can be placed in any of /Documents, your root folder, or the current working directory.
92 | An empty tinyllm.yaml file is at the source of the repo to get you setup.
93 |
94 |
95 | ## ⚡ Concurrency vs Parallelism vs Chaining
96 | These tend to be confusing across the board. Here's a quick explanation:
97 | - **Concurrency** : This means more than 1 Input/Ouput request at a time. Just like you can download 10 files
98 | concurrently on your web browser, you can call 10 APIs concurrently.
99 | - **Chaining** : An ordered list of Functions where a Function's output is the input of the next Function in the chain.
100 | - **Parallelism** : compute/calculations being performed on more than 1 process/CPU Core on the same machine. This is what
101 | model providers like OpenAI do using large GPU clusters (Nvidia, AMD...). This is used for "CPU Bound" tasks.
102 |
103 | Tinyllm does not care about Parallelism. Parallelism is implemented by LLM providers
104 | on a GPU/CPU level and should be abstracted away using an LLM microservice.
105 | Tinyllm only cares about Concurrency, Chaining and organizing IO Bound tasks.
106 |
107 |
108 |
109 | ### Logging
110 |
111 | Finite state machine with predictable and controlled state transitions for easy debugging of your chains/compute graphs.
112 |
113 | Below is the start and end of a trace for asking "What is the weather in Puerto Rico?" to an Agent with a get_weather Tool.
114 |
115 | ```
116 | INFO | tinyllm.function | 2023-12-25 19:37:10,617 : [Standard example selector] transition to: States.INIT
117 | INFO | tinyllm.function | 2023-12-25 19:37:12,720 : [BufferMemory] transition to: States.INIT
118 | INFO | tinyllm.function | 2023-12-25 19:37:12,729 : [get_weather] transition to: States.INIT
119 | INFO | tinyllm.function | 2023-12-25 19:37:12,729 : [Toolkit] transition to: States.INIT
120 | INFO | tinyllm.function | 2023-12-25 19:37:12,731 : [LiteLLM] transition to: States.INIT
121 | ...
122 | ...
123 | INFO | tinyllm.function | 2023-12-25 19:37:17,150 : [AnswerCorrectnessEvaluator] transition to: States.PROCESSING_OUTPUT
124 | INFO | tinyllm.function | 2023-12-25 19:37:17,151 : [AnswerCorrectnessEvaluator] transition to: States.PROCESSED_OUTPUT_VALIDATION
125 | INFO | tinyllm.function | 2023-12-25 19:37:17,151 : [AnswerCorrectnessEvaluator] transition to: States.COMPLETE
126 | INFO | tinyllm.function | 2023-12-25 19:37:17,846 : [Agent] transition to: States.PROCESSING_OUTPUT
127 | INFO | tinyllm.function | 2023-12-25 19:37:17,847 : [Agent] transition to: States.PROCESSED_OUTPUT_VALIDATION
128 | INFO | tinyllm.function | 2023-12-25 19:37:17,847 : [Agent] transition to: States.COMPLETE
129 | {'status': 'success', 'output': {'response': {'id': 'chatcmpl-8ZpjY0QmXbDiMIcSRwKuCUny4sxul', 'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'content': "It is 25 degrees celsius in Puerto Rico", 'role': 'assistant'}}], 'created': 1703551035, 'model': 'gpt-3.5-turbo-0613', 'object': 'chat.completion', 'system_fingerprint': None, 'usage': {'completion_tokens': 12, 'prompt_tokens': 138, 'total_tokens': 150}, '_response_ms': 785.606}}}
130 | ```
131 |
132 |
133 | ## ⚡ Concurrency vs Parallelism vs Chaining
134 | These tend to be confusing across the board. Here's a quick explanation:
135 | - **Concurrency** : This means more than 1 Input/Ouput request at a time. Just like you can download 10 files
136 | concurrently on your web browser, you can call 10 APIs concurrently.
137 | - **Chaining** : An ordered list of Functions where a Function's output is the input of the next Function in the chain.
138 | - **Parallelism** : compute/calculations being performed on more than 1 process/CPU Core on the same machine. This is what
139 | model providers like OpenAI do using large GPU clusters (Nvidia, AMD...). This is used for "CPU Bound" tasks.
140 |
141 | Tinyllm does not care about Parallelism. Parallelism is implemented by LLM providers
142 | on a GPU/CPU level and should be abstracted away using an LLM microservice.
143 | Tinyllm only cares about Concurrency, Chaining and organizing IO Bound tasks.
144 |
145 |
146 |
--------------------------------------------------------------------------------
/commit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export TINYLLM_CONFIG_PATH='/Users/othmanezoheir/PycharmProjects/zoheir-consulting/context-ai-env/tinyllm.yaml'
3 | PYTHONPATH="$PYTHONPATH:/Users/othmanezoheir/PycharmProjects/personal/smartpy"
4 | PYTHONPATH="$PYTHONPATH:/Users/othmanezoheir/PycharmProjects/rumorz-io/rumorz-backend"
5 | PYTHONPATH="$PYTHONPATH:/Users/othmanezoheir/PycharmProjects/rumorz-io/rumorz-env"
6 | PYTHONPATH="$PYTHONPATH:/Users/othmanezoheir/PycharmProjects/rumorz-io/rumorz-jobs"
7 | PYTHONPATH="$PYTHONPATH:/Users/othmanezoheir/PycharmProjects/rumorz-io/rumorz-models"
8 | PYTHONPATH="$PYTHONPATH:/Users/othmanezoheir/PycharmProjects/rumorz-io/rumorz-pyenv"
9 | PYTHONPATH="$PYTHONPATH:/Users/othmanezoheir/PycharmProjects/rumorz-io/rumorz-llms"
10 | PYTHONPATH="$PYTHONPATH:/Users/othmanezoheir/PycharmProjects/personal/langdt"
11 |
12 | export PYTHONPATH
13 |
14 |
15 | # Set up a virtual environment path for ease of use
16 | VENV_PATH="/Users/othmanezoheir/venv/rumorz-jobs-2/bin/python"
17 | cd /Users/othmanezoheir/PycharmProjects/openagents/tinyllm
18 | # Run unittest discovery in the tinyllm/tests directory
19 | if find ./tinyllm -name "test_*.py" ! -name "*vector_store*" -exec echo {} \; | xargs -n1 $VENV_PATH -m unittest; then
20 | # If tests pass, proceed to add, commit, and push changes to git
21 | git add .
22 | echo "Enter commit message:"
23 | read commit_message
24 | git commit -m "[$commit_message]"
25 | if [ $? -eq 0 ]; then
26 | git push
27 | echo "Changes pushed successfully!"
28 | else
29 | echo "Commit failed, changes not pushed."
30 | fi
31 | else
32 | # If tests fail, halt the process
33 | echo "Tests failed, changes not added or committed."
34 | fi
35 |
--------------------------------------------------------------------------------
/docs/examples/agent_example.py:
--------------------------------------------------------------------------------
1 | """
2 | VISUALIZATION OF THE AGENT STREAM EXAMPLE
3 |
4 | https://us.cloud.langfuse.com/project/cloz2bp020000l008kg9ujywd/traces/e5d46730-8528-4bd1-935d-ab949904f80d
5 | """
6 |
7 | import asyncio
8 |
9 | from tinyllm.agent.agent import Agent
10 | from tinyllm.agent.tool import Toolkit
11 | from tinyllm.agent.tool.tool import Tool
12 | from tinyllm.eval.evaluator import Evaluator
13 | from tinyllm.memory.memory import BufferMemory
14 |
15 | loop = asyncio.get_event_loop()
16 |
17 | class AnswerCorrectnessEvaluator(Evaluator):
18 |
19 | async def run(self, **kwargs):
20 | completion = kwargs['response']['choices'][0]['message']['content']
21 | evals = {
22 | "evals": {
23 | "correct_answer": 1 if 'january 1st' in completion.lower() else 0
24 | },
25 | "metadata": {}
26 | }
27 |
28 | return evals
29 |
30 |
31 | def get_user_property(asked_property):
32 | if asked_property == "name":
33 | return "Elias"
34 | elif asked_property == "birthday":
35 | return "January 1st"
36 |
37 |
38 | tools = [
39 | Tool(
40 | name="get_user_property",
41 | description="This is the tool you use retrieve ANY information about the user. Use it to answer questions about his birthday and any personal info",
42 | python_lambda=get_user_property,
43 | parameters={
44 | "type": "object",
45 | "properties": {
46 | "asked_property": {
47 | "type": "string",
48 | "enum": ["birthday", "name"],
49 | "description": "The specific property the user asked about",
50 | },
51 | },
52 | "required": ["asked_property"],
53 | },
54 | )
55 | ]
56 | toolkit = Toolkit(
57 | name='Toolkit',
58 | tools=tools,
59 | )
60 |
61 | async def run_agent():
62 | tiny_agent = Agent(system_role="You are a helpful agent that can answer questions about the user's profile using available tools.",
63 | toolkit=toolkit,
64 | memory=BufferMemory(),
65 | run_evaluators=[
66 | AnswerCorrectnessEvaluator(),
67 | ])
68 |
69 | result = await tiny_agent(content="What is the user's birthday?")
70 | print(result)
71 |
72 | result = loop.run_until_complete(run_agent())
73 |
--------------------------------------------------------------------------------
/docs/examples/agent_stream_example.py:
--------------------------------------------------------------------------------
1 | """
2 | VISUALIZATION OF THE AGENT STREAM EXAMPLE
3 |
4 | https://us.cloud.langfuse.com/project/cloz2bp020000l008kg9ujywd/traces/39a332ed-8c3d-4bc7-b4b7-b1738d1a5912
5 | """
6 |
7 | import asyncio
8 |
9 | from tinyllm.agent.agent_stream import AgentStream
10 | from tinyllm.agent.tool import Toolkit
11 | from tinyllm.agent.tool.tool import Tool
12 | from tinyllm.eval.evaluator import Evaluator
13 |
14 | loop = asyncio.get_event_loop()
15 |
16 | class AnswerCorrectnessEvaluator(Evaluator):
17 | async def run(self, **kwargs):
18 | completion = kwargs['output']['completion']
19 | evals = {
20 | "evals": {
21 | "correct_answer": 1 if 'january 1st' in completion.lower() else 0
22 | },
23 | "metadata": {}
24 | }
25 | return evals
26 |
27 |
28 | def get_user_property(asked_property):
29 | if asked_property == "name":
30 | return "Elias"
31 | elif asked_property == "birthday":
32 | return "January 1st"
33 |
34 |
35 | tools = [
36 | Tool(
37 | name="get_user_property",
38 | description="This is the tool you use retrieve ANY information about the user. Use it to answer questions about his birthday and any personal info",
39 | python_lambda=get_user_property,
40 | parameters={
41 | "type": "object",
42 | "properties": {
43 | "asked_property": {
44 | "type": "string",
45 | "enum": ["birthday", "name"],
46 | "description": "The specific property the user asked about",
47 | },
48 | },
49 | "required": ["asked_property"],
50 | },
51 | )
52 | ]
53 | toolkit = Toolkit(
54 | name='Toolkit',
55 | tools=tools
56 | )
57 |
58 |
59 | async def run_agent_stream():
60 |
61 |
62 | tiny_agent = AgentStream(system_role="You are a helpful agent that can answer questions about the user's profile using available tools.",
63 | toolkit=toolkit,
64 | run_evaluators=[
65 | AnswerCorrectnessEvaluator(
66 | name="Functional call corrector",
67 |
68 | ),
69 | ])
70 |
71 | msgs = []
72 | async for message in tiny_agent(content="What is the user's birthday?"):
73 | msgs.append(message)
74 | return msgs
75 |
76 | result = loop.run_until_complete(run_agent_stream())
77 |
--------------------------------------------------------------------------------
/docs/examples/function_calling.py:
--------------------------------------------------------------------------------
1 | import tinyllm
2 | from litellm import completion
3 |
4 | from tinyllm.agent.tool import tinyllm_toolkit
5 |
6 | messages = [
7 | {
8 | "role": "system",
9 | "content": "You are a helpful assistant"
10 | },
11 | {
12 | "name": "think_and_plan",
13 | "role": "tool",
14 | "content": "Execution plan: \n1. Get the population of Morocco\n2. Get the population of Senegal\n3. Multiply the population of Morocco by the population of Senegal\n4. Get Elon Musk's age\n5. Square the result from step 3 by Elon Musk's age",
15 | "tool_call_id": "call_JmVeXsLA8hEeWdhWYMoECpr4"
16 | },
17 | {
18 | "role": "assistant",
19 | "content": "",
20 | "tool_calls": [
21 | {
22 | "id": "call_Ord7obJdU2EK62c2T82t56DM",
23 | "type": "function",
24 | "function": {
25 | "name": "get_wikipedia_summary",
26 | "arguments": {
27 | "page_title": "Demographics_of_Morocco"
28 | }
29 | }
30 | }
31 | ]
32 | },
33 | {
34 | "name": "get_wikipedia_summary",
35 | "role": "tool",
36 | "content": "Demographic features of the population of Morocco include population density, ethnicity, education level, health of the populace, economic status, religious affiliations and other aspects of the population. The population of Morocco in 2021 is 37.271 million.",
37 | "tool_call_id": "call_Ord7obJdU2EK62c2T82t56DM"
38 | },
39 | {
40 | "role": "assistant",
41 | "content": "",
42 | "tool_calls": [
43 | {
44 | "id": "call_v65qP9Jc2dAY3C6X8O5YzjHX",
45 | "type": "function",
46 | "function": {
47 | "name": "get_wikipedia_summary",
48 | "arguments": {
49 | "page_title": "Demographics_of_Senegal"
50 | }
51 | }
52 | }
53 | ]
54 | },
55 | {
56 | "name": "get_wikipedia_summary",
57 | "role": "tool",
58 | "content": "Demographic features of the population of Senegal include population density, ethnicity, education level, health of the populace, economic status, religious affiliations and other aspects of the population.",
59 | "tool_call_id": "call_v65qP9Jc2dAY3C6X8O5YzjHX"
60 | },
61 | {
62 | "role": "assistant",
63 | "content": "",
64 | "tool_calls": [
65 | {
66 | "id": "call_z5vV9dcnCxSGShyaF0JCsXa3",
67 | "type": "function",
68 | "function": {
69 | "name": "get_wikipedia_summary",
70 | "arguments": {
71 | "page_title": "Elon_Musk"
72 | }
73 | }
74 | }
75 | ]
76 | },
77 | {
78 | "name": "get_wikipedia_summary",
79 | "role": "tool",
80 | "content": "Elon Reeve Musk is a businessman and investor. He is the founder, chairman, CEO, and CTO of SpaceX; angel investor, CEO, product architect and former chairman of Tesla, Inc.; owner, chairman and CTO of X Corp.; founder of the Boring Company and xAI; co-founder of Neuralink and OpenAI; and president of the Musk Foundation. He is the wealthiest person in the world, with an estimated net worth of US$232 billion as of December 2023, according to the Bloomberg Billionaires Index, and $254 billion according to Forbes, primarily from his ownership stakes in Tesla and SpaceX.",
81 | "tool_call_id": "call_z5vV9dcnCxSGShyaF0JCsXa3"
82 | },
83 | {
84 | "role": "assistant",
85 | "content": "",
86 | "tool_calls": [
87 | {
88 | "id": "call_xa6vT4h9RE7RmliTQAXXw3Zn",
89 | "type": "function",
90 | "function": {
91 | "name": "code_interpreter",
92 | "arguments": {
93 | "code": "population_morocco = 37.271\npopulation_senegal = 16.743\nresult = population_morocco * population_senegal\nelon_musk_age = 51\nfinal_result = result ** elon_musk_age\nfinal_result"
94 | }
95 | }
96 | }
97 | ]
98 | },
99 | {
100 | "name": "code_interpreter",
101 | "role": "tool",
102 | "content": "The code did not return anything. Did you forget to print?",
103 | "tool_call_id": "call_xa6vT4h9RE7RmliTQAXXw3Zn"
104 | },
105 | {
106 | "role": "assistant",
107 | "content": "",
108 | "tool_calls": [
109 | {
110 | "id": "call_MhrGzib7PgBI9ZBsEcUk5Y0h",
111 | "type": "function",
112 | "function": {
113 | "name": "code_interpreter",
114 | "arguments": {
115 | "code": "population_morocco = 37.271\npopulation_senegal = 16.743\nresult = population_morocco * population_senegal\nelon_musk_age = 51\nfinal_result = result ** elon_musk_age\nprint(final_result)"
116 | }
117 | }
118 | }
119 | ]
120 | },
121 | {
122 | "name": "code_interpreter",
123 | "role": "tool",
124 | "content": "StdOut:\nb'3.5926958058201675e+142\\n'\nStdErr:\nb''",
125 | "tool_call_id": "call_MhrGzib7PgBI9ZBsEcUk5Y0h"
126 | }
127 | ]
128 | tools = tinyllm_toolkit().as_dict_list()
129 |
130 | response = completion(
131 | messages= messages,
132 | model="gpt-3.5-turbo",
133 | tools=tools
134 | )
135 |
136 |
--------------------------------------------------------------------------------
/docs/examples/llm_chat.py:
--------------------------------------------------------------------------------
1 | from tinyllm.eval.evaluator import Evaluator
2 | from tinyllm.llms.lite_llm import LiteLLM
3 | from tinyllm.llms.lite_llm_stream import LiteLLMStream
4 | from tinyllm.util.helpers import get_openai_message
5 |
6 | import asyncio
7 |
8 | loop = asyncio.get_event_loop()
9 |
10 | #### Basic chat
11 | message = get_openai_message(role='user',
12 | content="Hi")
13 | litellm_chat = LiteLLM()
14 | response = loop.run_until_complete(litellm_chat(messages=[message]))
15 |
16 |
17 | #### Chat with evaluator
18 |
19 | class SuccessFullRunEvaluator(Evaluator):
20 | async def run(self, **kwargs):
21 | print('Evaluating...')
22 | return {
23 | "evals": {
24 | "successful_score": 1,
25 | },
26 | "metadata": {}
27 | }
28 |
29 |
30 | litellm_chat = LiteLLM(run_evaluators=[SuccessFullRunEvaluator()],
31 | processed_output_evaluators=[SuccessFullRunEvaluator()])
32 | message = get_openai_message(role='user',
33 | content="Hi")
34 | result = loop.run_until_complete(litellm_chat(messages=[message]))
35 |
36 | #### Chat stream
37 |
38 | litellmstream_chat = LiteLLMStream(name='Test: LiteLLM Stream')
39 |
40 | async def get_stream():
41 | message = get_openai_message(role='user',
42 | content="Hi")
43 | async for msg in litellmstream_chat(messages=[message]):
44 | i = 0
45 | return msg
46 |
47 |
48 | result = loop.run_until_complete(get_stream())
49 | print(result['output']['streaming_status'])
50 |
--------------------------------------------------------------------------------
/docs/get_started.md:
--------------------------------------------------------------------------------
1 | ## Getting started with tinyllm
2 |
3 | ## Configs
4 |
5 |
6 | A yaml config file needs to be provided as an environment variable. [**The config file template is is available here**](https://github.com/zozoheir/tinyllm/blob/main/tinyllm.yaml)
7 |
8 | ```python
9 | import os
10 | os.environ['TINYLLM_CONFIG_PATH'] = '/path/to/tinyllm.yaml'
11 | ```
12 |
13 |
14 | If the config file is not provided, the library will look for a tinyllm.yaml config file in the following directories:
15 | - The current working directory
16 | - The user's home directory
17 | - The user's Document's directory
18 |
19 |
20 | ## Example agent
21 |
22 |
23 | ```python
24 | import asyncio
25 | from typing import Any, Optional
26 | from pydantic import BaseModel, Field
27 | from tinyllm.agent.agent import Agent
28 |
29 | class Person(BaseModel):
30 | name: str = Field(..., description='Name of the person')
31 | age: int = Field(..., description='Age of the person')
32 | note: Optional[Any]
33 |
34 | class RiskScoreOutput(BaseModel):
35 | risk_score: float = Field(..., description='A risk score between 0 and 1')
36 | person: Person
37 |
38 | tiny_agent = Agent(
39 | name='Test: Agent JSON output',
40 | system_role="You are a Credit Risk Analyst. Respond with a risk score based on the provided customer data",
41 | output_model=RiskScoreOutput
42 | )
43 |
44 | result = asyncio.run(tiny_agent(content="Johny Vargas, 29yo, the customer has missed 99% of his bill payments in the last year"))
45 |
46 |
47 | ```
48 |
--------------------------------------------------------------------------------
/release.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Run the commit.sh script
4 | #./commit.sh
5 |
6 | if [ $? -eq 0 ]; then
7 | echo "Commit successful, proceeding to activate virtual environment."
8 | else
9 | echo "Commit failed, release process aborted."
10 | exit 1
11 | fi
12 |
13 | # Activate virtual environment
14 | source "/Users/othmanezoheir/venv/rumorz-jobs-2/bin/activate"
15 |
16 | if [ $? -eq 0 ]; then
17 | echo "Virtual environment activated successfully."
18 | else
19 | echo "Failed to activate virtual environment, release process aborted."
20 | exit 1
21 | fi
22 |
23 | # Increment version number
24 | bump2version patch
25 |
26 | if [ $? -eq 0 ]; then
27 | echo "Version number incremented successfully."
28 | else
29 | echo "Failed to increment version number, release process aborted."
30 | exit 1
31 | fi
32 |
33 | # Clearing the dist/ directory
34 | echo "Cleaning up the dist/ directory..."
35 | rm -rf dist/*
36 |
37 |
38 | # Build the package
39 | python setup.py sdist bdist_wheel
40 |
41 | if [ $? -eq 0 ]; then
42 | echo "Package built successfully."
43 | elseO
44 | echo "Failed to build package, release process aborted."
45 | exit 1
46 | fi
47 |
48 | # Upload to PyPI
49 | twine upload dist/*
50 |
51 | if [ $? -eq 0 ]; then
52 | echo "Package uploaded successfully. New version released!"
53 | else
54 | echo "Failed to upload package, release process aborted."
55 | exit 1
56 | fi
57 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name='tinyllm',
4 | version='0.1.1',
5 | description='Development and management infrastructure for LLM applications',
6 | packages=find_packages(),
7 | install_requires=[
8 | 'astor',
9 | 'asyncpg',
10 | 'click',
11 | 'tenacity',
12 | 'langfuse',
13 | 'litellm',
14 | 'openai',
15 | 'numpydoc',
16 | 'pathspec',
17 | 'pgvector',
18 | 'psutil',
19 | 'psycopg2-binary',
20 | 'pydantic>=2.0',
21 | 'pyperclip',
22 | 'pytest',
23 | 'pyyaml',
24 | 'python-Levenshtein',
25 | 'sqlalchemy',
26 | 'tiktoken',
27 | 'typing-extensions',
28 | 'uuid',
29 | 'fuzzywuzzy'
30 | ],
31 | author='Othmane Zoheir',
32 | author_email='othmane@rumorz.io',
33 | url='')
34 |
--------------------------------------------------------------------------------
/tinyllm.yaml:
--------------------------------------------------------------------------------
1 | LOGS:
2 | LOGGING: true
3 | LOG_STATES:
4 | - 'RUNNING'
5 | - 'COMPLETE'
6 | - 'FAILED'
7 | LLM_PROVIDERS:
8 | OPENAI_API_KEY:
9 | ANYSCALE_API_KEY:
10 | AZURE_API_KEY:
11 | AZURE_API_BASE:
12 | AZURE_API_VERSION:
13 | LANGFUSE:
14 | PROJECT_ID:
15 | PUBLIC_KEY:
16 | SECRET_KEY:
17 | HOST:
18 | POSTGRES:
19 | USERNAME:
20 | PASSWORD:
21 | HOST:
22 | PORT:
23 | NAME:
--------------------------------------------------------------------------------
/tinyllm/__init__.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os
3 | import pyperclip
4 |
5 | from langfuse import Langfuse
6 |
7 | import logging
8 | from logging import StreamHandler, Formatter
9 | from pathlib import Path
10 |
11 | tinyllm_logger = logging.getLogger('tinyllm')
12 | tinyllm_logger.propagate = False
13 | tinyllm_logger.setLevel(logging.DEBUG)
14 | formatter = Formatter('%(levelname)s | %(name)s | %(asctime)s : %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
15 | ch = StreamHandler()
16 | ch.setLevel(logging.DEBUG)
17 | ch.setFormatter(formatter)
18 | tinyllm_logger.addHandler(ch)
19 |
20 |
21 |
22 | global tinyllm_config
23 | global langfuse_client
24 |
25 | tinyllm_config = None
26 | langfuse_client = None
27 |
28 | def load_yaml_config(yaml_file_path: str) -> dict:
29 | config = None
30 | yaml_path = Path(yaml_file_path.strip())
31 | if yaml_path.is_file():
32 | with open(yaml_path, 'r') as stream:
33 | try:
34 | config = yaml.safe_load(stream)
35 | except yaml.YAMLError as exc:
36 | tinyllm_logger.error(f"Error loading YAML file: {exc}")
37 | raise exc
38 | else:
39 | tinyllm_logger.error(f"Config file not found at {yaml_path}")
40 | raise FileNotFoundError(f"No config file at {yaml_path}")
41 | return config
42 |
43 |
44 | def set_config(file_path: str):
45 |
46 |
47 | global tinyllm_config
48 | global langfuse_client
49 |
50 | # Load config file
51 |
52 | tinyllm_config = load_yaml_config(file_path)
53 |
54 | # Set LLM providers env vars from the config file
55 | for provider_key in tinyllm_config['LLM_PROVIDERS'].keys():
56 | os.environ[provider_key] = tinyllm_config['LLM_PROVIDERS'][provider_key]
57 |
58 | # Initialize Langfuse client
59 | langfuse_client = Langfuse(
60 | public_key=tinyllm_config['LANGFUSE']['PUBLIC_KEY'],
61 | secret_key=tinyllm_config['LANGFUSE']['SECRET_KEY'],
62 | host=tinyllm_config['LANGFUSE']['HOST'],
63 | flush_interval=0.1,
64 | )
65 |
66 |
67 | def find_yaml_config(yaml_file_name: str, directories: list) -> dict:
68 | for directory in directories:
69 | if directory is None:
70 | continue
71 | yaml_path = Path(directory) / yaml_file_name
72 | if yaml_path.is_file():
73 | tinyllm_logger.info(f"Tinyllm: config found at {yaml_path}")
74 | return yaml_path
75 |
76 |
77 | # Directories to look for the config file, in order of priority
78 |
79 | directories = [
80 | Path.cwd() if Path.cwd().name != 'tinyllm' else None,
81 | Path.home(),
82 | Path.home() / 'Documents',
83 | ]
84 |
85 | if langfuse_client is None and tinyllm_config is None:
86 | tinyllm_config_file_path = os.environ.get('TINYLLM_CONFIG_PATH', None)
87 | tinyllm_logger.info(f"TINYLLM_CONFIG_PATH: {tinyllm_config_file_path}")
88 | if tinyllm_config_file_path is not None and tinyllm_config_file_path != '':
89 | set_config(tinyllm_config_file_path)
90 | else:
91 | tinyllm_logger.info(f"Tinyllm: no config file path provided, searching for config file")
92 | found_config_path = find_yaml_config('tinyllm.yaml', directories)
93 | if found_config_path is None:
94 | raise FileNotFoundError(f"Please provide a config file for tinyllm")
95 | set_config(found_config_path)
96 |
97 |
98 |
99 |
100 | def get_agent_code(system_role):
101 | definition = f"""
102 | tiny_agent = Agent(
103 | name='My Agent',
104 | system_role='{system_role}',
105 | output_model=None
106 | )
107 | """
108 | pyperclip.copy(definition)
109 |
110 |
111 |
--------------------------------------------------------------------------------
/tinyllm/agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/agent/__init__.py
--------------------------------------------------------------------------------
/tinyllm/agent/agent.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pprint
3 | from abc import abstractmethod
4 | from typing import Optional, Union, List, Type, Callable
5 |
6 | from pydantic import BaseModel
7 |
8 | from tinyllm.agent.tool import Toolkit
9 | from tinyllm.examples.example_manager import ExampleManager
10 | from tinyllm.function import Function
11 | from tinyllm.llms.lite_llm import LiteLLM
12 | from tinyllm.memory.memory import Memory, BufferMemory
13 | from tinyllm.prompt_manager import PromptManager, MaxTokensStrategy
14 | from tinyllm.util.message import Content, UserMessage, ToolMessage, AssistantMessage
15 | from tinyllm.validator import Validator
16 |
17 |
18 | class AgentInitValidator(Validator):
19 | system_role: str
20 | llm: Optional[Function]
21 | memory: Optional[Memory]
22 | toolkit: Optional[Toolkit]
23 | example_manager: Optional[ExampleManager]
24 | initial_user_message_text: Optional[str]
25 | tool_retries: Optional[int]
26 | output_model: Optional[Type[BaseModel]]
27 | prompt_manager: Optional[PromptManager]
28 |
29 |
30 | class AgentInputValidator(Validator):
31 | content: Union[str, list, Content, List[Content]]
32 | max_tokens_strategy: Optional[MaxTokensStrategy] = None
33 | allowed_max_tokens: Optional[int] = int(4096 * 0.25)
34 | expects_block: Optional[str] = None
35 |
36 |
37 | class Brain(BaseModel):
38 | personality: Optional[str]
39 |
40 | @abstractmethod
41 | def update(self, **kwargs):
42 | pass
43 |
44 |
45 | class AgentCallBackHandler:
46 |
47 | async def on_tools(self,
48 | **kwargs):
49 | pass
50 |
51 |
52 | class Agent(Function):
53 |
54 | def __init__(self,
55 | system_role: str = 'You are a helpful assistant',
56 | example_manager: Optional[ExampleManager] = None,
57 | llm: Function = None,
58 | memory: Memory = None,
59 | toolkit: Optional[Toolkit] = None,
60 | initial_user_message_text: Optional[str] = None,
61 | prompt_manager: Optional[PromptManager] = None,
62 | tool_retries: int = 3,
63 | output_model: Optional[Type[BaseModel]] = None,
64 | brain: Brain = None,
65 | **kwargs):
66 |
67 | AgentInitValidator(system_role=system_role,
68 | llm=llm,
69 | toolkit=toolkit,
70 | memory=memory,
71 | example_manager=example_manager,
72 | initial_user_message_text=initial_user_message_text,
73 | tool_retries=tool_retries,
74 | output_model=output_model,
75 | brain=brain,
76 | prompt_manager=None)
77 | super().__init__(
78 | input_validator=AgentInputValidator,
79 | **kwargs
80 | )
81 | self.system_role = system_role.strip()
82 | self.output_model = output_model
83 | self.llm = llm or LiteLLM()
84 | self.toolkit = toolkit
85 | self.example_manager = example_manager
86 | self.prompt_manager = PromptManager(
87 | system_role=self.system_role,
88 | example_manager=example_manager,
89 | memory=memory or BufferMemory() if toolkit else None,
90 | initial_user_message_text=initial_user_message_text,
91 | ) if prompt_manager is None else prompt_manager
92 | self.tool_retries = tool_retries
93 | self.is_stuck = False
94 | self.brain = brain
95 | self.session_tool_messages = []
96 |
97 | @property
98 | def tools(self):
99 | if self.is_stuck:
100 | return None
101 | else:
102 | return self.toolkit.as_dict_list() if self.toolkit else None
103 |
104 | async def run(self,
105 | **kwargs):
106 |
107 | input_msgs = [UserMessage(kwargs['content'])]
108 |
109 | while True: # Loop until agent decides to respond
110 |
111 | request_kwargs = await self.prompt_manager.prepare_llm_request(messages=input_msgs,
112 | json_model=self.output_model,
113 | **kwargs)
114 | response_msg = await self.llm(tools=self.tools,
115 | **request_kwargs)
116 |
117 | for msg in input_msgs:
118 | await self.prompt_manager.add_memory(message=msg)
119 |
120 | if response_msg['status'] == 'success':
121 | is_tool_call = response_msg['output']['response']['choices'][0]['message'] == 'tool_calls' or \
122 | response_msg['output']['response']['choices'][0]['message'].get('tool_calls',
123 | None) is not None
124 |
125 | if is_tool_call:
126 | # Agent decides to call a tool
127 |
128 | # Memorize tool call
129 | tool_call_msg = response_msg['output']['response']['choices'][0]['message']
130 | await self.prompt_manager.add_memory(message=AssistantMessage(content='',
131 | tool_calls=tool_call_msg[
132 | 'tool_calls']))
133 |
134 | tool_calls = tool_call_msg.get('tool_calls', [])
135 | tool_results = await self.toolkit(
136 | tool_calls=[{
137 | 'name': tool_call['function']['name'],
138 | 'arguments': json.loads(tool_call['function']['arguments'])
139 | } for tool_call in tool_calls]
140 | )
141 |
142 |
143 | # Format for next openai call
144 | tool_call_messages = [ToolMessage(name=tool_result['name'],
145 | content=pprint.pformat(tool_result['content']),
146 | tool_call_id=tool_call['id']) for tool_result, tool_call in
147 | zip(tool_results['output']['tool_results'], tool_calls)]
148 |
149 | # Set next input
150 | input_msgs = tool_call_messages
151 |
152 | else:
153 |
154 | # Agent decides to respond
155 | if self.output_model is None:
156 | msg_content = response_msg['output']['response']['choices'][0]['message']['content']
157 | await self.prompt_manager.add_memory(
158 | message=AssistantMessage(msg_content)
159 | )
160 | return {'response': response_msg['output']['response']}
161 | else:
162 | msg_content = response_msg['output']['response']['choices'][0]['message']['content']
163 | msg_content = msg_content.replace('```json', '').replace('```', '')
164 | parsed_output = json.loads(msg_content)
165 | await self.prompt_manager.add_memory(
166 | message=AssistantMessage(msg_content)
167 | )
168 | return {'response': self.output_model(**parsed_output)}
169 |
170 | else:
171 | raise (Exception(response_msg))
172 |
173 | def is_tool_stuck(self, session_tool_results):
174 | if len(session_tool_results) < self.tool_retries:
175 | return False
176 | retry_window_tool_results = session_tool_results[len(session_tool_results) - self.tool_retries:]
177 | all_results_the_same = all(
178 | [tool_results == retry_window_tool_results[0] for tool_results in retry_window_tool_results])
179 | return all_results_the_same
180 |
--------------------------------------------------------------------------------
/tinyllm/agent/agent_stream.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Optional
3 |
4 | from tinyllm.agent.agent import AgentInitValidator, AgentInputValidator
5 | from tinyllm.agent.tool import Toolkit
6 |
7 | from tinyllm.examples.example_manager import ExampleManager
8 | from tinyllm.function_stream import FunctionStream
9 | from tinyllm.llms.lite_llm_stream import LiteLLMStream
10 | from tinyllm.memory.memory import BufferMemory, Memory
11 | from tinyllm.prompt_manager import PromptManager
12 | from tinyllm.util.helpers import get_openai_message
13 | from tinyllm.util.message import UserMessage, AssistantMessage, ToolMessage
14 |
15 |
16 | class AgentStream(FunctionStream):
17 |
18 | def __init__(self,
19 | system_role: str = 'You are a helpful assistant',
20 | example_manager: Optional[ExampleManager] = None,
21 | llm: FunctionStream = None,
22 | memory: Memory = None,
23 | toolkit: Optional[Toolkit] = None,
24 | initial_user_message_text: Optional[str] = None,
25 | tool_retries: int = 3,
26 | **kwargs):
27 | AgentInitValidator(system_role=system_role,
28 | llm=llm,
29 | toolkit=toolkit,
30 | memory=memory,
31 | example_manager=example_manager,
32 | initial_user_message_text=initial_user_message_text,
33 | tool_retries=tool_retries
34 | )
35 | super().__init__(
36 | input_validator=AgentInputValidator,
37 | **kwargs)
38 | self.system_role = system_role
39 | self.llm = llm or LiteLLMStream()
40 | self.toolkit = toolkit
41 | self.example_manager = example_manager
42 | self.prompt_manager = PromptManager(
43 | system_role=system_role,
44 | example_manager=example_manager,
45 | memory=memory or BufferMemory(),
46 | initial_user_message_text=initial_user_message_text,
47 | )
48 | self.tool_retries = tool_retries
49 |
50 | async def run(self,
51 | **kwargs):
52 |
53 | input_msg = UserMessage(kwargs['content'])
54 |
55 | while True:
56 | kwargs = await self.prompt_manager.prepare_llm_request(message=input_msg,
57 | **kwargs)
58 |
59 | async for msg in self.llm(tools=self.toolkit.as_dict_list() if self.toolkit else None,
60 | **kwargs):
61 | yield msg
62 |
63 | await self.prompt_manager.add_memory(message=input_msg)
64 |
65 | # Process the last message
66 | if msg['status'] == 'success':
67 | msg_output = msg['output']
68 |
69 | # Agent decides to call a tool
70 | if msg_output['type'] == 'tool':
71 | input_msg = await self.get_tool_message(msg_output)
72 | elif msg_output['type'] == 'assistant':
73 | break
74 |
75 | else:
76 | raise Exception(msg['message'])
77 |
78 | async def get_tool_message(self,
79 | msg_output):
80 | api_tool_call = msg_output['last_completion_delta']['tool_calls'][0]
81 | msg_output['last_completion_delta'].pop('function_call')
82 |
83 | # Memorize tool call with arguments
84 | json_tool_call = {
85 | 'name': msg_output['completion']['name'],
86 | 'arguments': msg_output['completion']['arguments']
87 | }
88 | api_tool_call['function'] = json_tool_call
89 | msg_output['last_completion_delta']['content'] = ''
90 | await self.prompt_manager.add_memory(message=AssistantMessage(content='',
91 | tool_calls=msg_output['last_completion_delta']['tool_calls']))
92 |
93 | # Memorize tool result
94 | tool_results = await self.toolkit(
95 | tool_calls=[{
96 | 'name': api_tool_call['function']['name'],
97 | 'arguments': json.loads(api_tool_call['function']['arguments'])
98 | }])
99 | tool_result = tool_results['output']['tool_results'][0]
100 | tool_call_result_msg = get_openai_message(
101 | name=tool_result['name'],
102 | role='tool',
103 | content=tool_result['content'],
104 | tool_call_id=api_tool_call['id']
105 | )
106 |
107 | tool_call_result_msg.pop('role')
108 |
109 | return ToolMessage(**tool_call_result_msg)
110 |
--------------------------------------------------------------------------------
/tinyllm/agent/tool/__init__.py:
--------------------------------------------------------------------------------
1 | from tinyllm.agent.tool.toolkit import Toolkit
2 | from tinyllm.agent.tool.tools.code_interpreter import get_code_interpreter_tool
3 | from tinyllm.agent.tool.tools.think_plan import get_think_and_plan_tool
4 | from tinyllm.agent.tool.tools.wikipedia import get_wikipedia_summary_tool
5 |
6 |
7 | def tinyllm_toolkit():
8 | return Toolkit(
9 | name='Toolkit',
10 | tools=[
11 | get_think_and_plan_tool(),
12 | get_code_interpreter_tool(),
13 | get_wikipedia_summary_tool(),
14 | ],
15 | )
16 |
17 |
--------------------------------------------------------------------------------
/tinyllm/agent/tool/tool.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import traceback
3 | from functools import partial
4 | from typing import Dict, Callable
5 |
6 | from langchain_core.utils.function_calling import convert_to_openai_function, convert_to_openai_tool
7 |
8 | from tinyllm.function import Function
9 | from tinyllm.tracing.langfuse_context import observation
10 | from tinyllm.util.helpers import get_openai_message
11 | from tinyllm.validator import Validator
12 |
13 |
14 | class ToolInitValidator(Validator):
15 | description: str
16 | parameters: dict
17 | python_lambda: Callable
18 |
19 |
20 |
21 |
22 | class Tool(Function):
23 |
24 | @classmethod
25 | def from_pydantic_model(cls, model):
26 | dictionary = convert_to_openai_tool(model)
27 | return cls(
28 | name=dictionary['function']['name'],
29 | description=dictionary['function']['description'],
30 | python_lambda=lambda **kwargs: model(**kwargs).call_tool(**kwargs),
31 | parameters=dictionary['function']['parameters']
32 | )
33 |
34 |
35 | def __init__(self,
36 | description,
37 | parameters,
38 | python_lambda,
39 | **kwargs):
40 | ToolInitValidator(
41 | description=description,
42 | parameters=parameters,
43 | python_lambda=python_lambda,
44 | )
45 | super().__init__(
46 | **kwargs)
47 | self.description = description.strip()
48 | self.parameters = parameters
49 | self.python_lambda = python_lambda
50 |
51 | def as_dict(self):
52 | return {
53 | "type": "function",
54 | "function": {
55 | "name": self.name,
56 | "description": self.description,
57 | "parameters": self.parameters,
58 | }
59 | }
60 |
61 | @observation(observation_type='span')
62 | async def run(self, **kwargs):
63 | try:
64 | if inspect.iscoroutinefunction(self.python_lambda):
65 | tool_output = await self.python_lambda(**kwargs)
66 | else:
67 | tool_output = self.python_lambda(**kwargs)
68 | except:
69 | tool_output = f"""
70 |
71 | The tool returned the following error:
72 | {traceback.format_exc()}
73 |
74 | """
75 |
76 | return {'response': get_openai_message(role='tool', content=tool_output, name=self.name)}
77 |
--------------------------------------------------------------------------------
/tinyllm/agent/tool/toolkit.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import List, Dict
3 |
4 | from tinyllm.agent.tool.tool import Tool
5 | from tinyllm.function import Function
6 | from tinyllm.tracing.langfuse_context import observation
7 | from tinyllm.validator import Validator
8 |
9 |
10 | class ToolkitInputValidator(Validator):
11 | tool_calls: List[Dict]
12 |
13 |
14 | class ToolkitOutputValidator(Validator):
15 | tool_results: List[Dict]
16 |
17 |
18 | class Toolkit(Function):
19 |
20 | def __init__(self,
21 | tools: List[Tool],
22 | **kwargs):
23 | super().__init__(
24 | input_validator=ToolkitInputValidator,
25 | **kwargs)
26 | self.tools = tools
27 |
28 | @observation(observation_type='span', input_mapping={'input': 'tool_calls'})
29 | async def run(self,
30 | **kwargs):
31 | tasks = []
32 |
33 | for tool_call in kwargs['tool_calls']:
34 | name = tool_call['name']
35 | tool = [tool for tool in self.tools if tool.name == name][0]
36 | tasks.append(tool(**tool_call['arguments']))
37 |
38 | results = await asyncio.gather(*tasks)
39 | tool_results = [result['output']['response'] for result in results]
40 | return {'tool_results': tool_results,
41 | 'tool_calls': kwargs['tool_calls']}
42 |
43 | def as_dict_list(self):
44 | return [tool.as_dict() for tool in self.tools]
45 |
--------------------------------------------------------------------------------
/tinyllm/agent/tool/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/agent/tool/tools/__init__.py
--------------------------------------------------------------------------------
/tinyllm/agent/tool/tools/code_interpreter.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import astor
3 | import subprocess
4 | import sys
5 |
6 | from tinyllm.agent.tool.tool import Tool
7 |
8 |
9 | def code_interpreter(code: str):
10 | modified_code = modify_code_to_print_last_expression(code)
11 | result = run_code(modified_code)
12 | if result.stdout == b'' and result.stderr == b'':
13 | return "The code did not return anything. Did you forget to print?"
14 |
15 | return f"StdOut:\n{result.stdout.decode('utf-8')}\nStdErr:\n{result.stderr.decode('utf-8')}"
16 |
17 |
18 | def run_code(code):
19 | return subprocess.run(
20 | [sys.executable, "-c", code], stdout=subprocess.PIPE, stderr=subprocess.PIPE
21 | )
22 |
23 |
24 | def modify_code_to_print_last_expression(code):
25 | """
26 | Parse the code and modify it to print the last expression or variable assignment.
27 | """
28 | try:
29 | tree = ast.parse(code)
30 | last_node = tree.body[-1]
31 |
32 | # Check if the last node is an expression or a variable assignment
33 | if isinstance(last_node, (ast.Expr, ast.Assign)):
34 | # Create a print node
35 | if isinstance(last_node, ast.Assign):
36 | # For variable assignment, print the variable
37 | print_node = ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
38 | args=[last_node.targets[0]],
39 | keywords=[]))
40 | else:
41 | # For direct expressions, print the expression
42 | print_node = ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
43 | args=[last_node.value],
44 | keywords=[]))
45 |
46 | # Add the print node to the AST
47 | tree.body.append(print_node)
48 |
49 | # Convert the AST back to code
50 | return astor.to_source(tree)
51 |
52 | except SyntaxError as e:
53 | return f"SyntaxError: {e}"
54 |
55 |
56 |
57 | def get_code_interpreter_tool():
58 | return Tool(
59 | name="code_interpreter",
60 | description="""
61 | Use this tool to run python code.
62 | """,
63 | python_lambda=code_interpreter,
64 | parameters={
65 | "type": "object",
66 | "required": ["code"],
67 | "properties": {
68 | "code": {
69 | "type": "string",
70 | "description": "Python code to execute",
71 | },
72 | }
73 | }
74 | )
75 |
--------------------------------------------------------------------------------
/tinyllm/agent/tool/tools/think_plan.py:
--------------------------------------------------------------------------------
1 | from tinyllm.agent.tool.tool import Tool
2 |
3 |
4 | def think_and_plan(execution_plan: str):
5 | return f"Execution plan: \n{execution_plan}"
6 |
7 |
8 | def get_think_and_plan_tool():
9 | return Tool(
10 | name="think_and_plan",
11 | description="Use this tool to plan complex execution of a task using tools",
12 | python_lambda=think_and_plan,
13 | parameters={
14 | "type": "object",
15 | "required": ["symbol"],
16 | "properties": {
17 | "execution_plan": {
18 | "type": "string",
19 | "description": "Think step by step about the execution plan, and output a numbered list of tools to execute. Eg: 1. tool1: > 2. tool2: ",
20 | },
21 | }
22 | }
23 | )
24 |
--------------------------------------------------------------------------------
/tinyllm/agent/tool/tools/wikipedia.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | from tinyllm.agent.tool.tool import Tool
4 |
5 |
6 | def get_wikipedia_summary(page_title):
7 | """ Fetch the summary of a Wikipedia page given its title. """
8 | URL = "https://en.wikipedia.org/api/rest_v1/page/summary/" + page_title
9 | try:
10 | response = requests.get(URL)
11 | data = response.json()
12 | return data.get('extract', 'No summary available.')
13 | except requests.RequestException as e:
14 | return str(e)
15 |
16 |
17 | def get_wikipedia_summary_tool():
18 | return Tool(
19 | name="get_wikipedia_summary",
20 | description="""
21 | Use this tool to get a Wikipedia summary.
22 | """,
23 | python_lambda=get_wikipedia_summary,
24 | parameters={
25 | "type": "object",
26 | "required": ["page_title"],
27 | "properties": {
28 | "page_title": {
29 | "type": "string",
30 | "description": "The wikipedia page_title (no spaces, only '_'). Example: Elon_Musk, Python_(programming_language)",
31 | },
32 | }
33 | }
34 | )
35 |
--------------------------------------------------------------------------------
/tinyllm/constants.py:
--------------------------------------------------------------------------------
1 | LLM_PRICING = {
2 | "azure/gpt-4o-mini": {
3 | "input": 0.000150,
4 | "output": 0.000600,
5 | },
6 | "azure/gpt-4o": {
7 | "input": 0.00500,
8 | "output": 0.01500,
9 | },
10 | }
11 |
--------------------------------------------------------------------------------
/tinyllm/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/eval/__init__.py
--------------------------------------------------------------------------------
/tinyllm/eval/evaluation_model.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import Column, UniqueConstraint, Integer, String
2 | from sqlalchemy.dialects import postgresql
3 | from sqlalchemy.orm import declarative_base
4 |
5 | Base = declarative_base()
6 |
7 | class Evaluations(Base):
8 | __tablename__ = "embeddings"
9 | __table_args__ = (UniqueConstraint('text', 'collection_name', name='uq_text_collection_name'),)
10 |
11 | id = Column(Integer, primary_key=True)
12 | context = Column(String)
13 | chat_response = Column(String)
14 | question = Column(String)
15 | correct_answer = Column(String)
16 | generated_answer = Column(String)
17 | generation_id = Column(String)
18 | scores = Column(postgresql.JSON)
--------------------------------------------------------------------------------
/tinyllm/eval/evaluator.py:
--------------------------------------------------------------------------------
1 | import pprint
2 | from typing import Optional, Any, Type, Union
3 | from tinyllm.function import Function
4 | from tinyllm.validator import Validator
5 |
6 |
7 |
8 | class EvaluatorInitValidator(Validator):
9 | prefix: Optional[str] = ''
10 |
11 | class EvaluatorInputValidator(Validator):
12 | observation: Any
13 |
14 | class EvaluatorInputValidator(Validator):
15 | observation: Any
16 |
17 | class EvaluatorOutputValidator(Validator):
18 | evals: dict
19 | comment: Optional[Any] = ""
20 |
21 |
22 | class Evaluator(Function):
23 |
24 | def __init__(self,
25 | prefix='',
26 | **kwargs):
27 | EvaluatorInitValidator(prefix=prefix)
28 | super().__init__(output_validator=EvaluatorOutputValidator,
29 | input_validator=EvaluatorInputValidator,
30 | **kwargs)
31 | self.prefix = prefix
32 | self.evals = None
33 |
34 |
35 | async def process_output(self, **kwargs):
36 | self.evals = kwargs['evals']
37 | for name, score in kwargs['evals'].items():
38 | self.input['observation'].score(
39 | name=self.prefix+name,
40 | value=score,
41 | comment=pprint.pformat(kwargs.get('metadata',{})),
42 | )
43 | return kwargs
--------------------------------------------------------------------------------
/tinyllm/eval/evaluators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/eval/evaluators/__init__.py
--------------------------------------------------------------------------------
/tinyllm/eval/evaluators/answer_accuracy_evaluator.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from tinyllm.eval.evaluator import Evaluator
4 | from tinyllm.llms.lite_llm import LiteLLM
5 | from tinyllm.util.helpers import get_openai_message
6 |
7 | EXAMPLE_INPUT = """
8 | Context:
9 | The spate of recommendations ended the silent period for the nearly 30 banks that underwrote Arm’s IPO in September.
10 | The chip manufacturer raised $4.87 billion for its owner, SoftBank Group, marking the biggest public listing of 2023. From a broader perspective, the IPO’s success provided much-needed confidence to investors and companies considering going public following a nearly two-year market drought. Arm’s IPO was one of the three big September listings, with delivery company Instacart and marketing automation firm Klaviyo debuting on the US stock exchanges.
11 | With Arm’s shares currently trading at $55.5 a piece, the aforementioned price targets by Wall Street giants imply the stock has an upside potential of between 10% and 27%. Meanwhile, some brokerages, like HSBC, offered a more cautious coverage for Arm’s stock, saying the company’s shares may remain range-bound due to smartphone market uncertainty.
12 | Where do you think Arm’s share price will stand by the end of 2023? Let us know in the comments below.
13 |
14 | Question:
15 | How much did Arm raise for its owner, SoftBank Group, during its IPO?
16 |
17 | Correct answer:
18 | Arm raised $4.87 billion for its owner, SoftBank Group, during its IPO.
19 |
20 | Generated answer:
21 | Arm Holdings raised a couple of billion dollars for its owner, SoftBank Group, during its IPO.
22 | """
23 |
24 | EXAMPLE_OUTPUT = """
25 | - Reasoning: The generated answer states that Arm Holdings raised "a couple of billion dollars" for its owner, SoftBank Group, during
26 | its IPO. The context clearly states that the chip manufacturer (Arm) raised "$4.87 billion" for SoftBank Group.
27 | The phrase "a couple of billion dollars" is typically interpreted as meaning "two billion dollars", which is
28 | significantly less than $4.87 billion. While the generated answer is in the correct ballpark, it is imprecise and
29 | understates the actual amount by nearly $3 billion.
30 | - Correctness score: 5/10
31 | """
32 |
33 | examples = [
34 | get_openai_message(role='user',content=EXAMPLE_INPUT),
35 | get_openai_message(role='assistant',content=EXAMPLE_OUTPUT)
36 | ]
37 |
38 | system_role = """
39 | ROLE:
40 | You are an evaluator. Given a question, a correct answer, and a generated answer, you are to evaluate the correctness of the
41 | predicted answer on a scale of 0 to 10 with respect to the question asked and correct answer.
42 | You will think and reason about the correctness of the generated answer then provide a Correctness score.
43 | If the the generated answer is "Not enough information", the score should be 0.
44 | """
45 |
46 |
47 | class AnswerCorrectnessEvaluator(Evaluator):
48 |
49 | def __init__(self, **kwargs):
50 | super().__init__(**kwargs)
51 |
52 | self.litellm_chat = LiteLLM(
53 | name="Answer Accuracy Evaluator",
54 | system_role=system_role,
55 |
56 | )
57 |
58 | async def run(self, **kwargs):
59 | context = kwargs["retrieved_chunks"]
60 | question = kwargs["input"]
61 | correct_answer = kwargs["correct_answer"]
62 | generated_answer = kwargs["response"]
63 | formatted_message = f"""
64 | Context:
65 | {context}
66 |
67 | Question:
68 | {question}
69 |
70 | Correct answer:
71 | {correct_answer}
72 |
73 | Generated answer:
74 | {generated_answer}
75 | """
76 | openai_response = await self.litellm_chat(
77 | message=formatted_message,
78 | generation_name="Answer Correctness Evaluator")
79 | chat_response = openai_response['output']["response"]
80 | # Regex patterns
81 | reasoning_pattern = r"- Reasoning: (.*?)- Correctness score:"
82 | correct_score_pattern = r"- Correctness score: (.*)"
83 | reasoning_match = re.search(reasoning_pattern, chat_response, re.DOTALL)
84 | truth_score_match = re.search(correct_score_pattern, chat_response)
85 | if reasoning_match:
86 | truth_score_reasoning = reasoning_match.group(1).strip()
87 | if truth_score_match:
88 | correctness_score = float(truth_score_match.group(1).strip().split('/')[0]) / 10
89 | return {
90 | "evals": {
91 | "correctness_score": correctness_score,
92 | },
93 | "metadata": {
94 | "truth_score_reasoning": truth_score_reasoning,
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/tinyllm/eval/evaluators/retrieval_evaluator.py:
--------------------------------------------------------------------------------
1 | from smartpy.utility.ai_util import get_cosine_similarity
2 | from tinyllm.eval.evaluator import Evaluator
3 |
4 |
5 | class RetrievalEvaluator(Evaluator):
6 |
7 | def __init__(self, **kwargs):
8 | super().__init__(**kwargs)
9 |
10 | async def run(self, **kwargs):
11 | truth_context = kwargs["truth_context"]
12 | question = kwargs["input"]
13 | retrieved_chunks = kwargs["retrieved_chunks"]
14 | chunk_texts = [chunk["text"] for chunk in retrieved_chunks]
15 | chunk_similarities = []
16 | async def embedding_function(text):
17 | return [[1]*384] #
18 |
19 | embeddings = await embedding_function(question)
20 | question_vector = embeddings[0]
21 | for chunk_text in chunk_texts:
22 | embeddings = await embedding_function(chunk_text)
23 | chunk_vector = embeddings[0]
24 | chunk_similarities.append(get_cosine_similarity(chunk_vector, question_vector))
25 |
26 | embeddings = await embedding_function(truth_context)
27 | truth_vectors = embeddings[0]
28 | truth_similarity = get_cosine_similarity(truth_vectors[0], question_vector)
29 | retrieved_similarity = sum(chunk_similarities) / len(chunk_similarities)
30 |
31 | evals = {
32 | "truth_similarity": truth_similarity,
33 | "precision": max(chunk_similarities) / truth_similarity,
34 | "avg_chunk_similarity_norm": retrieved_similarity / truth_similarity,
35 | "avg_chunk_similarity": retrieved_similarity,
36 | }
37 | return {
38 | "evals": evals,
39 | "metadata": {
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/tinyllm/eval/qa_generator.py:
--------------------------------------------------------------------------------
1 | import random
2 | import re
3 | from textwrap import dedent
4 | from typing import List
5 |
6 | from tinyllm.function import Function
7 | from tinyllm.llms.lite_llm import LiteLLM
8 | from tinyllm.util.helpers import get_openai_message
9 | from tinyllm.validator import Validator
10 |
11 | INPUT_EXAMPLE = """
12 | Relevant context for question/answer generation:
13 | The spate of recommendations ended the silent period for the nearly 30 banks that underwrote Arm’s IPO in September.
14 | The chip manufacturer raised $4.87 billion for its owner, SoftBank Group, marking the biggest public listing of 2023.
15 | From a broader perspective, the IPO’s success provided much-needed confidence to investors and companies considering
16 | going public following a nearly two-year market drought. Arm’s IPO was one of the three big September listings, with
17 | delivery company Instacart and marketing automation firm Klaviyo debuting on the US stock exchanges.
18 | With Arm’s shares currently trading at $55.5 a piece, the aforementioned price targets by Wall Street giants imply the stock has an upside potential of between 10% and 27%. Meanwhile, some brokerages, like HSBC, offered a more cautious coverage for Arm’s stock, saying the company’s shares may remain range-bound due to smartphone market uncertainty.
19 | Where do you think Arm’s share price will stand by the end of 2023? Let us know in the comments below.
20 | """
21 | OUTPUT_EXAMPLE = """
22 | Question: How much did Arm raise for its owner, SoftBank Group, during its IPO?
23 | Truthful answer: Arm raised $4.87 billion for its owner, SoftBank Group, during its IPO.
24 | """
25 |
26 | example_msgs = [
27 | get_openai_message(role='user', content=INPUT_EXAMPLE),
28 | get_openai_message(role='assistant', content=OUTPUT_EXAMPLE)
29 | ]
30 | system_role=dedent(f"""
31 | ROLE:
32 | You are a knowledgeable expert. Given a context, your role is to generate a relevant question about the context and
33 | provide a truthful answer based on the information in the context.
34 | """)
35 |
36 | class InputQASetGenerator(Validator):
37 | documents: list
38 | n: int
39 |
40 | class OutputQASetGenerator(Validator):
41 | qa_test_set: List[dict] # dict with keys: context, chat_response, question, correct_answer
42 |
43 |
44 | class QASetGenerator(Function):
45 |
46 | def __init__(self,
47 | **kwargs):
48 | super().__init__(input_validator=InputQASetGenerator,
49 | output_validator=OutputQASetGenerator,
50 | **kwargs)
51 |
52 | self.openai_chat = LiteLLM(
53 | name="QA Data Point Generator",
54 | system_role=system_role,
55 |
56 | )
57 |
58 | async def run(self, **kwargs):
59 | documents = kwargs["documents"]
60 | n = kwargs["n"]
61 |
62 | qa_test_set = []
63 | for _ in range(n):
64 | doc = random.choice(documents)
65 | context = doc["text"]
66 | formatted_context = dedent(f"""
67 | Relevant context for question/answer generation:
68 | {context}
69 | """)
70 | openai_response = await self.openai_chat(
71 | message=formatted_context,
72 | generation_name="QA Data Point Generator"
73 | )
74 |
75 | qa_test_set.append({
76 | "metadata": doc["emetadata"],
77 | "truth_context": context,
78 | "chat_response": openai_response['output']["response"]
79 | })
80 |
81 | return {"qa_test_set": qa_test_set}
82 |
83 | async def process_output(self, **kwargs) -> dict:
84 | qa_test_set = kwargs["qa_test_set"]
85 | for test_data_point in qa_test_set:
86 | chat_response = test_data_point['chat_response']
87 | question_match = re.search(r"Question: (.+?)\n", chat_response)
88 | answer_match = re.search(r"Truthful answer: (.+)", chat_response)
89 |
90 | if not (question_match and answer_match):
91 | raise ValueError("The provided string doesn't match the expected format.")
92 | else:
93 | question_match = question_match.group(1).strip()
94 | answer_match = answer_match.group(1).strip()
95 | if question_match and answer_match:
96 | test_data_point.update({
97 | "question": question_match,
98 | "correct_answer": answer_match
99 | })
100 | else:
101 | raise ValueError("The provided string doesn't match the expected format.")
102 | return {"qa_test_set": qa_test_set}
--------------------------------------------------------------------------------
/tinyllm/eval/rag_eval_pipeline.py:
--------------------------------------------------------------------------------
1 | """
2 | QuestionAnswerGenerator:
3 | - input: documents, embedding function
4 | - output: list of (context, question, correct_answer) dicts
5 |
6 | AnswerAccuracyEvaluator:
7 | - inputs : question, correct_answer, generated_output
8 | - outputs: accuracy, explanation
9 |
10 | Context relevance:
11 | - inputs: context, answer, generated_output
12 | - outputs: similarity
13 |
14 | EvalPipeline:
15 | - inputs: rag_lambda, evaluators, list of (context, question, correct_answer)
16 | - output: list of evaluator outputs
17 |
18 | """
19 | from typing import List
20 |
21 | from tinyllm.function import Function
22 |
23 | class RagEvaluationPipeline:
24 |
25 | def __init__(self,
26 | rag_lambda,
27 | qa_test_set,
28 | evaluators: List[Function] ):
29 | self.rag_lambda = rag_lambda
30 | self.qa_test_set = qa_test_set
31 | self.evaluators = evaluators
32 |
33 | async def run_evals(self):
34 |
35 | # Predict an answer for each question
36 | for data_point in self.qa_test_set:
37 | retrieved_chunks, generated_output, generation_id = await self.rag_lambda(data_point["question"])
38 | toinsert = {
39 | "retrieved_chunks": retrieved_chunks,
40 | "generated_answer": generated_output,
41 | "generation_id": generation_id
42 | }
43 | data_point.update(toinsert)
44 |
45 | # Run each evaluator
46 | for data_point in self.qa_test_set:
47 | data_point['scores'] = {}
48 | for evaluator in self.evaluators:
49 | eval_result = await evaluator(**data_point)
50 | if eval_result['status'] == 'success':
51 | data_point['scores'].update(eval_result['output'])
52 | else:
53 | data_point['scores'].update(eval_result)
54 | return self.qa_test_set
--------------------------------------------------------------------------------
/tinyllm/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/examples/__init__.py
--------------------------------------------------------------------------------
/tinyllm/examples/example_manager.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from tinyllm.util.message import *
4 |
5 |
6 | class Example:
7 | def __init__(self,
8 | user_message: UserMessage,
9 | assistant_message: AssistantMessage):
10 | self.user_message = user_message
11 | self.assistant_message = assistant_message
12 |
13 |
14 | class ExampleManager:
15 |
16 | def __init__(self,
17 | example_selector=None,
18 | constant_examples: Union[List[Example], Example] = None):
19 | self.constant_examples = constant_examples if type(constant_examples) == list else ([constant_examples]) if constant_examples else []
20 | self.example_selector = example_selector
21 |
--------------------------------------------------------------------------------
/tinyllm/examples/example_selector.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional, Callable
2 |
3 | import numpy as np
4 |
5 | from tinyllm.function import Function
6 | from tinyllm.validator import Validator
7 | from tinyllm.util.ai_util import get_top_n_similar_vectors_index
8 |
9 |
10 | class ExampleSelectorInitValidator(Validator):
11 | collection_name: str
12 |
13 | class InputValidator(Validator):
14 | input: str
15 | k: Optional[int] = 1
16 |
17 | class OutputValidator(Validator):
18 | best_examples: List[Dict]
19 |
20 | class ProcessedOutputValidator(Validator):
21 | best_examples: List[Dict]
22 |
23 | class ExampleSelectorInitValidator(Validator):
24 | examples: List[dict]
25 | embedding_function: Callable
26 | embeddings: Optional[List]
27 |
28 |
29 | class ExampleSelector(Function):
30 |
31 | def __init__(self,
32 | embedding_function,
33 | examples=[],
34 | embeddings=None,
35 | **kwargs):
36 | ExampleSelectorInitValidator(examples=examples,
37 | embedding_function=embedding_function,
38 | embeddings=embeddings)
39 | super().__init__(
40 | input_validator=InputValidator,
41 | output_validator=OutputValidator,
42 | processed_output_validator=ProcessedOutputValidator,
43 | **kwargs
44 | )
45 | self.example_dicts = examples
46 | self.embeddings = embeddings
47 | self.embedding_function = embedding_function
48 | all_example_dicts_have_embeddings = all([example.get('embedding') is not None for example in self.example_dicts])
49 | if all_example_dicts_have_embeddings is False and self.embedding_function is None:
50 | raise Exception('Example selector needs either an embedding function or vector embeddings for each example')
51 |
52 | async def embed_examples(self,
53 | **kwargs):
54 | example_dicts = kwargs.get('example_dicts', self.example_dicts)
55 | embeddings = []
56 | for example in example_dicts:
57 | embeddings_list = await self.embedding_function(example['user'])
58 | embeddings.append(embeddings_list[0])
59 | self.embeddings = embeddings
60 |
61 | async def run(self, **kwargs):
62 | embeddings = await self.embedding_function(kwargs['input'])
63 | similar_indexes = get_top_n_similar_vectors_index(input_vector=embeddings[0], vectors=self.embeddings, k=kwargs['k'])
64 | return {'best_examples': [self.example_dicts[i] for i in similar_indexes]}
65 |
66 | async def process_output(self, **kwargs):
67 | result = kwargs['best_examples']
68 | return {'best_examples': result}
--------------------------------------------------------------------------------
/tinyllm/exceptions.py:
--------------------------------------------------------------------------------
1 | class InvalidStateTransition(Exception):
2 | pass
3 |
4 | class MissingBlockException(Exception):
5 | pass
6 |
7 | class LLMJsonValidationError(Exception):
8 | pass
9 |
--------------------------------------------------------------------------------
/tinyllm/function.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import traceback
3 | from typing import Any, Optional, Type, Union
4 |
5 | from tinyllm.exceptions import InvalidStateTransition
6 | from tinyllm import langfuse_client, tinyllm_config, tinyllm_logger
7 | from tinyllm.state import States, ALLOWED_TRANSITIONS
8 | from tinyllm.tracing.langfuse_context import observation
9 | from tinyllm.validator import Validator
10 |
11 |
12 | class CallBackHandler:
13 |
14 | async def on_tools(self,
15 | **kwargs):
16 | pass
17 |
18 |
19 | class FunctionInitValidator(Validator):
20 | user_id: Optional[Union[str, int]]
21 | session_id: Optional[Union[str, int]]
22 | input_validator: Optional[Type[Validator]]
23 | output_validator: Optional[Type[Validator]]
24 | processed_output_validator: Optional[Type[Validator]] = None
25 | run_evaluators: Optional[list]
26 | processed_output_evaluators: Optional[list]
27 | stream: Optional[bool]
28 | callback_handler: Optional[Type[CallBackHandler]] = None
29 |
30 |
31 |
32 | class Function:
33 |
34 | def __init__(
35 | self,
36 | name=None,
37 | user_id=None,
38 | session_id=None,
39 | input_validator=Validator,
40 | output_validator=Validator,
41 | processed_output_validator=Validator,
42 | run_evaluators=[],
43 | processed_output_evaluators=[],
44 | required=True,
45 | stream=False,
46 | callback_handler=None
47 | ):
48 | FunctionInitValidator(
49 | user_id=user_id,
50 | session_id=str(session_id),
51 | input_validator=input_validator,
52 | output_validator=output_validator,
53 | processed_output_validator=processed_output_validator,
54 | run_evaluators=run_evaluators,
55 | processed_output_evaluators=processed_output_evaluators,
56 | stream=stream,
57 | )
58 | self.callback_handler = callback_handler
59 | self.user_id = user_id
60 | self.session_id = str(session_id)
61 | self.observation = None # For logging
62 |
63 | if name is None:
64 | self.name = self.__class__.__name__
65 | else:
66 | self.name = name
67 |
68 | self.input_validator = input_validator
69 | self.output_validator = output_validator
70 | self.processed_output_validator = processed_output_validator
71 | self.required = required
72 | self.state = None
73 | # Need the init the above to run the transition
74 | self.transition(States.INIT)
75 | self.input = None
76 | self.output = None
77 | self.processed_output = None
78 | self.current_observation = None
79 | self.run_evaluators = run_evaluators
80 | self.processed_output_evaluators = processed_output_evaluators
81 | for evaluator in self.processed_output_evaluators:
82 | evaluator.prefix = 'proc:'
83 |
84 | self.cache = {}
85 | self.generation = None
86 | self.trace = None
87 | self.stream = stream
88 | self.observation = None
89 |
90 | @observation('span')
91 | async def __call__(self, **kwargs):
92 | try:
93 | # Validate input
94 | self.input = kwargs
95 | self.transition(States.INPUT_VALIDATION)
96 | validated_input = self.validate_input(**kwargs)
97 | kwargs.update(validated_input)
98 | # Run
99 | self.transition(States.RUNNING)
100 | self.output = await self.run(**kwargs)
101 |
102 | # Validate output
103 | self.transition(States.OUTPUT_VALIDATION)
104 | self.validate_output(**self.output)
105 |
106 | # Evaluate output
107 | for evaluator in self.run_evaluators:
108 | await evaluator(**{'status': 'success', 'output': self.output}, observation=self.observation)
109 |
110 | # Process output
111 | self.transition(States.PROCESSING_OUTPUT)
112 | self.processed_output = await self.process_output(**self.output)
113 |
114 | # Validate processed output
115 | self.transition(States.PROCESSED_OUTPUT_VALIDATION)
116 |
117 | if self.processed_output_validator:
118 | validated_processed_output = self.validate_processed_output(**self.processed_output)
119 | self.processed_output.update(validated_processed_output)
120 |
121 | # Evaluate processed output
122 | for evaluator in self.processed_output_evaluators:
123 | await evaluator(**{'status': 'success', 'output': self.processed_output}, observation=self.observation)
124 |
125 | final_output = {"status": "success",
126 | "output": self.processed_output}
127 |
128 | await self.close(**final_output)
129 |
130 | # Complete
131 | self.transition(States.COMPLETE)
132 | langfuse_client.flush()
133 | return final_output
134 |
135 | except Exception as e:
136 | output_message = await self.handle_exception(e)
137 | # Raise or return error
138 | return output_message
139 |
140 |
141 | async def handle_exception(self,
142 | e):
143 | detailed_error_msg = traceback.format_exc()
144 | output_message = {"status": "error",
145 | "message": detailed_error_msg}
146 |
147 | # Evaluate output if not already done
148 | if self.state < States.OUTPUT_EVALUATION:
149 | for evaluator in self.run_evaluators:
150 | await evaluator(**output_message, observation=self.observation)
151 |
152 | if self.state < States.PROCESSED_OUTPUT_EVALUATION:
153 | for evaluator in self.processed_output_evaluators:
154 | await evaluator(**output_message, observation=self.observation)
155 |
156 |
157 |
158 | self.transition(States.FAILED, msg=detailed_error_msg)
159 | langfuse_client.flush()
160 |
161 | return output_message
162 |
163 | def transition(self, new_state: States, msg: Optional[str] = None):
164 | if new_state not in ALLOWED_TRANSITIONS[self.state]:
165 | raise InvalidStateTransition(
166 | self, f"Invalid state transition from {self.state.name} to {new_state.name}"
167 | )
168 | log_level = "error" if new_state == States.FAILED else "info"
169 | if new_state.name in tinyllm_config['LOGS']['LOG_STATES']:
170 | if log_level == 'error':
171 | self.log(
172 | f"transition from {self.state.name} to: {new_state.name}" + (f" ({msg})" if msg is not None else ""),
173 | level=log_level,
174 | )
175 | else:
176 | self.log(
177 | f"transition to: {new_state.name}" + (f" ({msg})" if msg is not None else ""),
178 | level=log_level,
179 | )
180 |
181 | self.state = new_state
182 |
183 | @property
184 | def log_prefix(self):
185 | base_url = "https://us.cloud.langfuse.com/project/{project_id}/traces/{trace_id}"
186 | if getattr(self, 'trace', None) is not None:
187 | trace_id = self.trace.id
188 | url = base_url.format(project_id=tinyllm_config['LANGFUSE']['PROJECT_ID'], trace_id=trace_id)
189 | return f"[{self.trace.id}][{self.name}]({url})"
190 | else:
191 | return f"[{self.name}]"
192 |
193 | def log(self, message, level="info"):
194 | if tinyllm_config['LOGS']['LOGGING']:
195 | if level == "error":
196 | tinyllm_logger.error(self.log_prefix+' '+message)
197 | else:
198 | tinyllm_logger.info(self.log_prefix+' '+message)
199 |
200 | def validate_input(self, **kwargs):
201 | return self.input_validator(**kwargs).model_dump()
202 |
203 | def validate_output(self, **kwargs):
204 | return self.output_validator(**kwargs).model_dump()
205 |
206 | def validate_processed_output(self, **kwargs):
207 | return self.processed_output_validator(**kwargs).model_dump()
208 |
209 | async def run(self, **kwargs) -> Any:
210 | return kwargs
211 |
212 | async def process_output(self, **kwargs):
213 | return kwargs
214 |
215 | async def close(self,
216 | **kwargs):
217 | pass
218 |
--------------------------------------------------------------------------------
/tinyllm/function_stream.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Optional
3 |
4 | from tinyllm.function import Function
5 | from tinyllm import langfuse_client, tinyllm_config
6 | from tinyllm.state import States
7 | from tinyllm.tracing.langfuse_context import observation
8 | from tinyllm.validator import Validator
9 |
10 |
11 | class DefaultFunctionStreamOutputValidator(Validator):
12 | streaming_status: str
13 | type: str # assistant_response, tool
14 | last_completion_delta: Optional[dict]
15 | completion: Any
16 |
17 |
18 | class FunctionStream(Function):
19 |
20 | def __init__(self,
21 | **kwargs):
22 | super().__init__(output_validator=DefaultFunctionStreamOutputValidator,
23 | **kwargs)
24 |
25 | @abstractmethod
26 | async def run(self,
27 | **kwargs):
28 | yield None
29 |
30 | @observation('span', stream=True)
31 | async def __call__(self, **kwargs):
32 | try:
33 | self.input = kwargs
34 |
35 | # Validate input
36 | self.transition(States.INPUT_VALIDATION)
37 | validated_input = self.validate_input(**kwargs)
38 | kwargs.update(validated_input)
39 | # Run
40 | self.transition(States.RUNNING)
41 | async for message in self.run(**validated_input):
42 | # Output validation
43 | if 'status' in message.keys():
44 | if message['status'] == 'success':
45 | message = message['output']
46 | else:
47 | raise Exception(message['message'])
48 |
49 | self.transition(States.OUTPUT_VALIDATION)
50 | self.validate_output(**message)
51 |
52 | yield {"status": "success",
53 | "output": message}
54 |
55 | self.output = message
56 |
57 | # Process output
58 | self.transition(States.PROCESSING_OUTPUT)
59 | self.processed_output = await self.process_output(**self.output)
60 |
61 | # Validate processed output
62 | if self.processed_output_validator:
63 | self.validate_processed_output(**self.processed_output)
64 |
65 | # Evaluate processed output
66 | for evaluator in self.processed_output_evaluators:
67 | await evaluator(**self.processed_output, observation=self.observation)
68 |
69 | self.transition(States.CLOSING)
70 | await self.close(**{"status": "success",
71 | "output": self.processed_output})
72 |
73 | # Complete
74 | self.transition(States.COMPLETE)
75 | langfuse_client.flush()
76 |
77 | except Exception as e:
78 | output_message = await self.handle_exception(e)
79 | # Raise or return error
80 | yield output_message
81 |
--------------------------------------------------------------------------------
/tinyllm/llms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/llms/__init__.py
--------------------------------------------------------------------------------
/tinyllm/llms/constants.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/llms/constants.py
--------------------------------------------------------------------------------
/tinyllm/llms/lite_llm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Any
2 |
3 | import openai
4 | import litellm
5 | from litellm import acompletion
6 | from openai import OpenAIError
7 | from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
8 |
9 | from tinyllm.function import Function
10 | from tinyllm.tracing.langfuse_context import observation
11 | from tinyllm.util.helpers import *
12 | from tinyllm.util.message import Content, Message
13 | from tinyllm.validator import Validator
14 |
15 | litellm.set_verbose = False
16 |
17 | model_parameters = [
18 | "messages",
19 | "model",
20 | "frequency_penalty",
21 | "logit_bias",
22 | "logprobs",
23 | "top_logprobs",
24 | "max_tokens",
25 | "n",
26 | "presence_penalty",
27 | "response_format",
28 | "seed",
29 | "stop",
30 | "stream",
31 | "temperature",
32 | "top_p"
33 | ]
34 |
35 |
36 | DEFAULT_LLM_MODEL = 'gpt-4o-mini'
37 |
38 |
39 | OPENAI_TOKEN_LIMITS = {
40 | "gpt-3.5-turbo-0125": 16385,
41 | "gpt-3.5-turbo-1106": 16385,
42 | "gpt-3.5-turbo": 4096,
43 | "gpt-3.5-turbo-16k": 16385,
44 | "gpt-3.5-turbo-instruct": 4096,
45 | "gpt-3.5-turbo-0613": 4096,
46 | "gpt-3.5-turbo-16k-0613": 16385,
47 | "gpt-3.5-turbo-0301": 4096,
48 | "text-davinci-003": 4096,
49 | "text-davinci-002": 4096,
50 | "code-davinci-002": 8001,
51 | "gpt-4-1106-preview": 128000,
52 | "gpt-4-0125-preview ": 128000,
53 | "gpt-4-turbo-preview": 128000,
54 | "gpt-4-vision-preview": 128000,
55 | "gpt-4": 8192,
56 | "gpt-4-32k": 32768,
57 | "gpt-4-0613": 8192,
58 | "gpt-4-32k-0613": 32768,
59 | "gpt-4-0314": 8192,
60 | "gpt-4-32k-0314": 32768,
61 |
62 | }
63 |
64 | ANYSCALE_TOKEN_LIMITS = {
65 | "anyscale/Open-Orca/Mistral-7B-OpenOrca": 8192,
66 | "anyscale/meta-llama/Llama-2-70b-chat-hf": 4096,
67 | }
68 |
69 | AZURE_TOKEN_LIMITS = {
70 | "azure/gpt41106": 128000,
71 | "azure/gpt-4o-mini": 16385,
72 | "azure/gpt4o0513": 128000,
73 | }
74 |
75 | LLM_TOKEN_LIMITS = {**OPENAI_TOKEN_LIMITS, **ANYSCALE_TOKEN_LIMITS, **AZURE_TOKEN_LIMITS}
76 |
77 |
78 |
79 | DEFAULT_CONTEXT_FALLBACK_DICT = {
80 | "gpt-3.5-turbo-0125": "gpt-4-turbo-preview",
81 | "gpt-4-1106-preview": "gpt-4-1106-preview",
82 | "gpt-3.5-turbo": "gpt-3.5-turbo-16k",
83 | "gpt-3.5-turbo-1106": "gpt-3.5-turbo-16k",
84 | "azure/gpt41106": "azure/gpt41106",
85 | "azure/gpt-4o-mini": "azure/gpt41106",
86 | "anyscale/Open-Orca/Mistral-7B-OpenOrca": "gpt-3.5-turbo-16k",
87 | "anyscale/meta-llama/Llama-2-70b-chat-hf": "gpt-3.5-turbo-16k",
88 | }
89 |
90 |
91 | class LiteLLMChatInitValidator(Validator):
92 | system_role: str
93 | answer_format_prompt: Optional[str]
94 |
95 |
96 | class LiteLLMChatInputValidator(Validator):
97 | messages: List[Union[Dict, Message]]
98 | model: Optional[str] = 'gpt-3.5-turbo'
99 | temperature: Optional[float] = 0
100 | max_tokens: Optional[int] = 850
101 | n: Optional[int] = 1
102 | stream: Optional[bool] = False
103 | context_window_fallback_dict: Optional[Dict] = DEFAULT_CONTEXT_FALLBACK_DICT
104 |
105 |
106 | class LiteLLMChatOutputValidator(Validator):
107 | type: str
108 | message: dict
109 | response: Any
110 |
111 |
112 | class LiteLLM(Function):
113 | def __init__(self, **kwargs):
114 | super().__init__(input_validator=LiteLLMChatInputValidator,
115 | **kwargs)
116 | self.generation = None
117 |
118 | def _validate_tool_args(self, **kwargs):
119 | tools_args = {}
120 | if kwargs.get('tools', None) is not None:
121 | tools_args = {
122 | 'tools': kwargs['tools'],
123 | 'tool_choice': kwargs.get('tool_choice', 'auto')
124 | }
125 | return tools_args
126 |
127 | def _parse_mesages(self, messages):
128 | if isinstance(messages[0], Message):
129 | messages = [message.to_dict() for message in messages]
130 | return messages
131 |
132 | @observation(observation_type='generation', input_mapping={'input': 'messages'},
133 | output_mapping={'output': 'response'})
134 | @retry(
135 | stop=stop_after_attempt(3),
136 | wait=wait_random_exponential(min=1, max=10),
137 | retry=retry_if_exception_type((OpenAIError, openai.InternalServerError))
138 | )
139 | async def run(self, **kwargs):
140 | kwargs['messages'] = self._parse_mesages(kwargs['messages'])
141 | tools_args = self._validate_tool_args(**kwargs)
142 | completion_kwargs = {arg: kwargs[arg] for arg in kwargs if arg in model_parameters}
143 | completion_kwargs.update(tools_args)
144 | api_result = await acompletion(
145 | **completion_kwargs,
146 | )
147 | model_dump = api_result.model_dump()
148 | msg_type = 'tool' if model_dump['choices'][0]['finish_reason'] == 'tool_calls' else 'completion'
149 | message = model_dump['choices'][0]['message']
150 | return {
151 | "type": msg_type,
152 | "message": message,
153 | "response": model_dump,
154 | "completion": message['content'],
155 | }
156 |
--------------------------------------------------------------------------------
/tinyllm/llms/lite_llm_stream.py:
--------------------------------------------------------------------------------
1 | import openai
2 | from litellm import acompletion
3 | from openai import OpenAIError
4 | from tenacity import stop_after_attempt, wait_random_exponential, retry_if_exception_type, retry
5 |
6 | from tinyllm.llms.lite_llm import LiteLLM, DEFAULT_CONTEXT_FALLBACK_DICT, DEFAULT_LLM_MODEL, model_parameters
7 | from tinyllm.function_stream import FunctionStream
8 | from tinyllm.tracing.langfuse_context import observation
9 | from tinyllm.util.helpers import get_openai_message
10 |
11 |
12 | class LiteLLMStream(LiteLLM, FunctionStream):
13 |
14 | @retry(
15 | stop=stop_after_attempt(3),
16 | wait=wait_random_exponential(min=1, max=10),
17 | retry=retry_if_exception_type((OpenAIError, openai.InternalServerError))
18 | )
19 | @observation(observation_type='generation', stream=True)
20 | async def run(self, **kwargs):
21 | kwargs['messages'] = self._parse_mesages(kwargs['messages'])
22 | tools_args = self._validate_tool_args(**kwargs)
23 | completion_kwargs = {arg: kwargs[arg] for arg in kwargs if arg in model_parameters}
24 | completion_kwargs.update(tools_args)
25 | completion_kwargs['stream'] = True
26 |
27 | response = await acompletion(
28 | **completion_kwargs,
29 | )
30 |
31 | # We need to track 2 things: the response delta and the function_call
32 | function_call = {
33 | "name": None,
34 | "arguments": ""
35 | }
36 | completion = ""
37 | last_completion_delta = None
38 | finish_delta = None
39 |
40 | # OpenAI function call works as follows: function name available at delta.tool_calls[0].function.
41 | # It returns a dict where: 'name' is returned only in the first chunk
42 | # tool argument tokens are sent in chunks after so need to keep track of them
43 |
44 | async for chunk in response:
45 | chunk_dict = chunk.dict()
46 | status = self.get_streaming_status(chunk_dict)
47 | chunk_role = self.get_chunk_type(chunk_dict)
48 | delta = chunk_dict['choices'][0]['delta']
49 |
50 | # When using tools:
51 | # We need the last response delta as it contains the full function message
52 | # The finish message does not contain any delta
53 |
54 | # If streaming , we need to store chunks of the completion/function call
55 | if status == "streaming":
56 | if chunk_role == "assistant":
57 | if delta['content']:
58 | completion += delta['content']
59 | last_completion_delta = delta
60 | elif chunk_role == "tool":
61 | if function_call['name'] is None:
62 | function_call['name'] = delta['tool_calls'][0]['function']['name']
63 | if last_completion_delta is None:
64 | last_completion_delta = delta
65 |
66 | completion = function_call
67 | if delta['tool_calls'][0]['function']['arguments']:
68 | function_call['arguments'] += delta['tool_calls'][0]['function']['arguments']
69 |
70 | elif status == "finished-streaming":
71 | finish_delta = delta
72 |
73 | yield {
74 | "streaming_status": status,
75 | "type": chunk_role,
76 | "last_completion_delta": last_completion_delta,
77 | "finish_delta": finish_delta,
78 | "completion": completion,
79 | "message": get_openai_message(role=chunk_role,
80 | content=completion),
81 | "last_chunk": chunk_dict,
82 | }
83 |
84 | def get_chunk_type(self,
85 | chunk):
86 | delta = chunk['choices'][0]['delta']
87 |
88 | if delta.get('tool_calls', None) is not None or chunk['choices'][0]['finish_reason'] == 'tool_calls':
89 | return "tool"
90 |
91 | return "assistant"
92 |
93 | def get_streaming_status(self,
94 | chunk):
95 | if chunk['choices'][0]['finish_reason'] in ['stop', 'tool_calls']:
96 | return "finished-streaming"
97 | else:
98 | return "streaming"
99 |
--------------------------------------------------------------------------------
/tinyllm/llms/tiny_function.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import traceback
3 | from textwrap import dedent
4 |
5 | from pydantic import BaseModel, create_model
6 | from typing import Type
7 |
8 | from tenacity import retry, stop_after_attempt, retry_if_exception_type, wait_fixed
9 |
10 | from tinyllm.agent.agent import Agent
11 | from tinyllm.exceptions import MissingBlockException, LLMJsonValidationError
12 | from tinyllm.tracing.langfuse_context import observation
13 | from tinyllm.util.parse_util import *
14 |
15 |
16 | def create_pydantic_model_from_dict(data: Dict[str, Any]) -> BaseModel:
17 | fields = {key: (type(value), ...) for key, value in data.items()}
18 | JSONOutput = create_model('JSONOutput', **fields)
19 | model_instance = JSONOutput(**data)
20 | return model_instance
21 |
22 |
23 | def model_to_string(model) -> str:
24 | fields = model.__fields__
25 | field_defs = []
26 | for field_name, field in fields.items():
27 | field_type = str(field.annotation).replace('typing.', '')
28 | description = field.description
29 | description = f" | Description: {description}" if description else ""
30 | field_defs.append(
31 | f" {field_name}: {field_type}" + description)
32 | model_prompt = "Model:\n" + "\n".join(field_defs) if field_defs else ""
33 | return model_prompt
34 |
35 |
36 | def get_system_role(func,
37 | output_model) -> str:
38 | system_tag = extract_html(func.__doc__.strip(), tag='system')
39 | system_prompt = dedent(system_tag[0]) + '\n' + dedent("""
40 | OUTPUT FORMAT
41 | Your output must be in JSON
42 |
43 | {data_model}
44 |
45 | """)
46 |
47 |
48 | pydantic_model = model_to_string(output_model) if output_model else None
49 |
50 | data_model = dedent(f"""DATA MODEL
51 | Your output must respect the following pydantic model:
52 | {pydantic_model}""") if pydantic_model else ""
53 |
54 | final_prompt = system_prompt.format(pydantic_model=pydantic_model,
55 | data_model=data_model)
56 |
57 | return final_prompt
58 |
59 |
60 |
61 | class JsonParsingException(Exception):
62 | pass
63 |
64 |
65 | def tiny_function(output_model: Type[BaseModel] = None,
66 | example_manager=None,
67 | model_kwargs={}):
68 | def decorator(func):
69 | @functools.wraps(func)
70 | @retry(
71 | reraise=True,
72 | stop=stop_after_attempt(3),
73 | wait=wait_fixed(1),
74 | retry=retry_if_exception_type((MissingBlockException, LLMJsonValidationError))
75 | )
76 | async def wrapper(*args, **kwargs):
77 |
78 | @observation(observation_type='span', name=func.__name__)
79 | async def traced_call(func, *args, **kwargs):
80 | system_role = get_system_role(func=func,
81 | output_model=output_model)
82 |
83 | prompt = extract_html(func.__doc__.strip(), tag='prompt')
84 | if len(prompt) == 0:
85 | assert 'content' in kwargs, "tinyllm_function requires content kwarg"
86 | agent_input_content = kwargs['content']
87 | else:
88 | prompt = prompt[0]
89 | agent_input_content = prompt.format(**kwargs)
90 |
91 | agent = Agent(
92 | name=func.__name__,
93 | system_role=system_role,
94 | example_manager=example_manager
95 | )
96 | model_kwargs['response_format'] = {"type": "json_object"}
97 |
98 | result = await agent(content=agent_input_content,
99 | **model_kwargs)
100 | if result['status'] == 'success':
101 | msg_content = result['output']['response']['choices'][0]['message']['content']
102 | try:
103 | parsed_output = json.loads(msg_content)
104 | if output_model is None:
105 | function_output_model = create_pydantic_model_from_dict(parsed_output)
106 | else:
107 | try:
108 | function_output_model = output_model(**parsed_output)
109 | except:
110 | raise LLMJsonValidationError(f"Output does not match the expected model: {output_model}")
111 |
112 | return {
113 | 'status': 'success',
114 | 'output': function_output_model
115 | }
116 |
117 | except (ValueError, json.JSONDecodeError) as e:
118 | return {"message": f"Parsing error : {traceback.format_exc()}",
119 | 'status': 'error'}
120 | else:
121 | return {
122 | 'status': 'error',
123 | "message": "Agent failed", "details": result
124 | }
125 |
126 | response = await traced_call(func, *args, **kwargs)
127 | return response
128 |
129 | return wrapper
130 |
131 | return decorator
132 |
--------------------------------------------------------------------------------
/tinyllm/memory/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/memory/__init__.py
--------------------------------------------------------------------------------
/tinyllm/memory/memory.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Union
3 |
4 | from tinyllm.function import Function
5 | from tinyllm.util.helpers import count_tokens
6 | from tinyllm.util.message import Message
7 | from tinyllm.validator import Validator
8 |
9 |
10 | class MemoryOutputValidator(Validator):
11 | memories: list
12 |
13 |
14 | class MemoryInputValidator(Validator):
15 | message: Message
16 |
17 |
18 | class Memory(Function):
19 | def __init__(self,
20 | **kwargs):
21 | super().__init__(
22 | output_validator=MemoryOutputValidator,
23 | **kwargs
24 | )
25 | self.memories = None
26 |
27 | async def run(self, **kwargs):
28 | self.memories.append(kwargs['message'])
29 | return {'memories': self.memories}
30 |
31 | @property
32 | def size(self):
33 | return count_tokens(self.memories)
34 |
35 | @abstractmethod
36 | async def load_memories(self):
37 | pass
38 |
39 | @abstractmethod
40 | async def get_memories(self):
41 | pass
42 |
43 |
44 | class BufferMemoryInitValidator(Validator):
45 | buffer_size: int
46 |
47 |
48 | class BufferMemory(Memory):
49 |
50 | def __init__(self,
51 | buffer_size=10,
52 | **kwargs):
53 | BufferMemoryInitValidator(buffer_size=buffer_size)
54 | super().__init__(**kwargs)
55 | self.buffer_size = buffer_size
56 | self.memories = []
57 |
58 | async def get_memories(self):
59 |
60 | memories_to_return = []
61 | msg_count = 0
62 | # Make sure we keep complete tool calls msgs
63 | for memory in self.memories[::-1]:
64 | memories_to_return.append(memory)
65 | if 'tool_calls' in memory.to_dict() or memory.role == 'tool':
66 | continue
67 | else:
68 | msg_count += 1
69 |
70 | if msg_count == self.buffer_size:
71 | break
72 |
73 | return memories_to_return[::-1]
74 |
75 |
76 |
77 | class CharacterBufferMemory(BufferMemory):
78 |
79 | async def load_memories(self):
80 | pass
81 |
82 |
83 | async def get_memories(self):
84 |
85 | memories_to_return = []
86 | size_count = 0
87 | # Make sure we keep complete tool calls msgs
88 | for memory in self.memories[::-1]:
89 | size_count += len(str(memory.to_dict())) * 0.9
90 | if size_count >= self.buffer_size:
91 | break
92 | memories_to_return.append(memory)
93 | if 'tool_calls' in memory.to_dict() or memory.role == 'tool':
94 | continue
95 | else:
96 | size_count += len(str(memory.to_dict()))*0.9
97 |
98 | return memories_to_return[::-1]
99 |
--------------------------------------------------------------------------------
/tinyllm/prompt_manager.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | import datetime as dt
3 | from enum import Enum
4 | from typing import Callable
5 |
6 | from tinyllm.examples.example_manager import ExampleManager
7 | from tinyllm.llms.lite_llm import DEFAULT_LLM_MODEL, LLM_TOKEN_LIMITS, DEFAULT_CONTEXT_FALLBACK_DICT
8 | from tinyllm.memory.memory import Memory
9 | from tinyllm.util.helpers import get_openai_message, count_tokens
10 | from tinyllm.util.message import SystemMessage, UserMessage, AssistantMessage
11 |
12 |
13 | class MaxTokensStrategy(Enum):
14 | MAX = 'max_context'
15 | EXPECTED_RATIO = 'max_user_context_examples'
16 |
17 |
18 | class PromptManager:
19 | """
20 | This class is responsible for formatting the prompt for the LLM and managing:
21 | - model (to avoid exceeding the token limit)
22 | - max_tokens (based on the expected completion size)
23 | """
24 |
25 | def __init__(self,
26 | system_role: str,
27 | example_manager: ExampleManager=ExampleManager(),
28 | memory: Memory = None,
29 | update_system_content: Callable = lambda x: x,
30 | initial_user_message_text: str = None,
31 | is_time_aware: bool = True, ):
32 | self.system_role = system_role
33 | self.example_manager = example_manager
34 | self.memory = memory
35 | self.initial_user_message_text = initial_user_message_text.strip() if initial_user_message_text is not None else None
36 | self.update_system_content = update_system_content
37 | self.is_time_aware = is_time_aware
38 |
39 | async def format_messages(self, messages):
40 |
41 | current_time = '\n\n\n'
44 |
45 | system_content = self.update_system_content(self.system_role) + '\n\n' + current_time
46 | system_msg = SystemMessage(system_content)
47 | memories = [] if self.memory is None else await self.memory.get_memories()
48 | examples = []
49 |
50 | if self.example_manager is not None:
51 | for example in self.example_manager.constant_examples:
52 | examples.append(example.user_message)
53 | examples.append(example.assistant_message)
54 |
55 | for message in messages:
56 | if self.example_manager and self.example_manager.example_selector and message['role'] == 'user':
57 | best_examples = await self.example_manager.example_selector(input=message['content'])
58 | for good_example in best_examples['output']['best_examples']:
59 | examples.append(UserMessage(good_example['user']))
60 | examples.append(AssistantMessage(good_example['assistant']))
61 |
62 | answer_format_msg = [
63 | UserMessage(self.initial_user_message_text)] if self.initial_user_message_text is not None else []
64 |
65 | messages = [system_msg] + memories + examples + answer_format_msg + messages
66 | return messages
67 |
68 | async def prepare_llm_request(self,
69 | messages,
70 | json_model=None,
71 | **kwargs):
72 |
73 | messages = await self.format_messages(messages)
74 |
75 | if kwargs['max_tokens_strategy']:
76 | max_tokens, model = self.get_run_config(messages=messages, **kwargs)
77 | kwargs['max_tokens'] = max_tokens
78 | kwargs['model'] = model
79 | else:
80 | kwargs['model'] = kwargs.get('model', DEFAULT_LLM_MODEL)
81 | kwargs['max_tokens'] = kwargs.get('max_tokens', 800)
82 |
83 | kwargs['messages'] = messages
84 | if json_model:
85 | kwargs['response_format'] = json_model
86 | return kwargs
87 |
88 | async def add_memory(self,
89 | message):
90 | if self.memory is not None:
91 | await self.memory(message=message)
92 |
93 | @property
94 | async def size(self):
95 | messages = await self.prepare_llm_request(message=get_openai_message(role='user', content=''))
96 | return count_tokens(messages)
97 |
98 | def get_run_config(self, messages, **kwargs):
99 |
100 | user_msg = messages[-1]
101 | user_msg_size = count_tokens(user_msg)
102 | all_msg_size = count_tokens(messages)
103 | model = kwargs.get('model', DEFAULT_LLM_MODEL)
104 | model_token_limit = LLM_TOKEN_LIMITS[model]
105 |
106 | if 'max_tokens' in kwargs:
107 | max_tokens = kwargs['max_tokens']
108 | elif kwargs['max_tokens_strategy'] == MaxTokensStrategy.MAX:
109 | leftover_to_use = model_token_limit - all_msg_size
110 | max_tokens = min(kwargs.get('allowed_max_tokens', 4096), leftover_to_use)
111 | elif kwargs['max_tokens_strategy'] == MaxTokensStrategy.EXPECTED_RATIO:
112 | max_tokens = min(max(800, user_msg_size * kwargs['expected_io_ratio']), 4096)
113 | expected_total_size = all_msg_size + max_tokens
114 | if expected_total_size / model_token_limit > 1:
115 | model = DEFAULT_CONTEXT_FALLBACK_DICT[kwargs['model']]
116 |
117 | return int(max_tokens), model
118 |
--------------------------------------------------------------------------------
/tinyllm/rag/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/rag/__init__.py
--------------------------------------------------------------------------------
/tinyllm/rag/document/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/rag/document/__init__.py
--------------------------------------------------------------------------------
/tinyllm/rag/document/document.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from tinyllm.util.helpers import count_tokens
4 | from tinyllm.util.prompt_util import stringify_dict
5 |
6 |
7 | class DocumentTypes(Enum):
8 | TEXT = 'text'
9 | DICTIONARY = 'dictionary'
10 | TABLE = 'table'
11 | IMAGE = 'image'
12 |
13 |
14 | class Document:
15 |
16 | def __init__(self,
17 | content,
18 | metadata: dict={},
19 | embeddings = None,
20 | type=DocumentTypes.TEXT,
21 | header='[doc]',
22 | include_keys=['content']):
23 | self.content = content
24 | self.metadata = metadata
25 | self.type = type
26 | self.header = header
27 | self.include_keys = include_keys
28 | self.embeddings = embeddings
29 |
30 | @property
31 | def size(self):
32 | content = self.to_string()
33 | return count_tokens(content)
34 |
35 | def to_string(self,
36 | **kwargs):
37 | full_dict = self.metadata.copy()
38 | full_dict.update({'content': self.content})
39 | return stringify_dict(header=kwargs.get('header', self.header),
40 | dict=full_dict,
41 | include_keys=kwargs.get('include_keys', self.include_keys))
42 |
43 |
44 | class ImageDocument(Document):
45 |
46 | def __init__(self, url, **kwargs):
47 | super().__init__(**kwargs)
48 | self.url = url
49 |
--------------------------------------------------------------------------------
/tinyllm/rag/document/store.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from tinyllm.rag.document.document import Document
4 |
5 |
6 | def remove_duplicate_dicts(list_of_lists):
7 | # Flatten the list of lists to easily identify duplicates across all lists
8 | flattened_list = [item for sublist in list_of_lists for item in sublist]
9 | # Remove duplicates while preserving order
10 | seen = set()
11 | unique_flattened_list = []
12 | for item in flattened_list:
13 | # Dictionaries are not hashable, so we use their string representation to keep track of duplicates
14 | item_str = str(item)
15 | if item_str not in seen:
16 | seen.add(item_str)
17 | unique_flattened_list.append(item)
18 |
19 | # Reconstruct the list of lists with unique dictionaries
20 | unique_list_of_lists = []
21 | for original_list in list_of_lists:
22 | new_list = []
23 | for item in original_list:
24 | if item in unique_flattened_list:
25 | new_list.append(item)
26 | # Remove the item from the unique list so it won't appear again
27 | unique_flattened_list.remove(item)
28 | unique_list_of_lists.append(new_list)
29 |
30 | return unique_list_of_lists
31 |
32 |
33 | class DocumentStore:
34 | def __init__(self):
35 | self.store = {}
36 |
37 | def add_docs(self,
38 | docs: List[Document],
39 | name: str):
40 | if name in self.store:
41 | self.store[name] += docs
42 | else:
43 | self.store[name] = docs
44 |
45 | def fit_store(self,
46 | context_size,
47 | weights: Optional[List[float]] = None) -> List[Document]:
48 |
49 | # If weights are not provided, distribute docs evenly
50 | docs_lists = list(self.store.values())
51 |
52 | if not weights:
53 | weights = [1 / len(docs_lists)] * len(docs_lists)
54 | else:
55 | assert len(weights) == len(
56 | docs_lists), "Length of weights must be equal to the number of document store keys."
57 |
58 | # Normalize weights to ensure they sum up to 1
59 | total_weight = sum(weights)
60 | normalized_weights = [weight / total_weight for weight in weights]
61 |
62 | # Calculate token size for each source based on weights
63 | token_sizes = [int(weight * context_size) for weight in normalized_weights]
64 |
65 | i = 0
66 | section_names = list(self.store.keys())
67 | for doc_list, token_size in zip(docs_lists, token_sizes):
68 | section_docs = []
69 | # For each doc source, count how many docs can fit into its token size
70 | current_tokens = 0
71 | for doc in doc_list:
72 | doc_tokens = doc.size
73 | if current_tokens + doc_tokens <= token_size:
74 | section_docs.append(doc)
75 | current_tokens += doc_tokens
76 | else:
77 | break
78 |
79 | self.store[section_names[i]] = section_docs
80 | i += 1
81 |
82 | def to_string(self,
83 | doc_header: str = '[doc]',
84 | include_keys: List[str] = ['content', 'metadata'],
85 | start_string: str = '-----SUPPORTING DOCS-----',
86 | end_string: str = '-----END OF SUPPORTING DOCS-----',
87 | context_size: int = None,
88 | weights: [List[float]] = None,
89 | ):
90 | # Fit the multiple sources of docs based on weights
91 | self.fit_store(context_size,
92 | weights)
93 |
94 | # Convert to appropriate format
95 | fitted_context_string = ''
96 | for section_name, docs in self.store.items():
97 | fitted_context_string += section_name + '\n'
98 | fitted_context_string += f'/n '.join([doc.to_string(include_keys=include_keys,
99 | header=doc_header) for doc in docs])
100 |
101 | # Format the final context
102 | formatted_context = start_string + "\n" + fitted_context_string + "\n" + end_string
103 | return formatted_context
104 |
--------------------------------------------------------------------------------
/tinyllm/rag/loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/rag/loaders/__init__.py
--------------------------------------------------------------------------------
/tinyllm/rag/loaders/image_loader.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/rag/loaders/image_loader.py
--------------------------------------------------------------------------------
/tinyllm/rag/loaders/loaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pprint
3 | from typing import List, Union
4 |
5 | import pandas as pd
6 |
7 | from smartpy.cloud.do.storage import DigitalOcean
8 | from smartpy.utility import os_util
9 | from tinyllm import tinyllm_config
10 | from tinyllm.llms.tiny_function import tiny_function
11 | from tinyllm.rag.document.document import Document, ImageDocument
12 | from tinyllm.util.message import UserMessage, Text, Image, Content
13 |
14 |
15 | class Loader:
16 |
17 | def __init__(self, file_path):
18 | self.file_path = file_path
19 |
20 | @classmethod
21 | def load(self) -> Document:
22 | pass
23 |
24 |
25 | class ExcelLoader(Loader):
26 |
27 | def __init__(self, file_path, sheets=None):
28 | super().__init__(file_path)
29 | self.excel_file = pd.ExcelFile(file_path, engine='openpyxl')
30 | self.sheets = sheets or list(self.excel_file.sheet_names)
31 |
32 | def get_content(self) -> str:
33 | file_content = ""
34 | for sheet in self.sheets:
35 | title = sheet.upper()
36 | file_content += f"## {title}\n"
37 | file_content += pd.read_excel(self.file_path, sheet_name=sheet, engine='openpyxl').to_markdown()
38 |
39 | def get_screenshots(self) -> List[str]:
40 | pass
41 |
42 |
43 | @tiny_function(model_kwargs={'model': 'gpt-4-vision-preview'})
44 | async def parse_image(content: Union[Content, List[Content]]):
45 | """
46 |
47 | ROLE:
48 | You are an Image document parser. You will be provided an image and/or text content from a single document. Your goal is to extract
49 | structured data (sections, fields, descriptions) from this document using the provided img/text.
50 |
51 | OUTPUT FORMAT:
52 | Your output should be a JSON object that properly structures the extracted data. Make sure the section
53 | and field names are semantically meaningful.
54 |
55 | """
56 | pass
57 |
58 |
59 | class ImageStorageSources:
60 | DO = 'digital ocean'
61 |
62 |
63 | class ImageLoader(Loader):
64 |
65 | def __init__(self,
66 | file_path,
67 | url: str = None,
68 | content: str = None,
69 | storage_source=None):
70 | super().__init__(file_path)
71 | self.url = url
72 | self.img_local_path = file_path
73 | self.storage_source = storage_source
74 | self.content = content
75 |
76 | def store_image(self):
77 | if self.url:
78 | return self.url
79 |
80 | if self.storage_source == ImageStorageSources.DO:
81 | do = DigitalOcean(
82 | region_name="nyc3",
83 | endpoint_url=tinyllm_config['CLOUD_PROVIDERS']['DO']['ENDPOINT'],
84 | key_id=tinyllm_config['CLOUD_PROVIDERS']['DO']['KEY'],
85 | secret_access_key=tinyllm_config['CLOUD_PROVIDERS']['DO']['SECRET']
86 | )
87 | self.url = do.upload_file(
88 | project_name=tinyllm_config['CLOUD_PROVIDERS']['DO']['PROJECT_NAME'],
89 | space_name="tinyllm",
90 | file_src=self.img_local_path,
91 | is_public=True
92 | )
93 |
94 | return self.url
95 |
96 | async def parse_with_ai(self):
97 | content = [
98 | Text("Use the provided img and its content to extract the relevant structured data and sections"),
99 | Image(self.url)
100 | ]
101 | parsing_output = await parse_image(content=content)
102 | return parsing_output['output']
103 |
104 | async def async_load(self, parse=False) -> Document:
105 | self.store_image()
106 | content = None
107 | if parse:
108 | content = str(self.parse_with_ai())
109 | image_doc = ImageDocument(
110 | content=content,
111 | url=self.url,
112 | metadata={})
113 | return image_doc
114 |
115 |
116 | class PDFFormLoader(Loader):
117 |
118 | def __init__(self, file_path):
119 | super().__init__(file_path)
120 | from PyPDF2 import PdfReader
121 | self.pdf_reader = PdfReader(open(file_path, "rb"))
122 |
123 | async def async_load(self,
124 | images=False) -> str:
125 | # Parse form dict
126 | form_content = self.pdf_reader.get_form_text_fields()
127 | form_content = {str(k): (float(v) if v.isdigit() else v) for k, v in form_content.items() if
128 | v is not None and len(v) > 0}
129 | form_content = pprint.pformat(form_content)
130 | img_docs = []
131 | if images:
132 | screenshot_paths = self.get_screenshots(self.file_path)
133 | image_loaders = [ImageLoader(file_path=image_path,
134 | content=form_content,
135 | storage_source=ImageStorageSources.DO) for image_path in
136 | screenshot_paths]
137 | img_docs = [await img_loader.async_load() for img_loader in image_loaders]
138 |
139 | pdf_form_doc = Document(content=form_content,
140 | metadata={'urls': [img_doc.url for img_doc in img_docs]})
141 |
142 | return [pdf_form_doc]+img_docs
143 |
144 | def get_screenshots(self, pdf_path, dpi=500):
145 | from pdf2image import convert_from_path
146 |
147 | images_path = os_util.getTempDir('tinyllm/files/')
148 | base_name = '-'.join(os_util.getBaseName(pdf_path).split('.'))
149 | pages = convert_from_path(pdf_path, dpi)
150 | image_paths = []
151 | for count, page in enumerate(pages):
152 | file_name = base_name + f'_page_{count}.jpg'
153 | img_path = os.path.join(images_path, file_name)
154 | page.update(img_path, 'JPEG')
155 | image_paths.append(img_path)
156 | return image_paths
157 |
158 |
159 | async def main():
160 | loader = PDFFormLoader(
161 | file_path='/Users/othmanezoheir/PycharmProjects/zoheir-consulting/lendmarq-ai/docs/Loan Application.pdf')
162 | doc = await loader.async_load(images=True)
163 |
164 |
165 | if __name__ == '__main__':
166 | import asyncio
167 |
168 | asyncio.run(main())
169 |
--------------------------------------------------------------------------------
/tinyllm/rag/rerank.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 |
3 | from smartpy.utility.ai_util import get_top_n_diverse_texts
4 | from tinyllm.rag.document.document import Document
5 |
6 |
7 | class ReRanker:
8 |
9 | def __init__(self,
10 | docs = [],
11 | scores = []):
12 | self.docs = docs
13 | self.scores = scores
14 |
15 | def add_doc(self,
16 | doc: Document,
17 | scores: float):
18 | self.docs.append(doc)
19 | self.scores.append(scores)
20 |
21 | def rerank(self,
22 | top_n: int) -> List[Document]:
23 | # Normalize scores
24 | # Average scores for each doc
25 | # Sort by score
26 | # Implement MMR based
27 | texts = [doc.content for doc in self.docs]
28 | embeddings = [doc.embeddings for doc in self.docs]
29 | top_n_texts = get_top_n_diverse_texts(texts=texts,
30 | embeddings=embeddings,
31 | top_n=top_n)
32 | top_n_docs = [doc for doc in self.docs if doc.content in top_n_texts]
33 | return top_n_docs
--------------------------------------------------------------------------------
/tinyllm/rag/vector_store.py:
--------------------------------------------------------------------------------
1 | import tinyllm
2 |
3 | from sqlalchemy import text
4 | from sqlalchemy import Column, Integer, String, UniqueConstraint
5 | from sqlalchemy.ext.declarative import declarative_base
6 | from sqlalchemy.dialects.postgresql import insert
7 | from pgvector.sqlalchemy import Vector
8 | from sqlalchemy.dialects import postgresql
9 |
10 | from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
11 | from sqlalchemy.orm import sessionmaker
12 |
13 | from tinyllm.function import Function
14 | from tinyllm.rag.document.document import Document, DocumentTypes
15 | from tinyllm.tracing.langfuse_context import observation
16 | from sqlalchemy import text as sql_text
17 |
18 | Base = declarative_base()
19 |
20 |
21 | def get_database_uri():
22 | postgres_config = tinyllm.tinyllm_config['POSTGRES']
23 | user = postgres_config['USERNAME']
24 | password = postgres_config['PASSWORD']
25 | host = postgres_config['HOST']
26 | port = postgres_config['PORT']
27 | name = postgres_config['NAME']
28 | return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
29 |
30 |
31 | class Embeddings(Base):
32 | __tablename__ = "embeddings"
33 | __table_args__ = (UniqueConstraint('text', 'collection_name', name='uq_text_collection_name'),)
34 |
35 | id = Column(Integer, primary_key=True)
36 | created_at = Column(postgresql.TIMESTAMP, server_default=text('now()'), nullable=False)
37 |
38 | collection_name = Column(String, nullable=False)
39 | embedding = Column(Vector(dim=384))
40 | text = Column(String)
41 | emetadata = Column(postgresql.JSON)
42 |
43 |
44 | class VectorStore(Function):
45 |
46 | def __init__(self,
47 | embedding_function):
48 | self.name = 'PGVectorStore'
49 | self._engine = create_async_engine(get_database_uri())
50 | self._Session = sessionmaker(
51 | bind=self._engine,
52 | expire_on_commit=False,
53 | class_=AsyncSession,
54 | )
55 | self.embedding_function = embedding_function
56 | # self.create_tables()
57 |
58 | def _build_metadata_filters(self, metadata_filters):
59 | filter_clauses = []
60 | for key, value in metadata_filters.items():
61 | if isinstance(value, list):
62 | str_values = [str(v) for v in value] # Convert all values to string
63 | filter_by_metadata = Embeddings.emetadata[key].astext.in_(str_values)
64 | filter_clauses.append(filter_by_metadata)
65 | elif isinstance(value, dict) and "in" in map(str.lower, value.keys()):
66 | value_case_insensitive = {k.lower(): v for k, v in value.items()}
67 | filter_by_metadata = Embeddings.emetadata[key].astext.in_(value_case_insensitive["in"])
68 | filter_clauses.append(filter_by_metadata)
69 | else:
70 | filter_by_metadata = Embeddings.emetadata[key].astext == str(value)
71 | filter_clauses.append(filter_by_metadata)
72 | return filter_clauses
73 |
74 | async def create_tables(self):
75 | async with self._engine.begin() as conn:
76 | await conn.execute(text("ALTER DATABASE defaultdb SET SEARCH_PATH TO postgres_schema;"))
77 | await conn.execute(text("""CREATE EXTENSION vector"""))
78 | async with self._engine.begin() as conn:
79 | await conn.run_sync(lambda conn_sync: Base.metadata.create_all(conn_sync))
80 |
81 |
82 | async def add_texts(self, texts, collection_name, metadatas=None):
83 | if metadatas is None:
84 | metadatas = [None] * len(texts)
85 |
86 | embeddings = await self.embedding_function(texts)
87 | async with self._Session() as session:
88 | async with session.begin():
89 | for text, embedding, metadata in zip(texts, embeddings, metadatas):
90 | stmt = insert(Embeddings).values(
91 | created_at=sql_text('now()'),
92 | text=text,
93 | embedding=embedding,
94 | emetadata=metadata,
95 | collection_name=collection_name
96 | ).on_conflict_do_nothing()
97 |
98 | await session.execute(stmt)
99 | await session.commit()
100 |
101 | return {
102 | "documents": [
103 | {
104 | "document": Document(content=text,
105 | type=DocumentTypes.TEXT,
106 | metadata=metadata),
107 | "embedding": embedding
108 |
109 | } for text, embedding, metadata in zip(texts, embeddings, metadatas)
110 | ]
111 | }
112 |
113 | @observation(observation_type='span')
114 | async def similarity_search(self, query, k, collection_filters, metadata_filters=None):
115 | query_embedding = await self.embedding_function(query)
116 | query_embedding = query_embedding[0]
117 | async with self._Session() as session: # Use an asynchronous session
118 | async with session.begin():
119 | # Initialize the base SQL query
120 | base_query = 'SELECT text, emetadata, collection_name, embedding <-> :embedding AS distance FROM embeddings'
121 |
122 | # Initialize the WHERE clauses and parameters
123 | where_clauses = []
124 | params = {'embedding': str(list(query_embedding)), 'k': k}
125 |
126 | # Apply collection filters if they exist
127 | if collection_filters:
128 | where_clauses.append('collection_name = ANY(:collection_filters)')
129 | params['collection_filters'] = collection_filters
130 |
131 | # Apply metadata filters if they exist
132 | if metadata_filters:
133 | for key, values in metadata_filters.items():
134 | where_clauses.append(f"emetadata->>'{key}' = ANY(:{key})")
135 | params[key] = values
136 |
137 | # Construct the final query
138 | if where_clauses:
139 | final_query = f"{base_query} WHERE {' AND '.join(where_clauses)} ORDER BY distance ASC LIMIT :k"
140 | else:
141 | final_query = f"{base_query} ORDER BY distance ASC LIMIT :k"
142 |
143 | # Execute the query
144 | stmt = text(final_query)
145 | result = await session.execute(stmt, params)
146 | rows = result.all()
147 |
148 | docs = [
149 | {
150 | 'document': Document(content=row.text,
151 | type=ContentTypes.TEXT,
152 | metadata=row.emetadata),
153 | 'distance': row.distance,
154 | } for row in rows
155 | ]
156 |
157 | return {
158 | "documents": docs,
159 | }
160 |
--------------------------------------------------------------------------------
/tinyllm/state.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 |
3 |
4 | class States(IntEnum):
5 | INIT = 1
6 | INPUT_VALIDATION = 2
7 | RUNNING = 3
8 | OUTPUT_VALIDATION = 4
9 | OUTPUT_EVALUATION = 5
10 | PROCESSING_OUTPUT = 6
11 | PROCESSED_OUTPUT_VALIDATION = 7
12 | PROCESSED_OUTPUT_EVALUATION = 8
13 | CLOSING = 9
14 | COMPLETE = 10
15 | FAILED = 11
16 |
17 |
18 | TERMINAL_STATES = [States.COMPLETE, States.FAILED]
19 |
20 | ALLOWED_TRANSITIONS = {
21 | None: [States.INIT],
22 |
23 | States.INIT: [States.INPUT_VALIDATION, States.FAILED],
24 | States.INPUT_VALIDATION: [States.RUNNING, States.FAILED],
25 |
26 | States.RUNNING: [States.OUTPUT_VALIDATION, States.FAILED],
27 |
28 | States.OUTPUT_VALIDATION: [States.COMPLETE, States.OUTPUT_VALIDATION, States.PROCESSING_OUTPUT,
29 | States.OUTPUT_EVALUATION, States.FAILED],
30 | # Can transition to itself in the case of a streaming function
31 | States.OUTPUT_EVALUATION: [States.COMPLETE, States.PROCESSING_OUTPUT, States.PROCESSING_OUTPUT, States.FAILED],
32 |
33 | States.PROCESSING_OUTPUT: [
34 | States.PROCESSED_OUTPUT_VALIDATION,
35 | States.CLOSING,
36 | States.FAILED,
37 | States.COMPLETE
38 | ],
39 | States.PROCESSED_OUTPUT_VALIDATION: [
40 | States.PROCESSED_OUTPUT_EVALUATION,
41 | States.COMPLETE,
42 | States.FAILED,
43 | ],
44 | States.PROCESSED_OUTPUT_EVALUATION: [States.CLOSING, States.COMPLETE, States.FAILED],
45 | States.CLOSING: [States.COMPLETE, States.FAILED],
46 |
47 | States.COMPLETE: [States.INPUT_VALIDATION],
48 | States.FAILED: [States.INPUT_VALIDATION]
49 | }
50 |
--------------------------------------------------------------------------------
/tinyllm/test_vector_store.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import unittest
3 |
4 | from sqlalchemy import delete
5 |
6 | from tinyllm.tests.base import AsyncioTestCase
7 | from tinyllm.rag.vector_store import VectorStore, Embeddings
8 |
9 |
10 | class TestVectorStore(AsyncioTestCase):
11 |
12 | def setUp(self):
13 | super().setUp()
14 |
15 | async def embedding_function(text):
16 | return [[1]*384] #
17 | import asyncio
18 | # Environment Variables for DB
19 | self.vector_store = VectorStore(embedding_function=embedding_function)
20 | asyncio.run(self.vector_store.create_tables())
21 | self.test_texts = ["Hello, world!", "Hi there!", "How are you?"]
22 | self.collection_name = 'test_collection'
23 | self.metadatas = [{"type": "test"}] * len(self.test_texts)
24 |
25 | def test_add_texts(self):
26 |
27 | # Adding test data
28 | self.loop.run_until_complete(self.vector_store.add_texts(self.test_texts, self.collection_name, self.metadatas))
29 | query = "Hello, World"
30 | k = 1
31 | collection_filters = [self.collection_name]
32 | metadata_filters = {"type": ["test"]}
33 | results = self.loop.run_until_complete(self.vector_store.similarity_search(query, k, collection_filters, metadata_filters))
34 | docs = results['documents']
35 | self.assertTrue(len(docs) <= k)
36 | self.assertTrue(all(r['document'].metadata['type'] == 'test' for r in docs))
37 |
38 |
39 | def tearDown(self):
40 |
41 | async def clear_dbb():
42 | async with self.vector_store._Session() as session: # Use an asynchronous session
43 | await session.begin()
44 | await session.execute(
45 | delete(Embeddings).where(Embeddings.collection_name == self.collection_name)
46 | )
47 | await session.commit()
48 |
49 | self.loop.run_until_complete(clear_dbb())
50 |
51 | super().tearDown()
52 |
53 | if __name__ == '__main__':
54 | unittest.main()
55 |
--------------------------------------------------------------------------------
/tinyllm/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/tests/__init__.py
--------------------------------------------------------------------------------
/tinyllm/tests/base.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import sys
4 | import unittest
5 |
6 | from tinyllm import langfuse_client, tinyllm_config
7 |
8 |
9 | class AsyncioTestCase(unittest.TestCase):
10 | def setUp(self):
11 | # Tests are run in live mode
12 | self.loop = asyncio.new_event_loop()
13 | asyncio.set_event_loop(self.loop)
14 |
15 | def tearDown(self):
16 | self.loop.close()
17 | asyncio.set_event_loop(None)
18 | langfuse_client.flush()
19 |
20 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_agent.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from typing import Any, Optional
3 |
4 | from pydantic import BaseModel, Field
5 |
6 | from tinyllm.agent.agent import Agent
7 | from tinyllm.agent.tool import tinyllm_toolkit
8 | from tinyllm.eval.evaluator import Evaluator
9 | from tinyllm.llms.lite_llm import LiteLLM
10 | from tinyllm.tests.base import AsyncioTestCase
11 | from tinyllm.util.message import Text, UserMessage
12 |
13 |
14 | class AnswerCorrectnessEvaluator(Evaluator):
15 |
16 | async def run(self, **kwargs):
17 | completion = kwargs['output']['response']['choices'][0]['message']['content']
18 | evals = {
19 | "evals": {
20 | "correct_answer": 1 if 'january 1st' in completion.lower() else 0
21 | },
22 | "metadata": {}
23 | }
24 |
25 | return evals
26 |
27 |
28 | # Define the test class
29 |
30 | class TestAgent(AsyncioTestCase):
31 |
32 | def test_json_output(self):
33 | class Person(BaseModel):
34 | name: str = Field(..., description='Name of the person')
35 | age: int = Field(..., description='Age of the person')
36 | note: Optional[Any]
37 |
38 | class RiskScoreOutput(BaseModel):
39 | risk_score: float = Field(..., description='Confidence level of the trade idea between 0 and 100')
40 | person: Person
41 |
42 | tiny_agent = Agent(
43 | name='Test: Agent JSON output',
44 | system_role="You are a Credit Risk Analyst. Respond with a risk score based on the provided customer data",
45 | output_model=RiskScoreOutput
46 | )
47 | # Run the asynchronous test
48 | result = self.loop.run_until_complete(tiny_agent(content="Johny Vargas, 29yo, the customer has missed 99% of his bill payments in the last year"))
49 | self.assertTrue(result['status'] == 'success')
50 | self.assertTrue(result['output']['response'].get('risk_score') is not None)
51 |
52 |
53 | def test_wiki_tool(self):
54 | tiny_agent = Agent(
55 | name='Test: Agent Wiki Tool',
56 | llm=LiteLLM(),
57 | toolkit=tinyllm_toolkit(),
58 | user_id='test_user',
59 | session_id='test_session',
60 | )
61 | # Run the asynchronous test
62 | result = self.loop.run_until_complete(tiny_agent(content="What does wiki say about Morocco"))
63 | self.assertTrue(result['status'] == 'success')
64 |
65 |
66 | def test_multi_tool(self):
67 | tiny_agent = Agent(
68 | name="Test: Agent Multi Tool",
69 | toolkit=tinyllm_toolkit(),
70 | )
71 | # Run the asynchronous test
72 | query = """Plan then execute this task for me: I need to multiply the population of Morocco by the population of
73 | Senegal, then square that number by Elon Musk's age"""
74 | result = self.loop.run_until_complete(tiny_agent(content=query,
75 | model='gpt-4')) # Parallel call is not handled yet
76 | self.assertTrue(result['status'] == 'success')
77 |
78 |
79 |
80 | # This allows the test to be run standalone
81 | if __name__ == '__main__':
82 | unittest.main()
83 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_agent_stream.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.agent.tool.tool import Tool
4 | from tinyllm.tests.base import AsyncioTestCase
5 | from tinyllm.agent.agent_stream import AgentStream
6 | from tinyllm.agent.tool import Toolkit, tinyllm_toolkit
7 | from tinyllm.eval.evaluator import Evaluator
8 | from tinyllm.memory.memory import BufferMemory
9 | from tinyllm.util.helpers import get_openai_message
10 |
11 |
12 | class AnswerCorrectnessEvaluator(Evaluator):
13 | async def run(self, **kwargs):
14 | completion = kwargs['output']['output']['completion']
15 |
16 | evals = {
17 | "evals": {
18 | "correct_answer": 1 if 'january 1st' in completion.lower() else 0
19 | },
20 | "metadata": {}
21 | }
22 | return evals
23 |
24 |
25 | def get_user_property(asked_property):
26 | if asked_property == "name":
27 | return "Elias"
28 | elif asked_property == "birthday":
29 | return "January 1st"
30 |
31 |
32 | tools = [
33 | Tool(
34 | name="get_user_property",
35 | description="This is the tool you use retrieve ANY information about the user. Use it to answer questions about his birthday and any personal info",
36 | python_lambda=get_user_property,
37 | parameters={
38 | "type": "object",
39 | "properties": {
40 | "asked_property": {
41 | "type": "string",
42 | "enum": ["birthday", "name"],
43 | "description": "The specific property the user asked about",
44 | },
45 | },
46 | "required": ["asked_property"],
47 | },
48 | )
49 | ]
50 | toolkit = Toolkit(
51 | name='Toolkit',
52 | tools=tools
53 | )
54 |
55 |
56 |
57 | # Define the test class
58 | class TestStreamingAgent(AsyncioTestCase):
59 |
60 | def test_tool_call(self):
61 |
62 | tiny_agent = AgentStream(
63 | name="Test: Agent Stream tools",
64 | toolkit=tinyllm_toolkit(),
65 | user_id='test_user',
66 | session_id='test_session',
67 | run_evaluators=[
68 | AnswerCorrectnessEvaluator(
69 | name="Eval: correct user info",
70 | ),
71 | ],
72 | )
73 |
74 | async def async_test():
75 | msgs = []
76 | async for message in tiny_agent(content="What is the 5th Fibonacci number?"):
77 | msgs.append(message)
78 | return msgs
79 |
80 | # Run the asynchronous test
81 | result = self.loop.run_until_complete(async_test())
82 | # Verify the last message in the list
83 | self.assertEqual(result[-1]['status'], 'success', "The last message status should be 'success'")
84 |
85 |
86 | # This allows the test to be run standalone
87 | if __name__ == '__main__':
88 | unittest.main()
89 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_document_store.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.rag.document.document import Document
4 | from tinyllm.rag.document.store import DocumentStore
5 |
6 |
7 | class TestDocsStore(unittest.TestCase):
8 |
9 | def setUp(self):
10 | # Define the initial parameters for DocsContextBuilder
11 | self.docs = [
12 | {"content": "First document text.",
13 | "metadata": {}},
14 | {"content": "Second document text, which is slightly longer.",
15 | "metadata": {}},
16 | {"content": "Third document text.",
17 | "metadata": {}}
18 | ]
19 | self.docs = [Document(**doc) for doc in self.docs]
20 | self.document_store = DocumentStore()
21 | self.document_store.add_docs(
22 | name='test_section',
23 | docs=self.docs
24 | )
25 | self.docs_2 = [
26 | {"content": "2 First document text.",
27 | "metadata": {}},
28 | {"content": "2 Second document text, which is slightly longer.",
29 | "metadata": {}},
30 | {"content": "2 Third document text.",
31 | "metadata": {}}
32 | ]
33 | self.docs_2 = [Document(**doc) for doc in self.docs_2]
34 | self.document_store.add_docs(
35 | name='test_section_2',
36 | docs=self.docs_2
37 | )
38 |
39 | def test_get_context(self):
40 | start_string = "SUPPORTING DOCS"
41 | end_string = "END OF SUPPORTING DOCS"
42 |
43 | # We limit the context size to the size of the first doc source in the store, + 1 out of 3 docs from the second source
44 | doc1_size = sum([i.size for i in self.docs])
45 | context_available = doc1_size + self.docs_2[0].size
46 |
47 | # Use the DocsContextBuilder to get the final context
48 | final_context = self.document_store.to_string(
49 | start_string=start_string,
50 | end_string=end_string,
51 | context_size=context_available,
52 | weights=[0.6, 0.4]
53 | )
54 | # Assert the presence of the start and end strings in the final context
55 | self.assertTrue(start_string in final_context)
56 | self.assertTrue(end_string in final_context)
57 | # Assert the presence of document texts in the final context
58 | for doc in self.docs[:2]:
59 | self.assertTrue(doc.to_string() in final_context)
60 | for doc in self.docs_2[:1]:
61 | self.assertTrue(doc.to_string() in final_context)
62 |
63 |
64 | if __name__ == '__main__':
65 | unittest.main()
66 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_evaluators.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.eval.evaluator import Evaluator
4 | from tinyllm.function import Function
5 | from tinyllm.tracing.langfuse_context import observation
6 | from tinyllm.util.helpers import get_openai_message
7 | from tinyllm.tests.base import AsyncioTestCase
8 |
9 |
10 | class SuccessFullRunEvaluator(Evaluator):
11 | async def run(self, **kwargs):
12 | print('')
13 | return {
14 | "evals": {
15 | "successful_score": 1,
16 | },
17 | "metadata": {}
18 | }
19 |
20 |
21 | class TestEvaluators(AsyncioTestCase):
22 |
23 |
24 | def test_evaluator(self):
25 |
26 |
27 | litellm_chat = Function(name='Test: LiteLLMChat evaluation',
28 | run_evaluators=[SuccessFullRunEvaluator()],
29 | processed_output_evaluators=[SuccessFullRunEvaluator()])
30 | message = get_openai_message(role='user',
31 | content="Hi")
32 | result = self.loop.run_until_complete(litellm_chat(messages=[message]))
33 | self.assertEqual(result['status'], 'success')
34 |
35 |
36 | def test_evaluator_decorator(self):
37 |
38 | @observation('span', evaluators=[SuccessFullRunEvaluator()])
39 | async def run_func(something=None):
40 | return {
41 | 'status': 'success',
42 | }
43 |
44 | result = self.loop.run_until_complete(run_func(something='something'))
45 |
46 | if __name__ == '__main__':
47 | unittest.main()
48 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_example_selector.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.examples.example_selector import ExampleSelector
4 | from tinyllm.tests.base import AsyncioTestCase
5 |
6 |
7 | async def embedding_function(text):
8 | return [[1] * 384] #
9 |
10 |
11 |
12 | class TestExampleSelector(AsyncioTestCase):
13 |
14 | def setUp(self):
15 | super().setUp()
16 | self.example_texts = [
17 | {
18 | "user": "Example question",
19 | "assistant": "Example answer",
20 | },
21 | {
22 | "user": "Another example question",
23 | "assistant": "Another example answer"
24 | }
25 | ]
26 |
27 |
28 | self.example_selector = ExampleSelector(
29 | name="Test local example selector",
30 | examples=self.example_texts,
31 | embedding_function=embedding_function,
32 | )
33 |
34 | self.loop.run_until_complete(self.example_selector.embed_examples())
35 |
36 |
37 | def test_selector(self):
38 | query = "Find a relevant example"
39 | results = self.loop.run_until_complete(self.example_selector(input=query,
40 | k=1))
41 | self.assertTrue(len(results['output']['best_examples']) == 1)
42 |
43 |
44 |
45 | if __name__ == '__main__':
46 | unittest.main()
47 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_function.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.tests.base import AsyncioTestCase
4 | from tinyllm.function import Function
5 | from tinyllm.exceptions import InvalidStateTransition
6 | from tinyllm.validator import Validator
7 | from tinyllm.state import States
8 |
9 |
10 | class InputValidator(Validator):
11 | value: float
12 |
13 | class OutputValidator(Validator):
14 | value: float
15 |
16 | class AddOneOperator(Function):
17 | def __init__(self, **kwargs):
18 | super().__init__(input_validator=InputValidator, output_validator=OutputValidator, **kwargs)
19 |
20 | async def run(self, **kwargs):
21 | value = kwargs["value"]
22 | self.log("Adding one to {}".format(value))
23 | result = value + 1
24 | self.log("Done adding")
25 | return {"value": result}
26 |
27 |
28 | class TestFunction(AsyncioTestCase):
29 |
30 | def test_add_one(self):
31 | operator = AddOneOperator(name="AddOneTest")
32 | result = self.loop.run_until_complete(operator(value=5.0))
33 | self.assertIsNotNone(result)
34 | self.assertEqual(result['output']["value"], 6.0)
35 |
36 | def test_invalid_state_transition(self):
37 | operator = AddOneOperator(name="AddOneTest")
38 | with self.assertRaises(InvalidStateTransition):
39 | operator.transition(States.COMPLETE)
40 |
41 | def test_invalid_input(self):
42 | operator = AddOneOperator()
43 | self.loop.run_until_complete(operator(value="wrong input"))
44 | assert operator.state == States.FAILED
45 |
46 |
47 |
48 | if __name__ == '__main__':
49 | unittest.main()
50 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_litellm.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.eval.evaluator import Evaluator
4 | from tinyllm.llms.lite_llm import LiteLLM
5 | from tinyllm.llms.lite_llm_stream import LiteLLMStream
6 | from tinyllm.util.helpers import get_openai_message
7 | from tinyllm.tests.base import AsyncioTestCase
8 | from tinyllm.util.message import UserMessage, Text
9 |
10 |
11 | class TestlitellmChat(AsyncioTestCase):
12 |
13 | def setUp(self):
14 | super().setUp()
15 |
16 | def test_litellm_chat(self):
17 | message = UserMessage("Hi")
18 | litellm_chat = LiteLLM(name='Test: LiteLLMChat')
19 | result = self.loop.run_until_complete(litellm_chat(messages=[message]))
20 | self.assertEqual(result['status'], 'success')
21 |
22 | def test_litellm_chat_evaluator(self):
23 | class SuccessFullRunEvaluator(Evaluator):
24 | async def run(self, **kwargs):
25 | return {
26 | "evals": {
27 | "successful_score": 1,
28 | },
29 | "metadata": {}
30 | }
31 |
32 | litellm_chat = LiteLLM(name='Test: LiteLLMChat evaluation',
33 | run_evaluators=[SuccessFullRunEvaluator()],
34 | processed_output_evaluators=[SuccessFullRunEvaluator()])
35 | message = UserMessage('hi')
36 | result = self.loop.run_until_complete(litellm_chat(messages=[message]))
37 | self.assertEqual(result['status'], 'success')
38 |
39 |
40 |
41 | if __name__ == '__main__':
42 | unittest.main()
43 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_litellm_stream.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.llms.lite_llm_stream import LiteLLMStream
4 | from tinyllm.tests.base import AsyncioTestCase
5 | from tinyllm.util.message import UserMessage
6 |
7 |
8 | class TestlitellmChat(AsyncioTestCase):
9 |
10 | def setUp(self):
11 | super().setUp()
12 |
13 | def test_litellm_chat_stream(self):
14 | litellmstream_chat = LiteLLMStream(name='Test: LiteLLM Stream')
15 |
16 | async def get_stream():
17 | message = UserMessage('Hi')
18 | msgs = []
19 | async for msg in litellmstream_chat(messages=[message]):
20 | i = 0
21 | msgs.append(msg)
22 | return msgs
23 |
24 | result = self.loop.run_until_complete(get_stream())
25 | deltas = []
26 | for res in result:
27 | print(res['output']['streaming_status'])
28 | if res['output']['streaming_status'] != 'finished-streaming':
29 | deltas.append(res['output']['last_completion_delta']['content'])
30 |
31 | self.assertEqual(res['status'], 'success')
32 | final_string = ''.join(deltas)
33 | self.assertTrue(final_string[-1] != final_string[-2], "The last Delta has been returned twice")
34 |
35 |
36 |
37 | if __name__ == '__main__':
38 | unittest.main()
39 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_memory.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.memory.memory import BufferMemory
4 | from tinyllm.state import States
5 | from tinyllm.tests.base import AsyncioTestCase
6 |
7 |
8 | class TestMemory(AsyncioTestCase):
9 |
10 | def test_memory(self):
11 | memory = BufferMemory(name="Memory test")
12 |
13 | # User message
14 | msg = {
15 | 'content': 'Hi agent, how are you?',
16 | 'role': 'user'
17 | }
18 |
19 | result = self.loop.run_until_complete(memory(message=msg))
20 |
21 | self.assertEqual(memory.state, States.COMPLETE)
22 | self.assertEqual(len(memory.memories), 1)
23 | self.assertEqual(memory.memories[0], {'role': 'user', 'content': 'Hi agent, how are you?'})
24 |
25 |
26 | if __name__ == '__main__':
27 | unittest.main()
28 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_rag.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.agent.tool.tool import Tool
4 | from tinyllm.rag.document.document import Document
5 | from tinyllm.tests.base import AsyncioTestCase
6 | from tinyllm.agent.tool import Toolkit, tinyllm_toolkit
7 | from tinyllm.eval.evaluator import Evaluator
8 | from tinyllm.memory.memory import BufferMemory
9 | from tinyllm.util.helpers import get_openai_message
10 |
11 |
12 | # Define the test class
13 | class TestRAG(AsyncioTestCase):
14 |
15 | def test_doc(self):
16 | doc = Document(content='Hello World')
17 | doc.to_string(header="[doc]",
18 | include_keys=['content', 'metadata'])
19 | # self.assertEqual(result[-1]['status'], 'success', "The last message status should be 'success'")
20 |
21 |
22 | # This allows the test to be run standalone
23 | if __name__ == '__main__':
24 | unittest.main()
25 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_tiny_function.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 |
3 | from tinyllm.llms.tiny_function import tiny_function
4 | from tinyllm.tests.base import AsyncioTestCase
5 | from tinyllm.util.message import Text
6 |
7 | # Mock response to simulate Agent's response
8 | mock_success_response = {
9 | "status": "success",
10 | "output": {
11 | "response": {
12 | "choices": [{
13 | "message": {
14 | "content": "{\"name\": \"Elon\", \"age\": 50, \"occupation\": \"CEO\"}"
15 | }
16 | }]
17 | }
18 | }
19 | }
20 |
21 | mock_fail_response = {"status": "error", "response": {"message": "Failed to process request"}}
22 |
23 |
24 | class TestTinyFunctionDecorator(AsyncioTestCase):
25 |
26 | def test_tiny_function_success(self):
27 | class CharacterInfo(BaseModel):
28 | name: str = Field(..., description="Name")
29 | age: int = Field(..., description="Age")
30 | occupation: str = Field(..., description="occupation")
31 |
32 | @tiny_function(output_model=CharacterInfo)
33 | async def get_character_info(doc1, doc2):
34 | """
35 |
36 | Extract character information from the provided documents
37 |
38 |
39 |
40 | {doc1}
41 | {doc2}
42 |
43 | """
44 | pass
45 |
46 | # Test the decorated function
47 | content = "Elon Musk is a 50 years old CEO"
48 | result = self.loop.run_until_complete(get_character_info(doc1=content, doc2=content))
49 |
50 | # Assertions
51 | self.assertIsInstance(result['output'], dict)
52 | self.assertTrue("Elon" in result['output']['name'])
53 | self.assertTrue(result['output']['age'], 50)
54 | self.assertTrue("CEO" in result['output']['occupation'])
55 |
56 | def test_no_model(self):
57 | @tiny_function()
58 | async def get_character_info(content: str):
59 | """
60 |
61 | Extract character information from the content
62 |
63 | """
64 | pass
65 |
66 | # Test the decorated function
67 | content = "Elon Musk is a 50 years old CEO"
68 | result = self.loop.run_until_complete(get_character_info(content=content))
69 | self.assertEqual(result['status'], 'success')
70 |
71 | def test_not_enough_tokens(self):
72 | class CharacterInfo(BaseModel):
73 | name: str = Field(..., description="Name")
74 | age: int = Field(..., description="Age")
75 | occupation: str = Field(..., description="occupation")
76 |
77 | @tiny_function(model_kwargs={'max_tokens': 80}, output_model=CharacterInfo)
78 | async def get_character_info(content: str):
79 | """
80 |
81 | Extract character information from the content
82 |
83 | """
84 | pass
85 |
86 | # Test the decorated function
87 | content = "Elon Musk is a 50 years old CEO"
88 | result = self.loop.run_until_complete(get_character_info(content=content))
89 | print(result)
90 | self.assertEqual(result['status'], 'success')
91 |
--------------------------------------------------------------------------------
/tinyllm/tests/test_tracing.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from tinyllm.eval.evaluator import Evaluator
4 | from tinyllm.function import Function
5 | from tinyllm.tracing.langfuse_context import observation
6 | from tinyllm.tests.base import AsyncioTestCase
7 |
8 |
9 | class TestlitellmChat(AsyncioTestCase):
10 |
11 | def setUp(self):
12 | super().setUp()
13 |
14 | def test_function_tracing(self):
15 | class SuccessFullRunEvaluator(Evaluator):
16 | async def run(self, **kwargs):
17 | print('Running evaluator')
18 | return {
19 | "evals": {
20 | "successful_score": 1,
21 | },
22 | "metadata": {}
23 | }
24 |
25 | class TestFunction(Function):
26 |
27 | @observation(observation_type='span')
28 | async def run(self, **kwargs):
29 | return {
30 | "result": 1
31 | }
32 |
33 | @observation(observation_type='span')
34 | async def process_output(self, **kwargs):
35 | result = 10 + kwargs['result']
36 | return {
37 | "result": result,
38 | }
39 |
40 | test_func = TestFunction(name='Test: tracing')
41 | message = {
42 | 'role': 'user',
43 | 'content': "Hi"
44 | }
45 | result = self.loop.run_until_complete(test_func(message=message))
46 | self.assertEqual(result['status'], 'success')
47 |
48 |
49 | if __name__ == '__main__':
50 | unittest.main()
51 |
--------------------------------------------------------------------------------
/tinyllm/tracing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/tracing/__init__.py
--------------------------------------------------------------------------------
/tinyllm/tracing/helpers.py:
--------------------------------------------------------------------------------
1 | import datetime as dt
2 | import inspect
3 | import traceback
4 | from functools import wraps
5 |
6 | import langfuse
7 | import numpy as np
8 |
9 | from tinyllm import langfuse_client
10 | from tinyllm.constants import LLM_PRICING
11 | from tinyllm.util.helpers import count_tokens, num_tokens_from_string
12 | from tinyllm.util.message import Message
13 |
14 | model_parameters = [
15 | "model",
16 | "frequency_penalty",
17 | "logit_bias",
18 | "logprobs",
19 | "top_logprobs",
20 | "max_tokens",
21 | "n",
22 | "presence_penalty",
23 | "response_format",
24 | "seed",
25 | "stop",
26 | "stream",
27 | "temperature",
28 | "top_p"
29 | ]
30 |
31 |
32 | ## I want you to implement an ObservationWrapper class that implements all of the above functions as class methods
33 |
34 | class ObservationUtil:
35 |
36 | @classmethod
37 | def handle_exception(cls, obs, e):
38 | if 'end' in dir(obs):
39 | obs.end(level='ERROR', status_message=str(traceback.format_exc()))
40 | elif 'update' in dir(obs):
41 | obs.update(level='ERROR', status_message=str(traceback.format_exc()))
42 |
43 | @classmethod
44 | def prepare_observation_input(cls, input_mapping, function_input):
45 | if not input_mapping:
46 | # stringify values
47 | function_input = cls.keep_accepted_types(function_input)
48 | return {'input': function_input}
49 |
50 | return {langfuse_kwarg: function_input[function_kwarg] for langfuse_kwarg, function_kwarg in
51 | input_mapping.items()}
52 |
53 | @classmethod
54 | def keep_accepted_types(self, d):
55 | acceptable_types = (str, dict, list, tuple, int, float, np.ndarray)
56 |
57 | def is_acceptable(v):
58 | if isinstance(v, acceptable_types):
59 | if isinstance(v, list):
60 | return all(isinstance(item, acceptable_types) for item in v)
61 | return True
62 | return False
63 |
64 | def clean(value):
65 | if isinstance(value, dict):
66 | return {k: clean(v) for k, v in value.items() if is_acceptable(v)}
67 | elif isinstance(value, list):
68 | return [clean(item) for item in value if is_acceptable(item)]
69 | else:
70 | return value
71 |
72 | return clean(d)
73 |
74 | @classmethod
75 | def end_observation(cls, obs, function_input, function_output, output_mapping, observation_type, function_kwargs):
76 | if type(obs) == langfuse.client.StatefulTraceClient:
77 | return
78 |
79 | mapped_output = {}
80 | function_output = cls.keep_accepted_types(function_output)
81 | if not output_mapping:
82 | mapped_output = {'output': function_output}
83 | else:
84 | for langfuse_kwarg, function_kwarg in output_mapping.items():
85 | mapped_output[langfuse_kwarg] = function_output.get(function_kwarg, None)
86 |
87 | if observation_type == 'generation':
88 |
89 | prompt_tokens = count_tokens(function_input)
90 |
91 | if function_output['type'] == 'tool' or function_output.get('message',{}).get('tool_calls',[]) != []:
92 | completion_tokens = count_tokens(function_output['message']['tool_calls'])
93 | else:
94 | completion_tokens = count_tokens(function_output['message']['content'])
95 |
96 | total_tokens = prompt_tokens + completion_tokens
97 | usage_info = {
98 | 'input': prompt_tokens,
99 | 'output': completion_tokens,
100 | 'total': total_tokens,
101 | "unit": "TOKENS",
102 | }
103 |
104 | model = function_kwargs.get('model', None)
105 | pricing = LLM_PRICING.get(model, None)
106 | cost = {}
107 | if pricing:
108 | input_cost = pricing['input'] * (prompt_tokens / 1000)
109 | output_cost = pricing['output'] * (completion_tokens / 1000)
110 | cost = {
111 | "input_cost": input_cost,
112 | "output_cost": output_cost,
113 | "total_cost": input_cost + output_cost
114 | }
115 | usage_info.update(cost)
116 | model_params = {k: v for k, v in function_kwargs.items() if
117 | k in model_parameters and k not in ['messages']}
118 | if 'response_format' in model_params:
119 | model_params['response_format'] = str(model_params['response_format'])
120 |
121 | obs.end(
122 | end_time=dt.datetime.now(),
123 | model=function_kwargs.get('model', None),
124 | model_parameters=model_params,
125 | usage=usage_info,
126 | **mapped_output)
127 | elif observation_type == 'span':
128 | obs.end(**mapped_output)
129 |
130 | @classmethod
131 | async def perform_evaluations(cls, observation, result, evaluators):
132 | if evaluators:
133 | for evaluator in evaluators:
134 | result['observation'] = observation
135 | await evaluator(**result)
136 | result.pop('observation')
137 |
138 | @classmethod
139 | def conditional_args(cls, observation_type, input_mapping=None, output_mapping=None):
140 | if observation_type == 'generation':
141 | if input_mapping is None:
142 | input_mapping = {'input': 'messages'}
143 | if output_mapping is None:
144 | output_mapping = {'output': 'message'}
145 | return input_mapping, output_mapping
146 |
147 | @classmethod
148 | def get_obs_name(cls, *args, func):
149 | name = None
150 | # Decorated method
151 | if len(args) > 0:
152 | if hasattr(args[0], 'name'):
153 | name = args[0].name + ('.' + func.__name__ if func.__name__ not in ['wrapper', '__call__'] else '')
154 | else:
155 | name = args[0].__class__.__name__ + '.' + func.__name__
156 |
157 | # Decorated function
158 | else:
159 | if hasattr(func, '__qualname__'):
160 | if len(func.__qualname__.split('.')) > 1 and '' not in func.__qualname__.split('.'):
161 | name = '.'.join(func.__qualname__.split('.')[-2::])
162 | else:
163 | name = func.__name__
164 | return name
165 |
166 | @classmethod
167 | def get_current_obs(cls,
168 | *args,
169 | parent_observation,
170 | observation_type,
171 | name,
172 | observation_input):
173 |
174 | if parent_observation is None:
175 | # This is the root function, create a new trace
176 | optional_args = {}
177 | for arg in ['user_id', 'session_id']:
178 | if len(args) > 0:
179 | if getattr(args[0], arg, None):
180 | optional_args[arg] = getattr(args[0], arg)
181 |
182 | observation = langfuse_client.trace(name=name,
183 | **optional_args,
184 | **observation_input)
185 | # Pass the trace to the Function
186 | if len(args) > 0:
187 | args[0].observation = observation
188 | args[0].trace = observation
189 | observation_method = getattr(observation, observation_type)
190 | if observation_type == 'generation':
191 | observation_input['input'] = [i.to_dict() for i in observation_input['input']]
192 |
193 | observation = observation_method(name=name, **observation_input)
194 | else:
195 | # Create child observations based on the type
196 | observation_method = getattr(parent_observation, observation_type)
197 | if observation_type == 'generation':
198 | observation_input['input'] = [i.to_dict() for i in observation_input['input']]
199 |
200 | observation = observation_method(name=name, **observation_input)
201 | # Pass the parent trace to this function
202 | args[0].observation = parent_observation
203 |
204 | # Pass the generation
205 | if len(args) > 0:
206 | if hasattr(args[0], 'generation'):
207 | if observation_type == 'generation':
208 | args[0].generation = observation
209 |
210 | return observation
211 |
--------------------------------------------------------------------------------
/tinyllm/tracing/langfuse_context.py:
--------------------------------------------------------------------------------
1 | from contextvars import ContextVar
2 |
3 | from pydantic import BaseModel
4 |
5 | from smartpy.utility.py_util import stringify_values_recursively
6 | from tinyllm.tracing.helpers import *
7 |
8 | current_observation_context = ContextVar('current_observation_context', default=None)
9 |
10 |
11 | class ObservationDecoratorFactory:
12 |
13 | @classmethod
14 | def get_streaming_decorator(self,
15 | observation_type,
16 | input_mapping=None,
17 | output_mapping=None,
18 | evaluators=None):
19 | def decorator(func):
20 | @wraps(func)
21 | async def wrapper(*args, **function_input):
22 | parent_observation = current_observation_context.get()
23 |
24 | name = ObservationUtil.get_obs_name(*args, func=func)
25 |
26 | # Prepare the input for the observation
27 | observation_input = ObservationUtil.prepare_observation_input(input_mapping, function_input)
28 |
29 | # Get the current observation
30 | observation = ObservationUtil.get_current_obs(*args,
31 | parent_observation=parent_observation,
32 | observation_type=observation_type,
33 | name=name,
34 | observation_input=observation_input)
35 | # Pass the observation to the class (so it can evaluate it)
36 | if len(args) > 0:
37 | args[0].observation = observation
38 | # Set the current observation in the context for child functions to access
39 | token = current_observation_context.set(observation)
40 | result = {}
41 | try:
42 | async for result in func(*args, **function_input):
43 | yield result
44 | if type(result) != dict: result = {'result': result}
45 | await ObservationUtil.perform_evaluations(observation, result, evaluators)
46 | except Exception as e:
47 | ObservationUtil.handle_exception(observation, e)
48 | finally:
49 | current_observation_context.reset(token)
50 | ObservationUtil.end_observation(observation, observation_input, result, output_mapping,
51 | observation_type, function_input)
52 |
53 | return wrapper
54 |
55 | return decorator
56 |
57 | @classmethod
58 | def get_decorator(self,
59 | observation_type,
60 | name=None,
61 | input_mapping=None,
62 | output_mapping=None,
63 | evaluators=None):
64 | def decorator(func):
65 |
66 | @wraps(func)
67 | async def wrapper(*args, **function_input):
68 | parent_observation = current_observation_context.get()
69 | if name is None:
70 | obs_name = ObservationUtil.get_obs_name(*args, func=func)
71 | else:
72 | obs_name = name
73 | observation_input = ObservationUtil.prepare_observation_input(input_mapping, function_input)
74 | observation = ObservationUtil.get_current_obs(*args,
75 | parent_observation=parent_observation,
76 | observation_type=observation_type,
77 | name=obs_name,
78 | observation_input=observation_input)
79 | token = current_observation_context.set(observation)
80 | result = {}
81 | if len(args) > 0:
82 | args[0].observation = observation
83 | try:
84 | result = await func(*args, **function_input)
85 | if type(result) != dict: result = {'result': result}
86 | # convert pydantic models to dict
87 | for key, value in result.items():
88 | if isinstance(value, BaseModel):
89 | result[key] = value.model_dump()
90 |
91 | await ObservationUtil.perform_evaluations(observation, result, evaluators)
92 | return result
93 | except Exception as e:
94 | ObservationUtil.handle_exception(observation, e)
95 | raise e
96 | finally:
97 | current_observation_context.reset(token)
98 | ObservationUtil.end_observation(observation, observation_input, result, output_mapping,
99 | observation_type, function_input)
100 | langfuse_client.flush()
101 |
102 | return wrapper
103 |
104 | return decorator
105 |
106 |
107 | def observation(observation_type='span',name=None, input_mapping=None, output_mapping=None, evaluators=None, stream=False):
108 | input_mapping, output_mapping = ObservationUtil.conditional_args(observation_type,
109 | input_mapping,
110 | output_mapping)
111 | if stream:
112 | return ObservationDecoratorFactory.get_streaming_decorator(observation_type, input_mapping, output_mapping,
113 | evaluators)
114 | else:
115 | return ObservationDecoratorFactory.get_decorator(observation_type, name, input_mapping, output_mapping, evaluators)
116 |
--------------------------------------------------------------------------------
/tinyllm/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/util/__init__.py
--------------------------------------------------------------------------------
/tinyllm/util/ai_util.py:
--------------------------------------------------------------------------------
1 | from textwrap import dedent
2 |
3 | import numpy as np
4 | import openai
5 |
6 | def get_cosine_similarity(a, b):
7 | a = np.array(a)
8 | b = np.array(b)
9 | return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
10 |
11 | def get_top_n_similar_vectors_index(input_vector, vectors, k=5, similarity_threshold=0.5):
12 | similarities = [get_cosine_similarity(input_vector, vector) for vector in vectors]
13 | top_similarities_indices = np.argsort(similarities)[-k:][::-1]
14 | return [int(index) for index in top_similarities_indices if similarities[index] > similarity_threshold]
15 |
16 | def get_openai_embedding(text, model="text-embedding-ada-002"):
17 | embedding = openai.Embedding.create(input=[text], model=model)['data'][0]['embedding']
18 | return embedding
19 |
20 | def generate_raw_ngrams(text, n):
21 | tokens = text.split()
22 | if len(tokens) < n:
23 | return []
24 | # Use a list comprehension to generate the n-grams
25 | ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
26 |
27 | return ngrams
28 |
29 |
--------------------------------------------------------------------------------
/tinyllm/util/db_util.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zozoheir/tinyllm/82d1973ee34bd614c0d71e874f83cc07bcbbc544/tinyllm/util/db_util.py
--------------------------------------------------------------------------------
/tinyllm/util/helpers.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Dict
2 |
3 | import tiktoken
4 |
5 | from tinyllm.util.message import Message
6 | from tinyllm.util.prompt_util import stringify_dict
7 |
8 | OPENAI_MODELS_CONTEXT_SIZES = {
9 | "gpt-4": 8192,
10 | "gpt-4-0314": 8192,
11 | "gpt-4-0613": 8192,
12 | "gpt-4-32k": 32768,
13 | "gpt-4-32k-0314": 32768,
14 | "gpt-4-32k-0613": 32768,
15 | "gpt-3.5-turbo": 4096,
16 | "gpt-3.5-turbo-0301": 4096,
17 | "gpt-3.5-turbo-0613": 4096,
18 | "gpt-3.5-turbo-16k": 16385,
19 | "gpt-3.5-turbo-16k-0613": 16385,
20 | "text-ada-001": 2049,
21 | "ada": 2049,
22 | "text-babbage-001": 2040,
23 | "babbage": 2049,
24 | "text-curie-001": 2049,
25 | "curie": 2049,
26 | "davinci": 2049,
27 | "text-davinci-003": 4097,
28 | "text-davinci-002": 4097,
29 | "code-davinci-002": 8001,
30 | "code-davinci-001": 8001,
31 | "code-cushman-002": 2048,
32 | "code-cushman-001": 2048,
33 | }
34 |
35 | def get_openai_message(role,
36 | content: Union[List, str],
37 | **kwargs):
38 | if role not in ['user', 'system', 'function','tool', 'assistant']:
39 | raise ValueError(f"Invalid role {role}.")
40 |
41 | msg = {'role': role,
42 | 'content': content}
43 | msg.update(kwargs)
44 | return msg
45 |
46 |
47 | def num_tokens_from_string(string: str, encoding_name: str = 'cl100k_base') -> int:
48 | """Returns the number of tokens in a text string."""
49 | encoding = tiktoken.get_encoding(encoding_name)
50 | num_tokens = len(encoding.encode(string))
51 | return num_tokens
52 |
53 |
54 | def count_openai_messages_tokens(messages, model="gpt-3.5-turbo"):
55 | """Returns the number of tokens used by a list of messages."""
56 | try:
57 | encoding = tiktoken.encoding_for_model(model)
58 | except KeyError:
59 | encoding = tiktoken.get_encoding("cl100k_base")
60 | if model in ["gpt-4", "gpt-3.5-turbo", "text-embedding-ada-002"]:
61 | num_tokens = 0
62 | for message in messages:
63 | num_tokens += 4 # every message follows {role/name}\n{content}\n
64 | for key, value in message.items():
65 | num_tokens += len(encoding.encode(value))
66 | if key == "name": # if there's a name, the role is omitted
67 | num_tokens += -1 # role is always required and always 1 token
68 | num_tokens += 2 # every reply is primed with assistant
69 | return int(num_tokens)
70 | else:
71 | raise NotImplementedError("openai_num_tokens_from_messages() is not implemented for this model.")
72 |
73 |
74 | def count_tokens(input: Union[List[Dict], Dict, str],
75 | **kwargs):
76 | if isinstance(input, list):
77 | if len(input) == 0:
78 | return 0
79 | if isinstance(input[0], str):
80 | return sum([num_tokens_from_string(string) for string in input])
81 | elif isinstance(input[0], dict):
82 | return sum([count_tokens(input_dict) for input_dict in input])
83 | elif isinstance(input[0], Message):
84 | return sum([count_tokens(msg.to_dict()) for msg in input])
85 |
86 | return sum([count_tokens(input_dict, **kwargs) for input_dict in input])
87 |
88 | elif isinstance(input, str):
89 | return num_tokens_from_string(input)
90 | elif isinstance(input, dict):
91 | dict_string = stringify_dict(header=kwargs.get('header', '[doc]'),
92 | dict=input,
93 | include_keys=kwargs.get('include_keys', []))
94 | return num_tokens_from_string(dict_string)
95 | elif isinstance(input, Message):
96 | return count_tokens(input.to_dict())
97 |
98 | else:
99 | raise NotImplementedError("count_tokens() is not implemented for this input type.")
100 |
101 |
--------------------------------------------------------------------------------
/tinyllm/util/message.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Dict
2 |
3 | from pydantic import BaseModel
4 |
5 | from tinyllm.validator import Validator
6 |
7 |
8 | # Base class for content types
9 | class Content:
10 | def dict(self) -> Dict:
11 | return self.__dict__
12 |
13 |
14 | class Text(Content):
15 | def __init__(self, text: str):
16 | self.type = "text"
17 | self.text = text
18 |
19 |
20 | class Image(Content):
21 | def __init__(self, url: str):
22 | self.type = "image_url"
23 | self.image_url = {"url": url}
24 |
25 |
26 |
27 | class MessageInput(Validator):
28 | role: str
29 | content: Union[List[Content], str]
30 |
31 |
32 | class Message:
33 |
34 | def __init__(self,
35 | role: str,
36 | content: Union[List[Content], str]):
37 | MessageInput(role=role, content=content)
38 | self.role = role
39 | self.content = content
40 | self.raw_content = content
41 | if type(content) == str:
42 | self.content = [Text(content)]
43 |
44 | def to_dict(self) -> Dict:
45 | return {"role": self.role,
46 | "content": self.raw_content.strip() if type(self.raw_content) == str else [c.dict() for c in self.content]}
47 |
48 |
49 | class UserMessage(Message):
50 | def __init__(self, content: List[Content]):
51 | super().__init__("user", content)
52 |
53 |
54 | class SystemMessage(Message):
55 | def __init__(self, content: List[Content]):
56 | super().__init__("system", content)
57 |
58 |
59 | class FunctionMessage(Message):
60 | def __init__(self, content: List[Content]):
61 | super().__init__("function", content)
62 |
63 |
64 | class ToolMessage(Message):
65 | def __init__(self,
66 | content: Union[List[Content], str],
67 | name: str = None,
68 | tool_calls: List[Dict] = None,
69 | tool_call_id: str = None):
70 | self.name = name
71 | self.tool_calls = tool_calls
72 | self.tool_call_id = tool_call_id
73 | super().__init__("tool", content)
74 |
75 | def to_dict(self) -> Dict:
76 | message = super().to_dict()
77 | if self.tool_call_id:
78 | message['tool_call_id'] = self.tool_call_id
79 | if self.tool_calls:
80 | message['tool_calls'] = self.tool_calls
81 | if self.name:
82 | message['name'] = self.name
83 | return message
84 |
85 |
86 | class AssistantMessage(Message):
87 | def __init__(self,
88 | content,
89 | tool_calls: List[Dict] = None):
90 | self.tool_calls = tool_calls
91 | super().__init__("assistant", content)
92 |
93 | def to_dict(self) -> Dict:
94 | message = super().to_dict()
95 | if self.tool_calls:
96 | message['tool_calls'] = self.tool_calls
97 | return message
98 |
--------------------------------------------------------------------------------
/tinyllm/util/os_util.py:
--------------------------------------------------------------------------------
1 | import psutil
2 | import getpass
3 | import os
4 | import pickle
5 | import platform
6 | import shutil
7 | import stat
8 | import sys
9 | import zipfile
10 | from pathlib import Path
11 | from typing import Union
12 |
13 | import json
14 |
15 |
16 | def loadJson(json_path=None):
17 | with open(json_path) as json_file:
18 | dic = json.load(json_file)
19 | return dic
20 |
21 |
22 | def saveJson(dic=None, save_to_path=None):
23 | ensureDir(save_to_path)
24 | with open(save_to_path, 'w') as f:
25 | json.dump(dic, f)
26 |
27 |
28 | def ensureDir(dir_path):
29 | if isFilePath(dir_path):
30 | dir_path = getParentDir(dir_path)
31 | if isDirPath(dir_path) and not dirExists(dir_path):
32 | os.makedirs(dir_path)
33 |
34 |
35 | def joinPaths(paths: list):
36 | # Remove '/' at the beginning of all paths
37 | paths = [path[1:] if path.startswith('/') and i != 0 else path for i, path in enumerate(paths)]
38 | return os.path.join(*paths)
39 |
40 |
41 | def getBaseName(path):
42 | return os.path.basename(path)
43 |
44 |
45 | def getCurrentDirPath():
46 | return os.getcwd()
47 |
48 |
49 | def getUserHomePath():
50 | return str(Path.home())
51 |
52 |
53 | def getUsername():
54 | return getpass.getuser()
55 |
56 |
57 | def getPythonExecutablePath():
58 | return sys.executable
59 |
60 |
61 | def getOS():
62 | return platform.system()
63 |
64 |
65 | def dirExists(path):
66 | return os.path.isdir(path)
67 |
68 |
69 | def fileExists(path):
70 | return os.path.isfile(path)
71 |
72 |
73 | def isFilePath(path):
74 | split_path = str(path).split('/')
75 | return '.' in split_path[len(split_path) - 1]
76 |
77 |
78 | def isDirPath(path):
79 | split_path = str(path).split('/')
80 | return '.' not in split_path[len(split_path) - 1]
81 |
82 |
83 | def getPythonVersion():
84 | return sys.version[:5]
85 |
86 |
87 | def pathExists(path):
88 | return sum([fileExists(path), dirExists(path)]) > 0
89 |
90 |
91 | def on_rm_error(func, path, exc_info):
92 | # sizer_path contains the sizer_path of the sizer_path that couldn't be removed
93 | # let's just assume that it's read-only and unlink it.
94 | os.chmod(path, stat.S_IWRITE)
95 | os.unlink(path)
96 |
97 |
98 | def remove(path):
99 | if fileExists(path):
100 | os.remove(path)
101 | elif dirExists(path):
102 | shutil.rmtree(path, onerror=on_rm_error)
103 |
104 |
105 |
106 | def runCommand(command):
107 | result = os.system(command)
108 | if result != 0:
109 | raise Exception(f'Command failed : "{command}"')
110 | else:
111 | return result
112 |
113 |
114 | def writeFile(path, lines):
115 | with open(path, 'w') as fOut:
116 | for line in lines:
117 | fOut.write(line)
118 | fOut.write("\n")
119 |
120 |
121 | def copyFromTo(from_path, to_path, overwrite=True, ignore_files: list = None):
122 | if fileExists(from_path):
123 | if overwrite:
124 | remove(to_path)
125 | ensureDir(getParentDir(to_path))
126 | shutil.copyfile(from_path, to_path)
127 | elif dirExists(from_path):
128 | if overwrite:
129 | remove(to_path)
130 | shutil.copytree(from_path, to_path)
131 |
132 |
133 | def getParentDir(path):
134 | return os.path.dirname(path)
135 |
136 |
137 | def listDir(path,
138 | formats='', recursive=False):
139 | if isinstance(formats, str):
140 | formats = [formats]
141 | if recursive is True:
142 | listOfFile = os.listdir(path)
143 | allFiles = list()
144 | for entry in listOfFile:
145 | fullPath = os.path.join(path, entry)
146 | if os.path.isdir(fullPath):
147 | allFiles = allFiles + listDir(fullPath, recursive=recursive, formats=formats)
148 | else:
149 | for format in formats:
150 | if fullPath.endswith(format):
151 | allFiles.append(fullPath)
152 | break
153 | return allFiles
154 | else:
155 | return [os.path.join(path, i) for i in os.listdir(path) if any(i.endswith(format) for format in formats)]
156 |
157 | def walkDir(paths: Union[str, list], extension="", ignore=[]) -> list:
158 | if isinstance(paths, str):
159 | paths = [paths]
160 | files = []
161 | for dir_path in paths:
162 | for current_dir_path, current_subdirs, current_files in os.walk(dir_path):
163 | for aFile in current_files:
164 | if aFile.endswith(extension):
165 | txt_file_path = str(os.path.join(current_dir_path, aFile))
166 | if not any(word in txt_file_path for word in ignore):
167 | files.append(txt_file_path)
168 | return list(files)
169 |
170 |
171 | def copyDir(src, dst):
172 | src = Path(src)
173 | dst = Path(dst)
174 | ensureDir(dst)
175 | for item in os.listdir(src):
176 | s = src / item
177 | d = dst / item
178 | if dirExists(s):
179 | copyDir(s, d)
180 | else:
181 | shutil.copy2(str(s), str(d))
182 |
183 |
184 | def zipFiles(src: list, dst: str, arcname=None):
185 | zip_ = zipfile.ZipFile(dst, 'w')
186 | for i in range(len(src)):
187 | if arcname is None:
188 | zip_.write(src[i], os.path.basename(src[i]), compress_type=zipfile.ZIP_DEFLATED)
189 | else:
190 | zip_.write(src[i], arcname[i], compress_type=zipfile.ZIP_DEFLATED)
191 | zip_.close()
192 |
193 |
194 | def zipDir(path, save_to):
195 | zip_ = zipfile.ZipFile(save_to, 'w')
196 | for root, dirs, files in os.walk(path):
197 | for file in files:
198 | zip_.write(os.path.join(root, file),
199 | os.path.relpath(os.path.join(root, file),
200 | os.path.join(path, '../..')))
201 | zip_.close()
202 |
203 |
204 | def recursiveOverwrite(src, dest, ignore=None):
205 | if dirExists(src):
206 | if not dirExists(dest):
207 | ensureDir(dest)
208 | files = listDir(src)
209 | if ignore is not None:
210 | ignored = ignore(src, files)
211 | else:
212 | ignored = set()
213 | for f in files:
214 | if f not in ignored:
215 | recursiveOverwrite(joinPaths([src, f]),
216 | joinPaths([dest, f]),
217 | ignore)
218 | else:
219 | shutil.copyfile(src, dest)
220 |
221 |
222 |
223 | def loadPickle(file_path):
224 | with open(file_path, 'rb') as handle:
225 | return pickle.load(handle)
226 |
227 |
228 | def savePickle(object, file_path):
229 | with open(file_path, 'openpyxl_sizer_workbook') as handle:
230 | pickle.dump(object, handle, protocol=pickle.HIGHEST_PROTOCOL)
231 |
232 |
233 | def getComputerStats():
234 | gb_divider = (1024.0 ** 3)
235 | stats = {
236 | "cpu_percent": psutil.cpu_percent(),
237 | "ram_available": psutil.virtual_memory().available/gb_divider,
238 | "ram_total": psutil.virtual_memory().total/gb_divider,
239 | "ram_percent": psutil.virtual_memory().available * 100 / psutil.virtual_memory().total
240 | }
241 | return stats
242 |
243 |
244 | def getTempDir(folder_name, ensure_dir=True):
245 | path = joinPaths([getUserHomePath(),"tmp",folder_name])
246 | if ensure_dir: ensureDir(path)
247 | return path
248 |
--------------------------------------------------------------------------------
/tinyllm/util/parse_util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import List, Union, Dict, Any
4 |
5 | import re
6 | from typing import List
7 |
8 | def extract_blocks(text: str, language: str= 'json') -> List[Union[Dict[str, Any], str]]:
9 | pattern = rf'```{language}\s*(.*?)\s*```'
10 | matches = re.findall(pattern, text.strip(), re.DOTALL)
11 | extracted_blocks = [json.loads(match) if language == 'json' else match for match in matches]
12 | return extracted_blocks
13 |
14 | def extract_html(text: str, tag='prompt') -> List[str]:
15 | pattern = fr'<{tag}>(.*?){tag}>'
16 | matches = re.findall(pattern, text, re.DOTALL)
17 | return matches
18 |
--------------------------------------------------------------------------------
/tinyllm/util/prompt_util.py:
--------------------------------------------------------------------------------
1 | import random
2 | import re
3 | from typing import List, Dict, Any, Optional
4 |
5 | from fuzzywuzzy import fuzz
6 | from tinyllm.util import os_util
7 |
8 | from typing import List, Type, get_args, get_origin
9 | from pydantic import BaseModel, Field
10 |
11 | import inspect
12 |
13 | def extract_function_signature(function):
14 | sig = inspect.signature(function)
15 | return {k: str(v.annotation) for k, v in sig.parameters.items() if v.default != inspect.Parameter.empty}
16 |
17 | def extract_models(model: Type[BaseModel], models_list: List[Type[BaseModel]]) -> None:
18 | if model not in models_list:
19 | models_list.append(model)
20 | fields = model.__fields__
21 | for field_name, field in fields.items():
22 | field_type = field.annotation
23 | origin = get_origin(field_type)
24 | args = get_args(field_type)
25 |
26 | if isinstance(field_type, type) and issubclass(field_type, BaseModel):
27 | extract_models(field_type, models_list)
28 | elif origin is not None:
29 | for arg in args:
30 | if isinstance(arg, type) and issubclass(arg, BaseModel):
31 | extract_models(arg, models_list)
32 |
33 |
34 | def model_to_string(model: Type[BaseModel], indent=4) -> str:
35 | fields = model.__fields__
36 | field_defs = []
37 | for field_name, field in fields.items():
38 | field_type = field.annotation
39 |
40 | origin = get_origin(field_type)
41 | origin = str(origin).replace("","").capitalize()
42 | if origin:
43 | field_type_name = f"{str(origin)}[{str(field_type).split('.')[-1].replace(']', '')}]"
44 | else:
45 | field_type_name = field_type.__name__
46 |
47 | description = field.description
48 | description = f" | description: {description}" if description else ""
49 | field_defs.append(f"{' ' * indent}- {field_name}: {field_type_name}{description}")
50 |
51 | model_str = f"{model.__name__}:\n" + "\n".join(field_defs)
52 | return model_str
53 |
54 |
55 | def pydantic_model_to_string(root_model: Type[BaseModel]) -> str:
56 | models_list = []
57 | extract_models(root_model, models_list)
58 | models_strings = []
59 | for model in models_list[1:]:
60 | model_str = model_to_string(model)
61 | models_strings.append(model_str)
62 | original_model_string = f"{model_to_string(models_list[0])}\n\n WHERE \n\n"
63 | recurive_models = " AND \n\n".join(models_strings)
64 | return original_model_string + recurive_models
65 |
66 |
67 |
68 | def stringify_string_list(paragraphs: List[str],
69 | separator="\n") -> str:
70 | """
71 | Concatenates a list of strings with newline separator.
72 |
73 | :param paragraphs: A list of strings to concatenate.
74 | :return: A string concatenated with newline separator.
75 | """
76 | return separator.join(paragraphs)
77 |
78 |
79 | def stringify_key_value(key: str, value: Any) -> str:
80 | """
81 | Formats a string based on a key-value pair.
82 |
83 | :param key: The key of the pair.
84 | :param value: The value of the pair.
85 | :return: A formatted string.
86 | """
87 | return f"- {key}: {value}"
88 |
89 |
90 | def stringify_dict(header: str,
91 | dict: Dict[str, Any],
92 | include_keys: Optional[List[str]] = []) -> str:
93 | """
94 | Formats a dictionary into a string with a specific format.
95 |
96 | :param dict: A dictionary to format.
97 | :param include_keys: A list of keys to include. Default is None, which includes all keys.
98 | :return: A formatted string.
99 | """
100 | all_strings = []
101 | # if there are included ids, make sure the dict is filtered and follows the same order
102 | dict = {k: dict[k] for k in include_keys} if include_keys else dict
103 | for key, value in dict.items():
104 | # Include the key only if include_keys is None (include all keys) or the key is in include_keys
105 | if value is None:
106 | value = ""
107 | if key in ['created_at', 'updated_at', 'timestamp']:
108 | value = str(value).split('+')[0] if '+' in str(value) else str(value)
109 | generated_string = stringify_key_value(key, str(value).split('+')[0])
110 | all_strings.append(generated_string)
111 |
112 | dict_string_representation = stringify_string_list(all_strings, separator="\n")
113 | return header + "\n" + dict_string_representation
114 |
115 |
116 | def stringify_dict_list(header: str,
117 | dict_list: List[Dict[str, Any]],
118 | include_keys: Optional[List[str]] = None) -> str:
119 | """
120 | Formats a list of dictionaries into a string with a specific format.
121 |
122 | :param dict_list: A list of dictionaries to format.
123 | :param include_keys: A list of keys to include. Default is None, which includes all keys.
124 | :return: A formatted string.
125 | """
126 | all_strings = []
127 | for dict in dict_list:
128 | dict_string_representation = stringify_dict("", dict, include_keys)
129 | all_strings.append(dict_string_representation)
130 |
131 | dict_list_string_representation = stringify_string_list(all_strings, separator="\n\n")
132 | return header + "\n" + dict_list_string_representation
133 |
134 |
135 | def remove_imports(code: str) -> str:
136 | lines = code.split('\n')
137 | lines = [line for line in lines if not line.lstrip().startswith(('import', 'from'))]
138 | return '\n'.join(lines)
139 |
140 |
141 | def extract_markdown_python(text: str):
142 | if '```python' not in text:
143 | return text
144 | pattern = r"```python(.*?)```"
145 | python_codes = re.findall(pattern, text, re.DOTALL)
146 | return "\n".join(python_codes)
147 |
148 |
149 | def get_files_content(file_list: list,
150 | formats: list):
151 | code_context = []
152 | for file_name in file_list:
153 | if os_util.isDirPath(file_name):
154 | for file_path in os_util.listDir(file_name, recursive=True, formats=formats):
155 | try:
156 | with open(file_path, 'r') as file:
157 | content = file.read()
158 | code_context.append(f'\n \nFILE: This is the content of the file {file_name}:\n \n {content}\n')
159 | code_context.append(f'\n------------------------\n')
160 | except FileNotFoundError:
161 | print(f'File {file_name} not found in the directory')
162 |
163 | else:
164 | try:
165 | with open(file_name, 'r') as file:
166 | content = file.read()
167 | code_context.append(f'\n \nFILE: This is the content of the file {file_name}:\n \n {content}\n')
168 | code_context.append(f'\n------------------------\n')
169 | except FileNotFoundError:
170 | print(f'File {file_name} not found in the directory')
171 |
172 | final_prompt = '\n'.join(code_context)
173 | return final_prompt
174 |
175 |
176 | def shuffle_with_freeze(input_list, freeze):
177 | not_frozen_dict = {i: input_list[i] for i in range(len(input_list)) if i not in freeze}
178 | not_frozen_indices = list(not_frozen_dict.keys())
179 | random.shuffle(not_frozen_indices)
180 | shuffled_dict = {i: not_frozen_dict[not_frozen_indices[i]] for i in range(len(not_frozen_indices))}
181 | output_list = [shuffled_dict.get(i) if i in shuffled_dict else input_list[i] for i in range(len(input_list))]
182 | return output_list
183 |
184 |
185 | def remove_duplicate_lines(input_string: str) -> str:
186 | lines = input_string.split('\n')
187 | seen_lines = set()
188 | unique_lines = []
189 | for line in lines:
190 | trimmed_line = line.strip() # Removing leading and trailing whitespaces
191 | if trimmed_line and trimmed_line not in seen_lines:
192 | seen_lines.add(trimmed_line)
193 | unique_lines.append(trimmed_line)
194 | return '\n'.join(unique_lines)
195 |
196 |
197 | def find_closest_match_char_by_char(source, target):
198 | max_ratio = 0
199 | best_match = (0, 0)
200 | n = len(source)
201 |
202 | for start in range(n):
203 | for end in range(start, n):
204 | substring = source[start:end + 1]
205 | ratio = fuzz.token_set_ratio(substring, target)
206 | if ratio > max_ratio:
207 | max_ratio = ratio
208 | best_match = (start, end)
209 |
210 | return best_match
211 |
212 |
213 | def get_smallest_chunk(source, matches):
214 | # Sort matches by start index
215 | matches.sort(key=lambda x: x[0])
216 |
217 | min_chunk = (0, len(source))
218 | for i in range(len(matches)):
219 | for j in range(i + 1, len(matches)):
220 | if matches[j][0] > matches[i][1]: # Ensuring the second element starts after the first
221 | chunk_size = matches[j][1] - matches[i][0]
222 | if chunk_size < (min_chunk[1] - min_chunk[0]):
223 | min_chunk = (matches[i][0], matches[j][1])
224 | break # No need to check further as we are looking for smallest chunk
225 |
226 | return min_chunk
227 |
228 |
229 | def preprocess_text(text):
230 | # Convert to lower case and remove special characters
231 | return re.sub(r'[^a-zA-Z0-9\s]', '', text.lower())
232 |
233 | def blockify(text, title=None):
234 | title = title.upper() if title else None
235 | if title:
236 | return f"<{title}>\n{text}\n{title}>\n\n"
237 | else:
238 | return text
239 |
240 |
241 | INSTRUCTIONS_BOOSTS = [
242 | 'You will be given $500 tip if you follow the instructions',
243 | 'This is important for my career',
244 | ]
245 |
--------------------------------------------------------------------------------
/tinyllm/validator.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Type
2 |
3 | from pydantic import BaseModel, ValidationError
4 |
5 |
6 | class Validator(BaseModel):
7 |
8 | def __init__(self, **data: Any):
9 | if not data:
10 | raise ValidationError("At least one argument is required")
11 | super().__init__(**data)
12 |
13 | class Config:
14 | extra = 'allow'
15 | arbitrary_types_allowed = True
16 |
--------------------------------------------------------------------------------