├── .coveragerc ├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── publish_pypi.yml │ └── test.yml ├── .gitignore ├── .pydocstyle ├── AUTHORS.md ├── CHANGELOG.md ├── LICENSE ├── README.md ├── SECURITY.md ├── autopep8.bat ├── autopep8.sh ├── codecov.yml ├── dev-requirements.txt ├── memor ├── __init__.py ├── errors.py ├── functions.py ├── keywords.py ├── params.py ├── prompt.py ├── response.py ├── session.py ├── template.py └── tokens_estimator.py ├── otherfiles ├── RELEASE.md ├── donation.png ├── meta.yaml ├── requirements-splitter.py └── version_check.py ├── requirements.txt ├── setup.py └── tests ├── test_prompt.py ├── test_prompt_template.py ├── test_response.py ├── test_session.py └── test_token_estimators.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | omit = 4 | */memor/__main__.py 5 | [report] 6 | # Regexes for lines to exclude from consideration 7 | exclude_lines = 8 | pragma: no cover 9 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a professional setting 37 | 38 | ## Enforcement Responsibilities 39 | 40 | Community leaders are responsible for clarifying and enforcing our standards of 41 | acceptable behavior and will take appropriate and fair corrective action in 42 | response to any behavior that they deem inappropriate, threatening, offensive, 43 | or harmful. 44 | 45 | Community leaders have the right and responsibility to remove, edit, or reject 46 | comments, commits, code, wiki edits, issues, and other contributions that are 47 | not aligned to this Code of Conduct, and will communicate reasons for moderation 48 | decisions when appropriate. 49 | 50 | ## Scope 51 | This Code of Conduct applies both within project spaces and in public spaces 52 | when an individual is representing the project or its community. 
53 | Examples of representing our community include using an official e-mail address, 54 | posting via an official social media account, or acting as an appointed 55 | representative at an online or offline event. 56 | 57 | ## Enforcement 58 | 59 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 60 | reported to the community leaders responsible for enforcement at 61 | memor@openscilab.com. 62 | All complaints will be reviewed and investigated promptly and fairly. 63 | 64 | All community leaders are obligated to respect the privacy and security of the 65 | reporter of any incident. 66 | 67 | ## Enforcement Guidelines 68 | 69 | Community leaders will follow these Community Impact Guidelines in determining 70 | the consequences for any action they deem in violation of this Code of Conduct: 71 | 72 | ### 1. Correction 73 | 74 | **Community Impact**: Use of inappropriate language or other behavior deemed 75 | unprofessional or unwelcome in the community. 76 | 77 | **Consequence**: A private, written warning from community leaders, providing 78 | clarity around the nature of the violation and an explanation of why the 79 | behavior was inappropriate. A public apology may be requested. 80 | 81 | ### 2. Warning 82 | 83 | **Community Impact**: A violation through a single incident or series of 84 | actions. 85 | 86 | **Consequence**: A warning with consequences for continued behavior. No 87 | interaction with the people involved, including unsolicited interaction with 88 | those enforcing the Code of Conduct, for a specified period of time. This 89 | includes avoiding interactions in community spaces as well as external channels 90 | like social media. Violating these terms may lead to a temporary or permanent 91 | ban. 92 | 93 | ### 3. Temporary Ban 94 | 95 | **Community Impact**: A serious violation of community standards, including 96 | sustained inappropriate behavior. 
97 | 98 | **Consequence**: A temporary ban from any sort of interaction or public 99 | communication with the community for a specified period of time. No public or 100 | private interaction with the people involved, including unsolicited interaction 101 | with those enforcing the Code of Conduct, is allowed during this period. 102 | Violating these terms may lead to a permanent ban. 103 | 104 | ### 4. Permanent Ban 105 | 106 | **Community Impact**: Demonstrating a pattern of violation of community 107 | standards, including sustained inappropriate behavior, harassment of an 108 | individual, or aggression toward or disparagement of classes of individuals. 109 | 110 | **Consequence**: A permanent ban from any sort of public interaction within the 111 | community. 112 | 113 | ## Attribution 114 | 115 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 116 | version 2.1, available at 117 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 118 | 119 | Community Impact Guidelines were inspired by 120 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 121 | 122 | For answers to common questions about this code of conduct, see the FAQ at 123 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 124 | [https://www.contributor-covenant.org/translations][translations]. 125 | 126 | [homepage]: https://www.contributor-covenant.org 127 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 128 | [Mozilla CoC]: https://github.com/mozilla/diversity 129 | [FAQ]: https://www.contributor-covenant.org/faq 130 | [translations]: https://www.contributor-covenant.org/translations 131 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution 2 | 3 | Changes and improvements are more than welcome! 
❤️ Feel free to fork and open a pull request. 4 | 5 | 6 | Please consider the following: 7 | 8 | 9 | 1. Fork it! 10 | 2. Create your feature branch (under `dev` branch) 11 | 3. Add your functions/methods to the proper files 12 | 4. Add standard `docstring` to your functions/methods 13 | 5. Add tests for your functions/methods (`unittest` testcases in `tests` folder) 14 | 6. Pass all CI tests 15 | 7. Update `CHANGELOG.md` 16 | - Describe changes under `[Unreleased]` section 17 | 8. Submit a pull request into `dev` (please complete the pull request template) 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report 3 | title: "[Bug]: " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for taking the time to fill out this bug report! 9 | - type: input 10 | id: contact 11 | attributes: 12 | label: Contact details 13 | description: How can we get in touch with you if we need more info? 14 | placeholder: ex. email@example.com 15 | validations: 16 | required: false 17 | - type: textarea 18 | id: what-happened 19 | attributes: 20 | label: What happened? 21 | description: Provide a clear and concise description of what the bug is. 22 | placeholder: > 23 | Tell us a description of the bug. 24 | validations: 25 | required: true 26 | - type: textarea 27 | id: step-to-reproduce 28 | attributes: 29 | label: Steps to reproduce 30 | description: Provide details of how to reproduce the bug. 31 | placeholder: > 32 | ex. 1. Go to '...' 33 | validations: 34 | required: true 35 | - type: textarea 36 | id: expected-behavior 37 | attributes: 38 | label: Expected behavior 39 | description: What did you expect to happen? 40 | placeholder: > 41 | ex. I expected '...' 
to happen 42 | validations: 43 | required: true 44 | - type: textarea 45 | id: actual-behavior 46 | attributes: 47 | label: Actual behavior 48 | description: What actually happened? 49 | placeholder: > 50 | ex. Instead '...' happened 51 | validations: 52 | required: true 53 | - type: dropdown 54 | id: operating-system 55 | attributes: 56 | label: Operating system 57 | description: Which operating system are you using? 58 | options: 59 | - Windows 60 | - macOS 61 | - Linux 62 | default: 0 63 | validations: 64 | required: true 65 | - type: dropdown 66 | id: python-version 67 | attributes: 68 | label: Python version 69 | description: Which version of Python are you using? 70 | options: 71 | - Python 3.13 72 | - Python 3.12 73 | - Python 3.11 74 | - Python 3.10 75 | - Python 3.9 76 | - Python 3.8 77 | - Python 3.7 78 | - Python 3.6 79 | default: 1 80 | validations: 81 | required: true 82 | - type: dropdown 83 | id: memor-version 84 | attributes: 85 | label: Memor version 86 | description: Which version of Memor are you using? 87 | options: 88 | - Memor 0.6 89 | - Memor 0.5 90 | - Memor 0.4 91 | - Memor 0.3 92 | - Memor 0.2 93 | - Memor 0.1 94 | default: 0 95 | validations: 96 | required: true 97 | - type: textarea 98 | id: logs 99 | attributes: 100 | label: Relevant log output 101 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. 
102 | render: shell 103 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Discord 4 | url: https://discord.gg/cZxGwZ6utB 5 | about: Ask questions and discuss with other Memor community members 6 | - name: Website 7 | url: https://openscilab.com/ 8 | about: Check out our website for more information 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Suggest a feature for this project 3 | title: "[Feature]: " 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: Describe the feature you want to add 9 | placeholder: > 10 | I'd like to be able to [...] 11 | validations: 12 | required: true 13 | - type: textarea 14 | id: possible-solution 15 | attributes: 16 | label: Describe your proposed solution 17 | placeholder: > 18 | I think this could be done by [...] 19 | validations: 20 | required: false 21 | - type: textarea 22 | id: alternatives 23 | attributes: 24 | label: Describe alternatives you've considered, if relevant 25 | placeholder: > 26 | Another way to do this would be [...] 27 | validations: 28 | required: false 29 | - type: textarea 30 | id: additional-context 31 | attributes: 32 | label: Additional context 33 | placeholder: > 34 | Add any other context or screenshots about the feature request here. 35 | validations: 36 | required: false 37 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | #### Reference Issues/PRs 2 | 3 | #### What does this implement/fix? Explain your changes. 
4 | 5 | #### Any other comments? 6 | 7 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: weekly 7 | time: "01:30" 8 | open-pull-requests-limit: 10 9 | target-branch: dev 10 | assignees: 11 | - "sadrasabouri" 12 | - "sepandhaghighi" 13 | -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a tag is pushed 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | push: 8 | # Sequence of patterns matched against refs/tags 9 | tags: 10 | - '*' # Push events for any tag, e.g. 
v1.0, v20.15.10 11 | 12 | jobs: 13 | deploy: 14 | 15 | runs-on: ubuntu-22.04 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: '3.x' 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install setuptools wheel twine 27 | - name: Build and publish 28 | env: 29 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 30 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 31 | run: | 32 | python setup.py sdist bdist_wheel 33 | twine upload dist/*.tar.gz 34 | twine upload dist/*.whl -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | - dev 11 | 12 | pull_request: 13 | branches: 14 | - dev 15 | - main 16 | 17 | env: 18 | TEST_PYTHON_VERSION: 3.9 19 | TEST_OS: 'ubuntu-22.04' 20 | 21 | jobs: 22 | build: 23 | runs-on: ${{ matrix.os }} 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | os: [ubuntu-22.04, windows-2022, macOS-13] 28 | python-version: [3.7, 3.8, 3.9, 3.10.5, 3.11.0, 3.12.0, 3.13.0] 29 | steps: 30 | - uses: actions/checkout@v2 31 | - name: Set up Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@v2 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | - name: Installation 36 | run: | 37 | python -m pip install --upgrade pip 38 | pip install . 
39 | - name: Test requirements installation 40 | run: | 41 | python otherfiles/requirements-splitter.py 42 | pip install --upgrade --upgrade-strategy=only-if-needed -r test-requirements.txt 43 | - name: Test with pytest 44 | run: | 45 | python -m pytest . --cov=memor --cov-report=term 46 | - name: Upload coverage to Codecov 47 | uses: codecov/codecov-action@v4 48 | with: 49 | fail_ci_if_error: true 50 | token: ${{ secrets.CODECOV_TOKEN }} 51 | if: matrix.python-version == env.TEST_PYTHON_VERSION && matrix.os == env.TEST_OS 52 | - name: Vulture, Bandit and Pydocstyle tests 53 | run: | 54 | python -m vulture memor/ otherfiles/ setup.py --min-confidence 65 --exclude=__init__.py --sort-by-size 55 | python -m bandit -r memor -s B311 56 | python -m pydocstyle -v 57 | if: matrix.python-version == env.TEST_PYTHON_VERSION 58 | - name: Version check 59 | run: | 60 | python otherfiles/version_check.py 61 | if: matrix.python-version == env.TEST_PYTHON_VERSION 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # dotenv 81 | .env 82 | 83 | # virtualenv 84 | .venv/ 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | 91 | # Rope project settings 92 | .ropeproject 93 | ### Example user template template 94 | ### Example user template 95 | 96 | # IntelliJ project files 97 | .idea 98 | *.iml 99 | out 100 | gen 101 | 102 | # Outputs 103 | 104 | prompt_test1.json 105 | prompt_test2.json 106 | prompt_test3.json 107 | response_test1.json 108 | response_test2.json 109 | response_test3.json 110 | session_test1.json 111 | template_test1.json 112 | template_test2.json 113 | -------------------------------------------------------------------------------- /.pydocstyle: -------------------------------------------------------------------------------- 1 | [pydocstyle] 2 | match_dir = ^(?!(tests|build)).* 3 | match = .*\.py 4 | 5 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | # Core Developers 2 | ---------- 3 | - Sepand Haghighi - Open Science Laboratory ([Github](https://github.com/sepandhaghighi)) ** 4 | - Sadra Sabouri - Open Science Laboratory ([Github](https://github.com/sadrasabouri)) ** 5 | 6 
| ** **Maintainer** 7 | 8 | # Other Contributors 9 | ---------- 10 | 11 | 12 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) 5 | and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). 6 | 7 | ## [Unreleased] 8 | ## [0.6] - 2025-05-05 9 | ### Added 10 | - `Response` class `id` property 11 | - `Prompt` class `id` property 12 | - `Response` class `regenerate_id` method 13 | - `Prompt` class `regenerate_id` method 14 | - `Session` class `render_counter` method 15 | - `Session` class `remove_message_by_index` and `remove_message_by_id` methods 16 | - `Session` class `get_message_by_index`, `get_message_by_id` and `get_message` methods 17 | - `LLMModel` enum 18 | - `AI_STUDIO` render format 19 | ### Changed 20 | - Test system modified 21 | - Modification handling centralized via `_mark_modified` method 22 | - `Session` class `remove_message` method modified 23 | ## [0.5] - 2025-04-16 24 | ### Added 25 | - `Session` class `check_render` method 26 | - `Session` class `clear_messages` method 27 | - `Prompt` class `check_render` method 28 | - `Session` class `estimate_tokens` method 29 | - `Prompt` class `estimate_tokens` method 30 | - `Response` class `estimate_tokens` method 31 | - `universal_tokens_estimator` function 32 | - `openai_tokens_estimator_gpt_3_5` function 33 | - `openai_tokens_estimator_gpt_4` function 34 | ### Changed 35 | - `init_check` parameter added to `Prompt` class 36 | - `init_check` parameter added to `Session` class 37 | - Test system modified 38 | - `Python 3.6` support dropped 39 | - `README.md` updated 40 | ## [0.4] - 2025-03-17 41 | ### Added 42 | - `Session` class `__contains__` method 43 | - `Session` class 
`__getitem__` method 44 | - `Session` class `mask_message` method 45 | - `Session` class `unmask_message` method 46 | - `Session` class `masks` attribute 47 | - `Response` class `__len__` method 48 | - `Prompt` class `__len__` method 49 | ### Changed 50 | - `inference_time` parameter added to `Response` class 51 | - `README.md` updated 52 | - Test system modified 53 | - Python typing features added to all modules 54 | - `Prompt` class default values updated 55 | - `Response` class default values updated 56 | ## [0.3] - 2025-03-08 57 | ### Added 58 | - `Session` class `__len__` method 59 | - `Session` class `__iter__` method 60 | - `Session` class `__add__` and `__radd__` methods 61 | ### Changed 62 | - `tokens` parameter added to `Prompt` class 63 | - `tokens` parameter added to `Response` class 64 | - `tokens` parameter added to preset templates 65 | - `Prompt` class modified 66 | - `Response` class modified 67 | - `PromptTemplate` class modified 68 | ## [0.2] - 2025-03-01 69 | ### Added 70 | - `Session` class 71 | ### Changed 72 | - `Prompt` class modified 73 | - `Response` class modified 74 | - `PromptTemplate` class modified 75 | - `README.md` updated 76 | - Test system modified 77 | ## [0.1] - 2025-02-12 78 | ### Added 79 | - `Prompt` class 80 | - `Response` class 81 | - `PromptTemplate` class 82 | - `PresetPromptTemplate` class 83 | 84 | 85 | [Unreleased]: https://github.com/openscilab/memor/compare/v0.6...dev 86 | [0.6]: https://github.com/openscilab/memor/compare/v0.5...v0.6 87 | [0.5]: https://github.com/openscilab/memor/compare/v0.4...v0.5 88 | [0.4]: https://github.com/openscilab/memor/compare/v0.3...v0.4 89 | [0.3]: https://github.com/openscilab/memor/compare/v0.2...v0.3 90 | [0.2]: https://github.com/openscilab/memor/compare/v0.1...v0.2 91 | [0.1]: https://github.com/openscilab/memor/compare/6594313...v0.1 -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 OpenSciLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Memor: A Python Library for Managing and Transferring Conversational Memory Across LLMs

3 |
4 | 5 | PyPI version 6 | built with Python3 7 | GitHub repo size 8 | Discord Channel 9 |
10 | 11 | ---------- 12 | 13 | 14 | ## Overview 15 |

16 | Memor is a library designed to help users manage the memory of their interactions with Large Language Models (LLMs). 17 | It enables users to seamlessly access and utilize the history of their conversations when prompting LLMs. 18 | This creates a more personalized and context-aware experience. 19 | Memor stands out by allowing users to transfer conversational history across different LLMs, eliminating cold starts where models don't have information about the user and their preferences. 20 | Users can select specific parts of past interactions with one LLM and share them with another. 21 | By bridging the gap between isolated LLM instances, Memor revolutionizes the way users interact with AI by making transitions between models smoother. 22 | 23 |

24 | 25 | 26 | 27 | 32 | 33 | 34 | 35 | 40 | 41 |
PyPI Counter 28 | 29 | 30 | 31 |
Github Stars 36 | 37 | 38 | 39 |
42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 53 | 56 | 57 |
Branchmaindev
CI 51 | 52 | 54 | 55 |
58 | 59 | 60 | 61 | 62 | 63 | 64 |
Code QualityCodeFactor
65 | 66 | 67 | ## Installation 68 | 69 | ### PyPI 70 | - Check [Python Packaging User Guide](https://packaging.python.org/installing/) 71 | - Run `pip install memor==0.6` 72 | ### Source code 73 | - Download [Version 0.6](https://github.com/openscilab/memor/archive/v0.6.zip) or [Latest Source](https://github.com/openscilab/memor/archive/dev.zip) 74 | - Run `pip install .` 75 | 76 | ## Usage 77 | Define your prompt and the response(s) to it; Memor will wrap them into an object with a templated representation. 78 | You can create a session by combining multiple prompts and responses, gradually building it up: 79 | 80 | ```pycon 81 | >>> from memor import Session, Prompt, Response, Role 82 | >>> from memor import PresetPromptTemplate, RenderFormat, LLMModel 83 | >>> response = Response(message="I am fine.", model=LLMModel.GPT_4, role=Role.ASSISTANT, temperature=0.9, score=0.9) 84 | >>> prompt = Prompt(message="Hello, how are you?", 85 | responses=[response], 86 | role=Role.USER, 87 | template=PresetPromptTemplate.INSTRUCTION1.PROMPT_RESPONSE_STANDARD) 88 | >>> system_prompt = Prompt(message="You are a friendly and informative AI assistant designed to answer questions on a wide range of topics.", 89 | role=Role.SYSTEM) 90 | >>> session = Session(messages=[system_prompt, prompt]) 91 | >>> session.render(RenderFormat.OPENAI) 92 | ``` 93 | 94 | The rendered output will be a list of messages formatted for compatibility with the OpenAI API. 95 | 96 | ```json 97 | [{"content": "You are a friendly and informative AI assistant designed to answer questions on a wide range of topics.", "role": "system"}, 98 | {"content": "I'm providing you with a history of a previous conversation. 
Please consider this context when responding to my new question.\n" 99 | "Prompt: Hello, how are you?\n" 100 | "Response: I am fine.", 101 | "role": "user"}] 102 | ``` 103 | 104 | ### Prompt Templates 105 | 106 | #### Preset Templates 107 | 108 | Memor provides a variety of pre-defined prompt templates to control how prompts and responses are rendered. Each template is prefixed by an optional instruction string and includes variations for different formatting styles. Following are different variants of parameters: 109 | 110 | | **Instruction Name** | **Description** | 111 | |---------------|----------| 112 | | `INSTRUCTION1` | "I'm providing you with a history of a previous conversation. Please consider this context when responding to my new question." | 113 | | `INSTRUCTION2` | "Here is the context from a prior conversation. Please learn from this information and use it to provide a thoughtful and context-aware response to my next questions." | 114 | | `INSTRUCTION3` | "I am sharing a record of a previous discussion. Use this information to provide a consistent and relevant answer to my next query." | 115 | 116 | | **Template Title** | **Description** | 117 | |--------------|----------| 118 | | `PROMPT` | Only includes the prompt message. | 119 | | `RESPONSE` | Only includes the response message. | 120 | | `RESPONSE0` to `RESPONSE3` | Include specific responses from a list of multiple responses. | 121 | | `PROMPT_WITH_LABEL` | Prompt with a "Prompt: " prefix. | 122 | | `RESPONSE_WITH_LABEL` | Response with a "Response: " prefix. | 123 | | `RESPONSE0_WITH_LABEL` to `RESPONSE3_WITH_LABEL` | Labeled response for the i-th response. | 124 | | `PROMPT_RESPONSE_STANDARD` | Includes both labeled prompt and response on a single line. | 125 | | `PROMPT_RESPONSE_FULL` | A detailed multi-line representation including role, date, model, etc. 
| 126 | 127 | You can access them like this: 128 | 129 | ```pycon 130 | >>> from memor import PresetPromptTemplate 131 | >>> template = PresetPromptTemplate.INSTRUCTION1.PROMPT_RESPONSE_STANDARD 132 | ``` 133 | 134 | #### Custom Templates 135 | 136 | You can define custom templates for your prompts using the `PromptTemplate` class. This class provides two key parameters that control its functionality: 137 | 138 | + `content`: A string that defines the template structure, following Python string formatting conventions. You can include dynamic fields using placeholders like `{field_name}`, which will be automatically populated using attributes from the prompt object. Some common examples of auto-filled fields are shown below: 139 | 140 | | **Prompt Object Attribute** | **Placeholder Syntax** | **Description** | 141 | |--------------------------------------|------------------------------------|----------------------------------------------| 142 | | `prompt.message` | `{prompt[message]}` | The main prompt message | 143 | | `prompt.selected_response` | `{prompt[response]}` | The selected response for the prompt | 144 | | `prompt.date_modified` | `{prompt[date_modified]}` | Timestamp of the last modification | 145 | | `prompt.responses[2].message` | `{responses[2][message]}` | Message from the response at index 2 | 146 | | `prompt.responses[0].inference_time` | `{responses[0][inference_time]}` | Inference time for the response at index 0 | 147 | 148 | 149 | + `custom_map`: In addition to the attributes listed above, you can define and insert custom placeholders (e.g., `{field_name}`) and provide their values through a dictionary. When rendering the template, each placeholder will be replaced with its corresponding value from `custom_map`. 150 | 151 | 152 | Suppose you want to prepend an instruction to every prompt message. 
You can define and use a template as follows: 153 | 154 | ```pycon 155 | >>> template = PromptTemplate(content="{instruction}, {prompt[message]}", custom_map={"instruction": "Hi"}) 156 | >>> prompt = Prompt(message="How are you?", template=template) 157 | >>> prompt.render() 158 | Hi, How are you? 159 | ``` 160 | 161 | By using this dynamic structure, you can create flexible and sophisticated prompt templates with Memor. You can design specific schemas for your conversational or instructional formats when interacting with LLMs. 162 | 163 | ## Issues & bug reports 164 | 165 | Just file an issue and describe it, and we'll check it ASAP! You can also send an email to [memor@openscilab.com](mailto:memor@openscilab.com "memor@openscilab.com"). 166 | 167 | - Please complete the issue template 168 | 169 | You can also join our Discord server 170 | 171 | 172 | Discord Channel 173 | 174 | 175 | ## Show your support 176 | 177 | 178 | ### Star this repo 179 | 180 | Give a ⭐️ if this project helped you! 181 | 182 | ### Donate to our project 183 | If you like our project, and we hope that you do, please consider supporting us. Our project is not, and never will be, run for profit. We need the money just so we can continue doing what we do ;-) . 184 | 185 | Memor Donation -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security policy 2 | 3 | ## Supported versions 4 | 5 | | Version | Supported | 6 | | ------------- | ------------------ | 7 | | 0.6 | :white_check_mark: | 8 | | < 0.6 | :x: | 9 | 10 | ## Reporting a vulnerability 11 | 12 | Please report security vulnerabilities by email to [memor@openscilab.com](mailto:memor@openscilab.com "memor@openscilab.com"). 13 | 14 | If the security vulnerability is accepted, a dedicated bugfix release will be issued as soon as possible (depending on the complexity of the fix). 
-------------------------------------------------------------------------------- /autopep8.bat: -------------------------------------------------------------------------------- 1 | python -m autopep8 memor --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721 2 | python -m autopep8 otherfiles --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721 3 | python -m autopep8 tests --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721 4 | python -m autopep8 setup.py --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose 5 | -------------------------------------------------------------------------------- /autopep8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python -m autopep8 memor --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721 3 | python -m autopep8 otherfiles --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721 4 | python -m autopep8 tests --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721 5 | python -m autopep8 setup.py --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose 6 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: up 7 | range: "70...100" 8 | status: 9 | patch: 10 | default: 11 | enabled: no 12 | project: 13 | default: 14 | threshold: 1% 15 | -------------------------------------------------------------------------------- 
/dev-requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools>=40.8.0 2 | vulture>=1.0 3 | bandit>=1.5.1 4 | pydocstyle>=3.0.0 5 | pytest>=4.3.1 6 | pytest-cov>=2.6.1 -------------------------------------------------------------------------------- /memor/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Memor modules.""" 3 | from .params import MEMOR_VERSION, RenderFormat, LLMModel 4 | from .tokens_estimator import TokensEstimator 5 | from .template import PromptTemplate, PresetPromptTemplate 6 | from .prompt import Prompt, Role 7 | from .response import Response 8 | from .session import Session 9 | from .errors import MemorRenderError, MemorValidationError 10 | 11 | __version__ = MEMOR_VERSION 12 | -------------------------------------------------------------------------------- /memor/errors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Memor errors.""" 3 | 4 | 5 | class MemorValidationError(ValueError): 6 | """Base class for validation errors in Memor.""" 7 | 8 | pass 9 | 10 | 11 | class MemorRenderError(Exception): 12 | """Base class for render error in Memor.""" 13 | 14 | pass 15 | -------------------------------------------------------------------------------- /memor/functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Memor functions.""" 3 | from typing import Any, Type 4 | import os 5 | import datetime 6 | import uuid 7 | from .params import INVALID_DATETIME_MESSAGE 8 | from .params import INVALID_PATH_MESSAGE, INVALID_STR_VALUE_MESSAGE 9 | from .params import INVALID_PROB_VALUE_MESSAGE 10 | from .params import INVALID_POSFLOAT_VALUE_MESSAGE 11 | from .params import INVALID_POSINT_VALUE_MESSAGE 12 | from .params import INVALID_CUSTOM_MAP_MESSAGE 13 | from .params import 
INVALID_BOOL_VALUE_MESSAGE 14 | from .params import INVALID_LIST_OF_X_MESSAGE 15 | from .params import INVALID_ID_MESSAGE 16 | from .errors import MemorValidationError 17 | 18 | 19 | def generate_message_id() -> str: 20 | """Generate message ID.""" 21 | return str(uuid.uuid4()) 22 | 23 | 24 | def _validate_message_id(message_id: str) -> bool: 25 | """ 26 | Validate message ID. 27 | 28 | :param message_id: message ID 29 | """ 30 | try: 31 | _ = uuid.UUID(message_id, version=4) 32 | except ValueError: 33 | raise MemorValidationError(INVALID_ID_MESSAGE) 34 | return True 35 | 36 | 37 | def get_time_utc() -> datetime.datetime: 38 | """ 39 | Get time in UTC format. 40 | 41 | :return: UTC format time as a datetime object 42 | """ 43 | return datetime.datetime.now(datetime.timezone.utc) 44 | 45 | 46 | def _validate_string(value: Any, parameter_name: str) -> bool: 47 | """ 48 | Validate string. 49 | 50 | :param value: value 51 | :param parameter_name: parameter name 52 | """ 53 | if not isinstance(value, str): 54 | raise MemorValidationError(INVALID_STR_VALUE_MESSAGE.format(parameter_name=parameter_name)) 55 | return True 56 | 57 | 58 | def _validate_bool(value: Any, parameter_name: str) -> bool: 59 | """ 60 | Validate boolean. 61 | 62 | :param value: value 63 | :param parameter_name: parameter name 64 | """ 65 | if not isinstance(value, bool): 66 | raise MemorValidationError(INVALID_BOOL_VALUE_MESSAGE.format(parameter_name=parameter_name)) 67 | return True 68 | 69 | 70 | def _can_convert_to_string(value: Any) -> bool: 71 | """ 72 | Check if value can be converted to string. 73 | 74 | :param value: value 75 | """ 76 | try: 77 | str(value) 78 | except Exception: 79 | return False 80 | return True 81 | 82 | 83 | def _validate_pos_int(value: Any, parameter_name: str) -> bool: 84 | """ 85 | Validate positive integer. 
86 | 87 | :param value: value 88 | :param parameter_name: parameter name 89 | """ 90 | if not isinstance(value, int) or value < 0: 91 | raise MemorValidationError(INVALID_POSINT_VALUE_MESSAGE.format(parameter_name=parameter_name)) 92 | return True 93 | 94 | 95 | def _validate_pos_float(value: Any, parameter_name: str) -> bool: 96 | """ 97 | Validate positive float. 98 | 99 | :param value: value 100 | :param parameter_name: parameter name 101 | """ 102 | if not isinstance(value, float) or value < 0: 103 | raise MemorValidationError(INVALID_POSFLOAT_VALUE_MESSAGE.format(parameter_name=parameter_name)) 104 | return True 105 | 106 | 107 | def _validate_probability(value: Any, parameter_name: str) -> bool: 108 | """ 109 | Validate probability (a float between 0 and 1). 110 | 111 | :param value: value 112 | :param parameter_name: parameter name 113 | """ 114 | if not isinstance(value, float) or value < 0 or value > 1: 115 | raise MemorValidationError(INVALID_PROB_VALUE_MESSAGE.format(parameter_name=parameter_name)) 116 | return True 117 | 118 | 119 | def _validate_list_of(value: Any, parameter_name: str, type_: Type, type_name: str) -> bool: 120 | """ 121 | Validate list of values. 122 | 123 | :param value: value 124 | :param parameter_name: parameter name 125 | :param type_: type 126 | :param type_name: type name 127 | """ 128 | if not isinstance(value, list): 129 | raise MemorValidationError(INVALID_LIST_OF_X_MESSAGE.format(parameter_name=parameter_name, type_name=type_name)) 130 | 131 | if not all(isinstance(x, type_) for x in value): 132 | raise MemorValidationError(INVALID_LIST_OF_X_MESSAGE.format(parameter_name=parameter_name, type_name=type_name)) 133 | return True 134 | 135 | 136 | def _validate_date_time(date_time: Any, parameter_name: str) -> bool: 137 | """ 138 | Validate date time. 
139 | 140 | :param date_time: date time 141 | :param parameter_name: parameter name 142 | """ 143 | if not isinstance(date_time, datetime.datetime) or date_time.tzinfo is None: 144 | raise MemorValidationError(INVALID_DATETIME_MESSAGE.format(parameter_name=parameter_name)) 145 | return True 146 | 147 | 148 | def _validate_path(path: Any) -> bool: 149 | """ 150 | Validate path property. 151 | 152 | :param path: path 153 | """ 154 | if not isinstance(path, str) or not os.path.exists(path): 155 | raise FileNotFoundError(INVALID_PATH_MESSAGE.format(path=path)) 156 | return True 157 | 158 | 159 | def _validate_custom_map(custom_map: Any) -> bool: 160 | """ 161 | Validate custom map a dictionary with keys and values that can be converted to strings. 162 | 163 | :param custom_map: custom map 164 | """ 165 | if not isinstance(custom_map, dict): 166 | raise MemorValidationError(INVALID_CUSTOM_MAP_MESSAGE) 167 | if not all(_can_convert_to_string(k) and _can_convert_to_string(v) for k, v in custom_map.items()): 168 | raise MemorValidationError(INVALID_CUSTOM_MAP_MESSAGE) 169 | return True 170 | -------------------------------------------------------------------------------- /memor/keywords.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Tokens estimator keywords.""" 3 | 4 | COMMON_PREFIXES = {"un", "re", "in", "dis", "pre", "mis", "non", "over", "under", "sub", "trans"} 5 | 6 | COMMON_SUFFIXES = {"ing", "ed", "ly", "es", "s", "ment", "able", "ness", "tion", "ive", "ous"} 7 | 8 | PYTHON_KEYWORDS = {"if", "else", "elif", "while", "for", "def", "return", "import", "from", "class", 9 | "try", "except", "finally", "with", "as", "break", "continue", "pass", "lambda", 10 | "True", "False", "None", "and", "or", "not", "in", "is", "global", "nonlocal"} 11 | 12 | JAVASCRIPT_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break", 13 | "continue", "function", "return", "var", "let", 
"const", "class", "extends", 14 | "super", "import", "export", "try", "catch", "finally", "throw", "new", 15 | "delete", "typeof", "instanceof", "in", "void", "yield", "this", "async", 16 | "await", "static", "get", "set", "true", "false", "null", "undefined"} 17 | 18 | JAVA_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break", 19 | "continue", "return", "void", "int", "float", "double", "char", "long", "short", 20 | "boolean", "byte", "class", "interface", "extends", "implements", "new", "import", 21 | "package", "public", "private", "protected", "static", "final", "abstract", 22 | "try", "catch", "finally", "throw", "throws", "synchronized", "volatile", "transient", 23 | "native", "strictfp", "assert", "instanceof", "super", "this", "true", "false", "null"} 24 | 25 | 26 | C_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break", "continue", 27 | "return", "void", "char", "int", "float", "double", "short", "long", "signed", 28 | "unsigned", "struct", "union", "typedef", "enum", "const", "volatile", "extern", 29 | "register", "static", "auto", "sizeof", "goto"} 30 | 31 | 32 | CPP_KEYWORDS = C_KEYWORDS | {"new", "delete", "class", 33 | "public", "private", "protected", "namespace", "using", "template", "friend", 34 | "virtual", "inline", "operator", "explicit", "this", "true", "false", "nullptr"} 35 | 36 | 37 | CSHARP_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break", 38 | "continue", "return", "void", "int", "float", "double", "char", "long", "short", 39 | "bool", "byte", "class", "interface", "struct", "new", "namespace", "using", 40 | "public", "private", "protected", "static", "readonly", "const", "try", "catch", 41 | "finally", "throw", "async", "await", "true", "false", "null"} 42 | 43 | 44 | GO_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "break", "continue", "return", 45 | "func", "var", "const", "type", "struct", "interface", "map", "chan", 
"package", 46 | "import", "defer", "go", "select", "range", "fallthrough", "goto"} 47 | 48 | 49 | RUST_KEYWORDS = {"if", "else", "match", "loop", "for", "while", "break", "continue", "return", 50 | "fn", "let", "const", "static", "struct", "enum", "trait", "impl", "mod", 51 | "use", "crate", "super", "self", "as", "type", "where", "pub", "unsafe", 52 | "dyn", "move", "async", "await", "true", "false"} 53 | 54 | 55 | SWIFT_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "repeat", "break", 56 | "continue", "return", "func", "var", "let", "class", "struct", "enum", "protocol", 57 | "import", "defer", "as", "is", "try", "catch", "throw", "throws", "inout", 58 | "guard", "self", "super", "true", "false", "nil"} 59 | 60 | 61 | KOTLIN_KEYWORDS = {"if", "else", "when", "for", "while", "do", "break", "continue", "return", 62 | "fun", "val", "var", "class", "object", "interface", "enum", "sealed", 63 | "import", "package", "as", "is", "in", "try", "catch", "finally", "throw", 64 | "super", "this", "by", "constructor", "init", "companion", "override", 65 | "abstract", "final", "open", "private", "protected", "public", "internal", 66 | "inline", "suspend", "operator", "true", "false", "null"} 67 | 68 | TYPESCRIPT_KEYWORDS = JAVASCRIPT_KEYWORDS | {"interface", "type", "namespace", "declare"} 69 | 70 | 71 | PHP_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break", 72 | "continue", "return", "function", "class", "public", "private", "protected", 73 | "extends", "implements", "namespace", "use", "new", "static", "global", 74 | "const", "var", "echo", "print", "try", "catch", "finally", "throw", "true", "false", "null"} 75 | 76 | RUBY_KEYWORDS = {"if", "else", "elsif", "unless", "case", "when", "for", "while", "do", "break", 77 | "continue", "return", "def", "class", "module", "end", "begin", "rescue", "ensure", 78 | "yield", "super", "self", "alias", "true", "false", "nil"} 79 | 80 | SQL_KEYWORDS = {"SELECT", "INSERT", 
"UPDATE", "DELETE", "FROM", "WHERE", "JOIN", "INNER", "LEFT", 81 | "RIGHT", "FULL", "ON", "GROUP BY", "HAVING", "ORDER BY", "LIMIT", "OFFSET", "AS", 82 | "AND", "OR", "NOT", "NULL", "TRUE", "FALSE"} 83 | 84 | BASH_KEYWORDS = {"if", "else", "fi", "then", "elif", "case", "esac", "for", "while", "do", "done", 85 | "break", "continue", "return", "function", "export", "readonly", "local", "declare", 86 | "eval", "trap", "exec", "true", "false"} 87 | 88 | MATLAB_KEYWORDS = {"if", "else", "elseif", "end", "for", "while", "break", "continue", "return", 89 | "function", "global", "persistent", "switch", "case", "otherwise", "try", "catch", 90 | "true", "false"} 91 | 92 | R_KEYWORDS = {"if", "else", "repeat", "while", "for", "break", "next", "return", "function", 93 | "TRUE", "FALSE", "NULL", "Inf", "NaN", "NA"} 94 | 95 | 96 | PERL_KEYWORDS = {"if", "else", "elsif", "unless", "while", "for", "foreach", "do", "last", "next", 97 | "redo", "goto", "return", "sub", "package", "use", "require", "my", "local", "our", 98 | "state", "BEGIN", "END", "true", "false"} 99 | 100 | LUA_KEYWORDS = {"if", "else", "elseif", "then", "for", "while", "repeat", "until", "break", "return", 101 | "function", "end", "local", "do", "true", "false", "nil"} 102 | 103 | SCALA_KEYWORDS = {"if", "else", "match", "case", "for", "while", "do", "yield", "return", 104 | "def", "val", "var", "lazy", "class", "object", "trait", "extends", 105 | "with", "import", "package", "new", "this", "super", "implicit", 106 | "override", "abstract", "final", "sealed", "private", "protected", 107 | "public", "try", "catch", "finally", "throw", "true", "false", "null"} 108 | 109 | DART_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break", 110 | "continue", "return", "var", "final", "const", "dynamic", "void", 111 | "int", "double", "bool", "String", "class", "interface", "extends", 112 | "implements", "mixin", "import", "library", "part", "typedef", 113 | "this", "super", "as", "is", "new", 
"try", "catch", "finally", "throw", 114 | "async", "await", "true", "false", "null"} 115 | 116 | JULIA_KEYWORDS = {"if", "else", "elseif", "for", "while", "break", "continue", "return", 117 | "function", "macro", "module", "import", "using", "export", "struct", 118 | "mutable", "const", "begin", "end", "do", "try", "catch", "finally", 119 | "true", "false", "nothing"} 120 | 121 | HASKELL_KEYWORDS = {"if", "then", "else", "case", "of", "let", "in", "where", "do", "module", 122 | "import", "class", "instance", "data", "type", "newtype", "deriving", 123 | "default", "foreign", "safe", "unsafe", "qualified", "true", "false"} 124 | 125 | COBOL_KEYWORDS = {"ACCEPT", "ADD", "CALL", "CANCEL", "CLOSE", "COMPUTE", "CONTINUE", "DELETE", 126 | "DISPLAY", "DIVIDE", "EVALUATE", "EXIT", "GOBACK", "GO", "IF", "INITIALIZE", 127 | "INSPECT", "MERGE", "MOVE", "MULTIPLY", "OPEN", "PERFORM", "READ", "RETURN", 128 | "REWRITE", "SEARCH", "SET", "SORT", "START", "STOP", "STRING", "SUBTRACT", 129 | "UNSTRING", "WRITE", "END-IF", "END-PERFORM"} 130 | 131 | OBJECTIVEC_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break", 132 | "continue", "return", "void", "int", "float", "double", "char", "long", "short", 133 | "signed", "unsigned", "class", "interface", "protocol", "implementation", 134 | "try", "catch", "finally", "throw", "import", "self", "super", "atomic", 135 | "nonatomic", "strong", "weak", "retain", "copy", "assign", "true", "false", "nil"} 136 | 137 | FSHARP_KEYWORDS = {"if", "then", "else", "match", "with", "for", "while", "do", "done", "let", 138 | "rec", "in", "try", "finally", "raise", "exception", "function", "return", 139 | "type", "mutable", "namespace", "module", "open", "abstract", "override", 140 | "inherit", "base", "new", "true", "false", "null"} 141 | 142 | LISP_KEYWORDS = {"defun", "setq", "let", "lambda", "if", "cond", "loop", "dolist", "dotimes", 143 | "progn", "return", "function", "defmacro", "quote", "eval", "apply", "car", 144 | 
"cdr", "cons", "list", "mapcar", "format", "read", "print", "load", "t", "nil"} 145 | 146 | PROLOG_KEYWORDS = {"if", "else", "end", "fail", "true", "false", "not", "repeat", "is", 147 | "assert", "retract", "call", "findall", "bagof", "setof", "atom", 148 | "integer", "float", "char_code", "compound", "number", "var"} 149 | 150 | ADA_KEYWORDS = {"if", "then", "else", "elsif", "case", "when", "for", "while", "loop", "exit", 151 | "return", "procedure", "function", "package", "use", "is", "begin", "end", 152 | "record", "type", "constant", "exception", "raise", "declare", "private", 153 | "null", "true", "false"} 154 | 155 | DELPHI_KEYWORDS = {"if", "then", "else", "case", "of", "for", "while", "repeat", "until", "break", 156 | "continue", "begin", "end", "procedure", "function", "var", "const", "type", 157 | "class", "record", "interface", "implementation", "unit", "uses", "inherited", 158 | "try", "except", "finally", "raise", "private", "public", "protected", "published", 159 | "true", "false", "nil"} 160 | 161 | VB_KEYWORDS = {"If", "Then", "Else", "ElseIf", "End", "For", "Each", "While", "Do", "Loop", 162 | "Select", "Case", "Try", "Catch", "Finally", "Throw", "Return", "Function", 163 | "Sub", "Class", "Module", "Namespace", "Imports", "Inherits", "Implements", 164 | "Public", "Private", "Protected", "Friend", "Shared", "Static", "Dim", "Const", 165 | "New", "Me", "MyBase", "MyClass", "Not", "And", "Or", "True", "False", "Nothing"} 166 | 167 | HTML_KEYWORDS = {"html", "head", "title", "meta", "link", "style", "script", "body", "div", "span", 168 | "h1", "h2", "h3", "h4", "h5", "h6", "p", "a", "img", "ul", "ol", "li", "table", 169 | "tr", "td", "th", "thead", "tbody", "tfoot", "form", "input", "button", "label", 170 | "select", "option", "textarea", "fieldset", "legend", "iframe", "nav", "section", 171 | "article", "aside", "header", "footer", "main", "blockquote", "cite", "code", 172 | "pre", "em", "strong", "b", "i", "u", "small", "br", "hr"} 173 | 174 | 
CSS_KEYWORDS = {"color", "background", "border", "margin", "padding", "width", "height", "font-size", 175 | "font-family", "text-align", "display", "position", "top", "bottom", "left", "right", 176 | "z-index", "visibility", "opacity", "overflow", "cursor", "flex", "grid", "align-items", 177 | "justify-content", "box-shadow", "text-shadow", "animation", "transition", "transform", 178 | "clip-path", "content", "filter", "outline", "max-width", "min-width", "max-height", 179 | "min-height", "letter-spacing", "line-height", "white-space", "word-break"} 180 | 181 | 182 | PROGRAMMING_LANGUAGES = { 183 | "Python": PYTHON_KEYWORDS, 184 | "JavaScript": JAVASCRIPT_KEYWORDS, 185 | "Java": JAVA_KEYWORDS, 186 | "C": C_KEYWORDS, 187 | "C++": CPP_KEYWORDS, 188 | "C#": CSHARP_KEYWORDS, 189 | "Go": GO_KEYWORDS, 190 | "Rust": RUST_KEYWORDS, 191 | "Swift": SWIFT_KEYWORDS, 192 | "Kotlin": KOTLIN_KEYWORDS, 193 | "TypeScript": TYPESCRIPT_KEYWORDS, 194 | "PHP": PHP_KEYWORDS, 195 | "Ruby": RUBY_KEYWORDS, 196 | "SQL": SQL_KEYWORDS, 197 | "Bash": BASH_KEYWORDS, 198 | "MATLAB": MATLAB_KEYWORDS, 199 | "R": R_KEYWORDS, 200 | "Perl": PERL_KEYWORDS, 201 | "Lua": LUA_KEYWORDS, 202 | "Scala": SCALA_KEYWORDS, 203 | "Dart": DART_KEYWORDS, 204 | "Julia": JULIA_KEYWORDS, 205 | "Haskell": HASKELL_KEYWORDS, 206 | "COBOL": COBOL_KEYWORDS, 207 | "Objective-C": OBJECTIVEC_KEYWORDS, 208 | "F#": FSHARP_KEYWORDS, 209 | "Lisp": LISP_KEYWORDS, 210 | "Prolog": PROLOG_KEYWORDS, 211 | "Ada": ADA_KEYWORDS, 212 | "Delphi": DELPHI_KEYWORDS, 213 | "Visual Basic": VB_KEYWORDS, 214 | "HTML": HTML_KEYWORDS, 215 | "CSS": CSS_KEYWORDS} 216 | 217 | PROGRAMMING_LANGUAGES_KEYWORDS = set() 218 | for language in PROGRAMMING_LANGUAGES: 219 | PROGRAMMING_LANGUAGES_KEYWORDS = PROGRAMMING_LANGUAGES_KEYWORDS | PROGRAMMING_LANGUAGES[language] 220 | -------------------------------------------------------------------------------- /memor/params.py: -------------------------------------------------------------------------------- 1 | # 
-*- coding: utf-8 -*- 2 | """Memor parameters and constants.""" 3 | from enum import Enum 4 | MEMOR_VERSION = "0.6" 5 | 6 | DATE_TIME_FORMAT = "%Y-%m-%d %H:%M:%S %z" 7 | 8 | INVALID_PATH_MESSAGE = "Invalid path: must be a string and refer to an existing location. Given path: {path}" 9 | INVALID_STR_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a string." 10 | INVALID_BOOL_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a boolean." 11 | INVALID_POSFLOAT_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a positive float." 12 | INVALID_POSINT_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a positive integer." 13 | INVALID_PROB_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a value between 0 and 1." 14 | INVALID_LIST_OF_X_MESSAGE = "Invalid value. `{parameter_name}` must be a list of {type_name}." 15 | INVALID_INT_OR_STR_MESSAGE = "Invalid value. `{parameter_name}` must be an integer or a string." 16 | INVALID_INT_OR_STR_SLICE_MESSAGE = "Invalid value. `{parameter_name}` must be an integer, string or a slice." 17 | INVALID_DATETIME_MESSAGE = "Invalid value. `{parameter_name}` must be a datetime object that includes timezone information." 18 | INVALID_TEMPLATE_MESSAGE = "Invalid template. It must be an instance of `PromptTemplate` or `PresetPromptTemplate`." 19 | INVALID_RESPONSE_MESSAGE = "Invalid response. It must be an instance of `Response`." 20 | INVALID_MESSAGE = "Invalid message. It must be an instance of `Prompt` or `Response`." 21 | INVALID_MESSAGE_STATUS_LEN_MESSAGE = "Invalid message status length. It must be equal to the number of messages." 22 | INVALID_CUSTOM_MAP_MESSAGE = "Invalid custom map: it must be a dictionary with keys and values that can be converted to strings." 23 | INVALID_ROLE_MESSAGE = "Invalid role. It must be an instance of Role enum." 24 | INVALID_ID_MESSAGE = "Invalid message ID. It must be a valid UUIDv4." 25 | INVALID_MODEL_MESSAGE = "Invalid model. 
It must be an instance of LLMModel enum or a string." 26 | INVALID_TEMPLATE_STRUCTURE_MESSAGE = "Invalid template structure. It should be a JSON object with proper fields." 27 | INVALID_PROMPT_STRUCTURE_MESSAGE = "Invalid prompt structure. It should be a JSON object with proper fields." 28 | INVALID_RESPONSE_STRUCTURE_MESSAGE = "Invalid response structure. It should be a JSON object with proper fields." 29 | INVALID_RENDER_FORMAT_MESSAGE = "Invalid render format. It must be an instance of RenderFormat enum." 30 | PROMPT_RENDER_ERROR_MESSAGE = "Prompt template and properties are incompatible." 31 | UNSUPPORTED_OPERAND_ERROR_MESSAGE = "Unsupported operand type(s) for {operator}: `{operand1}` and `{operand2}`" 32 | DATA_SAVE_SUCCESS_MESSAGE = "Everything seems good." 33 | 34 | 35 | class Role(Enum): 36 | """Role enum.""" 37 | 38 | SYSTEM = "system" 39 | USER = "user" 40 | ASSISTANT = "assistant" 41 | DEFAULT = USER 42 | 43 | 44 | class RenderFormat(Enum): 45 | """Render format.""" 46 | 47 | STRING = "STRING" 48 | OPENAI = "OPENAI" 49 | AI_STUDIO = "AI STUDIO" 50 | DICTIONARY = "DICTIONARY" 51 | ITEMS = "ITEMS" 52 | DEFAULT = STRING 53 | 54 | 55 | class LLMModel(Enum): 56 | """LLM model enum.""" 57 | 58 | GPT_O1 = "gpt-o1" 59 | GPT_O1_MINI = "gpt-o1-mini" 60 | GPT_4O = "gpt-4o" 61 | GPT_4O_MINI = "gpt-4o-mini" 62 | GPT_4_TURBO = "gpt-4-turbo" 63 | GPT_4 = "gpt-4" 64 | GPT_4_VISION = "gpt-4-vision" 65 | GPT_3_5_TURBO = "gpt-3.5-turbo" 66 | DAVINCI = "davinci" 67 | BABBAGE = "babbage" 68 | 69 | CLAUDE_3_5_SONNET = "claude-3.5-sonnet" 70 | CLAUDE_3_OPUS = "claude-3-opus" 71 | CLAUDE_3_HAIKU = "claude-3-haiku" 72 | CLAUDE_2 = "claude-2" 73 | CLAUDE_INSTANT = "claude-instant" 74 | 75 | LLAMA3_70B = "llama3-70b" 76 | LLAMA3_8B = "llama3-8b" 77 | LLAMA_GUARD_3_8B = "llama-guard-3-8b" 78 | 79 | MISTRAL_7B = "mistral-7b" 80 | MIXTRAL_8X7B = "mixtral-8x7b" 81 | MIXTRAL_8X22B = "mixtral-8x22b" 82 | MISTRAL_NEMO = "mistral-nemo" 83 | MISTRAL_TINY = "mistral-tiny" 84 | 
MISTRAL_SMALL = "mistral-small" 85 | MISTRAL_MEDIUM = "mistral-medium" 86 | MISTRAL_LARGE = "mistral-large" 87 | CODESTRAL = "codestral" 88 | PIXTRAL = "pixtral-12b" 89 | 90 | GEMMA_7B = "gemma-7b" 91 | GEMMA2_9B = "gemma2-9b" 92 | GEMINI_1_PRO = "gemini-1-pro" 93 | GEMINI_1_ULTRA = "gemini-1-ultra" 94 | GEMINI_1_5_PRO = "gemini-1.5-pro" 95 | GEMINI_1_5_ULTRA = "gemini-1.5-ultra" 96 | GEMINI_1_5_FLASH = "gemini-1.5-flash" 97 | GEMINI_2_FLASH = "gemini-2-flash" 98 | GEMINI_2_PRO = "gemini-2-pro" 99 | 100 | DEEPSEEK_V3 = "deepseek-v3" 101 | DEEPSEEK_R1 = "deepseek-r1" 102 | DEEPSEEK_CODER = "deepseek-coder" 103 | 104 | PHI_2 = "phi-2" 105 | PHI_4 = "phi-4" 106 | 107 | QWEN_1_8B = "qwen-1.8b" 108 | QWEN_7B = "qwen-7b" 109 | QWEN_14B = "qwen-14b" 110 | QWEN_72B = "qwen-72b" 111 | 112 | YI_6B = "yi-6b" 113 | YI_9B = "yi-9b" 114 | YI_34B = "yi-34b" 115 | 116 | DEFAULT = "unknown" 117 | -------------------------------------------------------------------------------- /memor/prompt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Prompt class.""" 3 | from typing import List, Dict, Union, Tuple, Any 4 | import datetime 5 | import json 6 | from .params import MEMOR_VERSION 7 | from .params import DATE_TIME_FORMAT 8 | from .params import RenderFormat, DATA_SAVE_SUCCESS_MESSAGE 9 | from .params import Role 10 | from .tokens_estimator import TokensEstimator 11 | from .params import INVALID_PROMPT_STRUCTURE_MESSAGE, INVALID_TEMPLATE_MESSAGE 12 | from .params import INVALID_ROLE_MESSAGE, INVALID_RESPONSE_MESSAGE 13 | from .params import PROMPT_RENDER_ERROR_MESSAGE 14 | from .params import INVALID_RENDER_FORMAT_MESSAGE 15 | from .errors import MemorValidationError, MemorRenderError 16 | from .functions import get_time_utc, generate_message_id 17 | from .functions import _validate_string, _validate_pos_int, _validate_list_of 18 | from .functions import _validate_path, _validate_message_id 19 | from .template 
class Prompt:
    """
    Prompt class.

    A prompt stores a message, a role, an optional token count, a template used for rendering
    and a list of responses, one of which is tracked as the selected response.

    >>> from memor import Prompt, Role, Response
    >>> responses = [Response(message="I am fine."), Response(message="I am not fine."), Response(message="I am okay.")]
    >>> prompt = Prompt(message="Hello, how are you?", responses=responses)
    >>> prompt.message
    'Hello, how are you?'
    >>> prompt.responses[1].message
    'I am not fine.'
    """

    def __init__(
            self,
            message: str = "",
            responses: List[Response] = None,
            role: Role = Role.DEFAULT,
            tokens: int = None,
            template: Union[PresetPromptTemplate, PromptTemplate] = PresetPromptTemplate.DEFAULT,
            file_path: str = None,
            init_check: bool = True) -> None:
        """
        Prompt object initiator.

        When `file_path` is given, the prompt is loaded from that file and the other
        content parameters are ignored.

        :param message: prompt message
        :param responses: prompt responses (None or an empty list means no responses)
        :param role: prompt role
        :param tokens: tokens
        :param template: prompt template
        :param file_path: prompt file path
        :param init_check: initial check flag (renders the prompt once to validate it)
        """
        self._message = ""
        self._tokens = None
        self._role = Role.DEFAULT
        self._template = PresetPromptTemplate.DEFAULT.value
        self._responses = []
        self._date_created = get_time_utc()
        self._mark_modified()
        self._memor_version = MEMOR_VERSION
        self._selected_response_index = 0
        self._selected_response = None
        self._id = None
        if file_path:
            self.load(file_path)
        else:
            if message:
                self.update_message(message)
            if role:
                self.update_role(role)
            if tokens:
                self.update_tokens(tokens)
            if responses:
                self.update_responses(responses)
            if template:
                self.update_template(template)
            self.select_response(index=self._selected_response_index)
            self._id = generate_message_id()
        _validate_message_id(self._id)
        if init_check:
            # Rendering raises if the template and the prompt properties are incompatible.
            self.render()

    def _mark_modified(self) -> None:
        """Mark modification."""
        self._date_modified = get_time_utc()

    def __eq__(self, other_prompt: "Prompt") -> bool:
        """
        Check prompts equality.

        Message, responses, role, template and tokens are compared; ID and dates are not.

        :param other_prompt: another prompt
        """
        if not isinstance(other_prompt, Prompt):
            return False
        return self._message == other_prompt._message and self._responses == other_prompt._responses and \
            self._role == other_prompt._role and self._template == other_prompt._template and \
            self._tokens == other_prompt._tokens

    def __str__(self) -> str:
        """Return string representation of Prompt."""
        return self.render(render_format=RenderFormat.STRING)

    def __repr__(self) -> str:
        """Return string representation of Prompt."""
        return "Prompt(message={message})".format(message=self._message)

    def __len__(self) -> int:
        """Return the length of the rendered Prompt (0 when rendering fails)."""
        try:
            return len(self.render(render_format=RenderFormat.STRING))
        except Exception:
            # Deliberate best effort: len() must not raise for a prompt that cannot be rendered.
            return 0

    def __copy__(self) -> "Prompt":
        """
        Return a copy of the Prompt object.

        The copy is shallow (attribute values, including the responses list, are shared
        with the original) and receives a new ID.

        :return: a copy of Prompt object
        """
        _class = self.__class__
        result = _class.__new__(_class)
        result.__dict__.update(self.__dict__)
        result.regenerate_id()
        return result

    def copy(self) -> "Prompt":
        """
        Return a copy of the Prompt object.

        :return: a copy of Prompt object
        """
        return self.__copy__()

    def add_response(self, response: Response, index: int = None) -> None:
        """
        Add a response to the prompt object.

        :param response: response
        :param index: insertion index (the response is appended when None)
        """
        if not isinstance(response, Response):
            raise MemorValidationError(INVALID_RESPONSE_MESSAGE)
        if index is None:
            self._responses.append(response)
        else:
            self._responses.insert(index, response)
        self._mark_modified()

    def remove_response(self, index: int) -> None:
        """
        Remove a response from the prompt object.

        The selected response and its index are not adjusted by this method.

        :param index: index
        """
        self._responses.pop(index)
        self._mark_modified()

    def select_response(self, index: int) -> None:
        """
        Select a response as selected response.

        Nothing happens when the prompt has no responses.

        :param index: index
        """
        if len(self._responses) > 0:
            # Look the response up first so an invalid index leaves the current selection untouched.
            selected = self._responses[index]
            self._selected_response_index = index
            self._selected_response = selected
            self._mark_modified()

    def update_responses(self, responses: List[Response]) -> None:
        """
        Update the prompt responses.

        :param responses: responses
        """
        _validate_list_of(responses, "responses", Response, "`Response`")
        # Store a copy so later add/remove calls do not mutate the caller's list.
        self._responses = list(responses)
        self._mark_modified()

    def update_message(self, message: str) -> None:
        """
        Update the prompt message.

        :param message: message
        """
        _validate_string(message, "message")
        self._message = message
        self._mark_modified()

    def update_role(self, role: Role) -> None:
        """
        Update the prompt role.

        :param role: role
        """
        if not isinstance(role, Role):
            raise MemorValidationError(INVALID_ROLE_MESSAGE)
        self._role = role
        self._mark_modified()

    def update_tokens(self, tokens: int) -> None:
        """
        Update the tokens.

        :param tokens: tokens
        """
        _validate_pos_int(tokens, "tokens")
        self._tokens = tokens
        self._mark_modified()

    def update_template(self, template: PromptTemplate) -> None:
        """
        Update the prompt template.

        :param template: a PromptTemplate, or a preset template enum member whose value is used
        """
        preset_types = (
            _BasicPresetPromptTemplate,
            _Instruction1PresetPromptTemplate,
            _Instruction2PresetPromptTemplate,
            _Instruction3PresetPromptTemplate)
        if not isinstance(template, (PromptTemplate,) + preset_types):
            raise MemorValidationError(INVALID_TEMPLATE_MESSAGE)
        if isinstance(template, PromptTemplate):
            self._template = template
        if isinstance(template, preset_types):
            self._template = template.value
        self._mark_modified()

    def save(self, file_path: str, save_template: bool = True) -> Dict[str, Any]:
        """
        Save method.

        Errors are not raised; they are reported through the returned dictionary.

        :param file_path: prompt file path
        :param save_template: save template flag
        :return: dictionary with `status` (bool) and `message` (str) keys
        """
        result = {"status": True, "message": DATA_SAVE_SUCCESS_MESSAGE}
        try:
            with open(file_path, "w") as file:
                data = self.to_json(save_template=save_template)
                json.dump(data, file)
        except Exception as e:
            result["status"] = False
            result["message"] = str(e)
        return result

    def load(self, file_path: str) -> None:
        """
        Load method.

        :param file_path: prompt file path
        """
        _validate_path(file_path)
        with open(file_path, "r") as file:
            self.from_json(file.read())

    def from_json(self, json_object: Union[str, Dict[str, Any]]) -> None:
        """
        Load attributes from the JSON object.

        Everything is parsed first and only then assigned, so a malformed input raises
        without leaving the prompt partially updated. The stored modification date is
        restored as-is.

        :param json_object: JSON object (a JSON string or an already decoded dictionary)
        """
        try:
            if isinstance(json_object, str):
                loaded_obj = json.loads(json_object)
            else:
                loaded_obj = json_object.copy()
            message = loaded_obj["message"]
            tokens = loaded_obj.get("tokens", None)
            responses = []
            for response in loaded_obj["responses"]:
                response_obj = Response()
                response_obj.from_json(response)
                responses.append(response_obj)
            role = Role(loaded_obj["role"])
            template = PresetPromptTemplate.DEFAULT.value
            if "template" in loaded_obj:
                template = PromptTemplate()
                template.from_json(loaded_obj["template"])
            memor_version = loaded_obj["memor_version"]
            date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
            date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
            selected_response_index = loaded_obj["selected_response_index"]
            # An out-of-range or non-integer index is a structure error, not a later surprise.
            selected_response = responses[selected_response_index] if responses else None
            message_id = loaded_obj["id"] if "id" in loaded_obj else generate_message_id()
        except Exception as e:
            raise MemorValidationError(INVALID_PROMPT_STRUCTURE_MESSAGE) from e
        self._message = message
        self._tokens = tokens
        self._id = message_id
        self._responses = responses
        self._role = role
        self._template = template
        self._memor_version = memor_version
        self._date_created = date_created
        self._date_modified = date_modified
        self._selected_response_index = selected_response_index
        self._selected_response = selected_response

    def to_json(self, save_template: bool = True) -> Dict[str, Any]:
        """
        Convert the prompt to a JSON object.

        :param save_template: save template flag
        :return: dictionary containing only JSON-serializable values
        """
        data = self.to_dict(save_template=save_template).copy()
        data["responses"] = [response.to_json() for response in data["responses"]]
        if "template" in data:
            data["template"] = data["template"].to_json()
        data["role"] = data["role"].value
        data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
        data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
        return data

    def to_dict(self, save_template: bool = True) -> Dict[str, Any]:
        """
        Convert the prompt to a dictionary.

        Values are kept as objects (responses, role, template, dates); see `to_json`
        for the serializable form.

        :param save_template: save template flag
        """
        data = {
            "type": "Prompt",
            "message": self._message,
            "responses": self._responses.copy(),
            "selected_response_index": self._selected_response_index,
            "tokens": self._tokens,
            "role": self._role,
            "id": self._id,
            "template": self._template,
            "memor_version": MEMOR_VERSION,
            "date_created": self._date_created,
            "date_modified": self._date_modified,
        }
        if not save_template:
            del data["template"]
        return data

    def regenerate_id(self) -> None:
        """Regenerate ID (the new ID always differs from the current one)."""
        new_id = self._id
        while new_id == self.id:
            new_id = generate_message_id()
        self._id = new_id

    @property
    def message(self) -> str:
        """Get the prompt message."""
        return self._message

    @property
    def responses(self) -> List[Response]:
        """Get the prompt responses."""
        return self._responses

    @property
    def role(self) -> Role:
        """Get the prompt role."""
        return self._role

    @property
    def tokens(self) -> int:
        """Get the prompt tokens."""
        return self._tokens

    @property
    def date_created(self) -> datetime.datetime:
        """Get the prompt creation date."""
        return self._date_created

    @property
    def date_modified(self) -> datetime.datetime:
        """Get the prompt object modification date."""
        return self._date_modified

    @property
    def template(self) -> PromptTemplate:
        """Get the prompt template."""
        return self._template

    @property
    def id(self) -> str:
        """Get the prompt ID."""
        return self._id

    @property
    def selected_response(self) -> Response:
        """Get the prompt selected response."""
        return self._selected_response

    def render(self, render_format: RenderFormat = RenderFormat.DEFAULT) -> Union[str,
                                                                                  Dict[str, Any],
                                                                                  List[Tuple[str, Any]]]:
        """
        Render method.

        The template content is formatted with `prompt` (JSON form of this prompt),
        `response` (JSON form of the selected response, when there is one), `responses`
        (JSON forms of all responses) and the template's custom map entries.

        :param render_format: render format
        :raises MemorValidationError: when `render_format` is not a RenderFormat member
        :raises MemorRenderError: when the template cannot be formatted with the prompt data
        """
        if not isinstance(render_format, RenderFormat):
            raise MemorValidationError(INVALID_RENDER_FORMAT_MESSAGE)
        try:
            format_kwargs = {"prompt": self.to_json(save_template=False)}
            if isinstance(self._selected_response, Response):
                format_kwargs["response"] = self._selected_response.to_json()
            format_kwargs["responses"] = [response.to_json() for response in self._responses]
            custom_map = self._template._custom_map
            if custom_map is not None:
                # Custom map entries are applied last, so they override the built-in keys.
                format_kwargs.update(custom_map)
            content = self._template._content.format(**format_kwargs)
            prompt_dict = self.to_dict()
            prompt_dict["content"] = content
            if render_format == RenderFormat.OPENAI:
                return {"role": self._role.value, "content": content}
            if render_format == RenderFormat.AI_STUDIO:
                return {"role": self._role.value, "parts": [{"text": content}]}
            if render_format == RenderFormat.STRING:
                return content
            if render_format == RenderFormat.DICTIONARY:
                return prompt_dict
            if render_format == RenderFormat.ITEMS:
                return list(prompt_dict.items())
        except Exception as e:
            raise MemorRenderError(PROMPT_RENDER_ERROR_MESSAGE) from e

    def check_render(self) -> bool:
        """Check render (True when the prompt renders without an error)."""
        try:
            self.render()
            return True
        except Exception:
            return False

    def estimate_tokens(self, method: TokensEstimator = TokensEstimator.DEFAULT) -> int:
        """
        Estimate the number of tokens in the rendered prompt.

        :param method: token estimator method
        """
        return method(self.render(render_format=RenderFormat.STRING))
class Response:
    """
    Response class.

    >>> from memor import Response, Role
    >>> from memor.params import LLMModel
    >>> response = Response(message="Hello!", score=0.9, role=Role.ASSISTANT, temperature=0.5, model=LLMModel.GPT_4)
    >>> response.message
    'Hello!'
    """

    def __init__(
            self,
            message: str = "",
            score: float = None,
            role: Role = Role.ASSISTANT,
            temperature: float = None,
            tokens: int = None,
            inference_time: float = None,
            model: Union[LLMModel, str] = LLMModel.DEFAULT,
            date: datetime.datetime = None,
            file_path: str = None) -> None:
        """
        Response object initiator.

        When `file_path` is given, the response is loaded from that file and the other
        content parameters are ignored.

        :param message: response message
        :param score: response score
        :param role: response role
        :param temperature: temperature
        :param tokens: tokens
        :param inference_time: inference time
        :param model: agent model
        :param date: response date (the creation time of the object is used when None)
        :param file_path: response file path
        """
        self._message = ""
        self._score = None
        self._role = Role.ASSISTANT
        self._temperature = None
        self._tokens = None
        self._inference_time = None
        self._model = LLMModel.DEFAULT.value
        # Taken at call time; a `get_time_utc()` default argument would be frozen at import time.
        self._date_created = get_time_utc()
        self._mark_modified()
        self._memor_version = MEMOR_VERSION
        self._id = None
        if file_path:
            self.load(file_path)
        else:
            if message:
                self.update_message(message)
            # Compare with None so an explicit score of 0 is stored instead of being dropped.
            if score is not None:
                self.update_score(score)
            if role:
                self.update_role(role)
            if model:
                self.update_model(model)
            # NOTE(review): falsy values (e.g. temperature=0, tokens=0) are skipped by the checks
            # below and the attribute stays None -- confirm whether the validators accept 0.
            if temperature:
                self.update_temperature(temperature)
            if tokens:
                self.update_tokens(tokens)
            if inference_time:
                self.update_inference_time(inference_time)
            if date:
                _validate_date_time(date, "date")
                self._date_created = date
            self._id = generate_message_id()
        _validate_message_id(self._id)

    def _mark_modified(self) -> None:
        """Mark modification."""
        self._date_modified = get_time_utc()

    def __eq__(self, other_response: "Response") -> bool:
        """
        Check responses equality.

        Message, score, role, temperature, model, tokens and inference time are compared;
        ID and dates are not.

        :param other_response: another response
        """
        if not isinstance(other_response, Response):
            return False
        return self._message == other_response._message and self._score == other_response._score and \
            self._role == other_response._role and self._temperature == other_response._temperature and \
            self._model == other_response._model and self._tokens == other_response._tokens and \
            self._inference_time == other_response._inference_time

    def __str__(self) -> str:
        """Return string representation of Response."""
        return self.render(render_format=RenderFormat.STRING)

    def __repr__(self) -> str:
        """Return string representation of Response."""
        return "Response(message={message})".format(message=self._message)

    def __len__(self) -> int:
        """Return the length of the Response object (length of its message)."""
        return len(self.render(render_format=RenderFormat.STRING))

    def __copy__(self) -> "Response":
        """Return a shallow copy of the Response object with a new ID."""
        _class = self.__class__
        result = _class.__new__(_class)
        result.__dict__.update(self.__dict__)
        result.regenerate_id()
        return result

    def copy(self) -> "Response":
        """Return a copy of the Response object."""
        return self.__copy__()

    def update_message(self, message: str) -> None:
        """
        Update the response message.

        :param message: message
        """
        _validate_string(message, "message")
        self._message = message
        self._mark_modified()

    def update_score(self, score: float) -> None:
        """
        Update the response score.

        :param score: score
        """
        _validate_probability(score, "score")
        self._score = score
        self._mark_modified()

    def update_role(self, role: Role) -> None:
        """
        Update the response role.

        :param role: role
        """
        if not isinstance(role, Role):
            raise MemorValidationError(INVALID_ROLE_MESSAGE)
        self._role = role
        self._mark_modified()

    def update_temperature(self, temperature: float) -> None:
        """
        Update the temperature.

        :param temperature: temperature
        """
        _validate_pos_float(temperature, "temperature")
        self._temperature = temperature
        self._mark_modified()

    def update_tokens(self, tokens: int) -> None:
        """
        Update the tokens.

        :param tokens: tokens
        """
        _validate_pos_int(tokens, "tokens")
        self._tokens = tokens
        self._mark_modified()

    def update_inference_time(self, inference_time: float) -> None:
        """
        Update inference time.

        :param inference_time: inference time
        """
        _validate_pos_float(inference_time, "inference_time")
        self._inference_time = inference_time
        self._mark_modified()

    def update_model(self, model: Union[LLMModel, str]) -> None:
        """
        Update the agent model.

        The model is always stored as a string (the enum value for LLMModel members).

        :param model: model
        """
        if isinstance(model, str):
            self._model = model
        elif isinstance(model, LLMModel):
            self._model = model.value
        else:
            raise MemorValidationError(INVALID_MODEL_MESSAGE)
        self._mark_modified()

    def save(self, file_path: str) -> Dict[str, Any]:
        """
        Save method.

        Errors are not raised; they are reported through the returned dictionary.

        :param file_path: response file path
        :return: dictionary with `status` (bool) and `message` (str) keys
        """
        result = {"status": True, "message": DATA_SAVE_SUCCESS_MESSAGE}
        try:
            with open(file_path, "w") as file:
                json.dump(self.to_json(), file)
        except Exception as e:
            result["status"] = False
            result["message"] = str(e)
        return result

    def load(self, file_path: str) -> None:
        """
        Load method.

        :param file_path: response file path
        """
        _validate_path(file_path)
        with open(file_path, "r") as file:
            self.from_json(file.read())

    def from_json(self, json_object: Union[str, Dict[str, Any]]) -> None:
        """
        Load attributes from the JSON object.

        Everything is parsed first and only then assigned, so a malformed input raises
        without leaving the response partially updated.

        :param json_object: JSON object (a JSON string or an already decoded dictionary)
        """
        try:
            if isinstance(json_object, str):
                loaded_obj = json.loads(json_object)
            else:
                loaded_obj = json_object.copy()
            message = loaded_obj["message"]
            score = loaded_obj["score"]
            temperature = loaded_obj["temperature"]
            tokens = loaded_obj.get("tokens", None)
            inference_time = loaded_obj.get("inference_time", None)
            model = loaded_obj["model"]
            role = Role(loaded_obj["role"])
            memor_version = loaded_obj["memor_version"]
            message_id = loaded_obj["id"] if "id" in loaded_obj else generate_message_id()
            date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
            date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
        except Exception as e:
            raise MemorValidationError(INVALID_RESPONSE_STRUCTURE_MESSAGE) from e
        self._message = message
        self._score = score
        self._temperature = temperature
        self._tokens = tokens
        self._inference_time = inference_time
        self._model = model
        self._role = role
        self._memor_version = memor_version
        self._id = message_id
        self._date_created = date_created
        self._date_modified = date_modified

    def to_json(self) -> Dict[str, Any]:
        """Convert the response to a JSON object (dates as strings, role as its value)."""
        data = self.to_dict().copy()
        data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
        data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
        data["role"] = data["role"].value
        return data

    def to_dict(self) -> Dict[str, Any]:
        """Convert the response to a dictionary (role and dates are kept as objects)."""
        return {
            "type": "Response",
            "message": self._message,
            "score": self._score,
            "temperature": self._temperature,
            "tokens": self._tokens,
            "inference_time": self._inference_time,
            "role": self._role,
            "model": self._model,
            "id": self._id,
            "memor_version": MEMOR_VERSION,
            "date_created": self._date_created,
            "date_modified": self._date_modified,
        }

    def render(self,
               render_format: RenderFormat = RenderFormat.DEFAULT) -> Union[str,
                                                                            Dict[str, Any],
                                                                            List[Tuple[str, Any]]]:
        """
        Render the response.

        :param render_format: render format
        :raises MemorValidationError: when `render_format` is not a RenderFormat member
        """
        if not isinstance(render_format, RenderFormat):
            raise MemorValidationError(INVALID_RENDER_FORMAT_MESSAGE)
        if render_format == RenderFormat.OPENAI:
            return {"role": self._role.value,
                    "content": self._message}
        if render_format == RenderFormat.AI_STUDIO:
            return {"role": self._role.value,
                    "parts": [{"text": self._message}]}
        if render_format == RenderFormat.DICTIONARY:
            return self.to_dict()
        if render_format == RenderFormat.ITEMS:
            # Return a real list, matching the annotation, instead of a dict view.
            return list(self.to_dict().items())
        # RenderFormat.STRING (and its DEFAULT alias)
        return self._message

    def estimate_tokens(self, method: TokensEstimator = TokensEstimator.DEFAULT) -> int:
        """
        Estimate the number of tokens in the response message.

        :param method: token estimator method
        """
        return method(self.render(render_format=RenderFormat.STRING))

    def regenerate_id(self) -> None:
        """Regenerate ID (the new ID always differs from the current one)."""
        new_id = self._id
        while new_id == self.id:
            new_id = generate_message_id()
        self._id = new_id

    @property
    def message(self) -> str:
        """Get the response message."""
        return self._message

    @property
    def score(self) -> float:
        """Get the response score."""
        return self._score

    @property
    def temperature(self) -> float:
        """Get the temperature."""
        return self._temperature

    @property
    def tokens(self) -> int:
        """Get the tokens."""
        return self._tokens

    @property
    def inference_time(self) -> float:
        """Get inference time."""
        return self._inference_time

    @property
    def role(self) -> Role:
        """Get the response role."""
        return self._role

    @property
    def model(self) -> str:
        """Get the agent model."""
        return self._model

    @property
    def id(self) -> str:
        """Get the response ID."""
        return self._id

    @property
    def date_created(self) -> datetime.datetime:
        """Get the response creation date."""
        return self._date_created

    @property
    def date_modified(self) -> datetime.datetime:
        """Get the response object modification date."""
        return self._date_modified
class Session:
    """
    Session class.

    A session keeps an ordered list of `Prompt`/`Response` messages together with a parallel
    list of boolean status flags (``True`` = enabled, ``False`` = disabled/masked).
    """

    def __init__(
            self,
            title: str = None,
            messages: List[Union[Prompt, Response]] = None,
            file_path: str = None,
            init_check: bool = True) -> None:
        """
        Session object initiator.

        :param title: title
        :param messages: messages (``None`` or an empty list leaves the session empty)
        :param file_path: file path (when given, the session is loaded from it and title/messages are ignored)
        :param init_check: initial check flag (renders the session once without touching the render counter)
        """
        self._title = None
        self._render_counter = 0
        self._messages = []
        self._messages_status = []
        self._date_created = get_time_utc()
        self._mark_modified()
        self._memor_version = MEMOR_VERSION
        if file_path:
            self.load(file_path)
        else:
            if title:
                self.update_title(title)
            if messages:
                self.update_messages(messages)
        if init_check:
            _ = self.render(enable_counter=False)

    def _mark_modified(self) -> None:
        """Mark modification."""
        self._date_modified = get_time_utc()

    def __eq__(self, other_session: "Session") -> bool:
        """
        Check sessions equality (same title and same messages).

        :param other_session: other session
        """
        if isinstance(other_session, Session):
            return self._title == other_session._title and self._messages == other_session._messages
        return False

    def __str__(self) -> str:
        """Return string representation of Session."""
        return self.render(render_format=RenderFormat.STRING, enable_counter=False)

    def __repr__(self) -> str:
        """Return string representation of Session."""
        return "Session(title={title})".format(title=self._title)

    def __len__(self) -> int:
        """Return the length of the Session object."""
        return len(self._messages)

    def __iter__(self) -> Generator[Union[Prompt, Response], None, None]:
        """Iterate through the Session object."""
        yield from self._messages

    def __add__(self, other_object: Union["Session", Response, Prompt]) -> "Session":
        """
        Addition method.

        Adding a message keeps this session's title; adding another session yields an untitled session.

        :param other_object: other object
        """
        if isinstance(other_object, (Response, Prompt)):
            new_messages = self._messages + [other_object]
            return Session(title=self.title, messages=new_messages)
        if isinstance(other_object, Session):
            new_messages = self._messages + other_object._messages
            return Session(messages=new_messages)
        raise TypeError(
            UNSUPPORTED_OPERAND_ERROR_MESSAGE.format(
                operator="+",
                operand1="Session",
                operand2=type(other_object).__name__))

    def __radd__(self, other_object: Union["Session", Response, Prompt]) -> "Session":
        """
        Reverse addition method.

        :param other_object: other object
        """
        if isinstance(other_object, (Response, Prompt)):
            new_messages = [other_object] + self._messages
            return Session(title=self.title, messages=new_messages)
        raise TypeError(
            UNSUPPORTED_OPERAND_ERROR_MESSAGE.format(
                operator="+",
                operand1="Session",
                operand2=type(other_object).__name__))

    def __contains__(self, message: Union[Prompt, Response]) -> bool:
        """
        Check if the Session contains the given message.

        :param message: message
        """
        return message in self._messages

    def __getitem__(self, identifier: Union[int, slice, str]) -> Union[Prompt, Response]:
        """
        Get a message from the session object.

        :param identifier: message identifier (index/slice or id)
        """
        return self.get_message(identifier=identifier)

    def __copy__(self) -> "Session":
        """
        Return a shallow copy of the Session object.

        The copy receives its own message and status lists (holding the same message objects),
        so adding or removing messages on one session does not silently change the other.
        """
        _class = self.__class__
        result = _class.__new__(_class)
        result.__dict__.update(self.__dict__)
        result._messages = list(self._messages)
        result._messages_status = list(self._messages_status)
        return result

    def copy(self) -> "Session":
        """Return a copy of the Session object."""
        return self.__copy__()

    def add_message(self,
                    message: Union[Prompt, Response],
                    status: bool = True,
                    index: int = None) -> None:
        """
        Add a message to the session object.

        :param message: message
        :param status: status
        :param index: index (``None`` appends to the end)
        """
        if not isinstance(message, (Prompt, Response)):
            raise MemorValidationError(INVALID_MESSAGE)
        _validate_bool(status, "status")
        if index is None:
            self._messages.append(message)
            self._messages_status.append(status)
        else:
            self._messages.insert(index, message)
            self._messages_status.insert(index, status)
        self._mark_modified()

    def get_message_by_index(self, index: Union[int, slice]) -> Union[Prompt, Response]:
        """
        Get a message from the session object by index/slice.

        :param index: index
        """
        return self._messages[index]

    def get_message_by_id(self, message_id: str) -> Union[Prompt, Response]:
        """
        Get a message from the session object by message id.

        Returns ``None`` when no message carries the given id.

        :param message_id: message id
        """
        for index, message in enumerate(self._messages):
            if message.id == message_id:
                return self.get_message_by_index(index=index)

    def get_message(self, identifier: Union[int, slice, str]) -> Union[Prompt, Response]:
        """
        Get a message from the session object.

        :param identifier: message identifier (index/slice or id)
        """
        if isinstance(identifier, (int, slice)):
            return self.get_message_by_index(index=identifier)
        elif isinstance(identifier, str):
            return self.get_message_by_id(message_id=identifier)
        else:
            raise MemorValidationError(INVALID_INT_OR_STR_SLICE_MESSAGE.format(parameter_name="identifier"))

    def remove_message_by_index(self, index: int) -> None:
        """
        Remove a message from the session object by index.

        :param index: index
        """
        self._messages.pop(index)
        self._messages_status.pop(index)
        self._mark_modified()

    def remove_message_by_id(self, message_id: str) -> None:
        """
        Remove a message from the session object by message id.

        Only the first matching message is removed; nothing happens when the id is unknown.

        :param message_id: message id
        """
        for index, message in enumerate(self._messages):
            if message.id == message_id:
                self.remove_message_by_index(index=index)
                break

    def remove_message(self, identifier: Union[int, str]) -> None:
        """
        Remove a message from the session object.

        :param identifier: message identifier (index or id)
        """
        if isinstance(identifier, int):
            self.remove_message_by_index(index=identifier)
        elif isinstance(identifier, str):
            self.remove_message_by_id(message_id=identifier)
        else:
            raise MemorValidationError(INVALID_INT_OR_STR_MESSAGE.format(parameter_name="identifier"))

    def clear_messages(self) -> None:
        """Remove all messages."""
        self._messages = []
        self._messages_status = []
        self._mark_modified()

    def enable_message(self, index: int) -> None:
        """
        Enable a message.

        :param index: index
        """
        self._messages_status[index] = True

    def disable_message(self, index: int) -> None:
        """
        Disable a message.

        :param index: index
        """
        self._messages_status[index] = False

    def mask_message(self, index: int) -> None:
        """
        Mask a message.

        :param index: index
        """
        self.disable_message(index)

    def unmask_message(self, index: int) -> None:
        """
        Unmask a message.

        :param index: index
        """
        self.enable_message(index)

    def update_title(self, title: str) -> None:
        """
        Update the session title.

        :param title: title
        """
        _validate_string(title, "title")
        self._title = title
        self._mark_modified()

    def update_messages(self,
                        messages: List[Union[Prompt, Response]],
                        status: List[bool] = None) -> None:
        """
        Update the session messages.

        Both lists are validated before anything is stored, so a rejected status list
        leaves the session unchanged.

        :param messages: messages
        :param status: status (``None`` or an empty list enables every message)
        """
        _validate_list_of(messages, "messages", (Prompt, Response), "`Prompt` or `Response`")
        if status:
            _validate_list_of(status, "status", bool, "booleans")
            if len(status) != len(messages):
                raise MemorValidationError(INVALID_MESSAGE_STATUS_LEN_MESSAGE)
            new_status = status
        else:
            new_status = len(messages) * [True]
        self._messages = messages
        self._messages_status = new_status
        self._mark_modified()

    def update_messages_status(self, status: List[bool]) -> None:
        """
        Update the session messages status.

        :param status: status (must hold one boolean per message)
        """
        _validate_list_of(status, "status", bool, "booleans")
        if len(status) != len(self._messages):
            raise MemorValidationError(INVALID_MESSAGE_STATUS_LEN_MESSAGE)
        self._messages_status = status

    def save(self, file_path: str) -> Dict[str, Any]:
        """
        Save method.

        Never raises: the outcome is reported through the returned ``status``/``message`` dictionary.

        :param file_path: session file path
        """
        result = {"status": True, "message": DATA_SAVE_SUCCESS_MESSAGE}
        try:
            # Serialize first so a failure cannot leave a truncated file behind.
            data = self.to_json()
            with open(file_path, "w") as file:
                json.dump(data, file)
        except Exception as e:
            result["status"] = False
            result["message"] = str(e)
        return result

    def load(self, file_path: str) -> None:
        """
        Load method.

        :param file_path: session file path
        """
        _validate_path(file_path)
        with open(file_path, "r") as file:
            self.from_json(file.read())

    def from_json(self, json_object: Union[str, Dict[str, Any]]) -> None:
        """
        Load attributes from the JSON object.

        The whole object is parsed before any attribute is assigned, so a malformed input
        does not leave the session half-updated.

        :param json_object: JSON object
        """
        if isinstance(json_object, str):
            loaded_obj = json.loads(json_object)
        else:
            loaded_obj = json_object.copy()
        title = loaded_obj["title"]
        render_counter = loaded_obj.get("render_counter", 0)
        messages_status = loaded_obj["messages_status"]
        message_classes = {"Prompt": Prompt, "Response": Response}
        messages = []
        for message in loaded_obj["messages"]:
            message_class = message_classes.get(message["type"])
            if message_class is None:
                # An unknown type must not fall through to a previously built message object.
                raise MemorValidationError(INVALID_MESSAGE)
            message_obj = message_class()
            message_obj.from_json(message)
            messages.append(message_obj)
        memor_version = loaded_obj["memor_version"]
        date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
        date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
        self._title = title
        self._render_counter = render_counter
        self._messages_status = messages_status
        self._messages = messages
        self._memor_version = memor_version
        self._date_created = date_created
        self._date_modified = date_modified

    def to_json(self) -> Dict[str, Any]:
        """Convert the session to a JSON object."""
        data = self.to_dict().copy()
        data["messages"] = [message.to_json() for message in data["messages"]]
        data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
        data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
        return data

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the session to a dictionary.

        The message and status lists in the result are copies of the internal lists.

        :return: dict
        """
        data = {
            "type": "Session",
            "title": self._title,
            "render_counter": self._render_counter,
            "messages": self._messages.copy(),
            "messages_status": self._messages_status.copy(),
            "memor_version": MEMOR_VERSION,
            "date_created": self._date_created,
            "date_modified": self._date_modified,
        }
        return data

    def render(self, render_format: RenderFormat = RenderFormat.DEFAULT,
               enable_counter: bool = True) -> Union[str, Dict[str, Any], List[Tuple[str, Any]]]:
        """
        Render method.

        NOTE(review): every stored message is rendered; the status flags are not consulted here.

        :param render_format: render format
        :param enable_counter: render counter flag
        """
        if not isinstance(render_format, RenderFormat):
            raise MemorValidationError(INVALID_RENDER_FORMAT_MESSAGE)
        result = None
        if render_format in [RenderFormat.OPENAI, RenderFormat.AI_STUDIO]:
            result = []
            for message in self._messages:
                if isinstance(message, Session):
                    result.extend(message.render(render_format=render_format))
                else:
                    result.append(message.render(render_format=render_format))
        else:
            session_dict = self.to_dict()
            # One line per message, each terminated by a newline.
            content = "".join(
                message.render(render_format=RenderFormat.STRING) + "\n" for message in self._messages)
            session_dict["content"] = content
            if render_format == RenderFormat.STRING:
                result = content
            if render_format == RenderFormat.DICTIONARY:
                result = session_dict
            if render_format == RenderFormat.ITEMS:
                result = list(session_dict.items())
        if enable_counter:
            self._render_counter += 1
            self._mark_modified()
        return result

    def check_render(self) -> bool:
        """Check render (``True`` when the session renders without raising)."""
        try:
            _ = self.render(enable_counter=False)
            return True
        except Exception:
            return False

    def estimate_tokens(self, method: TokensEstimator = TokensEstimator.DEFAULT) -> int:
        """
        Estimate the number of tokens in the session.

        :param method: token estimator method
        """
        return method(self.render(render_format=RenderFormat.STRING, enable_counter=False))

    @property
    def date_created(self) -> datetime.datetime:
        """Get the session creation date."""
        return self._date_created

    @property
    def date_modified(self) -> datetime.datetime:
        """Get the session object modification date."""
        return self._date_modified

    @property
    def title(self) -> str:
        """Get the session title."""
        return self._title

    @property
    def render_counter(self) -> int:
        """Get the render counter."""
        return self._render_counter

    @property
    def messages(self) -> List[Union[Prompt, Response]]:
        """Get the session messages."""
        return self._messages

    @property
    def messages_status(self) -> List[bool]:
        """Get the session messages status."""
        return self._messages_status

    @property
    def masks(self) -> List[bool]:
        """Get the session masks (the negation of the status flags)."""
        return [not x for x in self._messages_status]
class PromptTemplate:
    r"""
    Prompt template.

    >>> template = PromptTemplate(content="Take a deep breath\n{prompt_message}!", title="Greeting")
    >>> template.title
    'Greeting'
    """

    def __init__(
            self,
            content: str = None,
            file_path: str = None,
            title: str = None,
            custom_map: Dict[str, str] = None) -> None:
        """
        Prompt template object initiator.

        :param content: template content
        :param file_path: template file path (when given, the template is loaded from it and the other arguments are ignored)
        :param title: template title
        :param custom_map: custom map
        """
        self._content = None
        self._title = None
        self._date_created = get_time_utc()
        self._mark_modified()
        self._memor_version = MEMOR_VERSION
        self._custom_map = None
        if file_path:
            self.load(file_path)
        else:
            if title:
                self.update_title(title)
            if content:
                self.update_content(content)
            if custom_map:
                self.update_map(custom_map)

    def _mark_modified(self) -> None:
        """Mark modification."""
        self._date_modified = get_time_utc()

    def __eq__(self, other_template: "PromptTemplate") -> bool:
        """
        Check templates equality (same content, title and custom map).

        :param other_template: another template
        """
        if isinstance(other_template, PromptTemplate):
            return self._content == other_template._content and self._title == other_template._title and self._custom_map == other_template._custom_map
        return False

    def __str__(self) -> str:
        """Return string representation of PromptTemplate."""
        return self._content

    def __repr__(self) -> str:
        """Return string representation of PromptTemplate."""
        return "PromptTemplate(content={content})".format(content=self._content)

    def __copy__(self) -> "PromptTemplate":
        """Return a copy of the PromptTemplate object."""
        _class = self.__class__
        result = _class.__new__(_class)
        result.__dict__.update(self.__dict__)
        return result

    def copy(self) -> "PromptTemplate":
        """Return a copy of the PromptTemplate object."""
        return self.__copy__()

    def update_title(self, title: str) -> None:
        """
        Update title.

        :param title: title
        """
        _validate_string(title, "title")
        self._title = title
        self._mark_modified()

    def update_content(self, content: str) -> None:
        """
        Update content.

        :param content: content
        """
        _validate_string(content, "content")
        self._content = content
        self._mark_modified()

    def update_map(self, custom_map: Dict[str, str]) -> None:
        """
        Update custom map.

        :param custom_map: custom map
        """
        _validate_custom_map(custom_map)
        self._custom_map = custom_map
        self._mark_modified()

    def save(self, file_path: str) -> Dict[str, Any]:
        """
        Save method.

        Never raises: the outcome is reported through the returned ``status``/``message`` dictionary.

        :param file_path: template file path
        """
        result = {"status": True, "message": DATA_SAVE_SUCCESS_MESSAGE}
        try:
            # Serialize first so a failure cannot leave a truncated file behind.
            data = self.to_json()
            with open(file_path, "w") as file:
                json.dump(data, file)
        except Exception as e:
            result["status"] = False
            result["message"] = str(e)
        return result

    def load(self, file_path: str) -> None:
        """
        Load method.

        :param file_path: template file path
        """
        _validate_path(file_path)
        with open(file_path, "r") as file:
            self.from_json(file.read())

    def from_json(self, json_object: Union[str, Dict[str, Any]]) -> None:
        """
        Load attributes from the JSON object.

        All fields are parsed before any attribute is assigned, so a malformed input
        leaves the template unchanged.

        :param json_object: JSON object
        """
        try:
            if isinstance(json_object, str):
                loaded_obj = json.loads(json_object)
            else:
                loaded_obj = json_object.copy()
            content = loaded_obj["content"]
            title = loaded_obj["title"]
            memor_version = loaded_obj["memor_version"]
            custom_map = loaded_obj["custom_map"]
            date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
            date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
        except Exception as error:
            raise MemorValidationError(INVALID_TEMPLATE_STRUCTURE_MESSAGE) from error
        self._content = content
        self._title = title
        self._memor_version = memor_version
        self._custom_map = custom_map
        self._date_created = date_created
        self._date_modified = date_modified

    def to_json(self) -> Dict[str, Any]:
        """Convert PromptTemplate to json."""
        data = self.to_dict().copy()
        data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
        data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
        return data

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert PromptTemplate to dict.

        The custom map is copied when present and reported as ``None`` when the template has none.
        """
        custom_map = self._custom_map
        if custom_map is not None:
            custom_map = custom_map.copy()
        return {
            "title": self._title,
            "content": self._content,
            "memor_version": MEMOR_VERSION,
            "custom_map": custom_map,
            "date_created": self._date_created,
            "date_modified": self._date_modified,
        }

    @property
    def content(self) -> str:
        """Get the PromptTemplate content."""
        return self._content

    @property
    def title(self) -> str:
        """Get the PromptTemplate title."""
        return self._title

    @property
    def date_created(self) -> datetime.datetime:
        """Get the PromptTemplate creation date."""
        return self._date_created

    @property
    def date_modified(self) -> datetime.datetime:
        """Get the PromptTemplate modification date."""
        return self._date_modified

    @property
    def custom_map(self) -> Dict[str, str]:
        """Get the PromptTemplate custom map."""
        return self._custom_map
def _preset_template(content: str, title: str, instruction: str) -> PromptTemplate:
    """Build a preset template whose custom map holds only the given instruction text."""
    return PromptTemplate(content=content, title=title, custom_map={"instruction": instruction})


class _BasicPresetPromptTemplate(Enum):
    """Preset basic-prompt templates (empty instruction)."""

    PROMPT = _preset_template(
        BASIC_PROMPT_CONTENT, "Basic/Prompt", "")
    RESPONSE = _preset_template(
        BASIC_RESPONSE_CONTENT, "Basic/Response", "")
    RESPONSE0 = _preset_template(
        BASIC_RESPONSE0_CONTENT, "Basic/Response0", "")
    RESPONSE1 = _preset_template(
        BASIC_RESPONSE1_CONTENT, "Basic/Response1", "")
    RESPONSE2 = _preset_template(
        BASIC_RESPONSE2_CONTENT, "Basic/Response2", "")
    RESPONSE3 = _preset_template(
        BASIC_RESPONSE3_CONTENT, "Basic/Response3", "")
    PROMPT_WITH_LABEL = _preset_template(
        BASIC_PROMPT_CONTENT_LABEL, "Basic/Prompt With Label", "")
    RESPONSE_WITH_LABEL = _preset_template(
        BASIC_RESPONSE_CONTENT_LABEL, "Basic/Response With Label", "")
    RESPONSE0_WITH_LABEL = _preset_template(
        BASIC_RESPONSE0_CONTENT_LABEL, "Basic/Response0 With Label", "")
    RESPONSE1_WITH_LABEL = _preset_template(
        BASIC_RESPONSE1_CONTENT_LABEL, "Basic/Response1 With Label", "")
    RESPONSE2_WITH_LABEL = _preset_template(
        BASIC_RESPONSE2_CONTENT_LABEL, "Basic/Response2 With Label", "")
    RESPONSE3_WITH_LABEL = _preset_template(
        BASIC_RESPONSE3_CONTENT_LABEL, "Basic/Response3 With Label", "")
    PROMPT_RESPONSE_STANDARD = _preset_template(
        BASIC_PROMPT_RESPONSE_STANDARD_CONTENT, "Basic/Prompt-Response Standard", "")
    PROMPT_RESPONSE_FULL = _preset_template(
        BASIC_PROMPT_RESPONSE_FULL_CONTENT, "Basic/Prompt-Response Full", "")


class _Instruction1PresetPromptTemplate(Enum):
    """Preset instruction1-prompt templates."""

    PROMPT = _preset_template(
        BASIC_PROMPT_CONTENT, "Instruction1/Prompt", PROMPT_INSTRUCTION1)
    RESPONSE = _preset_template(
        BASIC_RESPONSE_CONTENT, "Instruction1/Response", PROMPT_INSTRUCTION1)
    RESPONSE0 = _preset_template(
        BASIC_RESPONSE0_CONTENT, "Instruction1/Response0", PROMPT_INSTRUCTION1)
    RESPONSE1 = _preset_template(
        BASIC_RESPONSE1_CONTENT, "Instruction1/Response1", PROMPT_INSTRUCTION1)
    RESPONSE2 = _preset_template(
        BASIC_RESPONSE2_CONTENT, "Instruction1/Response2", PROMPT_INSTRUCTION1)
    RESPONSE3 = _preset_template(
        BASIC_RESPONSE3_CONTENT, "Instruction1/Response3", PROMPT_INSTRUCTION1)
    PROMPT_WITH_LABEL = _preset_template(
        BASIC_PROMPT_CONTENT_LABEL, "Instruction1/Prompt With Label", PROMPT_INSTRUCTION1)
    RESPONSE_WITH_LABEL = _preset_template(
        BASIC_RESPONSE_CONTENT_LABEL, "Instruction1/Response With Label", PROMPT_INSTRUCTION1)
    RESPONSE0_WITH_LABEL = _preset_template(
        BASIC_RESPONSE0_CONTENT_LABEL, "Instruction1/Response0 With Label", PROMPT_INSTRUCTION1)
    RESPONSE1_WITH_LABEL = _preset_template(
        BASIC_RESPONSE1_CONTENT_LABEL, "Instruction1/Response1 With Label", PROMPT_INSTRUCTION1)
    RESPONSE2_WITH_LABEL = _preset_template(
        BASIC_RESPONSE2_CONTENT_LABEL, "Instruction1/Response2 With Label", PROMPT_INSTRUCTION1)
    RESPONSE3_WITH_LABEL = _preset_template(
        BASIC_RESPONSE3_CONTENT_LABEL, "Instruction1/Response3 With Label", PROMPT_INSTRUCTION1)
    PROMPT_RESPONSE_STANDARD = _preset_template(
        BASIC_PROMPT_RESPONSE_STANDARD_CONTENT, "Instruction1/Prompt-Response Standard", PROMPT_INSTRUCTION1)
    PROMPT_RESPONSE_FULL = _preset_template(
        BASIC_PROMPT_RESPONSE_FULL_CONTENT, "Instruction1/Prompt-Response Full", PROMPT_INSTRUCTION1)


class _Instruction2PresetPromptTemplate(Enum):
    """Preset instruction2-prompt templates."""

    PROMPT = _preset_template(
        BASIC_PROMPT_CONTENT, "Instruction2/Prompt", PROMPT_INSTRUCTION2)
    RESPONSE = _preset_template(
        BASIC_RESPONSE_CONTENT, "Instruction2/Response", PROMPT_INSTRUCTION2)
    RESPONSE0 = _preset_template(
        BASIC_RESPONSE0_CONTENT, "Instruction2/Response0", PROMPT_INSTRUCTION2)
    RESPONSE1 = _preset_template(
        BASIC_RESPONSE1_CONTENT, "Instruction2/Response1", PROMPT_INSTRUCTION2)
    RESPONSE2 = _preset_template(
        BASIC_RESPONSE2_CONTENT, "Instruction2/Response2", PROMPT_INSTRUCTION2)
    RESPONSE3 = _preset_template(
        BASIC_RESPONSE3_CONTENT, "Instruction2/Response3", PROMPT_INSTRUCTION2)
    PROMPT_WITH_LABEL = _preset_template(
        BASIC_PROMPT_CONTENT_LABEL, "Instruction2/Prompt With Label", PROMPT_INSTRUCTION2)
    RESPONSE_WITH_LABEL = _preset_template(
        BASIC_RESPONSE_CONTENT_LABEL, "Instruction2/Response With Label", PROMPT_INSTRUCTION2)
    RESPONSE0_WITH_LABEL = _preset_template(
        BASIC_RESPONSE0_CONTENT_LABEL, "Instruction2/Response0 With Label", PROMPT_INSTRUCTION2)
    RESPONSE1_WITH_LABEL = _preset_template(
        BASIC_RESPONSE1_CONTENT_LABEL, "Instruction2/Response1 With Label", PROMPT_INSTRUCTION2)
    RESPONSE2_WITH_LABEL = _preset_template(
        BASIC_RESPONSE2_CONTENT_LABEL, "Instruction2/Response2 With Label", PROMPT_INSTRUCTION2)
    RESPONSE3_WITH_LABEL = _preset_template(
        BASIC_RESPONSE3_CONTENT_LABEL, "Instruction2/Response3 With Label", PROMPT_INSTRUCTION2)
    PROMPT_RESPONSE_STANDARD = _preset_template(
        BASIC_PROMPT_RESPONSE_STANDARD_CONTENT, "Instruction2/Prompt-Response Standard", PROMPT_INSTRUCTION2)
    PROMPT_RESPONSE_FULL = _preset_template(
        BASIC_PROMPT_RESPONSE_FULL_CONTENT, "Instruction2/Prompt-Response Full", PROMPT_INSTRUCTION2)


class _Instruction3PresetPromptTemplate(Enum):
    """Preset instruction3-prompt templates."""

    PROMPT = _preset_template(
        BASIC_PROMPT_CONTENT, "Instruction3/Prompt", PROMPT_INSTRUCTION3)
    RESPONSE = _preset_template(
        BASIC_RESPONSE_CONTENT, "Instruction3/Response", PROMPT_INSTRUCTION3)
    RESPONSE0 = _preset_template(
        BASIC_RESPONSE0_CONTENT, "Instruction3/Response0", PROMPT_INSTRUCTION3)
    RESPONSE1 = _preset_template(
        BASIC_RESPONSE1_CONTENT, "Instruction3/Response1", PROMPT_INSTRUCTION3)
    RESPONSE2 = _preset_template(
        BASIC_RESPONSE2_CONTENT, "Instruction3/Response2", PROMPT_INSTRUCTION3)
    RESPONSE3 = _preset_template(
        BASIC_RESPONSE3_CONTENT, "Instruction3/Response3", PROMPT_INSTRUCTION3)
    PROMPT_WITH_LABEL = _preset_template(
        BASIC_PROMPT_CONTENT_LABEL, "Instruction3/Prompt With Label", PROMPT_INSTRUCTION3)
    RESPONSE_WITH_LABEL = _preset_template(
        BASIC_RESPONSE_CONTENT_LABEL, "Instruction3/Response With Label", PROMPT_INSTRUCTION3)
    RESPONSE0_WITH_LABEL = _preset_template(
        BASIC_RESPONSE0_CONTENT_LABEL, "Instruction3/Response0 With Label", PROMPT_INSTRUCTION3)
    RESPONSE1_WITH_LABEL = _preset_template(
        BASIC_RESPONSE1_CONTENT_LABEL, "Instruction3/Response1 With Label", PROMPT_INSTRUCTION3)
    RESPONSE2_WITH_LABEL = _preset_template(
        BASIC_RESPONSE2_CONTENT_LABEL, "Instruction3/Response2 With Label", PROMPT_INSTRUCTION3)
    RESPONSE3_WITH_LABEL = _preset_template(
        BASIC_RESPONSE3_CONTENT_LABEL, "Instruction3/Response3 With Label", PROMPT_INSTRUCTION3)
    PROMPT_RESPONSE_STANDARD = _preset_template(
        BASIC_PROMPT_RESPONSE_STANDARD_CONTENT, "Instruction3/Prompt-Response Standard", PROMPT_INSTRUCTION3)
    PROMPT_RESPONSE_FULL = _preset_template(
        BASIC_PROMPT_RESPONSE_FULL_CONTENT, "Instruction3/Prompt-Response Full", PROMPT_INSTRUCTION3)


class PresetPromptTemplate:
    """Namespace grouping the preset template families; ``DEFAULT`` is the basic prompt preset."""

    BASIC = _BasicPresetPromptTemplate
    INSTRUCTION1 = _Instruction1PresetPromptTemplate
    INSTRUCTION2 = _Instruction2PresetPromptTemplate
    INSTRUCTION3 = _Instruction3PresetPromptTemplate
    DEFAULT = BASIC.PROMPT
content=BASIC_PROMPT_RESPONSE_STANDARD_CONTENT, 528 | title="Instruction3/Prompt-Response Standard", 529 | custom_map={ 530 | "instruction": PROMPT_INSTRUCTION3}) 531 | PROMPT_RESPONSE_FULL = PromptTemplate( 532 | content=BASIC_PROMPT_RESPONSE_FULL_CONTENT, 533 | title="Instruction3/Prompt-Response Full", 534 | custom_map={ 535 | "instruction": PROMPT_INSTRUCTION3}) 536 | 537 | 538 | class PresetPromptTemplate: 539 | """Preset prompt templates.""" 540 | 541 | BASIC = _BasicPresetPromptTemplate 542 | INSTRUCTION1 = _Instruction1PresetPromptTemplate 543 | INSTRUCTION2 = _Instruction2PresetPromptTemplate 544 | INSTRUCTION3 = _Instruction3PresetPromptTemplate 545 | DEFAULT = BASIC.PROMPT 546 | -------------------------------------------------------------------------------- /memor/tokens_estimator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Tokens estimator functions.""" 3 | 4 | import re 5 | from enum import Enum 6 | from typing import Set, List 7 | from .keywords import PROGRAMMING_LANGUAGES_KEYWORDS 8 | from .keywords import COMMON_PREFIXES, COMMON_SUFFIXES 9 | 10 | 11 | def _is_code_snippet(message: str) -> bool: 12 | """ 13 | Check if the message is a code snippet based on common coding symbols. 14 | 15 | :param message: The input message to check. 16 | :return: Boolean indicating if the message is a code snippet. 17 | """ 18 | return bool(re.search(r"[=<>+\-*/{}();]", message)) 19 | 20 | 21 | def _preprocess_message(message: str, is_code: bool) -> str: 22 | """ 23 | Preprocess message by replacing contractions in non-code text. 24 | 25 | :param message: The input message to preprocess. 26 | :param is_code: Boolean indicating if the message is a code. 27 | :return: Preprocessed message. 
28 | """ 29 | if not is_code: 30 | return re.sub(r"(?<=\w)'(?=\w)", " ", message) 31 | return message 32 | 33 | 34 | def _tokenize_message(message: str) -> List[str]: 35 | """ 36 | Tokenize the message based on words, symbols, and numbers. 37 | 38 | :param message: The input message to tokenize. 39 | :return: List of tokens. 40 | """ 41 | return re.findall(r"[A-Za-z_][A-Za-z0-9_]*|[+\-*/=<>(){}[\],.:;]|\"[^\"]*\"|'[^']*'|\d+|\S", message) 42 | 43 | 44 | def _count_code_tokens(token: str, common_keywords: Set[str]) -> int: 45 | """ 46 | Count tokens in code snippets considering different token types. 47 | 48 | :param token: The token to count. 49 | :param common_keywords: Set of common keywords in programming languages. 50 | :return: Count of tokens. 51 | """ 52 | if token in common_keywords or re.match(r"[+\-*/=<>(){}[\],.:;]", token): 53 | return 1 54 | if token.isdigit(): 55 | return max(1, len(token) // 4) 56 | if token.startswith(("'", '"')) and token.endswith(("'", '"')): 57 | return max(1, len(token) // 6) 58 | if "_" in token: 59 | return len(token.split("_")) 60 | if re.search(r"[A-Z]", token): 61 | return len(re.findall(r"[A-Z][a-z]*", token)) 62 | return 1 63 | 64 | 65 | def _count_text_tokens(token: str, prefixes: Set[str], suffixes: Set[str]) -> int: 66 | """ 67 | Count tokens in text based on prefixes, suffixes, and subwords. 68 | 69 | :param token: The token to count. 70 | :param prefixes: Set of common prefixes. 71 | :param suffixes: Set of common suffixes. 72 | :return: Token count. 
73 | """ 74 | if len(token) == 1 and not token.isalnum(): 75 | return 1 76 | if token.isdigit(): 77 | return max(1, len(token) // 4) 78 | prefix_count = sum(token.startswith(p) for p in prefixes if len(token) > len(p) + 3) 79 | suffix_count = sum(token.endswith(s) for s in suffixes if len(token) > len(s) + 3) 80 | parts = re.findall(r"[aeiou]+|[^aeiou]+", token) 81 | subword_count = max(1, len(parts) // 2) 82 | 83 | return prefix_count + suffix_count + subword_count 84 | 85 | 86 | def universal_tokens_estimator(message: str) -> int: 87 | """ 88 | Estimate the number of tokens in a given text or code snippet. 89 | 90 | :param message: The input text or code snippet to estimate tokens for. 91 | :return: Estimated number of tokens. 92 | """ 93 | is_code = _is_code_snippet(message) 94 | message = _preprocess_message(message, is_code) 95 | tokens = _tokenize_message(message) 96 | 97 | return sum( 98 | _count_code_tokens( 99 | token, 100 | PROGRAMMING_LANGUAGES_KEYWORDS) if is_code else _count_text_tokens( 101 | token, 102 | COMMON_PREFIXES, 103 | COMMON_SUFFIXES) for token in tokens) 104 | 105 | 106 | def _openai_tokens_estimator(text: str) -> int: 107 | """ 108 | Estimate the number of tokens in a given text for OpenAI's models. 109 | 110 | :param text: The input text to estimate tokens for. 111 | :return: Estimated number of tokens. 
112 | """ 113 | char_count = len(text) 114 | token_estimate = char_count / 4 115 | 116 | space_count = text.count(" ") 117 | punctuation_count = sum(1 for char in text if char in ",.?!;:") 118 | token_estimate += (space_count + punctuation_count) * 0.5 119 | 120 | if any(keyword in text for keyword in PROGRAMMING_LANGUAGES_KEYWORDS): 121 | token_estimate *= 1.1 122 | 123 | newline_count = text.count("\n") 124 | token_estimate += newline_count * 0.8 125 | 126 | long_word_penalty = sum(len(word) / 10 for word in text.split() if len(word) > 15) 127 | token_estimate += long_word_penalty 128 | 129 | if "http" in text: 130 | token_estimate *= 1.1 131 | 132 | rare_char_count = sum(1 for char in text if ord(char) > 10000) 133 | token_estimate += rare_char_count * 0.8 134 | 135 | return token_estimate 136 | 137 | 138 | def openai_tokens_estimator_gpt_3_5(text: str) -> int: 139 | """ 140 | Estimate the number of tokens in a given text for OpenAI's GPT-3.5 Turbo model. 141 | 142 | :param text: The input text to estimate tokens for. 143 | :return: Estimated number of tokens. 144 | """ 145 | token_estimate = _openai_tokens_estimator(text) 146 | return int(max(1, token_estimate)) 147 | 148 | 149 | def openai_tokens_estimator_gpt_4(text: str) -> int: 150 | """ 151 | Estimate the number of tokens in a given text for OpenAI's GPT-4 model. 152 | 153 | :param text: The input text to estimate tokens for. 154 | :return: Estimated number of tokens. 
155 | """ 156 | token_estimate = _openai_tokens_estimator(text) 157 | token_estimate *= 1.05 # Adjusting for GPT-4's tokenization 158 | return int(max(1, token_estimate)) 159 | 160 | 161 | class TokensEstimator(Enum): 162 | """Token estimator enum.""" 163 | 164 | UNIVERSAL = universal_tokens_estimator 165 | OPENAI_GPT_3_5 = openai_tokens_estimator_gpt_3_5 166 | OPENAI_GPT_4 = openai_tokens_estimator_gpt_4 167 | DEFAULT = UNIVERSAL 168 | -------------------------------------------------------------------------------- /otherfiles/RELEASE.md: -------------------------------------------------------------------------------- 1 | # Memor Release Instructions 2 | 3 | **Last Update: 2024-12-27** 4 | 5 | 1. Create the `release` branch under `dev` 6 | 2. Update all version tags 7 | 1. `setup.py` 8 | 2. `README.md` 9 | 3. `otherfiles/version_check.py` 10 | 4. `otherfiles/meta.yaml` 11 | 5. `memor/params.py` 12 | 3. Update `CHANGELOG.md` 13 | 1. Add a new header under `Unreleased` section (Example: `## [0.1] - 2022-08-17`) 14 | 2. Add a new compare link to the end of the file (Example: `[0.2]: https://github.com/openscilab/memor/compare/v0.1...v0.2`) 15 | 3. Update `dev` compare link (Example: `[Unreleased]: https://github.com/openscilab/memor/compare/v0.2...dev`) 16 | 4. Update `.github/ISSUE_TEMPLATE/bug_report.yml` 17 | 1. Add new version tag to `Memor version` dropdown options 18 | 5. Create a PR from `release` to `dev` 19 | 1. Title: `Version x.x` (Example: `Version 0.1`) 20 | 2. Tag all related issues 21 | 3. Labels: `release` 22 | 4. Set milestone 23 | 5. Wait for all CI checks to pass 24 | 6. Need review (**2** reviewers) 25 | 7. Squash and merge 26 | 8. Delete `release` branch 27 | 6. Merge `dev` branch into `main` 28 | 1. `git checkout main` 29 | 2. `git merge dev` 30 | 3. `git push origin main` 31 | 4. Wait for all CI checks to pass 32 | 7. Create a new release 33 | 1. Target branch: `main` 34 | 2. Tag: `vx.x` (Example: `v0.1`) 35 | 3. 
Title: `Version x.x` (Example: `Version 0.1`) 36 | 4. Copy changelogs 37 | 5. Tag all related issues 38 | 8. Bump!! 39 | 9. Close this version's issues 40 | 10. Close milestone -------------------------------------------------------------------------------- /otherfiles/donation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openscilab/memor/325e17b8ffc5eb62a9bb4044df63021b9968b94b/otherfiles/donation.png -------------------------------------------------------------------------------- /otherfiles/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set name = "memor" %} 2 | {% set version = "0.6" %} 3 | 4 | package: 5 | name: {{ name|lower }} 6 | version: {{ version }} 7 | source: 8 | git_url: https://github.com/openscilab/memor 9 | git_rev: v{{ version }} 10 | build: 11 | noarch: python 12 | number: 0 13 | script: {{ PYTHON }} -m pip install . -vv 14 | requirements: 15 | host: 16 | - pip 17 | - setuptools 18 | - python >=3.7 19 | run: 20 | - python >=3.7 21 | about: 22 | home: https://github.com/openscilab/memor 23 | license: MIT 24 | license_family: MIT 25 | summary: "Memor: A Python Library for Managing and Transferring Conversational Memory Across LLMs" 26 | description: | 27 | Memor is a library designed to help users manage the memory of their interactions with Large Language Models (LLMs). It enables users to seamlessly access and utilize the history of their conversations when prompting LLMs. That would create a more personalized and context-aware experience. Memor stands out by allowing users to transfer conversational history across different LLMs, eliminating cold starts where models don't have information about user and their preferences. 
Users can select specific parts of past interactions with one LLM and share them with another. By bridging the gap between isolated LLM instances, Memor revolutionizes the way users interact with AI by making transitions between models smoother. 28 | 29 | Website: https://openscilab.com 30 | 31 | Repo: https://github.com/openscilab/memor 32 | extra: 33 | recipe-maintainers: 34 | - sepandhaghighi 35 | - sadrasabouri 36 | -------------------------------------------------------------------------------- /otherfiles/requirements-splitter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Requirements splitter.""" 3 | 4 | test_req = "" 5 | 6 | with open('dev-requirements.txt', 'r') as f: 7 | for line in f: 8 | if '==' not in line: 9 | test_req += line 10 | 11 | with open('test-requirements.txt', 'w') as f: 12 | f.write(test_req) 13 | -------------------------------------------------------------------------------- /otherfiles/version_check.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Version-check script.""" 3 | import os 4 | import sys 5 | import codecs 6 | Failed = 0 7 | MEMOR_VERSION = "0.6" 8 | 9 | 10 | SETUP_ITEMS = [ 11 | "version='{0}'", 12 | 'https://github.com/openscilab/memor/tarball/v{0}'] 13 | README_ITEMS = [ 14 | "[Version {0}](https://github.com/openscilab/memor/archive/v{0}.zip)", 15 | "pip install memor=={0}"] 16 | CHANGELOG_ITEMS = [ 17 | "## [{0}]", 18 | "https://github.com/openscilab/memor/compare/v{0}...dev", 19 | "[{0}]:"] 20 | PARAMS_ITEMS = ['MEMOR_VERSION = "{0}"'] 21 | META_ITEMS = ['% set version = "{0}" %'] 22 | ISSUE_TEMPLATE_ITEMS = ["- Memor {0}"] 23 | SECURITY_ITEMS = ["| {0} | :white_check_mark: |", "| < {0} | :x: |"] 24 | 25 | FILES = { 26 | os.path.join("otherfiles", "meta.yaml"): META_ITEMS, 27 | "setup.py": SETUP_ITEMS, 28 | "README.md": README_ITEMS, 29 | "CHANGELOG.md": CHANGELOG_ITEMS, 30 | 
"SECURITY.md": SECURITY_ITEMS, 31 | os.path.join("memor", "params.py"): PARAMS_ITEMS, 32 | os.path.join(".github", "ISSUE_TEMPLATE", "bug_report.yml"): ISSUE_TEMPLATE_ITEMS, 33 | } 34 | 35 | TEST_NUMBER = len(FILES) 36 | 37 | 38 | def print_result(failed: bool = False) -> None: 39 | """ 40 | Print final result. 41 | 42 | :param failed: failed flag 43 | """ 44 | message = "Version tag tests " 45 | if not failed: 46 | print("\n" + message + "passed!") 47 | else: 48 | print("\n" + message + "failed!") 49 | print("Passed : " + str(TEST_NUMBER - Failed) + "/" + str(TEST_NUMBER)) 50 | 51 | 52 | if __name__ == "__main__": 53 | for file_name in FILES: 54 | try: 55 | file_content = codecs.open( 56 | file_name, "r", "utf-8", 'ignore').read() 57 | for test_item in FILES[file_name]: 58 | if file_content.find(test_item.format(MEMOR_VERSION)) == -1: 59 | print("Incorrect version tag in " + file_name) 60 | Failed += 1 61 | break 62 | except Exception as e: 63 | Failed += 1 64 | print("Error in " + file_name + "\n" + "Message : " + str(e)) 65 | 66 | if Failed == 0: 67 | print_result(False) 68 | sys.exit(0) 69 | else: 70 | print_result(True) 71 | sys.exit(1) 72 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openscilab/memor/325e17b8ffc5eb62a9bb4044df63021b9968b94b/requirements.txt -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Setup module.""" 3 | try: 4 | from setuptools import setup 5 | except ImportError: 6 | from distutils.core import setup 7 | 8 | 9 | def get_requires() -> list: 10 | """Read requirements.txt.""" 11 | requirements = open("requirements.txt", "r").read() 12 | return list(filter(lambda x: x != "", requirements.split())) 13 | 14 | 15 | def 
read_description() -> str: 16 | """Read README.md and CHANGELOG.md.""" 17 | try: 18 | with open("README.md") as r: 19 | description = "\n" 20 | description += r.read() 21 | with open("CHANGELOG.md") as c: 22 | description += "\n" 23 | description += c.read() 24 | return description 25 | except Exception: 26 | return '''Memor is a library designed to help users manage the memory of their interactions with Large Language Models (LLMs). 27 | It enables users to seamlessly access and utilize the history of their conversations when prompting LLMs. 28 | That would create a more personalized and context-aware experience. 29 | Memor stands out by allowing users to transfer conversational history across different LLMs, eliminating cold starts where models don\'t have information about user and their preferences. 30 | Users can select specific parts of past interactions with one LLM and share them with another.By bridging the gap between isolated LLM instances, Memor revolutionizes the way users interact with AI by making transitions between models smoother.''' 31 | 32 | 33 | setup( 34 | name='memor', 35 | packages=[ 36 | 'memor', ], 37 | version='0.6', 38 | description='Memor: A Python Library for Managing and Transferring Conversational Memory Across LLMs', 39 | long_description=read_description(), 40 | long_description_content_type='text/markdown', 41 | author='Memor Development Team', 42 | author_email='memor@openscilab.com', 43 | url='https://github.com/openscilab/memor', 44 | download_url='https://github.com/openscilab/memor/tarball/v0.6', 45 | keywords="llm memory management conversational history ai agent", 46 | project_urls={ 47 | 'Source': 'https://github.com/openscilab/memor', 48 | }, 49 | install_requires=get_requires(), 50 | python_requires='>=3.7', 51 | classifiers=[ 52 | 'Development Status :: 3 - Alpha', 53 | 'Natural Language :: English', 54 | 'License :: OSI Approved :: MIT License', 55 | 'Operating System :: OS Independent', 56 | 'Programming Language :: 
Python :: 3.7', 57 | 'Programming Language :: Python :: 3.8', 58 | 'Programming Language :: Python :: 3.9', 59 | 'Programming Language :: Python :: 3.10', 60 | 'Programming Language :: Python :: 3.11', 61 | 'Programming Language :: Python :: 3.12', 62 | 'Programming Language :: Python :: 3.13', 63 | 'Intended Audience :: Developers', 64 | 'Intended Audience :: Education', 65 | 'Intended Audience :: End Users/Desktop', 66 | 'Intended Audience :: Manufacturing', 67 | 'Intended Audience :: Science/Research', 68 | 'Topic :: Education', 69 | 'Topic :: Scientific/Engineering', 70 | 'Topic :: Scientific/Engineering :: Information Analysis', 71 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 72 | ], 73 | license='MIT', 74 | ) 75 | -------------------------------------------------------------------------------- /tests/test_prompt.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import uuid 3 | import copy 4 | import pytest 5 | from memor import Prompt, Response, Role, LLMModel 6 | from memor import PresetPromptTemplate, PromptTemplate 7 | from memor import RenderFormat, MemorValidationError, MemorRenderError 8 | from memor import TokensEstimator 9 | 10 | TEST_CASE_NAME = "Prompt tests" 11 | 12 | 13 | def test_message1(): 14 | prompt = Prompt(message="Hello, how are you?") 15 | assert prompt.message == "Hello, how are you?" 16 | 17 | 18 | def test_message2(): 19 | prompt = Prompt(message="Hello, how are you?") 20 | prompt.update_message("What's Up?") 21 | assert prompt.message == "What's Up?" 22 | 23 | 24 | def test_message3(): 25 | prompt = Prompt(message="Hello, how are you?") 26 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`message` must be a string."): 27 | prompt.update_message(22) 28 | 29 | 30 | def test_tokens1(): 31 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 32 | assert prompt.tokens is None 33 | 34 | 35 | def test_tokens2(): 36 | prompt = Prompt(message="Hello, how are you?", role=Role.USER, tokens=4) 37 | assert prompt.tokens == 4 38 | 39 | 40 | def test_tokens3(): 41 | prompt = Prompt(message="Hello, how are you?", role=Role.USER, tokens=4) 42 | prompt.update_tokens(7) 43 | assert prompt.tokens == 7 44 | 45 | 46 | def test_tokens4(): 47 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 48 | with pytest.raises(MemorValidationError, match=r"Invalid value. `tokens` must be a positive integer."): 49 | prompt.update_tokens("4") 50 | 51 | 52 | def test_estimated_tokens1(): 53 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 54 | assert prompt.estimate_tokens(TokensEstimator.UNIVERSAL) == 7 55 | 56 | 57 | def test_estimated_tokens2(): 58 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 59 | assert prompt.estimate_tokens(TokensEstimator.OPENAI_GPT_3_5) == 7 60 | 61 | 62 | def test_estimated_tokens3(): 63 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 64 | assert prompt.estimate_tokens(TokensEstimator.OPENAI_GPT_4) == 8 65 | 66 | 67 | def test_role1(): 68 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 69 | assert prompt.role == Role.USER 70 | 71 | 72 | def test_role2(): 73 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 74 | prompt.update_role(Role.SYSTEM) 75 | assert prompt.role == Role.SYSTEM 76 | 77 | 78 | def test_role3(): 79 | prompt = Prompt(message="Hello, how are you?", role=None) 80 | assert prompt.role == Role.USER 81 | 82 | 83 | def test_role4(): 84 | prompt = Prompt(message="Hello, how are you?", role=None) 85 | with pytest.raises(MemorValidationError, match=r"Invalid role. 
It must be an instance of Role enum."): 86 | prompt.update_role(2) 87 | 88 | 89 | def test_id1(): 90 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 91 | assert uuid.UUID(prompt.id, version=4) == uuid.UUID(prompt._id, version=4) 92 | 93 | 94 | def test_id2(): 95 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 96 | prompt._id = "123" 97 | _ = prompt.save("prompt_test3.json") 98 | with pytest.raises(MemorValidationError, match=r"Invalid message ID. It must be a valid UUIDv4."): 99 | _ = Prompt(file_path="prompt_test3.json") 100 | 101 | 102 | def test_responses1(): 103 | message = "Hello, how are you?" 104 | response = Response(message="I am fine.") 105 | prompt = Prompt(message=message, responses=[response]) 106 | assert prompt.responses[0].message == "I am fine." 107 | 108 | 109 | def test_responses2(): 110 | message = "Hello, how are you?" 111 | response0 = Response(message="I am fine.") 112 | response1 = Response(message="Good!") 113 | prompt = Prompt(message=message, responses=[response0, response1]) 114 | assert prompt.responses[0].message == "I am fine." and prompt.responses[1].message == "Good!" 115 | 116 | 117 | def test_responses3(): 118 | message = "Hello, how are you?" 119 | response0 = Response(message="I am fine.") 120 | response1 = Response(message="Good!") 121 | prompt = Prompt(message=message) 122 | prompt.update_responses([response0, response1]) 123 | assert prompt.responses[0].message == "I am fine." and prompt.responses[1].message == "Good!" 124 | 125 | 126 | def test_responses4(): 127 | message = "Hello, how are you?" 128 | prompt = Prompt(message=message) 129 | with pytest.raises(MemorValidationError, match=r"Invalid value. `responses` must be a list of `Response`."): 130 | prompt.update_responses({"I am fine.", "Good!"}) 131 | 132 | 133 | def test_responses5(): 134 | message = "Hello, how are you?" 
135 | response0 = Response(message="I am fine.") 136 | prompt = Prompt(message=message) 137 | with pytest.raises(MemorValidationError, match=r"Invalid value. `responses` must be a list of `Response`."): 138 | prompt.update_responses([response0, "Good!"]) 139 | 140 | 141 | def test_add_response1(): 142 | message = "Hello, how are you?" 143 | response0 = Response(message="I am fine.") 144 | prompt = Prompt(message=message, responses=[response0]) 145 | response1 = Response(message="Great!") 146 | prompt.add_response(response1) 147 | assert prompt.responses[0] == response0 and prompt.responses[1] == response1 148 | 149 | 150 | def test_add_response2(): 151 | message = "Hello, how are you?" 152 | response0 = Response(message="I am fine.") 153 | prompt = Prompt(message=message, responses=[response0]) 154 | response1 = Response(message="Great!") 155 | prompt.add_response(response1, index=0) 156 | assert prompt.responses[0] == response1 and prompt.responses[1] == response0 157 | 158 | 159 | def test_add_response3(): 160 | message = "Hello, how are you?" 161 | response0 = Response(message="I am fine.") 162 | prompt = Prompt(message=message, responses=[response0]) 163 | with pytest.raises(MemorValidationError, match=r"Invalid response. It must be an instance of `Response`."): 164 | prompt.add_response(1) 165 | 166 | 167 | def test_remove_response(): 168 | message = "Hello, how are you?" 169 | response0 = Response(message="I am fine.") 170 | response1 = Response(message="Great!") 171 | prompt = Prompt(message=message, responses=[response0, response1]) 172 | prompt.remove_response(0) 173 | assert response0 not in prompt.responses 174 | 175 | 176 | def test_select_response(): 177 | message = "Hello, how are you?" 
178 | response0 = Response(message="I am fine.") 179 | prompt = Prompt(message=message, responses=[response0]) 180 | response1 = Response(message="Great!") 181 | prompt.add_response(response1) 182 | prompt.select_response(index=1) 183 | assert prompt.selected_response == response1 184 | 185 | 186 | def test_template1(): 187 | message = "Hello, how are you?" 188 | prompt = Prompt(message=message, template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD, init_check=False) 189 | assert prompt.template == PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD.value 190 | 191 | 192 | def test_template2(): 193 | message = "Hello, how are you?" 194 | prompt = Prompt(message=message, template=PresetPromptTemplate.BASIC.RESPONSE, init_check=False) 195 | prompt.update_template(PresetPromptTemplate.INSTRUCTION1.PROMPT) 196 | assert prompt.template.content == PresetPromptTemplate.INSTRUCTION1.PROMPT.value.content 197 | 198 | 199 | def test_template3(): 200 | message = "Hello, how are you?" 201 | template = PromptTemplate(content="{message}-{response}") 202 | prompt = Prompt(message=message, template=template, init_check=False) 203 | assert prompt.template.content == "{message}-{response}" 204 | 205 | 206 | def test_template4(): 207 | message = "Hello, how are you?" 208 | prompt = Prompt(message=message, template=None) 209 | assert prompt.template == PresetPromptTemplate.DEFAULT.value 210 | 211 | 212 | def test_template5(): 213 | message = "Hello, how are you?" 214 | prompt = Prompt(message=message, template=PresetPromptTemplate.BASIC.RESPONSE, init_check=False) 215 | with pytest.raises(MemorValidationError, match=r"Invalid template. It must be an instance of `PromptTemplate` or `PresetPromptTemplate`."): 216 | prompt.update_template("{prompt_message}") 217 | 218 | 219 | def test_copy1(): 220 | message = "Hello, how are you?" 
221 | response = Response(message="I am fine.") 222 | prompt1 = Prompt(message=message, responses=[response], role=Role.USER, 223 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 224 | prompt2 = copy.copy(prompt1) 225 | assert id(prompt1) != id(prompt2) and prompt1.id != prompt2.id 226 | 227 | 228 | def test_copy2(): 229 | message = "Hello, how are you?" 230 | response = Response(message="I am fine.") 231 | prompt1 = Prompt(message=message, responses=[response], role=Role.USER, 232 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 233 | prompt2 = prompt1.copy() 234 | assert id(prompt1) != id(prompt2) and prompt1.id != prompt2.id 235 | 236 | 237 | def test_str(): 238 | message = "Hello, how are you?" 239 | response = Response(message="I am fine.") 240 | prompt = Prompt(message=message, responses=[response], role=Role.USER, 241 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 242 | assert str(prompt) == prompt.render(render_format=RenderFormat.STRING) 243 | 244 | 245 | def test_repr(): 246 | message = "Hello, how are you?" 247 | response = Response(message="I am fine.") 248 | prompt = Prompt(message=message, responses=[response], role=Role.USER, 249 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 250 | assert repr(prompt) == "Prompt(message={message})".format(message=prompt.message) 251 | 252 | 253 | def test_json1(): 254 | message = "Hello, how are you?" 
255 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 256 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 257 | prompt1 = Prompt( 258 | message=message, 259 | responses=[ 260 | response1, 261 | response2], 262 | role=Role.USER, 263 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 264 | prompt1_json = prompt1.to_json() 265 | prompt2 = Prompt() 266 | prompt2.from_json(prompt1_json) 267 | assert prompt1 == prompt2 268 | 269 | 270 | def test_json2(): 271 | message = "Hello, how are you?" 272 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 273 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 274 | prompt1 = Prompt( 275 | message=message, 276 | responses=[ 277 | response1, 278 | response2], 279 | role=Role.USER, 280 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 281 | prompt1_json = prompt1.to_json(save_template=False) 282 | prompt2 = Prompt() 283 | prompt2.from_json(prompt1_json) 284 | assert prompt1 != prompt2 and prompt1.template == PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD.value and prompt2.template == PresetPromptTemplate.DEFAULT.value 285 | 286 | 287 | def test_json3(): 288 | prompt = Prompt() 289 | with pytest.raises(MemorValidationError, match=r"Invalid prompt structure. It should be a JSON object with proper fields."): 290 | prompt.from_json("{}") 291 | 292 | 293 | def test_save1(): 294 | message = "Hello, how are you?" 
295 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 296 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 297 | prompt = Prompt( 298 | message=message, 299 | responses=[ 300 | response1, 301 | response2], 302 | role=Role.USER, 303 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 304 | result = prompt.save("f:/") 305 | assert result["status"] == False 306 | 307 | 308 | def test_save2(): 309 | message = "Hello, how are you?" 310 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 311 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 312 | prompt1 = Prompt( 313 | message=message, 314 | responses=[ 315 | response1, 316 | response2], 317 | role=Role.USER, 318 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 319 | result = prompt1.save("prompt_test1.json") 320 | prompt2 = Prompt(file_path="prompt_test1.json") 321 | assert result["status"] and prompt1 == prompt2 322 | 323 | 324 | def test_load1(): 325 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: 22"): 326 | _ = Prompt(file_path=22) 327 | 328 | 329 | def test_load2(): 330 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: prompt_test10.json"): 331 | _ = Prompt(file_path="prompt_test10.json") 332 | 333 | 334 | def test_save3(): 335 | message = "Hello, how are you?" 
336 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 337 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 338 | prompt1 = Prompt( 339 | message=message, 340 | responses=[ 341 | response1, 342 | response2], 343 | role=Role.USER, 344 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 345 | result = prompt1.save("prompt_test2.json", save_template=False) 346 | prompt2 = Prompt(file_path="prompt_test2.json") 347 | assert result["status"] and prompt1 != prompt2 and prompt1.template == PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD.value and prompt2.template == PresetPromptTemplate.DEFAULT.value 348 | 349 | 350 | def test_render1(): 351 | message = "Hello, how are you?" 352 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 353 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 354 | prompt = Prompt( 355 | message=message, 356 | responses=[ 357 | response1, 358 | response2], 359 | role=Role.USER, 360 | template=PresetPromptTemplate.BASIC.PROMPT) 361 | assert prompt.render() == "Hello, how are you?" 362 | 363 | 364 | def test_render2(): 365 | message = "Hello, how are you?" 366 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 367 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 368 | prompt = Prompt( 369 | message=message, 370 | responses=[ 371 | response1, 372 | response2], 373 | role=Role.USER, 374 | template=PresetPromptTemplate.BASIC.PROMPT) 375 | assert prompt.render(RenderFormat.OPENAI) == {"role": "user", "content": "Hello, how are you?"} 376 | 377 | 378 | def test_render3(): 379 | message = "Hello, how are you?" 
380 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 381 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 382 | prompt = Prompt( 383 | message=message, 384 | responses=[ 385 | response1, 386 | response2], 387 | role=Role.USER, 388 | template=PresetPromptTemplate.BASIC.PROMPT) 389 | assert prompt.render(RenderFormat.DICTIONARY)["content"] == "Hello, how are you?" 390 | 391 | 392 | def test_render4(): 393 | message = "Hello, how are you?" 394 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 395 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 396 | prompt = Prompt( 397 | message=message, 398 | responses=[ 399 | response1, 400 | response2], 401 | role=Role.USER, 402 | template=PresetPromptTemplate.BASIC.PROMPT) 403 | assert ("content", "Hello, how are you?") in prompt.render(RenderFormat.ITEMS) 404 | 405 | 406 | def test_render5(): 407 | message = "How are you?" 408 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 409 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 410 | template = PromptTemplate(content="{instruction}, {prompt[message]}", custom_map={"instruction": "Hi"}) 411 | prompt = Prompt( 412 | message=message, 413 | responses=[ 414 | response1, 415 | response2], 416 | role=Role.USER, 417 | template=template) 418 | assert prompt.render(RenderFormat.OPENAI) == {"role": "user", "content": "Hi, How are you?"} 419 | 420 | 421 | def test_render6(): 422 | message = "Hello, how are you?" 
423 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 424 | template = PromptTemplate(content="{response[2][message]}") 425 | prompt = Prompt( 426 | message=message, 427 | responses=[response], 428 | role=Role.USER, 429 | template=template, 430 | init_check=False) 431 | with pytest.raises(MemorRenderError, match=r"Prompt template and properties are incompatible."): 432 | prompt.render() 433 | 434 | 435 | def test_render7(): 436 | message = "How are you?" 437 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 438 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 439 | template = PromptTemplate(content="{instruction}, {prompt[message]}", custom_map={"instruction": "Hi"}) 440 | prompt = Prompt( 441 | message=message, 442 | responses=[ 443 | response1, 444 | response2], 445 | role=Role.USER, 446 | template=template) 447 | assert prompt.render(RenderFormat.AI_STUDIO) == {'role': 'user', 'parts': [{'text': 'Hi, How are you?'}]} 448 | 449 | 450 | def test_init_check(): 451 | message = "Hello, how are you?" 452 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 453 | template = PromptTemplate(content="{response[2][message]}") 454 | with pytest.raises(MemorRenderError, match=r"Prompt template and properties are incompatible."): 455 | _ = Prompt(message=message, responses=[response], role=Role.USER, template=template) 456 | 457 | 458 | def test_check_render1(): 459 | message = "Hello, how are you?" 
460 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 461 | template = PromptTemplate(content="{response[2][message]}") 462 | prompt = Prompt( 463 | message=message, 464 | responses=[response], 465 | role=Role.USER, 466 | template=template, 467 | init_check=False) 468 | assert not prompt.check_render() 469 | 470 | 471 | def test_check_render2(): 472 | message = "How are you?" 473 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 474 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 475 | template = PromptTemplate(content="{instruction}, {prompt[message]}", custom_map={"instruction": "Hi"}) 476 | prompt = Prompt( 477 | message=message, 478 | responses=[ 479 | response1, 480 | response2], 481 | role=Role.USER, 482 | template=template) 483 | assert prompt.check_render() 484 | 485 | 486 | def test_equality1(): 487 | message = "Hello, how are you?" 488 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 489 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 490 | prompt1 = Prompt( 491 | message=message, 492 | responses=[ 493 | response1, 494 | response2], 495 | role=Role.USER, 496 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 497 | prompt2 = prompt1.copy() 498 | assert prompt1 == prompt2 499 | 500 | 501 | def test_equality2(): 502 | message = "Hello, how are you?" 
503 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 504 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 505 | prompt1 = Prompt(message=message, responses=[response1], role=Role.USER, 506 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 507 | prompt2 = Prompt(message=message, responses=[response2], role=Role.USER, 508 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 509 | assert prompt1 != prompt2 510 | 511 | 512 | def test_equality3(): 513 | message = "Hello, how are you?" 514 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 515 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 516 | prompt1 = Prompt( 517 | message=message, 518 | responses=[ 519 | response1, 520 | response2], 521 | role=Role.USER, 522 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 523 | prompt2 = Prompt( 524 | message=message, 525 | responses=[ 526 | response1, 527 | response2], 528 | role=Role.USER, 529 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 530 | assert prompt1 == prompt2 531 | 532 | 533 | def test_equality4(): 534 | message = "Hello, how are you?" 535 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 536 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 537 | prompt = Prompt( 538 | message=message, 539 | responses=[ 540 | response1, 541 | response2], 542 | role=Role.USER, 543 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 544 | assert prompt != 2 545 | 546 | 547 | def test_length1(): 548 | prompt = Prompt(message="Hello, how are you?") 549 | assert len(prompt) == 19 550 | 551 | 552 | def test_length2(): 553 | message = "Hello, how are you?" 
554 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 555 | template = PromptTemplate(content="{response[2][message]}") 556 | prompt = Prompt( 557 | message=message, 558 | responses=[response], 559 | role=Role.USER, 560 | template=template, 561 | init_check=False) 562 | assert len(prompt) == 0 563 | 564 | 565 | def test_length3(): 566 | prompt = Prompt() 567 | assert len(prompt) == 0 568 | 569 | 570 | def test_date_modified(): 571 | message = "Hello, how are you?" 572 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 573 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 574 | prompt = Prompt( 575 | message=message, 576 | responses=[ 577 | response1, 578 | response2], 579 | role=Role.USER, 580 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 581 | assert isinstance(prompt.date_modified, datetime.datetime) 582 | 583 | 584 | def test_date_created(): 585 | message = "Hello, how are you?" 
586 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 587 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 588 | prompt = Prompt( 589 | message=message, 590 | responses=[ 591 | response1, 592 | response2], 593 | role=Role.USER, 594 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD) 595 | assert isinstance(prompt.date_created, datetime.datetime) 596 | -------------------------------------------------------------------------------- /tests/test_prompt_template.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import copy 4 | import pytest 5 | from memor import PromptTemplate, MemorValidationError 6 | 7 | TEST_CASE_NAME = "PromptTemplate tests" 8 | 9 | 10 | def test_title1(): 11 | template = PromptTemplate( 12 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 13 | custom_map={ 14 | "language": "Python"}) 15 | assert template.title is None 16 | 17 | 18 | def test_title2(): 19 | template = PromptTemplate( 20 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 21 | custom_map={ 22 | "language": "Python"}) 23 | template.update_title("template1") 24 | assert template.title == "template1" 25 | 26 | 27 | def test_title3(): 28 | template = PromptTemplate( 29 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 30 | custom_map={ 31 | "language": "Python"}, 32 | title=None) 33 | assert template.title is None 34 | 35 | 36 | def test_title4(): 37 | template = PromptTemplate( 38 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 39 | custom_map={ 40 | "language": "Python"}, 41 | title=None) 42 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`title` must be a string."): 43 | template.update_title(25) 44 | 45 | 46 | def test_content1(): 47 | template = PromptTemplate( 48 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 49 | custom_map={ 50 | "language": "Python"}) 51 | assert template.content == "Act as a {language} developer and respond to this question:\n{prompt_message}" 52 | 53 | 54 | def test_content2(): 55 | template = PromptTemplate( 56 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 57 | custom_map={ 58 | "language": "Python"}) 59 | template.update_content(content="Act as a {language} developer and respond to this query:\n{prompt_message}") 60 | assert template.content == "Act as a {language} developer and respond to this query:\n{prompt_message}" 61 | 62 | 63 | def test_content3(): 64 | template = PromptTemplate( 65 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 66 | custom_map={ 67 | "language": "Python"}) 68 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`content` must be a string."): 69 | template.update_content(content=22) 70 | 71 | 72 | def test_custom_map1(): 73 | template = PromptTemplate( 74 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 75 | custom_map={ 76 | "language": "Python"}) 77 | assert template.custom_map == {"language": "Python"} 78 | 79 | 80 | def test_custom_map2(): 81 | template = PromptTemplate( 82 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 83 | custom_map={ 84 | "language": "Python"}) 85 | template.update_map({"language": "C++"}) 86 | assert template.custom_map == {"language": "C++"} 87 | 88 | 89 | def test_custom_map3(): 90 | template = PromptTemplate( 91 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 92 | custom_map={ 93 | "language": "Python"}) 94 | with pytest.raises(MemorValidationError, match=r"Invalid custom map: it must be a dictionary with keys and values that can be converted to strings."): 95 | template.update_map(["C++"]) 96 | 97 | 98 | def test_date_modified(): 99 | template = PromptTemplate( 100 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 101 | custom_map={ 102 | "language": "Python"}) 103 | assert isinstance(template.date_modified, datetime.datetime) 104 | 105 | 106 | def test_date_created(): 107 | template = PromptTemplate( 108 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 109 | custom_map={ 110 | "language": "Python"}) 111 | assert isinstance(template.date_created, datetime.datetime) 112 | 113 | 114 | def test_json1(): 115 | template1 = PromptTemplate( 116 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 117 | custom_map={ 118 | "language": "Python"}) 119 | template1_json = template1.to_json() 120 | template2 = PromptTemplate() 121 | template2.from_json(template1_json) 122 | assert template1 == template2 123 | 
124 | 125 | def test_json2(): 126 | template = PromptTemplate() 127 | with pytest.raises(MemorValidationError, match=r"Invalid template structure. It should be a JSON object with proper fields."): 128 | template.from_json("{}") 129 | 130 | 131 | def test_save1(): 132 | template = PromptTemplate( 133 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 134 | custom_map={ 135 | "language": "Python"}) 136 | result = template.save("template_test1.json") 137 | with open("template_test1.json", "r") as file: 138 | saved_template = json.loads(file.read()) 139 | assert result["status"] and template.to_json() == saved_template 140 | 141 | 142 | def test_save2(): 143 | template = PromptTemplate( 144 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 145 | custom_map={ 146 | "language": "Python"}) 147 | result = template.save("f:/") 148 | assert result["status"] == False 149 | 150 | 151 | def test_load1(): 152 | template1 = PromptTemplate( 153 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 154 | custom_map={ 155 | "language": "Python"}) 156 | result = template1.save("template_test2.json") 157 | template2 = PromptTemplate(file_path="template_test2.json") 158 | assert result["status"] and template1 == template2 159 | 160 | 161 | def test_load2(): 162 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: 22"): 163 | _ = PromptTemplate(file_path=22) 164 | 165 | 166 | def test_load3(): 167 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. 
Given path: template_test10.json"): 168 | _ = PromptTemplate(file_path="template_test10.json") 169 | 170 | 171 | def test_copy1(): 172 | template1 = PromptTemplate( 173 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 174 | custom_map={ 175 | "language": "Python"}) 176 | template2 = copy.copy(template1) 177 | assert id(template1) != id(template2) 178 | 179 | 180 | def test_copy2(): 181 | template1 = PromptTemplate( 182 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 183 | custom_map={ 184 | "language": "Python"}) 185 | template2 = template1.copy() 186 | assert id(template1) != id(template2) 187 | 188 | 189 | def test_str(): 190 | template = PromptTemplate( 191 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 192 | custom_map={ 193 | "language": "Python"}) 194 | assert str(template) == template.content 195 | 196 | 197 | def test_repr(): 198 | template = PromptTemplate( 199 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 200 | custom_map={ 201 | "language": "Python"}) 202 | assert repr(template) == "PromptTemplate(content={content})".format(content=template.content) 203 | 204 | 205 | def test_equality1(): 206 | template1 = PromptTemplate( 207 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 208 | custom_map={ 209 | "language": "Python"}) 210 | template2 = template1.copy() 211 | assert template1 == template2 212 | 213 | 214 | def test_equality2(): 215 | template1 = PromptTemplate( 216 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 217 | custom_map={ 218 | "language": "Python"}, 219 | title="template1") 220 | template2 = PromptTemplate( 221 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 222 | custom_map={ 223 | "language": "Python"}, 224 | title="template2") 225 | assert 
template1 != template2 226 | 227 | 228 | def test_equality3(): 229 | template1 = PromptTemplate( 230 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 231 | custom_map={ 232 | "language": "Python"}, 233 | title="template1") 234 | template2 = PromptTemplate( 235 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 236 | custom_map={ 237 | "language": "Python"}, 238 | title="template1") 239 | assert template1 == template2 240 | 241 | 242 | def test_equality4(): 243 | template = PromptTemplate( 244 | content="Act as a {language} developer and respond to this question:\n{prompt_message}", 245 | custom_map={ 246 | "language": "Python"}, 247 | title="template1") 248 | assert template != 2 249 | -------------------------------------------------------------------------------- /tests/test_response.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import uuid 3 | import json 4 | import copy 5 | import pytest 6 | from memor import Response, Role, LLMModel, MemorValidationError 7 | from memor import RenderFormat 8 | from memor import TokensEstimator 9 | 10 | TEST_CASE_NAME = "Response tests" 11 | 12 | 13 | def test_message1(): 14 | response = Response(message="I am fine.") 15 | assert response.message == "I am fine." 16 | 17 | 18 | def test_message2(): 19 | response = Response(message="I am fine.") 20 | response.update_message("OK!") 21 | assert response.message == "OK!" 22 | 23 | 24 | def test_message3(): 25 | response = Response(message="I am fine.") 26 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`message` must be a string."): 27 | response.update_message(22) 28 | 29 | 30 | def test_tokens1(): 31 | response = Response(message="I am fine.") 32 | assert response.tokens is None 33 | 34 | 35 | def test_tokens2(): 36 | response = Response(message="I am fine.", tokens=4) 37 | assert response.tokens == 4 38 | 39 | 40 | def test_tokens3(): 41 | response = Response(message="I am fine.", tokens=4) 42 | response.update_tokens(6) 43 | assert response.tokens == 6 44 | 45 | 46 | def test_estimated_tokens1(): 47 | response = Response(message="I am fine.") 48 | assert response.estimate_tokens(TokensEstimator.UNIVERSAL) == 5 49 | 50 | 51 | def test_estimated_tokens2(): 52 | response = Response(message="I am fine.") 53 | assert response.estimate_tokens(TokensEstimator.OPENAI_GPT_3_5) == 4 54 | 55 | 56 | def test_estimated_tokens3(): 57 | response = Response(message="I am fine.") 58 | assert response.estimate_tokens(TokensEstimator.OPENAI_GPT_4) == 4 59 | 60 | 61 | def test_tokens4(): 62 | response = Response(message="I am fine.", tokens=4) 63 | with pytest.raises(MemorValidationError, match=r"Invalid value. `tokens` must be a positive integer."): 64 | response.update_tokens(-2) 65 | 66 | 67 | def test_inference_time1(): 68 | response = Response(message="I am fine.") 69 | assert response.inference_time is None 70 | 71 | 72 | def test_inference_time2(): 73 | response = Response(message="I am fine.", inference_time=8.2) 74 | assert response.inference_time == 8.2 75 | 76 | 77 | def test_inference_time3(): 78 | response = Response(message="I am fine.", inference_time=8.2) 79 | response.update_inference_time(9.5) 80 | assert response.inference_time == 9.5 81 | 82 | 83 | def test_inference_time4(): 84 | response = Response(message="I am fine.", inference_time=8.2) 85 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`inference_time` must be a positive float."): 86 | response.update_inference_time(-5) 87 | 88 | 89 | def test_score1(): 90 | response = Response(message="I am fine.", score=0.9) 91 | assert response.score == 0.9 92 | 93 | 94 | def test_score2(): 95 | response = Response(message="I am fine.", score=0.9) 96 | response.update_score(0.5) 97 | assert response.score == 0.5 98 | 99 | 100 | def test_score3(): 101 | response = Response(message="I am fine.", score=0.9) 102 | with pytest.raises(MemorValidationError, match=r"Invalid value. `score` must be a value between 0 and 1."): 103 | response.update_score(-2) 104 | 105 | 106 | def test_role1(): 107 | response = Response(message="I am fine.", role=Role.ASSISTANT) 108 | assert response.role == Role.ASSISTANT 109 | 110 | 111 | def test_role2(): 112 | response = Response(message="I am fine.", role=Role.ASSISTANT) 113 | response.update_role(Role.USER) 114 | assert response.role == Role.USER 115 | 116 | 117 | def test_role3(): 118 | response = Response(message="I am fine.", role=None) 119 | assert response.role == Role.ASSISTANT 120 | 121 | 122 | def test_role4(): 123 | response = Response(message="I am fine.", role=Role.ASSISTANT) 124 | with pytest.raises(MemorValidationError, match=r"Invalid role. It must be an instance of Role enum."): 125 | response.update_role(2) 126 | 127 | 128 | def test_temperature1(): 129 | response = Response(message="I am fine.", temperature=0.2) 130 | assert response.temperature == 0.2 131 | 132 | 133 | def test_temperature2(): 134 | response = Response(message="I am fine.", temperature=0.2) 135 | response.update_temperature(0.7) 136 | assert response.temperature == 0.7 137 | 138 | 139 | def test_temperature3(): 140 | response = Response(message="I am fine.", temperature=0.2) 141 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`temperature` must be a positive float."): 142 | response.update_temperature(-22) 143 | 144 | 145 | def test_model1(): 146 | response = Response(message="I am fine.", model=LLMModel.GPT_4) 147 | assert response.model == LLMModel.GPT_4.value 148 | 149 | 150 | def test_model2(): 151 | response = Response(message="I am fine.", model=LLMModel.GPT_4) 152 | response.update_model(LLMModel.GPT_4O) 153 | assert response.model == LLMModel.GPT_4O.value 154 | 155 | 156 | def test_model3(): 157 | response = Response(message="I am fine.", model=LLMModel.GPT_4) 158 | response.update_model("my-trained-llm-instruct") 159 | assert response.model == "my-trained-llm-instruct" 160 | 161 | 162 | def test_model4(): 163 | response = Response(message="I am fine.", model=LLMModel.GPT_4) 164 | with pytest.raises(MemorValidationError, match=r"Invalid model. It must be an instance of LLMModel enum."): 165 | response.update_model(4) 166 | 167 | 168 | def test_id1(): 169 | response = Response(message="I am fine.", model=LLMModel.GPT_4) 170 | assert uuid.UUID(response.id, version=4) == uuid.UUID(response._id, version=4) 171 | 172 | 173 | def test_id2(): 174 | response = Response(message="I am fine.", model=LLMModel.GPT_4) 175 | response._id = "123" 176 | _ = response.save("response_test3.json") 177 | with pytest.raises(MemorValidationError, match=r"Invalid message ID. It must be a valid UUIDv4."): 178 | _ = Response(file_path="response_test3.json") 179 | 180 | 181 | def test_date1(): 182 | date_time_utc = datetime.datetime.now(datetime.timezone.utc) 183 | response = Response(message="I am fine.", date=date_time_utc) 184 | assert response.date_created == date_time_utc 185 | 186 | 187 | def test_date2(): 188 | response = Response(message="I am fine.", date=None) 189 | assert isinstance(response.date_created, datetime.datetime) 190 | 191 | 192 | def test_date3(): 193 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`date` must be a datetime object that includes timezone information."): 194 | _ = Response(message="I am fine.", date="2/25/2025") 195 | 196 | 197 | def test_date4(): 198 | with pytest.raises(MemorValidationError, match=r"Invalid value. `date` must be a datetime object that includes timezone information."): 199 | _ = Response(message="I am fine.", date=datetime.datetime.now()) 200 | 201 | 202 | def test_json1(): 203 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 204 | response1_json = response1.to_json() 205 | response2 = Response() 206 | response2.from_json(response1_json) 207 | assert response1 == response2 208 | 209 | 210 | def test_json2(): 211 | response = Response() 212 | with pytest.raises(MemorValidationError, match=r"Invalid response structure. It should be a JSON object with proper fields."): 213 | response.from_json("{}") 214 | 215 | 216 | def test_save1(): 217 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 218 | result = response.save("response_test1.json") 219 | with open("response_test1.json", "r") as file: 220 | saved_response = json.loads(file.read()) 221 | assert result["status"] and response.to_json() == saved_response 222 | 223 | 224 | def test_save2(): 225 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 226 | result = response.save("f:/") 227 | assert result["status"] == False 228 | 229 | 230 | def test_load1(): 231 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 232 | result = response1.save("response_test2.json") 233 | response2 = Response(file_path="response_test2.json") 234 | assert result["status"] and response1 == response2 235 | 236 | 237 | def test_load2(): 238 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. 
Given path: 2"): 239 | response = Response(file_path=2) 240 | 241 | 242 | def test_load3(): 243 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: response_test10.json"): 244 | response = Response(file_path="response_test10.json") 245 | 246 | 247 | def test_copy1(): 248 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 249 | response2 = copy.copy(response1) 250 | assert id(response1) != id(response2) and response1.id != response2.id 251 | 252 | 253 | def test_copy2(): 254 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 255 | response2 = response1.copy() 256 | assert id(response1) != id(response2) and response1.id != response2.id 257 | 258 | 259 | def test_str(): 260 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 261 | assert str(response) == response.message 262 | 263 | 264 | def test_repr(): 265 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 266 | assert repr(response) == "Response(message={message})".format(message=response.message) 267 | 268 | 269 | def test_render1(): 270 | response = Response(message="I am fine.") 271 | assert response.render() == "I am fine." 
272 | 273 | 274 | def test_render2(): 275 | response = Response(message="I am fine.") 276 | assert response.render(RenderFormat.OPENAI) == {"role": "assistant", "content": "I am fine."} 277 | 278 | 279 | def test_render3(): 280 | response = Response(message="I am fine.") 281 | assert response.render(RenderFormat.DICTIONARY) == response.to_dict() 282 | 283 | 284 | def test_render4(): 285 | response = Response(message="I am fine.") 286 | assert response.render(RenderFormat.ITEMS) == response.to_dict().items() 287 | 288 | 289 | def test_render5(): 290 | response = Response(message="I am fine.") 291 | with pytest.raises(MemorValidationError, match=r"Invalid render format. It must be an instance of RenderFormat enum."): 292 | response.render("OPENAI") 293 | 294 | 295 | def test_render6(): 296 | response = Response(message="I am fine.") 297 | assert response.render(RenderFormat.AI_STUDIO) == {'role': 'assistant', 'parts': [{'text': 'I am fine.'}]} 298 | 299 | 300 | def test_equality1(): 301 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 302 | response2 = response1.copy() 303 | assert response1 == response2 304 | 305 | 306 | def test_equality2(): 307 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 308 | response2 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.6) 309 | assert response1 != response2 310 | 311 | 312 | def test_equality3(): 313 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 314 | response2 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 315 | assert response1 == response2 316 | 317 | 318 | def test_equality4(): 319 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 320 | assert response != 2 321 | 322 | 323 | def 
test_length1(): 324 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 325 | assert len(response) == 10 326 | 327 | 328 | def test_length2(): 329 | response = Response() 330 | assert len(response) == 0 331 | 332 | 333 | def test_date_modified(): 334 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 335 | assert isinstance(response.date_modified, datetime.datetime) 336 | 337 | 338 | def test_date_created(): 339 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8) 340 | assert isinstance(response.date_created, datetime.datetime) 341 | -------------------------------------------------------------------------------- /tests/test_session.py: -------------------------------------------------------------------------------- 1 | import re 2 | import datetime 3 | import copy 4 | import pytest 5 | from memor import Session, Prompt, Response, Role 6 | from memor import PromptTemplate 7 | from memor import RenderFormat 8 | from memor import MemorRenderError, MemorValidationError 9 | from memor import TokensEstimator 10 | 11 | TEST_CASE_NAME = "Session tests" 12 | 13 | 14 | def test_title1(): 15 | session = Session(title="session1") 16 | assert session.title == "session1" 17 | 18 | 19 | def test_title2(): 20 | session = Session(title="session1") 21 | session.update_title("session2") 22 | assert session.title == "session2" 23 | 24 | 25 | def test_title3(): 26 | session = Session(title="session1") 27 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`title` must be a string."): 28 | session.update_title(2) 29 | 30 | 31 | def test_messages1(): 32 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 33 | response = Response(message="I am fine.") 34 | session = Session(messages=[prompt, response]) 35 | assert session.messages == [prompt, response] 36 | 37 | 38 | def test_messages2(): 39 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 40 | response = Response(message="I am fine.") 41 | session = Session(messages=[prompt, response]) 42 | session.update_messages([prompt, response, prompt, response]) 43 | assert session.messages == [prompt, response, prompt, response] 44 | 45 | 46 | def test_messages3(): 47 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 48 | session = Session(messages=[prompt]) 49 | with pytest.raises(MemorValidationError, match=r"Invalid value. `messages` must be a list of `Prompt` or `Response`."): 50 | session.update_messages([prompt, "I am fine."]) 51 | 52 | 53 | def test_messages4(): 54 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 55 | session = Session(messages=[prompt]) 56 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`messages` must be a list of `Prompt` or `Response`."): 57 | session.update_messages("I am fine.") 58 | 59 | 60 | def test_messages_status1(): 61 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 62 | response = Response(message="I am fine.") 63 | session = Session(messages=[prompt, response]) 64 | assert session.messages_status == [True, True] 65 | 66 | 67 | def test_messages_status2(): 68 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 69 | response = Response(message="I am fine.") 70 | session = Session(messages=[prompt, response]) 71 | session.update_messages_status([False, True]) 72 | assert session.messages_status == [False, True] 73 | 74 | 75 | def test_messages_status3(): 76 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 77 | response = Response(message="I am fine.") 78 | session = Session(messages=[prompt, response]) 79 | with pytest.raises(MemorValidationError, match=r"Invalid value. `status` must be a list of booleans."): 80 | session.update_messages_status(["False", True]) 81 | 82 | 83 | def test_messages_status4(): 84 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 85 | response = Response(message="I am fine.") 86 | session = Session(messages=[prompt, response]) 87 | with pytest.raises(MemorValidationError, match=r"Invalid message status length. 
It must be equal to the number of messages."): 88 | session.update_messages_status([False, True, True]) 89 | 90 | 91 | def test_enable_message(): 92 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 93 | response = Response(message="I am fine.") 94 | session = Session(messages=[prompt, response]) 95 | session.update_messages_status([False, False]) 96 | session.enable_message(0) 97 | assert session.messages_status == [True, False] 98 | 99 | 100 | def test_disable_message(): 101 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 102 | response = Response(message="I am fine.") 103 | session = Session(messages=[prompt, response]) 104 | session.update_messages_status([True, True]) 105 | session.disable_message(0) 106 | assert session.messages_status == [False, True] 107 | 108 | 109 | def test_mask_message(): 110 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 111 | response = Response(message="I am fine.") 112 | session = Session(messages=[prompt, response]) 113 | session.mask_message(0) 114 | assert session.messages_status == [False, True] 115 | 116 | 117 | def test_unmask_message(): 118 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 119 | response = Response(message="I am fine.") 120 | session = Session(messages=[prompt, response]) 121 | session.update_messages_status([False, False]) 122 | session.unmask_message(0) 123 | assert session.messages_status == [True, False] 124 | 125 | 126 | def test_masks(): 127 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 128 | response = Response(message="I am fine.") 129 | session = Session(messages=[prompt, response]) 130 | session.update_messages_status([False, True]) 131 | assert session.masks == [True, False] 132 | 133 | 134 | def test_add_message1(): 135 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 136 | response = Response(message="I am fine.") 137 | session = Session(messages=[prompt, response]) 138 | 
session.add_message(Response("Good!")) 139 | assert session.messages[2] == Response("Good!") 140 | 141 | 142 | def test_add_message2(): 143 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 144 | response = Response(message="I am fine.") 145 | session = Session(messages=[prompt, response]) 146 | session.add_message(message=Response("Good!"), status=False, index=0) 147 | assert session.messages[0] == Response("Good!") and session.messages_status[0] == False 148 | 149 | 150 | def test_add_message3(): 151 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 152 | response = Response(message="I am fine.") 153 | session = Session(messages=[prompt, response]) 154 | with pytest.raises(MemorValidationError, match=r"Invalid message. It must be an instance of `Prompt` or `Response`."): 155 | session.add_message(message="Good!", status=False, index=0) 156 | 157 | 158 | def test_add_message4(): 159 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 160 | response = Response(message="I am fine.") 161 | session = Session(messages=[prompt, response]) 162 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`status` must be a boolean."): 163 | session.add_message(message=prompt, status="False", index=0) 164 | 165 | 166 | def test_remove_message1(): 167 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 168 | response = Response(message="I am fine.") 169 | session = Session(messages=[prompt, response]) 170 | session.remove_message(1) 171 | assert session.messages == [prompt] and session.messages_status == [True] 172 | 173 | 174 | def test_remove_message2(): 175 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 176 | response = Response(message="I am fine.") 177 | session = Session(messages=[prompt, response]) 178 | session.remove_message_by_index(1) 179 | assert session.messages == [prompt] and session.messages_status == [True] 180 | 181 | 182 | def test_remove_message3(): 183 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 184 | response = Response(message="I am fine.") 185 | session = Session(messages=[prompt, response]) 186 | session.remove_message_by_id(response.id) 187 | assert session.messages == [prompt] and session.messages_status == [True] 188 | 189 | 190 | def test_remove_message4(): 191 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 192 | response = Response(message="I am fine.") 193 | session = Session(messages=[prompt, response]) 194 | session.remove_message(response.id) 195 | assert session.messages == [prompt] and session.messages_status == [True] 196 | 197 | 198 | def test_remove_message5(): 199 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 200 | response = Response(message="I am fine.") 201 | session = Session(messages=[prompt, response]) 202 | with pytest.raises(MemorValidationError, match=r"Invalid value. 
`identifier` must be an integer or a string."): 203 | session.remove_message(3.5) 204 | 205 | 206 | def test_clear_messages(): 207 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 208 | response = Response(message="I am fine.") 209 | session = Session(messages=[prompt, response]) 210 | assert len(session) == 2 211 | session.clear_messages() 212 | assert len(session) == 0 213 | 214 | 215 | def test_copy1(): 216 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 217 | response = Response(message="I am fine.") 218 | session1 = Session(messages=[prompt, response], title="session") 219 | session2 = copy.copy(session1) 220 | assert id(session1) != id(session2) 221 | 222 | 223 | def test_copy2(): 224 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 225 | response = Response(message="I am fine.") 226 | session1 = Session(messages=[prompt, response], title="session") 227 | session2 = session1.copy() 228 | assert id(session1) != id(session2) 229 | 230 | 231 | def test_str(): 232 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 233 | response = Response(message="I am fine.") 234 | session = Session(messages=[prompt, response], title="session1") 235 | assert str(session) == session.render(render_format=RenderFormat.STRING) 236 | 237 | 238 | def test_repr(): 239 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 240 | response = Response(message="I am fine.") 241 | session = Session(messages=[prompt, response], title="session1") 242 | assert repr(session) == "Session(title={title})".format(title=session.title) 243 | 244 | 245 | def test_json(): 246 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 247 | response = Response(message="I am fine.") 248 | session1 = Session(messages=[prompt, response], title="session1") 249 | session1_json = session1.to_json() 250 | session2 = Session() 251 | session2.from_json(session1_json) 252 | assert session1 == session2 253 | 254 | 255 | def 
test_save1(): 256 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 257 | response = Response(message="I am fine.") 258 | session = Session(messages=[prompt, response], title="session1") 259 | result = session.save("f:/") 260 | assert result["status"] == False 261 | 262 | 263 | def test_save2(): 264 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 265 | response = Response(message="I am fine.") 266 | session1 = Session(messages=[prompt, response], title="session1") 267 | _ = session1.render() 268 | result = session1.save("session_test1.json") 269 | session2 = Session(file_path="session_test1.json") 270 | assert result["status"] and session1 == session2 and session2.render_counter == 1 271 | 272 | 273 | def test_load1(): 274 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: 22"): 275 | _ = Session(file_path=22) 276 | 277 | 278 | def test_load2(): 279 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. 
Given path: session_test10.json"): 280 | _ = Session(file_path="session_test10.json") 281 | 282 | 283 | def test_render1(): 284 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 285 | response = Response(message="I am fine.") 286 | session = Session(messages=[prompt, response], title="session1") 287 | assert session.render() == "Hello, how are you?\nI am fine.\n" 288 | 289 | 290 | def test_render2(): 291 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 292 | response = Response(message="I am fine.") 293 | session = Session(messages=[prompt, response], title="session1") 294 | assert session.render(RenderFormat.OPENAI) == [{"role": "user", "content": "Hello, how are you?"}, { 295 | "role": "assistant", "content": "I am fine."}] 296 | 297 | 298 | def test_render3(): 299 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 300 | response = Response(message="I am fine.") 301 | session = Session(messages=[prompt, response], title="session1") 302 | assert session.render(RenderFormat.DICTIONARY)["content"] == "Hello, how are you?\nI am fine.\n" 303 | 304 | 305 | def test_render4(): 306 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 307 | response = Response(message="I am fine.") 308 | session = Session(messages=[prompt, response], title="session1") 309 | assert ("content", "Hello, how are you?\nI am fine.\n") in session.render(RenderFormat.ITEMS) 310 | 311 | 312 | def test_render5(): 313 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 314 | response = Response(message="I am fine.") 315 | session = Session(messages=[prompt, response], title="session1") 316 | with pytest.raises(MemorValidationError, match=r"Invalid render format. 
It must be an instance of RenderFormat enum."): 317 | session.render("OPENAI") 318 | 319 | 320 | def test_render6(): 321 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 322 | response = Response(message="I am fine.") 323 | session = Session(messages=[prompt, response], title="session1") 324 | assert session.render(RenderFormat.AI_STUDIO) == [{'role': 'user', 'parts': [{'text': 'Hello, how are you?'}]}, { 325 | 'role': 'assistant', 'parts': [{'text': 'I am fine.'}]}] 326 | 327 | 328 | def test_check_render1(): 329 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 330 | response = Response(message="I am fine.") 331 | session = Session(messages=[prompt, response], title="session1") 332 | assert session.check_render() 333 | 334 | 335 | def test_check_render2(): 336 | template = PromptTemplate(content="{response[2][message]}") 337 | prompt = Prompt(message="Hello, how are you?", role=Role.USER, template=template, init_check=False) 338 | response = Response(message="I am fine.") 339 | session = Session(messages=[prompt, response], title="session1", init_check=False) 340 | assert not session.check_render() 341 | 342 | 343 | def test_render_counter1(): 344 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 345 | response = Response(message="I am fine.") 346 | session = Session(messages=[prompt, response], title="session1") 347 | assert session.render_counter == 0 348 | for _ in range(10): 349 | __ = session.render() 350 | assert session.render_counter == 10 351 | 352 | 353 | def test_render_counter2(): 354 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 355 | response = Response(message="I am fine.") 356 | session = Session(messages=[prompt, response], title="session1") 357 | assert session.render_counter == 0 358 | for _ in range(10): 359 | __ = session.render() 360 | for _ in range(2): 361 | __ = session.render(enable_counter=False) 362 | assert session.render_counter == 10 363 | 364 | 365 | def 
test_render_counter3(): 366 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 367 | response = Response(message="I am fine.") 368 | session = Session(messages=[prompt, response], title="session1", init_check=True) 369 | _ = str(session) 370 | _ = session.check_render() 371 | _ = session.estimate_tokens() 372 | assert session.render_counter == 0 373 | 374 | 375 | def test_init_check(): 376 | template = PromptTemplate(content="{response[2][message]}") 377 | prompt = Prompt(message="Hello, how are you?", role=Role.USER, template=template, init_check=False) 378 | response = Response(message="I am fine.") 379 | with pytest.raises(MemorRenderError, match=r"Prompt template and properties are incompatible."): 380 | _ = Session(messages=[prompt, response], title="session1") 381 | 382 | 383 | def test_equality1(): 384 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 385 | response = Response(message="I am fine.") 386 | session1 = Session(messages=[prompt, response], title="session1") 387 | session2 = session1.copy() 388 | assert session1 == session2 389 | 390 | 391 | def test_equality2(): 392 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 393 | response = Response(message="I am fine.") 394 | session1 = Session(messages=[prompt, response], title="session1") 395 | session2 = Session(messages=[prompt, response], title="session2") 396 | assert session1 != session2 397 | 398 | 399 | def test_equality3(): 400 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 401 | response = Response(message="I am fine.") 402 | session1 = Session(messages=[prompt, response], title="session1") 403 | session2 = Session(messages=[prompt, response], title="session1") 404 | assert session1 == session2 405 | 406 | 407 | def test_equality4(): 408 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 409 | response = Response(message="I am fine.") 410 | session = Session(messages=[prompt, response], title="session1") 411 | assert 
session != 2 412 | 413 | 414 | def test_date_modified(): 415 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 416 | response = Response(message="I am fine.") 417 | session = Session(messages=[prompt, response], title="session1") 418 | assert isinstance(session.date_modified, datetime.datetime) 419 | 420 | 421 | def test_date_created(): 422 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 423 | response = Response(message="I am fine.") 424 | session = Session(messages=[prompt, response], title="session1") 425 | assert isinstance(session.date_created, datetime.datetime) 426 | 427 | 428 | def test_length(): 429 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 430 | response = Response(message="I am fine.") 431 | session = Session(messages=[prompt, response], title="session1") 432 | assert len(session) == len(session.messages) and len(session) == 2 433 | 434 | 435 | def test_iter(): 436 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 437 | response = Response(message="I am fine.") 438 | session = Session(messages=[prompt, response, prompt, response], title="session1") 439 | messages = [] 440 | for message in session: 441 | messages.append(message) 442 | assert session.messages == messages 443 | 444 | 445 | def test_addition1(): 446 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 447 | response = Response(message="I am fine.") 448 | session1 = Session(messages=[prompt, response, prompt, response], title="session1") 449 | session2 = Session(messages=[prompt, prompt, response, response], title="session2") 450 | session3 = session1 + session2 451 | assert session3.title is None and session3.messages == session1.messages + session2.messages 452 | 453 | 454 | def test_addition2(): 455 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 456 | response = Response(message="I am fine.") 457 | session1 = Session(messages=[prompt, response, prompt, response], title="session1") 458 | 
session2 = Session(messages=[prompt, prompt, response, response], title="session2") 459 | session3 = session2 + session1 460 | assert session3.title is None and session3.messages == session2.messages + session1.messages 461 | 462 | 463 | def test_addition3(): 464 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 465 | response = Response(message="I am fine.") 466 | session1 = Session(messages=[prompt, response, prompt, response], title="session1") 467 | session2 = session1 + response 468 | assert session2.title == "session1" and session2.messages == session1.messages + [response] 469 | 470 | 471 | def test_addition4(): 472 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 473 | response = Response(message="I am fine.") 474 | session1 = Session(messages=[prompt, response, prompt, response], title="session1") 475 | session2 = session1 + prompt 476 | assert session2.title == "session1" and session2.messages == session1.messages + [prompt] 477 | 478 | 479 | def test_addition5(): 480 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 481 | response = Response(message="I am fine.") 482 | session1 = Session(messages=[prompt, response, prompt, response], title="session1") 483 | session2 = response + session1 484 | assert session2.title == "session1" and session2.messages == [response] + session1.messages 485 | 486 | 487 | def test_addition6(): 488 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 489 | response = Response(message="I am fine.") 490 | session1 = Session(messages=[prompt, response, prompt, response], title="session1") 491 | session2 = prompt + session1 492 | assert session2.title == "session1" and session2.messages == [prompt] + session1.messages 493 | 494 | 495 | def test_addition7(): 496 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 497 | response = Response(message="I am fine.") 498 | session1 = Session(messages=[prompt, response, prompt, response], title="session1") 499 | with 
pytest.raises(TypeError, match=re.escape(r"Unsupported operand type(s) for +: `Session` and `int`")): 500 | _ = session1 + 2 501 | 502 | 503 | def test_addition8(): 504 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 505 | response = Response(message="I am fine.") 506 | session1 = Session(messages=[prompt, response, prompt, response], title="session1") 507 | with pytest.raises(TypeError, match=re.escape(r"Unsupported operand type(s) for +: `Session` and `int`")): 508 | _ = 2 + session1 509 | 510 | 511 | def test_contains1(): 512 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 513 | response = Response(message="I am fine.") 514 | session = Session(messages=[prompt, response], title="session") 515 | assert prompt in session and response in session 516 | 517 | 518 | def test_contains2(): 519 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 520 | response1 = Response(message="I am fine.") 521 | response2 = Response(message="Good!") 522 | session = Session(messages=[prompt, response1], title="session") 523 | assert response2 not in session 524 | 525 | 526 | def test_contains3(): 527 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 528 | response = Response(message="I am fine.") 529 | session = Session(messages=[prompt, response], title="session") 530 | assert "I am fine." 
not in session 531 | 532 | 533 | def test_getitem1(): 534 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 535 | response = Response(message="I am fine.") 536 | session = Session(messages=[prompt, response], title="session") 537 | assert session[0] == session.messages[0] and session[1] == session.messages[1] 538 | 539 | 540 | def test_getitem2(): 541 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 542 | response = Response(message="I am fine.") 543 | session = Session(messages=[prompt, response, response, response], title="session") 544 | assert session[:] == session.messages 545 | 546 | 547 | def test_getitem3(): 548 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 549 | response = Response(message="I am fine.") 550 | session = Session(messages=[prompt, response], title="session") 551 | assert session[0] == session.get_message_by_index(0) and session[1] == session.get_message_by_index(1) 552 | 553 | 554 | def test_getitem4(): 555 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 556 | response = Response(message="I am fine.") 557 | session = Session(messages=[prompt, response], title="session") 558 | assert session[0] == session.get_message(0) and session[1] == session.get_message(1) 559 | 560 | 561 | def test_getitem5(): 562 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 563 | response = Response(message="I am fine.") 564 | session = Session(messages=[prompt, response], title="session") 565 | assert session[0] == session.get_message_by_id(prompt.id) and session[1] == session.get_message_by_id(response.id) 566 | 567 | 568 | def test_getitem6(): 569 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 570 | response = Response(message="I am fine.") 571 | session = Session(messages=[prompt, response], title="session") 572 | assert session[0] == session.get_message(prompt.id) and session[1] == session.get_message(response.id) 573 | 574 | 575 | def test_getitem7(): 576 | 
prompt = Prompt(message="Hello, how are you?", role=Role.USER) 577 | response = Response(message="I am fine.") 578 | session = Session(messages=[prompt, response], title="session") 579 | with pytest.raises(MemorValidationError, match=r"Invalid value. `identifier` must be an integer, string or a slice."): 580 | _ = session[3.5] 581 | 582 | 583 | def test_getitem8(): 584 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 585 | response = Response(message="I am fine.") 586 | session = Session(messages=[prompt, response], title="session") 587 | with pytest.raises(MemorValidationError, match=r"Invalid value. `identifier` must be an integer, string or a slice."): 588 | _ = session.get_message(3.5) 589 | 590 | 591 | def test_estimated_tokens1(): 592 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 593 | response = Response(message="I am fine.") 594 | session = Session(messages=[prompt, response], title="session") 595 | assert session.estimate_tokens(TokensEstimator.UNIVERSAL) == 12 596 | 597 | 598 | def test_estimated_tokens2(): 599 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 600 | response = Response(message="I am fine.") 601 | session = Session(messages=[prompt, response], title="session") 602 | assert session.estimate_tokens(TokensEstimator.OPENAI_GPT_3_5) == 14 603 | 604 | 605 | def test_estimated_tokens3(): 606 | prompt = Prompt(message="Hello, how are you?", role=Role.USER) 607 | response = Response(message="I am fine.") 608 | session = Session(messages=[prompt, response], title="session") 609 | assert session.estimate_tokens(TokensEstimator.OPENAI_GPT_4) == 15 610 | -------------------------------------------------------------------------------- /tests/test_token_estimators.py: -------------------------------------------------------------------------------- 1 | from memor.tokens_estimator import openai_tokens_estimator_gpt_3_5, openai_tokens_estimator_gpt_4, universal_tokens_estimator 2 | 3 | TEST_CASE_NAME = 
"Token Estimators tests" 4 | 5 | 6 | def test_universal_tokens_estimator_with_contractions(): 7 | message = "I'm going to the park." 8 | assert universal_tokens_estimator(message) == 7 9 | message = "They'll be here soon." 10 | assert universal_tokens_estimator(message) == 7 11 | 12 | 13 | def test_universal_tokens_estimator_with_code_snippets(): 14 | message = "def foo(): return 42" 15 | assert universal_tokens_estimator(message) == 7 16 | message = "if x == 10:" 17 | assert universal_tokens_estimator(message) == 6 18 | 19 | 20 | def test_universal_tokens_estimator_with_loops(): 21 | message = "for i in range(10):" 22 | assert universal_tokens_estimator(message) == 8 23 | message = "while True:" 24 | assert universal_tokens_estimator(message) == 4 25 | 26 | 27 | def test_universal_tokens_estimator_with_long_sentences(): 28 | message = "Understanding natural language processing is fun!" 29 | assert universal_tokens_estimator(message) == 17 30 | message = "Tokenization involves splitting text into meaningful units." 31 | assert universal_tokens_estimator(message) == 24 32 | 33 | 34 | def test_universal_tokens_estimator_with_variable_names(): 35 | message = "some_variable_name = 100" 36 | assert universal_tokens_estimator(message) == 5 37 | message = "another_long_var_name = 'test'" 38 | assert universal_tokens_estimator(message) == 6 39 | 40 | 41 | def test_universal_tokens_estimator_with_function_definitions(): 42 | message = "The function `def add(x, y): return x + y` adds two numbers." 43 | assert universal_tokens_estimator(message) == 20 44 | message = "Use `for i in range(5):` to loop." 45 | assert universal_tokens_estimator(message) == 14 46 | 47 | 48 | def test_universal_tokens_estimator_with_numbers(): 49 | message = "The year 2023 was great!" 50 | assert universal_tokens_estimator(message) == 6 51 | message = "42 is the answer to everything." 
52 | assert universal_tokens_estimator(message) == 11 53 | 54 | 55 | def test_universal_tokens_estimator_with_print_statements(): 56 | message = "print('Hello, world!')" 57 | assert universal_tokens_estimator(message) == 5 58 | message = "name = \"Alice\"" 59 | assert universal_tokens_estimator(message) == 3 60 | 61 | 62 | def test_openai_tokens_estimator_with_function_definition(): 63 | message = "def add(a, b): return a + b" 64 | assert openai_tokens_estimator_gpt_3_5(message) == 11 65 | 66 | 67 | def test_openai_tokens_estimator_with_url(): 68 | message = "Visit https://openai.com for more info." 69 | assert openai_tokens_estimator_gpt_3_5(message) == 18 70 | 71 | 72 | def test_openai_tokens_estimator_with_long_words(): 73 | message = "This is a verylongwordwithoutspaces and should be counted properly." 74 | assert openai_tokens_estimator_gpt_3_5(message) == 25 75 | 76 | 77 | def test_openai_tokens_estimator_with_newlines(): 78 | message = "Line1\nLine2\nLine3\n" 79 | assert openai_tokens_estimator_gpt_3_5(message) == 7 80 | 81 | 82 | def test_openai_tokens_estimator_with_gpt4_model(): 83 | message = "This is a test sentence that should be counted properly even with GPT-4. I am making it longer to test the model." 84 | assert openai_tokens_estimator_gpt_4(message) == 45 85 | --------------------------------------------------------------------------------