├── .coveragerc
├── .github
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   └── feature_request.yml
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── dependabot.yml
│   └── workflows
│       ├── publish_pypi.yml
│       └── test.yml
├── .gitignore
├── .pydocstyle
├── AUTHORS.md
├── CHANGELOG.md
├── LICENSE
├── README.md
├── SECURITY.md
├── autopep8.bat
├── autopep8.sh
├── codecov.yml
├── dev-requirements.txt
├── memor
│   ├── __init__.py
│   ├── errors.py
│   ├── functions.py
│   ├── keywords.py
│   ├── params.py
│   ├── prompt.py
│   ├── response.py
│   ├── session.py
│   ├── template.py
│   └── tokens_estimator.py
├── otherfiles
│   ├── RELEASE.md
│   ├── donation.png
│   ├── meta.yaml
│   ├── requirements-splitter.py
│   └── version_check.py
├── requirements.txt
├── setup.py
└── tests
    ├── test_prompt.py
    ├── test_prompt_template.py
    ├── test_response.py
    ├── test_session.py
    └── test_token_estimators.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | branch = True
3 | omit =
4 | */memor/__main__.py
5 | [report]
6 | # Regexes for lines to exclude from consideration
7 | exclude_lines =
8 | pragma: no cover
9 |
--------------------------------------------------------------------------------
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, caste, color, religion, or sexual
10 | identity and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the overall
26 | community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or advances of
31 | any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email address,
35 | without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a professional setting
37 |
38 | ## Enforcement Responsibilities
39 |
40 | Community leaders are responsible for clarifying and enforcing our standards of
41 | acceptable behavior and will take appropriate and fair corrective action in
42 | response to any behavior that they deem inappropriate, threatening, offensive,
43 | or harmful.
44 |
45 | Community leaders have the right and responsibility to remove, edit, or reject
46 | comments, commits, code, wiki edits, issues, and other contributions that are
47 | not aligned to this Code of Conduct, and will communicate reasons for moderation
48 | decisions when appropriate.
49 |
50 | ## Scope
51 | This Code of Conduct applies both within project spaces and in public spaces
52 | when an individual is representing the project or its community.
53 | Examples of representing our community include using an official e-mail address,
54 | posting via an official social media account, or acting as an appointed
55 | representative at an online or offline event.
56 |
57 | ## Enforcement
58 |
59 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
60 | reported to the community leaders responsible for enforcement at
61 | memor@openscilab.com.
62 | All complaints will be reviewed and investigated promptly and fairly.
63 |
64 | All community leaders are obligated to respect the privacy and security of the
65 | reporter of any incident.
66 |
67 | ## Enforcement Guidelines
68 |
69 | Community leaders will follow these Community Impact Guidelines in determining
70 | the consequences for any action they deem in violation of this Code of Conduct:
71 |
72 | ### 1. Correction
73 |
74 | **Community Impact**: Use of inappropriate language or other behavior deemed
75 | unprofessional or unwelcome in the community.
76 |
77 | **Consequence**: A private, written warning from community leaders, providing
78 | clarity around the nature of the violation and an explanation of why the
79 | behavior was inappropriate. A public apology may be requested.
80 |
81 | ### 2. Warning
82 |
83 | **Community Impact**: A violation through a single incident or series of
84 | actions.
85 |
86 | **Consequence**: A warning with consequences for continued behavior. No
87 | interaction with the people involved, including unsolicited interaction with
88 | those enforcing the Code of Conduct, for a specified period of time. This
89 | includes avoiding interactions in community spaces as well as external channels
90 | like social media. Violating these terms may lead to a temporary or permanent
91 | ban.
92 |
93 | ### 3. Temporary Ban
94 |
95 | **Community Impact**: A serious violation of community standards, including
96 | sustained inappropriate behavior.
97 |
98 | **Consequence**: A temporary ban from any sort of interaction or public
99 | communication with the community for a specified period of time. No public or
100 | private interaction with the people involved, including unsolicited interaction
101 | with those enforcing the Code of Conduct, is allowed during this period.
102 | Violating these terms may lead to a permanent ban.
103 |
104 | ### 4. Permanent Ban
105 |
106 | **Community Impact**: Demonstrating a pattern of violation of community
107 | standards, including sustained inappropriate behavior, harassment of an
108 | individual, or aggression toward or disparagement of classes of individuals.
109 |
110 | **Consequence**: A permanent ban from any sort of public interaction within the
111 | community.
112 |
113 | ## Attribution
114 |
115 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
116 | version 2.1, available at
117 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
118 |
119 | Community Impact Guidelines were inspired by
120 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
121 |
122 | For answers to common questions about this code of conduct, see the FAQ at
123 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
124 | [https://www.contributor-covenant.org/translations][translations].
125 |
126 | [homepage]: https://www.contributor-covenant.org
127 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
128 | [Mozilla CoC]: https://github.com/mozilla/diversity
129 | [FAQ]: https://www.contributor-covenant.org/faq
130 | [translations]: https://www.contributor-covenant.org/translations
131 |
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution
2 |
3 | Changes and improvements are more than welcome! ❤️ Feel free to fork and open a pull request.
4 |
5 |
6 | Please consider the following:
7 |
8 |
9 | 1. Fork it!
10 | 2. Create your feature branch (under `dev` branch)
11 | 3. Add your functions/methods to proper files
12 | 4. Add standard `docstring` to your functions/methods
13 | 5. Add tests for your functions/methods (`unittest` testcases in `tests` folder)
14 | 6. Pass all CI tests
15 | 7. Update `CHANGELOG.md`
16 | - Describe changes under `[Unreleased]` section
17 | 8. Submit a pull request into `dev` (please complete the pull request template)
18 |
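Step 5 above asks for `unittest` testcases in the `tests` folder. As a minimal, hypothetical sketch (the class name and tested behavior here are illustrative, not part of Memor):

```python
# Hypothetical sketch of a testcase added under tests/ for a new feature.
import unittest


class TestMyFeature(unittest.TestCase):
    """Testcases for a newly added function/method."""

    def test_basic_case(self):
        # Replace with assertions against your new function/method.
        self.assertEqual("memor".upper(), "MEMOR")


if __name__ == "__main__":
    unittest.main()
```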
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: File a bug report
3 | title: "[Bug]: "
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | Thanks for taking the time to fill out this bug report!
9 | - type: input
10 | id: contact
11 | attributes:
12 | label: Contact details
13 | description: How can we get in touch with you if we need more info?
14 | placeholder: ex. email@example.com
15 | validations:
16 | required: false
17 | - type: textarea
18 | id: what-happened
19 | attributes:
20 | label: What happened?
21 | description: Provide a clear and concise description of what the bug is.
22 | placeholder: >
23 | Tell us a description of the bug.
24 | validations:
25 | required: true
26 | - type: textarea
27 | id: step-to-reproduce
28 | attributes:
29 | label: Steps to reproduce
30 | description: Provide details of how to reproduce the bug.
31 | placeholder: >
32 | ex. 1. Go to '...'
33 | validations:
34 | required: true
35 | - type: textarea
36 | id: expected-behavior
37 | attributes:
38 | label: Expected behavior
39 | description: What did you expect to happen?
40 | placeholder: >
41 | ex. I expected '...' to happen
42 | validations:
43 | required: true
44 | - type: textarea
45 | id: actual-behavior
46 | attributes:
47 | label: Actual behavior
48 | description: What actually happened?
49 | placeholder: >
50 | ex. Instead '...' happened
51 | validations:
52 | required: true
53 | - type: dropdown
54 | id: operating-system
55 | attributes:
56 | label: Operating system
57 | description: Which operating system are you using?
58 | options:
59 | - Windows
60 | - macOS
61 | - Linux
62 | default: 0
63 | validations:
64 | required: true
65 | - type: dropdown
66 | id: python-version
67 | attributes:
68 | label: Python version
69 | description: Which version of Python are you using?
70 | options:
71 | - Python 3.13
72 | - Python 3.12
73 | - Python 3.11
74 | - Python 3.10
75 | - Python 3.9
76 | - Python 3.8
77 | - Python 3.7
78 | - Python 3.6
79 | default: 1
80 | validations:
81 | required: true
82 | - type: dropdown
83 | id: memor-version
84 | attributes:
85 | label: Memor version
86 | description: Which version of Memor are you using?
87 | options:
88 | - Memor 0.6
89 | - Memor 0.5
90 | - Memor 0.4
91 | - Memor 0.3
92 | - Memor 0.2
93 | - Memor 0.1
94 | default: 0
95 | validations:
96 | required: true
97 | - type: textarea
98 | id: logs
99 | attributes:
100 | label: Relevant log output
101 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
102 | render: shell
103 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Discord
4 | url: https://discord.gg/cZxGwZ6utB
5 | about: Ask questions and discuss with other Memor community members
6 | - name: Website
7 | url: https://openscilab.com/
8 | about: Check out our website for more information
9 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature Request
2 | description: Suggest a feature for this project
3 | title: "[Feature]: "
4 | body:
5 | - type: textarea
6 | id: description
7 | attributes:
8 | label: Describe the feature you want to add
9 | placeholder: >
10 | I'd like to be able to [...]
11 | validations:
12 | required: true
13 | - type: textarea
14 | id: possible-solution
15 | attributes:
16 | label: Describe your proposed solution
17 | placeholder: >
18 | I think this could be done by [...]
19 | validations:
20 | required: false
21 | - type: textarea
22 | id: alternatives
23 | attributes:
24 | label: Describe alternatives you've considered, if relevant
25 | placeholder: >
26 | Another way to do this would be [...]
27 | validations:
28 | required: false
29 | - type: textarea
30 | id: additional-context
31 | attributes:
32 | label: Additional context
33 | placeholder: >
34 | Add any other context or screenshots about the feature request here.
35 | validations:
36 | required: false
37 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | #### Reference Issues/PRs
2 |
3 | #### What does this implement/fix? Explain your changes.
4 |
5 | #### Any other comments?
6 |
7 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: pip
4 | directory: "/"
5 | schedule:
6 | interval: weekly
7 | time: "01:30"
8 | open-pull-requests-limit: 10
9 | target-branch: dev
10 | assignees:
11 | - "sadrasabouri"
12 | - "sepandhaghighi"
13 |
--------------------------------------------------------------------------------
/.github/workflows/publish_pypi.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | push:
8 | # Sequence of patterns matched against refs/tags
9 | tags:
10 | - '*' # Push events to any tag, e.g. v1.0, v20.15.10
11 |
12 | jobs:
13 | deploy:
14 |
15 | runs-on: ubuntu-22.04
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python
20 | uses: actions/setup-python@v1
21 | with:
22 | python-version: '3.x'
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install setuptools wheel twine
27 | - name: Build and publish
28 | env:
29 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
30 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
31 | run: |
32 | python setup.py sdist bdist_wheel
33 | twine upload dist/*.tar.gz
34 | twine upload dist/*.whl
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CI
5 |
6 | on:
7 | push:
8 | branches:
9 | - main
10 | - dev
11 |
12 | pull_request:
13 | branches:
14 | - dev
15 | - main
16 |
17 | env:
18 | TEST_PYTHON_VERSION: 3.9
19 | TEST_OS: 'ubuntu-22.04'
20 |
21 | jobs:
22 | build:
23 | runs-on: ${{ matrix.os }}
24 | strategy:
25 | fail-fast: false
26 | matrix:
27 | os: [ubuntu-22.04, windows-2022, macOS-13]
28 | python-version: [3.7, 3.8, 3.9, 3.10.5, 3.11.0, 3.12.0, 3.13.0]
29 | steps:
30 | - uses: actions/checkout@v2
31 | - name: Set up Python ${{ matrix.python-version }}
32 | uses: actions/setup-python@v2
33 | with:
34 | python-version: ${{ matrix.python-version }}
35 | - name: Installation
36 | run: |
37 | python -m pip install --upgrade pip
38 | pip install .
39 | - name: Test requirements installation
40 | run: |
41 | python otherfiles/requirements-splitter.py
42 | pip install --upgrade --upgrade-strategy=only-if-needed -r test-requirements.txt
43 | - name: Test with pytest
44 | run: |
45 | python -m pytest . --cov=memor --cov-report=term
46 | - name: Upload coverage to Codecov
47 | uses: codecov/codecov-action@v4
48 | with:
49 | fail_ci_if_error: true
50 | token: ${{ secrets.CODECOV_TOKEN }}
51 | if: matrix.python-version == env.TEST_PYTHON_VERSION && matrix.os == env.TEST_OS
52 | - name: Vulture, Bandit and Pydocstyle tests
53 | run: |
54 | python -m vulture memor/ otherfiles/ setup.py --min-confidence 65 --exclude=__init__.py --sort-by-size
55 | python -m bandit -r memor -s B311
56 | python -m pydocstyle -v
57 | if: matrix.python-version == env.TEST_PYTHON_VERSION
58 | - name: Version check
59 | run: |
60 | python otherfiles/version_check.py
61 | if: matrix.python-version == env.TEST_PYTHON_VERSION
62 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *,cover
48 | .hypothesis/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # dotenv
81 | .env
82 |
83 | # virtualenv
84 | .venv/
85 | venv/
86 | ENV/
87 |
88 | # Spyder project settings
89 | .spyderproject
90 |
91 | # Rope project settings
92 | .ropeproject
93 | ### Example user template template
94 | ### Example user template
95 |
96 | # IntelliJ project files
97 | .idea
98 | *.iml
99 | out
100 | gen
101 |
102 | # Outputs
103 |
104 | prompt_test1.json
105 | prompt_test2.json
106 | prompt_test3.json
107 | response_test1.json
108 | response_test2.json
109 | response_test3.json
110 | session_test1.json
111 | template_test1.json
112 | template_test2.json
113 |
--------------------------------------------------------------------------------
/.pydocstyle:
--------------------------------------------------------------------------------
1 | [pydocstyle]
2 | match_dir = ^(?!(tests|build)).*
3 | match = .*\.py
4 |
5 |
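The `match_dir` pattern above relies on a negative lookahead to skip linting the `tests` and `build` directories. A quick stdlib check of which directory names it selects:

```python
import re

# match_dir from .pydocstyle: the negative lookahead rejects directory
# names that start with "tests" or "build"; everything else is linted.
match_dir = re.compile(r"^(?!(tests|build)).*")

print(bool(match_dir.match("memor")))  # True: package dirs are checked
print(bool(match_dir.match("tests")))  # False: test dirs are skipped
print(bool(match_dir.match("build")))  # False: build dirs are skipped
```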
--------------------------------------------------------------------------------
/AUTHORS.md:
--------------------------------------------------------------------------------
1 | # Core Developers
2 | ----------
3 | - Sepand Haghighi - Open Science Laboratory ([Github](https://github.com/sepandhaghighi)) **
4 | - Sadra Sabouri - Open Science Laboratory ([Github](https://github.com/sadrasabouri)) **
5 |
6 | ** **Maintainer**
7 |
8 | # Other Contributors
9 | ----------
10 |
11 |
12 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | All notable changes to this project will be documented in this file.
3 |
4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
5 | and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
6 |
7 | ## [Unreleased]
8 | ## [0.6] - 2025-05-05
9 | ### Added
10 | - `Response` class `id` property
11 | - `Prompt` class `id` property
12 | - `Response` class `regenerate_id` method
13 | - `Prompt` class `regenerate_id` method
14 | - `Session` class `render_counter` method
15 | - `Session` class `remove_message_by_index` and `remove_message_by_id` methods
16 | - `Session` class `get_message_by_index`, `get_message_by_id` and `get_message` methods
17 | - `LLMModel` enum
18 | - `AI_STUDIO` render format
19 | ### Changed
20 | - Test system modified
21 | - Modification handling centralized via `_mark_modified` method
22 | - `Session` class `remove_message` method modified
23 | ## [0.5] - 2025-04-16
24 | ### Added
25 | - `Session` class `check_render` method
26 | - `Session` class `clear_messages` method
27 | - `Prompt` class `check_render` method
28 | - `Session` class `estimate_tokens` method
29 | - `Prompt` class `estimate_tokens` method
30 | - `Response` class `estimate_tokens` method
31 | - `universal_tokens_estimator` function
32 | - `openai_tokens_estimator_gpt_3_5` function
33 | - `openai_tokens_estimator_gpt_4` function
34 | ### Changed
35 | - `init_check` parameter added to `Prompt` class
36 | - `init_check` parameter added to `Session` class
37 | - Test system modified
38 | - `Python 3.6` support dropped
39 | - `README.md` updated
40 | ## [0.4] - 2025-03-17
41 | ### Added
42 | - `Session` class `__contains__` method
43 | - `Session` class `__getitem__` method
44 | - `Session` class `mask_message` method
45 | - `Session` class `unmask_message` method
46 | - `Session` class `masks` attribute
47 | - `Response` class `__len__` method
48 | - `Prompt` class `__len__` method
49 | ### Changed
50 | - `inference_time` parameter added to `Response` class
51 | - `README.md` updated
52 | - Test system modified
53 | - Python typing features added to all modules
54 | - `Prompt` class default values updated
55 | - `Response` class default values updated
56 | ## [0.3] - 2025-03-08
57 | ### Added
58 | - `Session` class `__len__` method
59 | - `Session` class `__iter__` method
60 | - `Session` class `__add__` and `__radd__` methods
61 | ### Changed
62 | - `tokens` parameter added to `Prompt` class
63 | - `tokens` parameter added to `Response` class
64 | - `tokens` parameter added to preset templates
65 | - `Prompt` class modified
66 | - `Response` class modified
67 | - `PromptTemplate` class modified
68 | ## [0.2] - 2025-03-01
69 | ### Added
70 | - `Session` class
71 | ### Changed
72 | - `Prompt` class modified
73 | - `Response` class modified
74 | - `PromptTemplate` class modified
75 | - `README.md` updated
76 | - Test system modified
77 | ## [0.1] - 2025-02-12
78 | ### Added
79 | - `Prompt` class
80 | - `Response` class
81 | - `PromptTemplate` class
82 | - `PresetPromptTemplate` class
83 |
84 |
85 | [Unreleased]: https://github.com/openscilab/memor/compare/v0.6...dev
86 | [0.6]: https://github.com/openscilab/memor/compare/v0.5...v0.6
87 | [0.5]: https://github.com/openscilab/memor/compare/v0.4...v0.5
88 | [0.4]: https://github.com/openscilab/memor/compare/v0.3...v0.4
89 | [0.3]: https://github.com/openscilab/memor/compare/v0.2...v0.3
90 | [0.2]: https://github.com/openscilab/memor/compare/v0.1...v0.2
91 | [0.1]: https://github.com/openscilab/memor/compare/6594313...v0.1
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 OpenSciLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Memor: A Python Library for Managing and Transferring Conversational Memory Across LLMs
3 |
4 | [project logo and badges]
5 |
10 |
11 | ----------
12 |
13 |
14 | ## Overview
15 |
16 | Memor is a library designed to help users manage the memory of their interactions with Large Language Models (LLMs).
17 | It enables users to seamlessly access and utilize the history of their conversations when prompting LLMs.
18 | This creates a more personalized and context-aware experience.
19 | Memor stands out by allowing users to transfer conversational history across different LLMs, eliminating cold starts where models have no information about the user and their preferences.
20 | Users can select specific parts of past interactions with one LLM and share them with another.
21 | By bridging the gap between isolated LLM instances, Memor revolutionizes the way users interact with AI by making transitions between models smoother.
22 |
23 |
24 | [table: PyPI Counter and Github Stars badges]
40 |
41 | [table: CI and Code Quality status for `main` and `dev` branches]
66 |
67 | ## Installation
68 |
69 | ### PyPI
70 | - Check [Python Packaging User Guide](https://packaging.python.org/installing/)
71 | - Run `pip install memor==0.6`
72 | ### Source code
73 | - Download [Version 0.6](https://github.com/openscilab/memor/archive/v0.6.zip) or [Latest Source](https://github.com/openscilab/memor/archive/dev.zip)
74 | - Run `pip install .`
75 |
76 | ## Usage
77 | Define your prompt and its response(s); Memor will wrap them into an object with a templated representation.
78 | You can create a session by combining multiple prompts and responses, gradually building it up:
79 |
80 | ```pycon
81 | >>> from memor import Session, Prompt, Response, Role
82 | >>> from memor import PresetPromptTemplate, RenderFormat, LLMModel
83 | >>> response = Response(message="I am fine.", model=LLMModel.GPT_4, role=Role.ASSISTANT, temperature=0.9, score=0.9)
84 | >>> prompt = Prompt(message="Hello, how are you?",
85 | responses=[response],
86 | role=Role.USER,
87 | template=PresetPromptTemplate.INSTRUCTION1.PROMPT_RESPONSE_STANDARD)
88 | >>> system_prompt = Prompt(message="You are a friendly and informative AI assistant designed to answer questions on a wide range of topics.",
89 | role=Role.SYSTEM)
90 | >>> session = Session(messages=[system_prompt, prompt])
91 | >>> session.render(RenderFormat.OPENAI)
92 | ```
93 |
94 | The rendered output will be a list of messages formatted for compatibility with the OpenAI API.
95 |
96 | ```json
97 | [{"content": "You are a friendly and informative AI assistant designed to answer questions on a wide range of topics.", "role": "system"},
98 | {"content": "I'm providing you with a history of a previous conversation. Please consider this context when responding to my new question.\n"
99 | "Prompt: Hello, how are you?\n"
100 | "Response: I am fine.",
101 | "role": "user"}]
102 | ```
103 |
104 | ### Prompt Templates
105 |
106 | #### Preset Templates
107 |
108 | Memor provides a variety of pre-defined prompt templates to control how prompts and responses are rendered. Each template is prefixed by an optional instruction string and includes variations for different formatting styles. The available instruction and template variants are listed below:
109 |
110 | | **Instruction Name** | **Description** |
111 | |---------------|----------|
112 | | `INSTRUCTION1` | "I'm providing you with a history of a previous conversation. Please consider this context when responding to my new question." |
113 | | `INSTRUCTION2` | "Here is the context from a prior conversation. Please learn from this information and use it to provide a thoughtful and context-aware response to my next questions." |
114 | | `INSTRUCTION3` | "I am sharing a record of a previous discussion. Use this information to provide a consistent and relevant answer to my next query." |
115 |
116 | | **Template Title** | **Description** |
117 | |--------------|----------|
118 | | `PROMPT` | Only includes the prompt message. |
119 | | `RESPONSE` | Only includes the response message. |
120 | | `RESPONSE0` to `RESPONSE3` | Include specific responses from a list of multiple responses. |
121 | | `PROMPT_WITH_LABEL` | Prompt with a "Prompt: " prefix. |
122 | | `RESPONSE_WITH_LABEL` | Response with a "Response: " prefix. |
123 | | `RESPONSE0_WITH_LABEL` to `RESPONSE3_WITH_LABEL` | Labeled response for the i-th response. |
124 | | `PROMPT_RESPONSE_STANDARD` | Includes both labeled prompt and response on a single line. |
125 | | `PROMPT_RESPONSE_FULL` | A detailed multi-line representation including role, date, model, etc. |
126 |
127 | You can access them like this:
128 |
129 | ```pycon
130 | >>> from memor import PresetPromptTemplate
131 | >>> template = PresetPromptTemplate.INSTRUCTION1.PROMPT_RESPONSE_STANDARD
132 | ```
133 |
134 | #### Custom Templates
135 |
136 | You can define custom templates for your prompts using the `PromptTemplate` class. This class provides two key parameters that control its functionality:
137 |
138 | + `content`: A string that defines the template structure, following Python string formatting conventions. You can include dynamic fields using placeholders like `{field_name}`, which will be automatically populated using attributes from the prompt object. Some common examples of auto-filled fields are shown below:
139 |
140 | | **Prompt Object Attribute** | **Placeholder Syntax** | **Description** |
141 | |--------------------------------------|------------------------------------|----------------------------------------------|
142 | | `prompt.message` | `{prompt[message]}` | The main prompt message |
143 | | `prompt.selected_response` | `{prompt[response]}` | The selected response for the prompt |
144 | | `prompt.date_modified` | `{prompt[date_modified]}` | Timestamp of the last modification |
145 | | `prompt.responses[2].message` | `{responses[2][message]}` | Message from the response at index 2 |
146 | | `prompt.responses[0].inference_time` | `{responses[0][inference_time]}` | Inference time for the response at index 0 |
147 |
148 |
149 | + `custom_map`: In addition to the attributes listed above, you can define and insert custom placeholders (e.g., `{field_name}`) and provide their values through a dictionary. When rendering the template, each placeholder will be replaced with its corresponding value from `custom_map`.
150 |
151 |
152 | Suppose you want to prepend an instruction to every prompt message. You can define and use a template as follows:
153 |
154 | ```pycon
155 | >>> template = PromptTemplate(content="{instruction}, {prompt[message]}", custom_map={"instruction": "Hi"})
156 | >>> prompt = Prompt(message="How are you?", template=template)
157 | >>> prompt.render()
158 | Hi, How are you?
159 | ```
160 |
161 | By using this dynamic structure, you can create flexible and sophisticated prompt templates with Memor. You can design specific schemas for your conversational or instructional formats when interacting with LLMs.
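The placeholder resolution above follows standard Python `str.format` semantics, where `{prompt[message]}` is a dict-style item lookup on the `prompt` argument. A stdlib-only sketch of that mechanism (the data values here are illustrative, not produced by Memor):

```python
# Illustrative data standing in for a Prompt object's attributes.
prompt_data = {"message": "How are you?", "response": "I am fine."}

content = "{instruction}, {prompt[message]}"
custom_map = {"instruction": "Hi"}

# str.format resolves {prompt[message]} via item access on the mapping,
# and {instruction} from the custom_map entries.
rendered = content.format(prompt=prompt_data, **custom_map)
print(rendered)  # Hi, How are you?
```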
162 |
163 | ## Issues & bug reports
164 |
165 | Just file an issue and describe it, and we'll check it as soon as possible. Alternatively, you can send an email to [memor@openscilab.com](mailto:memor@openscilab.com "memor@openscilab.com").
166 |
167 | - Please complete the issue template
168 |
169 | You can also join our Discord server.
170 |
171 |
172 |
173 |
174 |
175 | ## Show your support
176 |
177 |
178 | ### Star this repo
179 |
180 | Give a ⭐️ if this project helped you!
181 |
182 | ### Donate to our project
183 | If you like our project, and we hope that you do, please consider supporting us. Our project is not, and never will be, run for profit. We need the money just so we can continue doing what we do ;-) .
184 |
185 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security policy
2 |
3 | ## Supported versions
4 |
5 | | Version | Supported |
6 | | ------------- | ------------------ |
7 | | 0.6 | :white_check_mark: |
8 | | < 0.6 | :x: |
9 |
10 | ## Reporting a vulnerability
11 |
12 | Please report security vulnerabilities by email to [memor@openscilab.com](mailto:memor@openscilab.com "memor@openscilab.com").
13 |
14 | If the security vulnerability is accepted, a dedicated bugfix release will be issued as soon as possible (depending on the complexity of the fix).
--------------------------------------------------------------------------------
/autopep8.bat:
--------------------------------------------------------------------------------
1 | python -m autopep8 memor --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721
2 | python -m autopep8 otherfiles --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721
3 | python -m autopep8 tests --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721
4 | python -m autopep8 setup.py --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose
5 |
--------------------------------------------------------------------------------
/autopep8.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | python -m autopep8 memor --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721
3 | python -m autopep8 otherfiles --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721
4 | python -m autopep8 tests --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose --ignore=E721
5 | python -m autopep8 setup.py --recursive --aggressive --aggressive --in-place --pep8-passes 2000 --max-line-length 120 --verbose
6 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | require_ci_to_pass: yes
3 |
4 | coverage:
5 | precision: 2
6 | round: up
7 | range: "70...100"
8 | status:
9 | patch:
10 | default:
11 | enabled: no
12 | project:
13 | default:
14 | threshold: 1%
15 |
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | setuptools>=40.8.0
2 | vulture>=1.0
3 | bandit>=1.5.1
4 | pydocstyle>=3.0.0
5 | pytest>=4.3.1
6 | pytest-cov>=2.6.1
--------------------------------------------------------------------------------
/memor/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Memor modules."""
3 | from .params import MEMOR_VERSION, RenderFormat, LLMModel
4 | from .tokens_estimator import TokensEstimator
5 | from .template import PromptTemplate, PresetPromptTemplate
6 | from .prompt import Prompt, Role
7 | from .response import Response
8 | from .session import Session
9 | from .errors import MemorRenderError, MemorValidationError
10 |
11 | __version__ = MEMOR_VERSION
12 |
--------------------------------------------------------------------------------
/memor/errors.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Memor errors."""
3 |
4 |
5 | class MemorValidationError(ValueError):
6 | """Base class for validation errors in Memor."""
7 |
8 | pass
9 |
10 |
11 | class MemorRenderError(Exception):
12 | """Base class for render error in Memor."""
13 |
14 | pass
15 |
--------------------------------------------------------------------------------
/memor/functions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Memor functions."""
3 | from typing import Any, Type
4 | import os
5 | import datetime
6 | import uuid
7 | from .params import INVALID_DATETIME_MESSAGE
8 | from .params import INVALID_PATH_MESSAGE, INVALID_STR_VALUE_MESSAGE
9 | from .params import INVALID_PROB_VALUE_MESSAGE
10 | from .params import INVALID_POSFLOAT_VALUE_MESSAGE
11 | from .params import INVALID_POSINT_VALUE_MESSAGE
12 | from .params import INVALID_CUSTOM_MAP_MESSAGE
13 | from .params import INVALID_BOOL_VALUE_MESSAGE
14 | from .params import INVALID_LIST_OF_X_MESSAGE
15 | from .params import INVALID_ID_MESSAGE
16 | from .errors import MemorValidationError
17 |
18 |
19 | def generate_message_id() -> str:
20 | """Generate message ID."""
21 | return str(uuid.uuid4())
22 |
23 |
24 | def _validate_message_id(message_id: str) -> bool:
25 | """
26 | Validate message ID.
27 |
28 | :param message_id: message ID
29 | """
30 | try:
31 | _ = uuid.UUID(message_id, version=4)
32 | except ValueError:
33 | raise MemorValidationError(INVALID_ID_MESSAGE)
34 | return True
35 |
36 |
37 | def get_time_utc() -> datetime.datetime:
38 | """
39 | Get time in UTC format.
40 |
41 | :return: UTC format time as a datetime object
42 | """
43 | return datetime.datetime.now(datetime.timezone.utc)
44 |
45 |
46 | def _validate_string(value: Any, parameter_name: str) -> bool:
47 | """
48 | Validate string.
49 |
50 | :param value: value
51 | :param parameter_name: parameter name
52 | """
53 | if not isinstance(value, str):
54 | raise MemorValidationError(INVALID_STR_VALUE_MESSAGE.format(parameter_name=parameter_name))
55 | return True
56 |
57 |
58 | def _validate_bool(value: Any, parameter_name: str) -> bool:
59 | """
60 | Validate boolean.
61 |
62 | :param value: value
63 | :param parameter_name: parameter name
64 | """
65 | if not isinstance(value, bool):
66 | raise MemorValidationError(INVALID_BOOL_VALUE_MESSAGE.format(parameter_name=parameter_name))
67 | return True
68 |
69 |
70 | def _can_convert_to_string(value: Any) -> bool:
71 | """
72 | Check if value can be converted to string.
73 |
74 | :param value: value
75 | """
76 | try:
77 | str(value)
78 | except Exception:
79 | return False
80 | return True
81 |
82 |
83 | def _validate_pos_int(value: Any, parameter_name: str) -> bool:
84 | """
85 | Validate positive integer.
86 |
87 | :param value: value
88 | :param parameter_name: parameter name
89 | """
90 | if not isinstance(value, int) or value < 0:
91 | raise MemorValidationError(INVALID_POSINT_VALUE_MESSAGE.format(parameter_name=parameter_name))
92 | return True
93 |
94 |
95 | def _validate_pos_float(value: Any, parameter_name: str) -> bool:
96 | """
97 | Validate positive float.
98 |
99 | :param value: value
100 | :param parameter_name: parameter name
101 | """
102 | if not isinstance(value, float) or value < 0:
103 | raise MemorValidationError(INVALID_POSFLOAT_VALUE_MESSAGE.format(parameter_name=parameter_name))
104 | return True
105 |
106 |
107 | def _validate_probability(value: Any, parameter_name: str) -> bool:
108 | """
109 | Validate probability (a float between 0 and 1).
110 |
111 | :param value: value
112 | :param parameter_name: parameter name
113 | """
114 | if not isinstance(value, float) or value < 0 or value > 1:
115 | raise MemorValidationError(INVALID_PROB_VALUE_MESSAGE.format(parameter_name=parameter_name))
116 | return True
117 |
118 |
119 | def _validate_list_of(value: Any, parameter_name: str, type_: Type, type_name: str) -> bool:
120 | """
121 | Validate list of values.
122 |
123 | :param value: value
124 | :param parameter_name: parameter name
125 | :param type_: type
126 | :param type_name: type name
127 | """
128 | if not isinstance(value, list):
129 | raise MemorValidationError(INVALID_LIST_OF_X_MESSAGE.format(parameter_name=parameter_name, type_name=type_name))
130 |
131 | if not all(isinstance(x, type_) for x in value):
132 | raise MemorValidationError(INVALID_LIST_OF_X_MESSAGE.format(parameter_name=parameter_name, type_name=type_name))
133 | return True
134 |
135 |
136 | def _validate_date_time(date_time: Any, parameter_name: str) -> bool:
137 | """
138 | Validate date time.
139 |
140 | :param date_time: date time
141 | :param parameter_name: parameter name
142 | """
143 | if not isinstance(date_time, datetime.datetime) or date_time.tzinfo is None:
144 | raise MemorValidationError(INVALID_DATETIME_MESSAGE.format(parameter_name=parameter_name))
145 | return True
146 |
147 |
148 | def _validate_path(path: Any) -> bool:
149 | """
150 | Validate path property.
151 |
152 | :param path: path
153 | """
154 | if not isinstance(path, str) or not os.path.exists(path):
155 | raise FileNotFoundError(INVALID_PATH_MESSAGE.format(path=path))
156 | return True
157 |
158 |
159 | def _validate_custom_map(custom_map: Any) -> bool:
160 | """
161 |     Validate custom map: a dictionary with keys and values that can be converted to strings.
162 |
163 | :param custom_map: custom map
164 | """
165 | if not isinstance(custom_map, dict):
166 | raise MemorValidationError(INVALID_CUSTOM_MAP_MESSAGE)
167 | if not all(_can_convert_to_string(k) and _can_convert_to_string(v) for k, v in custom_map.items()):
168 | raise MemorValidationError(INVALID_CUSTOM_MAP_MESSAGE)
169 | return True
170 |
--------------------------------------------------------------------------------
/memor/keywords.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Tokens estimator keywords."""
3 |
4 | COMMON_PREFIXES = {"un", "re", "in", "dis", "pre", "mis", "non", "over", "under", "sub", "trans"}
5 |
6 | COMMON_SUFFIXES = {"ing", "ed", "ly", "es", "s", "ment", "able", "ness", "tion", "ive", "ous"}
7 |
8 | PYTHON_KEYWORDS = {"if", "else", "elif", "while", "for", "def", "return", "import", "from", "class",
9 | "try", "except", "finally", "with", "as", "break", "continue", "pass", "lambda",
10 | "True", "False", "None", "and", "or", "not", "in", "is", "global", "nonlocal"}
11 |
12 | JAVASCRIPT_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break",
13 | "continue", "function", "return", "var", "let", "const", "class", "extends",
14 | "super", "import", "export", "try", "catch", "finally", "throw", "new",
15 | "delete", "typeof", "instanceof", "in", "void", "yield", "this", "async",
16 | "await", "static", "get", "set", "true", "false", "null", "undefined"}
17 |
18 | JAVA_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break",
19 | "continue", "return", "void", "int", "float", "double", "char", "long", "short",
20 | "boolean", "byte", "class", "interface", "extends", "implements", "new", "import",
21 | "package", "public", "private", "protected", "static", "final", "abstract",
22 | "try", "catch", "finally", "throw", "throws", "synchronized", "volatile", "transient",
23 | "native", "strictfp", "assert", "instanceof", "super", "this", "true", "false", "null"}
24 |
25 |
26 | C_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break", "continue",
27 | "return", "void", "char", "int", "float", "double", "short", "long", "signed",
28 | "unsigned", "struct", "union", "typedef", "enum", "const", "volatile", "extern",
29 | "register", "static", "auto", "sizeof", "goto"}
30 |
31 |
32 | CPP_KEYWORDS = C_KEYWORDS | {"new", "delete", "class",
33 | "public", "private", "protected", "namespace", "using", "template", "friend",
34 | "virtual", "inline", "operator", "explicit", "this", "true", "false", "nullptr"}
35 |
36 |
37 | CSHARP_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break",
38 | "continue", "return", "void", "int", "float", "double", "char", "long", "short",
39 | "bool", "byte", "class", "interface", "struct", "new", "namespace", "using",
40 | "public", "private", "protected", "static", "readonly", "const", "try", "catch",
41 | "finally", "throw", "async", "await", "true", "false", "null"}
42 |
43 |
44 | GO_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "break", "continue", "return",
45 | "func", "var", "const", "type", "struct", "interface", "map", "chan", "package",
46 | "import", "defer", "go", "select", "range", "fallthrough", "goto"}
47 |
48 |
49 | RUST_KEYWORDS = {"if", "else", "match", "loop", "for", "while", "break", "continue", "return",
50 | "fn", "let", "const", "static", "struct", "enum", "trait", "impl", "mod",
51 | "use", "crate", "super", "self", "as", "type", "where", "pub", "unsafe",
52 | "dyn", "move", "async", "await", "true", "false"}
53 |
54 |
55 | SWIFT_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "repeat", "break",
56 | "continue", "return", "func", "var", "let", "class", "struct", "enum", "protocol",
57 | "import", "defer", "as", "is", "try", "catch", "throw", "throws", "inout",
58 | "guard", "self", "super", "true", "false", "nil"}
59 |
60 |
61 | KOTLIN_KEYWORDS = {"if", "else", "when", "for", "while", "do", "break", "continue", "return",
62 | "fun", "val", "var", "class", "object", "interface", "enum", "sealed",
63 | "import", "package", "as", "is", "in", "try", "catch", "finally", "throw",
64 | "super", "this", "by", "constructor", "init", "companion", "override",
65 | "abstract", "final", "open", "private", "protected", "public", "internal",
66 | "inline", "suspend", "operator", "true", "false", "null"}
67 |
68 | TYPESCRIPT_KEYWORDS = JAVASCRIPT_KEYWORDS | {"interface", "type", "namespace", "declare"}
69 |
70 |
71 | PHP_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break",
72 | "continue", "return", "function", "class", "public", "private", "protected",
73 | "extends", "implements", "namespace", "use", "new", "static", "global",
74 | "const", "var", "echo", "print", "try", "catch", "finally", "throw", "true", "false", "null"}
75 |
76 | RUBY_KEYWORDS = {"if", "else", "elsif", "unless", "case", "when", "for", "while", "do", "break",
77 | "continue", "return", "def", "class", "module", "end", "begin", "rescue", "ensure",
78 | "yield", "super", "self", "alias", "true", "false", "nil"}
79 |
80 | SQL_KEYWORDS = {"SELECT", "INSERT", "UPDATE", "DELETE", "FROM", "WHERE", "JOIN", "INNER", "LEFT",
81 | "RIGHT", "FULL", "ON", "GROUP BY", "HAVING", "ORDER BY", "LIMIT", "OFFSET", "AS",
82 | "AND", "OR", "NOT", "NULL", "TRUE", "FALSE"}
83 |
84 | BASH_KEYWORDS = {"if", "else", "fi", "then", "elif", "case", "esac", "for", "while", "do", "done",
85 | "break", "continue", "return", "function", "export", "readonly", "local", "declare",
86 | "eval", "trap", "exec", "true", "false"}
87 |
88 | MATLAB_KEYWORDS = {"if", "else", "elseif", "end", "for", "while", "break", "continue", "return",
89 | "function", "global", "persistent", "switch", "case", "otherwise", "try", "catch",
90 | "true", "false"}
91 |
92 | R_KEYWORDS = {"if", "else", "repeat", "while", "for", "break", "next", "return", "function",
93 | "TRUE", "FALSE", "NULL", "Inf", "NaN", "NA"}
94 |
95 |
96 | PERL_KEYWORDS = {"if", "else", "elsif", "unless", "while", "for", "foreach", "do", "last", "next",
97 | "redo", "goto", "return", "sub", "package", "use", "require", "my", "local", "our",
98 | "state", "BEGIN", "END", "true", "false"}
99 |
100 | LUA_KEYWORDS = {"if", "else", "elseif", "then", "for", "while", "repeat", "until", "break", "return",
101 | "function", "end", "local", "do", "true", "false", "nil"}
102 |
103 | SCALA_KEYWORDS = {"if", "else", "match", "case", "for", "while", "do", "yield", "return",
104 | "def", "val", "var", "lazy", "class", "object", "trait", "extends",
105 | "with", "import", "package", "new", "this", "super", "implicit",
106 | "override", "abstract", "final", "sealed", "private", "protected",
107 | "public", "try", "catch", "finally", "throw", "true", "false", "null"}
108 |
109 | DART_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break",
110 | "continue", "return", "var", "final", "const", "dynamic", "void",
111 | "int", "double", "bool", "String", "class", "interface", "extends",
112 | "implements", "mixin", "import", "library", "part", "typedef",
113 | "this", "super", "as", "is", "new", "try", "catch", "finally", "throw",
114 | "async", "await", "true", "false", "null"}
115 |
116 | JULIA_KEYWORDS = {"if", "else", "elseif", "for", "while", "break", "continue", "return",
117 | "function", "macro", "module", "import", "using", "export", "struct",
118 | "mutable", "const", "begin", "end", "do", "try", "catch", "finally",
119 | "true", "false", "nothing"}
120 |
121 | HASKELL_KEYWORDS = {"if", "then", "else", "case", "of", "let", "in", "where", "do", "module",
122 | "import", "class", "instance", "data", "type", "newtype", "deriving",
123 | "default", "foreign", "safe", "unsafe", "qualified", "true", "false"}
124 |
125 | COBOL_KEYWORDS = {"ACCEPT", "ADD", "CALL", "CANCEL", "CLOSE", "COMPUTE", "CONTINUE", "DELETE",
126 | "DISPLAY", "DIVIDE", "EVALUATE", "EXIT", "GOBACK", "GO", "IF", "INITIALIZE",
127 | "INSPECT", "MERGE", "MOVE", "MULTIPLY", "OPEN", "PERFORM", "READ", "RETURN",
128 | "REWRITE", "SEARCH", "SET", "SORT", "START", "STOP", "STRING", "SUBTRACT",
129 | "UNSTRING", "WRITE", "END-IF", "END-PERFORM"}
130 |
131 | OBJECTIVEC_KEYWORDS = {"if", "else", "switch", "case", "default", "for", "while", "do", "break",
132 | "continue", "return", "void", "int", "float", "double", "char", "long", "short",
133 | "signed", "unsigned", "class", "interface", "protocol", "implementation",
134 | "try", "catch", "finally", "throw", "import", "self", "super", "atomic",
135 | "nonatomic", "strong", "weak", "retain", "copy", "assign", "true", "false", "nil"}
136 |
137 | FSHARP_KEYWORDS = {"if", "then", "else", "match", "with", "for", "while", "do", "done", "let",
138 | "rec", "in", "try", "finally", "raise", "exception", "function", "return",
139 | "type", "mutable", "namespace", "module", "open", "abstract", "override",
140 | "inherit", "base", "new", "true", "false", "null"}
141 |
142 | LISP_KEYWORDS = {"defun", "setq", "let", "lambda", "if", "cond", "loop", "dolist", "dotimes",
143 | "progn", "return", "function", "defmacro", "quote", "eval", "apply", "car",
144 | "cdr", "cons", "list", "mapcar", "format", "read", "print", "load", "t", "nil"}
145 |
146 | PROLOG_KEYWORDS = {"if", "else", "end", "fail", "true", "false", "not", "repeat", "is",
147 | "assert", "retract", "call", "findall", "bagof", "setof", "atom",
148 | "integer", "float", "char_code", "compound", "number", "var"}
149 |
150 | ADA_KEYWORDS = {"if", "then", "else", "elsif", "case", "when", "for", "while", "loop", "exit",
151 | "return", "procedure", "function", "package", "use", "is", "begin", "end",
152 | "record", "type", "constant", "exception", "raise", "declare", "private",
153 | "null", "true", "false"}
154 |
155 | DELPHI_KEYWORDS = {"if", "then", "else", "case", "of", "for", "while", "repeat", "until", "break",
156 | "continue", "begin", "end", "procedure", "function", "var", "const", "type",
157 | "class", "record", "interface", "implementation", "unit", "uses", "inherited",
158 | "try", "except", "finally", "raise", "private", "public", "protected", "published",
159 | "true", "false", "nil"}
160 |
161 | VB_KEYWORDS = {"If", "Then", "Else", "ElseIf", "End", "For", "Each", "While", "Do", "Loop",
162 | "Select", "Case", "Try", "Catch", "Finally", "Throw", "Return", "Function",
163 | "Sub", "Class", "Module", "Namespace", "Imports", "Inherits", "Implements",
164 | "Public", "Private", "Protected", "Friend", "Shared", "Static", "Dim", "Const",
165 | "New", "Me", "MyBase", "MyClass", "Not", "And", "Or", "True", "False", "Nothing"}
166 |
167 | HTML_KEYWORDS = {"html", "head", "title", "meta", "link", "style", "script", "body", "div", "span",
168 | "h1", "h2", "h3", "h4", "h5", "h6", "p", "a", "img", "ul", "ol", "li", "table",
169 | "tr", "td", "th", "thead", "tbody", "tfoot", "form", "input", "button", "label",
170 | "select", "option", "textarea", "fieldset", "legend", "iframe", "nav", "section",
171 | "article", "aside", "header", "footer", "main", "blockquote", "cite", "code",
172 | "pre", "em", "strong", "b", "i", "u", "small", "br", "hr"}
173 |
174 | CSS_KEYWORDS = {"color", "background", "border", "margin", "padding", "width", "height", "font-size",
175 | "font-family", "text-align", "display", "position", "top", "bottom", "left", "right",
176 | "z-index", "visibility", "opacity", "overflow", "cursor", "flex", "grid", "align-items",
177 | "justify-content", "box-shadow", "text-shadow", "animation", "transition", "transform",
178 | "clip-path", "content", "filter", "outline", "max-width", "min-width", "max-height",
179 | "min-height", "letter-spacing", "line-height", "white-space", "word-break"}
180 |
181 |
182 | PROGRAMMING_LANGUAGES = {
183 | "Python": PYTHON_KEYWORDS,
184 | "JavaScript": JAVASCRIPT_KEYWORDS,
185 | "Java": JAVA_KEYWORDS,
186 | "C": C_KEYWORDS,
187 | "C++": CPP_KEYWORDS,
188 | "C#": CSHARP_KEYWORDS,
189 | "Go": GO_KEYWORDS,
190 | "Rust": RUST_KEYWORDS,
191 | "Swift": SWIFT_KEYWORDS,
192 | "Kotlin": KOTLIN_KEYWORDS,
193 | "TypeScript": TYPESCRIPT_KEYWORDS,
194 | "PHP": PHP_KEYWORDS,
195 | "Ruby": RUBY_KEYWORDS,
196 | "SQL": SQL_KEYWORDS,
197 | "Bash": BASH_KEYWORDS,
198 | "MATLAB": MATLAB_KEYWORDS,
199 | "R": R_KEYWORDS,
200 | "Perl": PERL_KEYWORDS,
201 | "Lua": LUA_KEYWORDS,
202 | "Scala": SCALA_KEYWORDS,
203 | "Dart": DART_KEYWORDS,
204 | "Julia": JULIA_KEYWORDS,
205 | "Haskell": HASKELL_KEYWORDS,
206 | "COBOL": COBOL_KEYWORDS,
207 | "Objective-C": OBJECTIVEC_KEYWORDS,
208 | "F#": FSHARP_KEYWORDS,
209 | "Lisp": LISP_KEYWORDS,
210 | "Prolog": PROLOG_KEYWORDS,
211 | "Ada": ADA_KEYWORDS,
212 | "Delphi": DELPHI_KEYWORDS,
213 | "Visual Basic": VB_KEYWORDS,
214 | "HTML": HTML_KEYWORDS,
215 | "CSS": CSS_KEYWORDS}
216 |
217 | PROGRAMMING_LANGUAGES_KEYWORDS = set()
218 | for language in PROGRAMMING_LANGUAGES:
219 | PROGRAMMING_LANGUAGES_KEYWORDS = PROGRAMMING_LANGUAGES_KEYWORDS | PROGRAMMING_LANGUAGES[language]
220 |
--------------------------------------------------------------------------------
/memor/params.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Memor parameters and constants."""
3 | from enum import Enum
4 | MEMOR_VERSION = "0.6"
5 |
6 | DATE_TIME_FORMAT = "%Y-%m-%d %H:%M:%S %z"
7 |
8 | INVALID_PATH_MESSAGE = "Invalid path: must be a string and refer to an existing location. Given path: {path}"
9 | INVALID_STR_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a string."
10 | INVALID_BOOL_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a boolean."
11 | INVALID_POSFLOAT_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a positive float."
12 | INVALID_POSINT_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a positive integer."
13 | INVALID_PROB_VALUE_MESSAGE = "Invalid value. `{parameter_name}` must be a value between 0 and 1."
14 | INVALID_LIST_OF_X_MESSAGE = "Invalid value. `{parameter_name}` must be a list of {type_name}."
15 | INVALID_INT_OR_STR_MESSAGE = "Invalid value. `{parameter_name}` must be an integer or a string."
16 | INVALID_INT_OR_STR_SLICE_MESSAGE = "Invalid value. `{parameter_name}` must be an integer, string or a slice."
17 | INVALID_DATETIME_MESSAGE = "Invalid value. `{parameter_name}` must be a datetime object that includes timezone information."
18 | INVALID_TEMPLATE_MESSAGE = "Invalid template. It must be an instance of `PromptTemplate` or `PresetPromptTemplate`."
19 | INVALID_RESPONSE_MESSAGE = "Invalid response. It must be an instance of `Response`."
20 | INVALID_MESSAGE = "Invalid message. It must be an instance of `Prompt` or `Response`."
21 | INVALID_MESSAGE_STATUS_LEN_MESSAGE = "Invalid message status length. It must be equal to the number of messages."
22 | INVALID_CUSTOM_MAP_MESSAGE = "Invalid custom map: it must be a dictionary with keys and values that can be converted to strings."
23 | INVALID_ROLE_MESSAGE = "Invalid role. It must be an instance of Role enum."
24 | INVALID_ID_MESSAGE = "Invalid message ID. It must be a valid UUIDv4."
25 | INVALID_MODEL_MESSAGE = "Invalid model. It must be an instance of LLMModel enum or a string."
26 | INVALID_TEMPLATE_STRUCTURE_MESSAGE = "Invalid template structure. It should be a JSON object with proper fields."
27 | INVALID_PROMPT_STRUCTURE_MESSAGE = "Invalid prompt structure. It should be a JSON object with proper fields."
28 | INVALID_RESPONSE_STRUCTURE_MESSAGE = "Invalid response structure. It should be a JSON object with proper fields."
29 | INVALID_RENDER_FORMAT_MESSAGE = "Invalid render format. It must be an instance of RenderFormat enum."
30 | PROMPT_RENDER_ERROR_MESSAGE = "Prompt template and properties are incompatible."
31 | UNSUPPORTED_OPERAND_ERROR_MESSAGE = "Unsupported operand type(s) for {operator}: `{operand1}` and `{operand2}`"
32 | DATA_SAVE_SUCCESS_MESSAGE = "Everything seems good."
33 |
34 |
35 | class Role(Enum):
36 | """Role enum."""
37 |
38 | SYSTEM = "system"
39 | USER = "user"
40 | ASSISTANT = "assistant"
41 | DEFAULT = USER
42 |
43 |
44 | class RenderFormat(Enum):
45 | """Render format."""
46 |
47 | STRING = "STRING"
48 | OPENAI = "OPENAI"
49 | AI_STUDIO = "AI STUDIO"
50 | DICTIONARY = "DICTIONARY"
51 | ITEMS = "ITEMS"
52 | DEFAULT = STRING
53 |
54 |
55 | class LLMModel(Enum):
56 | """LLM model enum."""
57 |
58 | GPT_O1 = "gpt-o1"
59 | GPT_O1_MINI = "gpt-o1-mini"
60 | GPT_4O = "gpt-4o"
61 | GPT_4O_MINI = "gpt-4o-mini"
62 | GPT_4_TURBO = "gpt-4-turbo"
63 | GPT_4 = "gpt-4"
64 | GPT_4_VISION = "gpt-4-vision"
65 | GPT_3_5_TURBO = "gpt-3.5-turbo"
66 | DAVINCI = "davinci"
67 | BABBAGE = "babbage"
68 |
69 | CLAUDE_3_5_SONNET = "claude-3.5-sonnet"
70 | CLAUDE_3_OPUS = "claude-3-opus"
71 | CLAUDE_3_HAIKU = "claude-3-haiku"
72 | CLAUDE_2 = "claude-2"
73 | CLAUDE_INSTANT = "claude-instant"
74 |
75 | LLAMA3_70B = "llama3-70b"
76 | LLAMA3_8B = "llama3-8b"
77 | LLAMA_GUARD_3_8B = "llama-guard-3-8b"
78 |
79 | MISTRAL_7B = "mistral-7b"
80 | MIXTRAL_8X7B = "mixtral-8x7b"
81 | MIXTRAL_8X22B = "mixtral-8x22b"
82 | MISTRAL_NEMO = "mistral-nemo"
83 | MISTRAL_TINY = "mistral-tiny"
84 | MISTRAL_SMALL = "mistral-small"
85 | MISTRAL_MEDIUM = "mistral-medium"
86 | MISTRAL_LARGE = "mistral-large"
87 | CODESTRAL = "codestral"
88 | PIXTRAL = "pixtral-12b"
89 |
90 | GEMMA_7B = "gemma-7b"
91 | GEMMA2_9B = "gemma2-9b"
92 | GEMINI_1_PRO = "gemini-1-pro"
93 | GEMINI_1_ULTRA = "gemini-1-ultra"
94 | GEMINI_1_5_PRO = "gemini-1.5-pro"
95 | GEMINI_1_5_ULTRA = "gemini-1.5-ultra"
96 | GEMINI_1_5_FLASH = "gemini-1.5-flash"
97 | GEMINI_2_FLASH = "gemini-2-flash"
98 | GEMINI_2_PRO = "gemini-2-pro"
99 |
100 | DEEPSEEK_V3 = "deepseek-v3"
101 | DEEPSEEK_R1 = "deepseek-r1"
102 | DEEPSEEK_CODER = "deepseek-coder"
103 |
104 | PHI_2 = "phi-2"
105 | PHI_4 = "phi-4"
106 |
107 | QWEN_1_8B = "qwen-1.8b"
108 | QWEN_7B = "qwen-7b"
109 | QWEN_14B = "qwen-14b"
110 | QWEN_72B = "qwen-72b"
111 |
112 | YI_6B = "yi-6b"
113 | YI_9B = "yi-9b"
114 | YI_34B = "yi-34b"
115 |
116 | DEFAULT = "unknown"
117 |
--------------------------------------------------------------------------------
/memor/prompt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Prompt class."""
3 | from typing import List, Dict, Union, Tuple, Any
4 | import datetime
5 | import json
6 | from .params import MEMOR_VERSION
7 | from .params import DATE_TIME_FORMAT
8 | from .params import RenderFormat, DATA_SAVE_SUCCESS_MESSAGE
9 | from .params import Role
10 | from .tokens_estimator import TokensEstimator
11 | from .params import INVALID_PROMPT_STRUCTURE_MESSAGE, INVALID_TEMPLATE_MESSAGE
12 | from .params import INVALID_ROLE_MESSAGE, INVALID_RESPONSE_MESSAGE
13 | from .params import PROMPT_RENDER_ERROR_MESSAGE
14 | from .params import INVALID_RENDER_FORMAT_MESSAGE
15 | from .errors import MemorValidationError, MemorRenderError
16 | from .functions import get_time_utc, generate_message_id
17 | from .functions import _validate_string, _validate_pos_int, _validate_list_of
18 | from .functions import _validate_path, _validate_message_id
19 | from .template import PromptTemplate, PresetPromptTemplate
20 | from .template import _BasicPresetPromptTemplate, _Instruction1PresetPromptTemplate, _Instruction2PresetPromptTemplate, _Instruction3PresetPromptTemplate
21 | from .response import Response
22 |
23 |
24 | class Prompt:
25 | """
26 | Prompt class.
27 |
28 | >>> from memor import Prompt, Role, Response
29 | >>> responses = [Response(message="I am fine."), Response(message="I am not fine."), Response(message="I am okay.")]
30 | >>> prompt = Prompt(message="Hello, how are you?", responses=responses)
31 | >>> prompt.message
32 | 'Hello, how are you?'
33 | >>> prompt.responses[1].message
34 | 'I am not fine.'
35 | """
36 |
37 | def __init__(
38 | self,
39 | message: str = "",
40 | responses: List[Response] = [],
41 | role: Role = Role.DEFAULT,
42 | tokens: int = None,
43 | template: Union[PresetPromptTemplate, PromptTemplate] = PresetPromptTemplate.DEFAULT,
44 | file_path: str = None,
45 | init_check: bool = True) -> None:
46 | """
47 | Prompt object initiator.
48 |
49 | :param message: prompt message
50 | :param responses: prompt responses
51 | :param role: prompt role
52 | :param tokens: tokens
53 | :param template: prompt template
54 | :param file_path: prompt file path
55 | :param init_check: initial check flag
56 | """
57 | self._message = ""
58 | self._tokens = None
59 | self._role = Role.DEFAULT
60 | self._template = PresetPromptTemplate.DEFAULT.value
61 | self._responses = []
62 | self._date_created = get_time_utc()
63 | self._mark_modified()
64 | self._memor_version = MEMOR_VERSION
65 | self._selected_response_index = 0
66 | self._selected_response = None
67 | self._id = None
68 | if file_path:
69 | self.load(file_path)
70 | else:
71 | if message:
72 | self.update_message(message)
73 | if role:
74 | self.update_role(role)
75 | if tokens:
76 | self.update_tokens(tokens)
77 | if responses:
78 | self.update_responses(responses)
79 | if template:
80 | self.update_template(template)
81 | self.select_response(index=self._selected_response_index)
82 | self._id = generate_message_id()
83 | _validate_message_id(self._id)
84 | if init_check:
85 | _ = self.render()
86 |
87 | def _mark_modified(self) -> None:
88 | """Mark modification."""
89 | self._date_modified = get_time_utc()
90 |
91 | def __eq__(self, other_prompt: "Prompt") -> bool:
92 | """
93 | Check prompts equality.
94 |
95 | :param other_prompt: another prompt
96 | """
97 | if isinstance(other_prompt, Prompt):
98 | return self._message == other_prompt._message and self._responses == other_prompt._responses and \
99 | self._role == other_prompt._role and self._template == other_prompt._template and \
100 | self._tokens == other_prompt._tokens
101 | return False
102 |
103 | def __str__(self) -> str:
104 | """Return string representation of Prompt."""
105 | return self.render(render_format=RenderFormat.STRING)
106 |
107 | def __repr__(self) -> str:
108 | """Return string representation of Prompt."""
109 | return "Prompt(message={message})".format(message=self._message)
110 |
111 | def __len__(self) -> int:
112 | """Return the length of the Prompt object."""
113 | try:
114 | return len(self.render(render_format=RenderFormat.STRING))
115 | except Exception:
116 | return 0
117 |
118 | def __copy__(self) -> "Prompt":
119 | """
120 | Return a copy of the Prompt object.
121 |
122 | :return: a copy of Prompt object
123 | """
124 |         result = self.__class__.__new__(self.__class__)
125 |         result.__dict__.update(self.__dict__)
126 |         result._responses = self._responses.copy()
127 |         result.regenerate_id()
128 |         return result
129 |
130 | def copy(self) -> "Prompt":
131 | """
132 | Return a copy of the Prompt object.
133 |
134 | :return: a copy of Prompt object
135 | """
136 | return self.__copy__()
137 |
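`copy` clones via `__new__` plus `__dict__.update`, which is a shallow operation: attribute references are copied, so any mutable attribute (a list of responses, for example) is shared between the two objects unless it is re-copied explicitly. A minimal standalone sketch of that idiom (hypothetical class, not part of memor):

```python
# Standalone sketch (hypothetical class): cloning via __new__ + __dict__.update
# copies attribute *references*, so mutable attributes are shared.
class Holder:
    def __init__(self):
        self.items = []

    def __copy__(self):
        cls = self.__class__
        clone = cls.__new__(cls)              # bypass __init__
        clone.__dict__.update(self.__dict__)  # shallow: references copied
        return clone

original = Holder()
clone = original.__copy__()
clone.items.append("shared")
# Both objects see the append, because `items` is the same list object.
assert original.items == ["shared"]
```

In memor, `regenerate_id()` is called on the clone afterwards, so the two objects at least carry distinct IDs.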
138 | def add_response(self, response: Response, index: int = None) -> None:
139 | """
140 | Add a response to the prompt object.
141 |
142 | :param response: response
143 | :param index: index
144 | """
145 | if not isinstance(response, Response):
146 | raise MemorValidationError(INVALID_RESPONSE_MESSAGE)
147 | if index is None:
148 | self._responses.append(response)
149 | else:
150 | self._responses.insert(index, response)
151 | self._mark_modified()
152 |
153 | def remove_response(self, index: int) -> None:
154 | """
155 | Remove a response from the prompt object.
156 |
157 | :param index: index
158 | """
159 | self._responses.pop(index)
160 | self._mark_modified()
161 |
162 | def select_response(self, index: int) -> None:
163 | """
164 |         Select the response at the given index as the selected response.
165 |
166 | :param index: index
167 | """
168 | if len(self._responses) > 0:
169 | self._selected_response_index = index
170 | self._selected_response = self._responses[index]
171 | self._mark_modified()
172 |
173 | def update_responses(self, responses: List[Response]) -> None:
174 | """
175 | Update the prompt responses.
176 |
177 | :param responses: responses
178 | """
179 | _validate_list_of(responses, "responses", Response, "`Response`")
180 | self._responses = responses
181 | self._mark_modified()
182 |
183 | def update_message(self, message: str) -> None:
184 | """
185 | Update the prompt message.
186 |
187 | :param message: message
188 | """
189 | _validate_string(message, "message")
190 | self._message = message
191 | self._mark_modified()
192 |
193 | def update_role(self, role: Role) -> None:
194 | """
195 | Update the prompt role.
196 |
197 | :param role: role
198 | """
199 | if not isinstance(role, Role):
200 | raise MemorValidationError(INVALID_ROLE_MESSAGE)
201 | self._role = role
202 | self._mark_modified()
203 |
204 | def update_tokens(self, tokens: int) -> None:
205 | """
206 | Update the tokens.
207 |
208 | :param tokens: tokens
209 | """
210 | _validate_pos_int(tokens, "tokens")
211 | self._tokens = tokens
212 | self._mark_modified()
213 |
214 | def update_template(self, template: PromptTemplate) -> None:
215 | """
216 | Update the prompt template.
217 |
218 | :param template: template
219 | """
220 | if not isinstance(
221 | template,
222 | (PromptTemplate,
223 | _BasicPresetPromptTemplate,
224 | _Instruction1PresetPromptTemplate,
225 | _Instruction2PresetPromptTemplate,
226 | _Instruction3PresetPromptTemplate)):
227 | raise MemorValidationError(INVALID_TEMPLATE_MESSAGE)
228 | if isinstance(template, PromptTemplate):
229 | self._template = template
230 | if isinstance(
231 | template,
232 | (_BasicPresetPromptTemplate,
233 | _Instruction1PresetPromptTemplate,
234 | _Instruction2PresetPromptTemplate,
235 | _Instruction3PresetPromptTemplate)):
236 | self._template = template.value
237 | self._mark_modified()
238 |
239 | def save(self, file_path: str, save_template: bool = True) -> Dict[str, Any]:
240 | """
241 |         Save the prompt to a file.
242 |
243 | :param file_path: prompt file path
244 | :param save_template: save template flag
245 | """
246 | result = {"status": True, "message": DATA_SAVE_SUCCESS_MESSAGE}
247 | try:
248 | with open(file_path, "w") as file:
249 | data = self.to_json(save_template=save_template)
250 | json.dump(data, file)
251 | except Exception as e:
252 | result["status"] = False
253 | result["message"] = str(e)
254 | return result
255 |
256 | def load(self, file_path: str) -> None:
257 | """
258 |         Load the prompt from a file.
259 |
260 | :param file_path: prompt file path
261 | """
262 | _validate_path(file_path)
263 | with open(file_path, "r") as file:
264 | self.from_json(file.read())
265 |
266 | def from_json(self, json_object: Union[str, Dict[str, Any]]) -> None:
267 | """
268 | Load attributes from the JSON object.
269 |
270 | :param json_object: JSON object
271 | """
272 | try:
273 | if isinstance(json_object, str):
274 | loaded_obj = json.loads(json_object)
275 | else:
276 | loaded_obj = json_object.copy()
277 | self._message = loaded_obj["message"]
278 | self._tokens = loaded_obj.get("tokens", None)
279 | self._id = loaded_obj.get("id", generate_message_id())
280 | responses = []
281 | for response in loaded_obj["responses"]:
282 | response_obj = Response()
283 | response_obj.from_json(response)
284 | responses.append(response_obj)
285 | self._responses = responses
286 | self._role = Role(loaded_obj["role"])
287 | self._template = PresetPromptTemplate.DEFAULT.value
288 | if "template" in loaded_obj:
289 | template_obj = PromptTemplate()
290 | template_obj.from_json(loaded_obj["template"])
291 | self._template = template_obj
292 | self._memor_version = loaded_obj["memor_version"]
293 | self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
294 | self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
295 | self._selected_response_index = loaded_obj["selected_response_index"]
296 | self.select_response(index=self._selected_response_index)
297 | except Exception:
298 | raise MemorValidationError(INVALID_PROMPT_STRUCTURE_MESSAGE)
299 |
300 | def to_json(self, save_template: bool = True) -> Dict[str, Any]:
301 | """
302 | Convert the prompt to a JSON object.
303 |
304 | :param save_template: save template flag
305 | """
306 | data = self.to_dict(save_template=save_template).copy()
307 | for index, response in enumerate(data["responses"]):
308 | data["responses"][index] = response.to_json()
309 | if "template" in data:
310 | data["template"] = data["template"].to_json()
311 | data["role"] = data["role"].value
312 | data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
313 | data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
314 | return data
315 |
316 | def to_dict(self, save_template: bool = True) -> Dict[str, Any]:
317 | """
318 | Convert the prompt to a dictionary.
319 |
320 | :param save_template: save template flag
321 | """
322 | data = {
323 | "type": "Prompt",
324 | "message": self._message,
325 | "responses": self._responses.copy(),
326 | "selected_response_index": self._selected_response_index,
327 | "tokens": self._tokens,
328 | "role": self._role,
329 | "id": self._id,
330 | "template": self._template,
331 | "memor_version": MEMOR_VERSION,
332 | "date_created": self._date_created,
333 | "date_modified": self._date_modified,
334 | }
335 | if not save_template:
336 | del data["template"]
337 | return data
338 |
339 | def regenerate_id(self) -> None:
340 | """Regenerate ID."""
341 | new_id = self._id
342 | while new_id == self.id:
343 | new_id = generate_message_id()
344 | self._id = new_id
345 |
346 | @property
347 | def message(self) -> str:
348 | """Get the prompt message."""
349 | return self._message
350 |
351 | @property
352 | def responses(self) -> List[Response]:
353 | """Get the prompt responses."""
354 | return self._responses
355 |
356 | @property
357 | def role(self) -> Role:
358 | """Get the prompt role."""
359 | return self._role
360 |
361 | @property
362 | def tokens(self) -> int:
363 | """Get the prompt tokens."""
364 | return self._tokens
365 |
366 | @property
367 | def date_created(self) -> datetime.datetime:
368 | """Get the prompt creation date."""
369 | return self._date_created
370 |
371 | @property
372 | def date_modified(self) -> datetime.datetime:
373 | """Get the prompt object modification date."""
374 | return self._date_modified
375 |
376 | @property
377 | def template(self) -> PromptTemplate:
378 | """Get the prompt template."""
379 | return self._template
380 |
381 | @property
382 | def id(self) -> str:
383 | """Get the prompt ID."""
384 | return self._id
385 |
386 | @property
387 | def selected_response(self) -> Response:
388 | """Get the prompt selected response."""
389 | return self._selected_response
390 |
391 | def render(self, render_format: RenderFormat = RenderFormat.DEFAULT) -> Union[str,
392 | Dict[str, Any],
393 | List[Tuple[str, Any]]]:
394 | """
395 |         Render the prompt.
396 |
397 | :param render_format: render format
398 | """
399 | if not isinstance(render_format, RenderFormat):
400 | raise MemorValidationError(INVALID_RENDER_FORMAT_MESSAGE)
401 | try:
402 | format_kwargs = {"prompt": self.to_json(save_template=False)}
403 | if isinstance(self._selected_response, Response):
404 | format_kwargs.update({"response": self._selected_response.to_json()})
405 | responses_dicts = []
406 |             for response in self._responses:
407 |                 responses_dicts.append(response.to_json())
408 | format_kwargs.update({"responses": responses_dicts})
409 | custom_map = self._template._custom_map
410 | if custom_map is not None:
411 | format_kwargs.update(custom_map)
412 | content = self._template._content.format(**format_kwargs)
413 | prompt_dict = self.to_dict()
414 | prompt_dict["content"] = content
415 | if render_format == RenderFormat.OPENAI:
416 | return {"role": self._role.value, "content": content}
417 | if render_format == RenderFormat.AI_STUDIO:
418 | return {"role": self._role.value, "parts": [{"text": content}]}
419 | if render_format == RenderFormat.STRING:
420 | return content
421 | if render_format == RenderFormat.DICTIONARY:
422 | return prompt_dict
423 | if render_format == RenderFormat.ITEMS:
424 | return list(prompt_dict.items())
425 | except Exception:
426 | raise MemorRenderError(PROMPT_RENDER_ERROR_MESSAGE)
427 |
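`render` above merges several sources into one kwargs dict (the prompt JSON, the selected response, the responses list, and the template's `_custom_map`) and then expands it into `str.format`. A standalone sketch of that merge-then-format pattern (template text and field names are hypothetical, not memor's presets):

```python
# Standalone sketch of merge-then-format (hypothetical template and fields).
template_content = "{greeting}, {name}! You asked: {question}"
format_kwargs = {"name": "Ada", "question": "What is a monad?"}
custom_map = {"greeting": "Hello"}  # template-specific extras
if custom_map is not None:
    format_kwargs.update(custom_map)  # custom entries override on collision
content = template_content.format(**format_kwargs)
assert content == "Hello, Ada! You asked: What is a monad?"
```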
428 | def check_render(self) -> bool:
429 | """Check render."""
430 | try:
431 | _ = self.render()
432 | return True
433 | except Exception:
434 | return False
435 |
436 | def estimate_tokens(self, method: TokensEstimator = TokensEstimator.DEFAULT) -> int:
437 | """
438 | Estimate the number of tokens in the prompt message.
439 |
440 | :param method: token estimator method
441 | """
442 | return method(self.render(render_format=RenderFormat.STRING))
443 |
--------------------------------------------------------------------------------
/memor/response.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Response class."""
3 | from typing import List, Dict, Union, Tuple, Any
4 | import datetime
5 | import json
6 | from .params import MEMOR_VERSION
7 | from .params import DATE_TIME_FORMAT
8 | from .params import DATA_SAVE_SUCCESS_MESSAGE
9 | from .params import INVALID_RESPONSE_STRUCTURE_MESSAGE
10 | from .params import INVALID_ROLE_MESSAGE, INVALID_RENDER_FORMAT_MESSAGE, INVALID_MODEL_MESSAGE
11 | from .params import Role, RenderFormat, LLMModel
12 | from .tokens_estimator import TokensEstimator
13 | from .errors import MemorValidationError
14 | from .functions import get_time_utc, generate_message_id
15 | from .functions import _validate_string, _validate_pos_float, _validate_pos_int, _validate_message_id
16 | from .functions import _validate_date_time, _validate_probability, _validate_path
17 |
18 |
19 | class Response:
20 | """
21 | Response class.
22 |
23 |     >>> from memor import Response, Role, LLMModel
24 | >>> response = Response(message="Hello!", score=0.9, role=Role.ASSISTANT, temperature=0.5, model=LLMModel.GPT_4)
25 | >>> response.message
26 | 'Hello!'
27 | """
28 |
29 | def __init__(
30 | self,
31 | message: str = "",
32 | score: float = None,
33 | role: Role = Role.ASSISTANT,
34 | temperature: float = None,
35 | tokens: int = None,
36 | inference_time: float = None,
37 | model: Union[LLMModel, str] = LLMModel.DEFAULT,
38 |             date: datetime.datetime = None,
39 | file_path: str = None) -> None:
40 | """
41 | Response object initiator.
42 |
43 | :param message: response message
44 | :param score: response score
45 | :param role: response role
46 |         :param temperature: generation temperature
47 |         :param tokens: number of tokens in the response message
48 | :param inference_time: inference time
49 | :param model: agent model
50 | :param date: response date
51 | :param file_path: response file path
52 | """
53 | self._message = ""
54 | self._score = None
55 | self._role = Role.ASSISTANT
56 | self._temperature = None
57 | self._tokens = None
58 | self._inference_time = None
59 | self._model = LLMModel.DEFAULT.value
60 | self._date_created = get_time_utc()
61 | self._mark_modified()
62 | self._memor_version = MEMOR_VERSION
63 | self._id = None
64 | if file_path:
65 | self.load(file_path)
66 | else:
67 | if message:
68 | self.update_message(message)
69 | if score:
70 | self.update_score(score)
71 | if role:
72 | self.update_role(role)
73 | if model:
74 | self.update_model(model)
75 | if temperature:
76 | self.update_temperature(temperature)
77 | if tokens:
78 | self.update_tokens(tokens)
79 | if inference_time:
80 | self.update_inference_time(inference_time)
81 | if date:
82 | _validate_date_time(date, "date")
83 | self._date_created = date
84 | self._id = generate_message_id()
85 | _validate_message_id(self._id)
86 |
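Python evaluates default argument values once, when the `def` statement runs, so an expression such as a timestamp in a default is frozen at definition time; the usual safe idiom is a `None` default resolved inside the body. A standalone sketch of the difference (hypothetical functions, not part of memor):

```python
import datetime

# A default like `when=now()` is evaluated once, at definition time,
# and the same object is reused for every call.
def stamp_eager(when=datetime.datetime.now()):
    return when

# The safe idiom: default to None and resolve inside the body.
def stamp_lazy(when=None):
    if when is None:
        when = datetime.datetime.now()
    return when

a = stamp_eager()
b = stamp_eager()
assert a is b          # same frozen timestamp object on every call

c = stamp_lazy()
d = stamp_lazy()
assert c is not d      # a fresh timestamp object per call
```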
87 | def _mark_modified(self) -> None:
88 | """Mark modification."""
89 | self._date_modified = get_time_utc()
90 |
91 | def __eq__(self, other_response: "Response") -> bool:
92 | """
93 | Check responses equality.
94 |
95 | :param other_response: another response
96 | """
97 | if isinstance(other_response, Response):
98 | return self._message == other_response._message and self._score == other_response._score and self._role == other_response._role and self._temperature == other_response._temperature and \
99 | self._model == other_response._model and self._tokens == other_response._tokens and self._inference_time == other_response._inference_time
100 | return False
101 |
102 | def __str__(self) -> str:
103 | """Return string representation of Response."""
104 | return self.render(render_format=RenderFormat.STRING)
105 |
106 | def __repr__(self) -> str:
107 | """Return string representation of Response."""
108 | return "Response(message={message})".format(message=self._message)
109 |
110 | def __len__(self) -> int:
111 | """Return the length of the Response object."""
112 | return len(self.render(render_format=RenderFormat.STRING))
113 |
114 | def __copy__(self) -> "Response":
115 | """Return a copy of the Response object."""
116 | _class = self.__class__
117 | result = _class.__new__(_class)
118 | result.__dict__.update(self.__dict__)
119 | result.regenerate_id()
120 | return result
121 |
122 | def copy(self) -> "Response":
123 | """Return a copy of the Response object."""
124 | return self.__copy__()
125 |
126 | def update_message(self, message: str) -> None:
127 | """
128 | Update the response message.
129 |
130 | :param message: message
131 | """
132 | _validate_string(message, "message")
133 | self._message = message
134 | self._mark_modified()
135 |
136 | def update_score(self, score: float) -> None:
137 | """
138 | Update the response score.
139 |
140 | :param score: score
141 | """
142 | _validate_probability(score, "score")
143 | self._score = score
144 | self._mark_modified()
145 |
146 | def update_role(self, role: Role) -> None:
147 | """
148 | Update the response role.
149 |
150 | :param role: role
151 | """
152 | if not isinstance(role, Role):
153 | raise MemorValidationError(INVALID_ROLE_MESSAGE)
154 | self._role = role
155 | self._mark_modified()
156 |
157 | def update_temperature(self, temperature: float) -> None:
158 | """
159 | Update the temperature.
160 |
161 | :param temperature: temperature
162 | """
163 | _validate_pos_float(temperature, "temperature")
164 | self._temperature = temperature
165 | self._mark_modified()
166 |
167 | def update_tokens(self, tokens: int) -> None:
168 | """
169 | Update the tokens.
170 |
171 | :param tokens: tokens
172 | """
173 | _validate_pos_int(tokens, "tokens")
174 | self._tokens = tokens
175 | self._mark_modified()
176 |
177 | def update_inference_time(self, inference_time: float) -> None:
178 | """
179 | Update inference time.
180 |
181 | :param inference_time: inference time
182 | """
183 | _validate_pos_float(inference_time, "inference_time")
184 | self._inference_time = inference_time
185 | self._mark_modified()
186 |
187 | def update_model(self, model: Union[LLMModel, str]) -> None:
188 | """
189 | Update the agent model.
190 |
191 | :param model: model
192 | """
193 | if isinstance(model, str):
194 | self._model = model
195 | elif isinstance(model, LLMModel):
196 | self._model = model.value
197 | else:
198 | raise MemorValidationError(INVALID_MODEL_MESSAGE)
199 | self._mark_modified()
200 |
201 | def save(self, file_path: str) -> Dict[str, Any]:
202 | """
203 |         Save the response to a file.
204 |
205 | :param file_path: response file path
206 | """
207 | result = {"status": True, "message": DATA_SAVE_SUCCESS_MESSAGE}
208 | try:
209 | with open(file_path, "w") as file:
210 | json.dump(self.to_json(), file)
211 | except Exception as e:
212 | result["status"] = False
213 | result["message"] = str(e)
214 | return result
215 |
216 | def load(self, file_path: str) -> None:
217 | """
218 |         Load the response from a file.
219 |
220 | :param file_path: response file path
221 | """
222 | _validate_path(file_path)
223 | with open(file_path, "r") as file:
224 | self.from_json(file.read())
225 |
226 | def from_json(self, json_object: Union[str, Dict[str, Any]]) -> None:
227 | """
228 | Load attributes from the JSON object.
229 |
230 | :param json_object: JSON object
231 | """
232 | try:
233 | if isinstance(json_object, str):
234 | loaded_obj = json.loads(json_object)
235 | else:
236 | loaded_obj = json_object.copy()
237 | self._message = loaded_obj["message"]
238 | self._score = loaded_obj["score"]
239 | self._temperature = loaded_obj["temperature"]
240 | self._tokens = loaded_obj.get("tokens", None)
241 | self._inference_time = loaded_obj.get("inference_time", None)
242 | self._model = loaded_obj["model"]
243 | self._role = Role(loaded_obj["role"])
244 | self._memor_version = loaded_obj["memor_version"]
245 | self._id = loaded_obj.get("id", generate_message_id())
246 | self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
247 | self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
248 | except Exception:
249 | raise MemorValidationError(INVALID_RESPONSE_STRUCTURE_MESSAGE)
250 |
251 | def to_json(self) -> Dict[str, Any]:
252 | """Convert the response to a JSON object."""
253 | data = self.to_dict().copy()
254 | data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
255 | data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
256 | data["role"] = data["role"].value
257 | return data
258 |
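`to_json`/`from_json` keep datetimes JSON-safe by formatting with `strftime` on save and parsing back with `strptime` on load, against one shared format string. A standalone round-trip sketch (the format string here is illustrative, not memor's actual `DATE_TIME_FORMAT`):

```python
import datetime

# Serialize a datetime with strftime and parse it back with strptime
# using the same format string (an illustrative format).
FMT = "%Y-%m-%d %H:%M:%S"

created = datetime.datetime(2024, 5, 1, 12, 30, 0)
payload = {"date_created": created.strftime(FMT)}   # -> JSON-safe string

restored = datetime.datetime.strptime(payload["date_created"], FMT)
assert restored == created  # lossless round-trip at second precision
```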
259 | def to_dict(self) -> Dict[str, Any]:
260 | """Convert the response to a dictionary."""
261 | return {
262 | "type": "Response",
263 | "message": self._message,
264 | "score": self._score,
265 | "temperature": self._temperature,
266 | "tokens": self._tokens,
267 | "inference_time": self._inference_time,
268 | "role": self._role,
269 | "model": self._model,
270 | "id": self._id,
271 | "memor_version": MEMOR_VERSION,
272 | "date_created": self._date_created,
273 | "date_modified": self._date_modified,
274 | }
275 |
276 | def render(self,
277 | render_format: RenderFormat = RenderFormat.DEFAULT) -> Union[str,
278 | Dict[str, Any],
279 | List[Tuple[str, Any]]]:
280 | """
281 | Render the response.
282 |
283 | :param render_format: render format
284 | """
285 | if not isinstance(render_format, RenderFormat):
286 | raise MemorValidationError(INVALID_RENDER_FORMAT_MESSAGE)
287 | if render_format == RenderFormat.STRING:
288 | return self._message
289 | elif render_format == RenderFormat.OPENAI:
290 | return {"role": self._role.value,
291 | "content": self._message}
292 | elif render_format == RenderFormat.AI_STUDIO:
293 | return {"role": self._role.value,
294 | "parts": [{"text": self._message}]}
295 | elif render_format == RenderFormat.DICTIONARY:
296 | return self.to_dict()
297 | elif render_format == RenderFormat.ITEMS:
298 |             return list(self.to_dict().items())
299 | return self._message
300 |
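`render` dispatches on a `RenderFormat` enum, validating the argument first and falling back to the plain message. A standalone sketch of that dispatch shape (a reduced hypothetical enum; the real `RenderFormat` also has members such as `OPENAI` and `AI_STUDIO`):

```python
from enum import Enum

# Standalone sketch of enum-based render dispatch (hypothetical enum).
class Fmt(Enum):
    STRING = "string"
    DICTIONARY = "dictionary"

def render(message, role, fmt=Fmt.STRING):
    if not isinstance(fmt, Fmt):
        raise ValueError("invalid render format")
    if fmt == Fmt.STRING:
        return message
    if fmt == Fmt.DICTIONARY:
        return {"role": role, "content": message}
    return message  # fallback mirrors the original's final return

assert render("Hi", "assistant") == "Hi"
assert render("Hi", "assistant", Fmt.DICTIONARY) == {"role": "assistant", "content": "Hi"}
```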
301 | def estimate_tokens(self, method: TokensEstimator = TokensEstimator.DEFAULT) -> int:
302 | """
303 | Estimate the number of tokens in the response message.
304 |
305 | :param method: token estimator method
306 | """
307 | return method(self.render(render_format=RenderFormat.STRING))
308 |
309 | def regenerate_id(self) -> None:
310 | """Regenerate ID."""
311 | new_id = self._id
312 | while new_id == self.id:
313 | new_id = generate_message_id()
314 | self._id = new_id
315 |
316 | @property
317 | def message(self) -> str:
318 | """Get the response message."""
319 | return self._message
320 |
321 | @property
322 | def score(self) -> float:
323 | """Get the response score."""
324 | return self._score
325 |
326 | @property
327 | def temperature(self) -> float:
328 | """Get the temperature."""
329 | return self._temperature
330 |
331 | @property
332 | def tokens(self) -> int:
333 | """Get the tokens."""
334 | return self._tokens
335 |
336 | @property
337 | def inference_time(self) -> float:
338 | """Get inference time."""
339 | return self._inference_time
340 |
341 | @property
342 | def role(self) -> Role:
343 | """Get the response role."""
344 | return self._role
345 |
346 | @property
347 | def model(self) -> str:
348 | """Get the agent model."""
349 | return self._model
350 |
351 | @property
352 | def id(self) -> str:
353 | """Get the response ID."""
354 | return self._id
355 |
356 | @property
357 | def date_created(self) -> datetime.datetime:
358 | """Get the response creation date."""
359 | return self._date_created
360 |
361 | @property
362 | def date_modified(self) -> datetime.datetime:
363 | """Get the response object modification date."""
364 | return self._date_modified
365 |
--------------------------------------------------------------------------------
/memor/session.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Session class."""
3 | from typing import List, Dict, Tuple, Any, Union, Generator
4 | import datetime
5 | import json
6 | from .params import MEMOR_VERSION
7 | from .params import DATE_TIME_FORMAT, DATA_SAVE_SUCCESS_MESSAGE
8 | from .params import INVALID_MESSAGE
9 | from .params import INVALID_MESSAGE_STATUS_LEN_MESSAGE, INVALID_RENDER_FORMAT_MESSAGE
10 | from .params import INVALID_INT_OR_STR_MESSAGE, INVALID_INT_OR_STR_SLICE_MESSAGE
11 | from .params import UNSUPPORTED_OPERAND_ERROR_MESSAGE
12 | from .params import RenderFormat
13 | from .tokens_estimator import TokensEstimator
14 | from .prompt import Prompt
15 | from .response import Response
16 | from .errors import MemorValidationError
17 | from .functions import get_time_utc
18 | from .functions import _validate_bool, _validate_path
19 | from .functions import _validate_list_of, _validate_string
20 |
21 |
22 | class Session:
23 | """Session class."""
24 |
25 | def __init__(
26 | self,
27 | title: str = None,
28 |             messages: List[Union[Prompt, Response]] = None,
29 | file_path: str = None,
30 | init_check: bool = True) -> None:
31 | """
32 | Session object initiator.
33 |
34 | :param title: title
35 | :param messages: messages
36 | :param file_path: file path
37 | :param init_check: initial check flag
38 | """
39 | self._title = None
40 | self._render_counter = 0
41 | self._messages = []
42 | self._messages_status = []
43 | self._date_created = get_time_utc()
44 | self._mark_modified()
45 | self._memor_version = MEMOR_VERSION
46 | if file_path:
47 | self.load(file_path)
48 | else:
49 | if title:
50 | self.update_title(title)
51 | if messages:
52 | self.update_messages(messages)
53 | if init_check:
54 | _ = self.render(enable_counter=False)
55 |
56 | def _mark_modified(self) -> None:
57 | """Mark modification."""
58 | self._date_modified = get_time_utc()
59 |
60 | def __eq__(self, other_session: "Session") -> bool:
61 | """
62 | Check sessions equality.
63 |
64 | :param other_session: other session
65 | """
66 | if isinstance(other_session, Session):
67 | return self._title == other_session._title and self._messages == other_session._messages
68 | return False
69 |
70 | def __str__(self) -> str:
71 | """Return string representation of Session."""
72 | return self.render(render_format=RenderFormat.STRING, enable_counter=False)
73 |
74 | def __repr__(self) -> str:
75 | """Return string representation of Session."""
76 | return "Session(title={title})".format(title=self._title)
77 |
78 | def __len__(self) -> int:
79 | """Return the length of the Session object."""
80 | return len(self._messages)
81 |
82 | def __iter__(self) -> Generator[Union[Prompt, Response], None, None]:
83 | """Iterate through the Session object."""
84 | yield from self._messages
85 |
86 | def __add__(self, other_object: Union["Session", Response, Prompt]) -> "Session":
87 | """
88 | Addition method.
89 |
90 | :param other_object: other object
91 | """
92 | if isinstance(other_object, (Response, Prompt)):
93 | new_messages = self._messages + [other_object]
94 | return Session(title=self.title, messages=new_messages)
95 | if isinstance(other_object, Session):
96 | new_messages = self._messages + other_object._messages
97 | return Session(messages=new_messages)
98 | raise TypeError(
99 | UNSUPPORTED_OPERAND_ERROR_MESSAGE.format(
100 | operator="+",
101 | operand1="Session",
102 | operand2=type(other_object).__name__))
103 |
104 | def __radd__(self, other_object: Union["Session", Response, Prompt]) -> "Session":
105 | """
106 | Reverse addition method.
107 |
108 | :param other_object: other object
109 | """
110 | if isinstance(other_object, (Response, Prompt)):
111 | new_messages = [other_object] + self._messages
112 | return Session(title=self.title, messages=new_messages)
113 | raise TypeError(
114 | UNSUPPORTED_OPERAND_ERROR_MESSAGE.format(
115 | operator="+",
116 | operand1="Session",
117 | operand2=type(other_object).__name__))
118 |
119 | def __contains__(self, message: Union[Prompt, Response]) -> bool:
120 | """
121 | Check if the Session contains the given message.
122 |
123 | :param message: message
124 | """
125 | return message in self._messages
126 |
127 | def __getitem__(self, identifier: Union[int, slice, str]) -> Union[Prompt, Response]:
128 | """
129 | Get a message from the session object.
130 |
131 | :param identifier: message identifier (index/slice or id)
132 | """
133 | return self.get_message(identifier=identifier)
134 |
135 |     def __copy__(self) -> "Session":
136 |         """Return a copy of the Session object."""
137 |         result = self.__class__.__new__(self.__class__)
138 |         result.__dict__.update(self.__dict__)
139 |         result._messages, result._messages_status = self._messages.copy(), self._messages_status.copy()
140 |         return result
141 |
142 | def copy(self) -> "Session":
143 | """Return a copy of the Session object."""
144 | return self.__copy__()
145 |
146 | def add_message(self,
147 | message: Union[Prompt, Response],
148 | status: bool = True,
149 | index: int = None) -> None:
150 | """
151 | Add a message to the session object.
152 |
153 | :param message: message
154 | :param status: status
155 | :param index: index
156 | """
157 | if not isinstance(message, (Prompt, Response)):
158 | raise MemorValidationError(INVALID_MESSAGE)
159 | _validate_bool(status, "status")
160 | if index is None:
161 | self._messages.append(message)
162 | self._messages_status.append(status)
163 | else:
164 | self._messages.insert(index, message)
165 | self._messages_status.insert(index, status)
166 | self._mark_modified()
167 |
168 | def get_message_by_index(self, index: Union[int, slice]) -> Union[Prompt, Response]:
169 | """
170 | Get a message from the session object by index/slice.
171 |
172 | :param index: index
173 | """
174 | return self._messages[index]
175 |
176 | def get_message_by_id(self, message_id: str) -> Union[Prompt, Response]:
177 | """
178 | Get a message from the session object by message id.
179 |
180 | :param message_id: message id
181 | """
182 | for index, message in enumerate(self._messages):
183 | if message.id == message_id:
184 | return self.get_message_by_index(index=index)
185 |
186 | def get_message(self, identifier: Union[int, slice, str]) -> Union[Prompt, Response]:
187 | """
188 | Get a message from the session object.
189 |
190 | :param identifier: message identifier (index/slice or id)
191 | """
192 | if isinstance(identifier, (int, slice)):
193 | return self.get_message_by_index(index=identifier)
194 | elif isinstance(identifier, str):
195 | return self.get_message_by_id(message_id=identifier)
196 | else:
197 | raise MemorValidationError(INVALID_INT_OR_STR_SLICE_MESSAGE.format(parameter_name="identifier"))
198 |
199 | def remove_message_by_index(self, index: int) -> None:
200 | """
201 | Remove a message from the session object by index.
202 |
203 | :param index: index
204 | """
205 | self._messages.pop(index)
206 | self._messages_status.pop(index)
207 | self._mark_modified()
208 |
209 | def remove_message_by_id(self, message_id: str) -> None:
210 | """
211 | Remove a message from the session object by message id.
212 |
213 | :param message_id: message id
214 | """
215 | for index, message in enumerate(self._messages):
216 | if message.id == message_id:
217 | self.remove_message_by_index(index=index)
218 | break
219 |
220 | def remove_message(self, identifier: Union[int, str]) -> None:
221 | """
222 | Remove a message from the session object.
223 |
224 | :param identifier: message identifier (index or id)
225 | """
226 | if isinstance(identifier, int):
227 | self.remove_message_by_index(index=identifier)
228 | elif isinstance(identifier, str):
229 | self.remove_message_by_id(message_id=identifier)
230 | else:
231 | raise MemorValidationError(INVALID_INT_OR_STR_MESSAGE.format(parameter_name="identifier"))
232 |
233 | def clear_messages(self) -> None:
234 | """Remove all messages."""
235 | self._messages = []
236 | self._messages_status = []
237 | self._mark_modified()
238 |
239 | def enable_message(self, index: int) -> None:
240 | """
241 | Enable a message.
242 |
243 | :param index: index
244 | """
245 | self._messages_status[index] = True
246 |
247 | def disable_message(self, index: int) -> None:
248 | """
249 | Disable a message.
250 |
251 | :param index: index
252 | """
253 | self._messages_status[index] = False
254 |
255 | def mask_message(self, index: int) -> None:
256 | """
257 | Mask a message.
258 |
259 | :param index: index
260 | """
261 | self.disable_message(index)
262 |
263 | def unmask_message(self, index: int) -> None:
264 | """
265 | Unmask a message.
266 |
267 | :param index: index
268 | """
269 | self.enable_message(index)
270 |
271 | def update_title(self, title: str) -> None:
272 | """
273 | Update the session title.
274 |
275 | :param title: title
276 | """
277 | _validate_string(title, "title")
278 | self._title = title
279 | self._mark_modified()
280 |
281 | def update_messages(self,
282 | messages: List[Union[Prompt, Response]],
283 | status: List[bool] = None) -> None:
284 | """
285 | Update the session messages.
286 |
287 | :param messages: messages
288 | :param status: status
289 | """
290 | _validate_list_of(messages, "messages", (Prompt, Response), "`Prompt` or `Response`")
291 | self._messages = messages
292 | if status:
293 | self.update_messages_status(status)
294 | else:
295 | self.update_messages_status(len(messages) * [True])
296 | self._mark_modified()
297 |
298 | def update_messages_status(self, status: List[bool]) -> None:
299 | """
300 | Update the session messages status.
301 |
302 | :param status: status
303 | """
304 | _validate_list_of(status, "status", bool, "booleans")
305 | if len(status) != len(self._messages):
306 | raise MemorValidationError(INVALID_MESSAGE_STATUS_LEN_MESSAGE)
307 | self._messages_status = status
308 |
309 | def save(self, file_path: str) -> Dict[str, Any]:
310 | """
311 | Save the session into a file.
312 |
313 | :param file_path: session file path
314 | """
315 | result = {"status": True, "message": DATA_SAVE_SUCCESS_MESSAGE}
316 | try:
317 | with open(file_path, "w") as file:
318 | data = self.to_json()
319 | json.dump(data, file)
320 | except Exception as e:
321 | result["status"] = False
322 | result["message"] = str(e)
323 | return result
324 |
325 | def load(self, file_path: str) -> None:
326 | """
327 | Load the session from a file.
328 |
329 | :param file_path: session file path
330 | """
331 | _validate_path(file_path)
332 | with open(file_path, "r") as file:
333 | self.from_json(file.read())
334 |
335 | def from_json(self, json_object: Union[str, Dict[str, Any]]) -> None:
336 | """
337 | Load attributes from the JSON object.
338 |
339 | :param json_object: JSON object
340 | """
341 | if isinstance(json_object, str):
342 | loaded_obj = json.loads(json_object)
343 | else:
344 | loaded_obj = json_object.copy()
345 | self._title = loaded_obj["title"]
346 | self._render_counter = loaded_obj.get("render_counter", 0)
347 | self._messages_status = loaded_obj["messages_status"]
348 | messages = []
349 | for message in loaded_obj["messages"]:
350 | if message["type"] == "Prompt":
351 | message_obj = Prompt()
352 | elif message["type"] == "Response":
353 | message_obj = Response()
354 | else:  # skip unknown message types to avoid an unbound `message_obj`
355 | continue
356 | message_obj.from_json(message)
357 | messages.append(message_obj)
356 | self._messages = messages
357 | self._memor_version = loaded_obj["memor_version"]
358 | self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
359 | self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
360 |
361 | def to_json(self) -> Dict[str, Any]:
362 | """Convert the session to a JSON object."""
363 | data = self.to_dict().copy()
364 | for index, message in enumerate(data["messages"]):
365 | data["messages"][index] = message.to_json()
366 | data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
367 | data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
368 | return data
369 |
370 | def to_dict(self) -> Dict[str, Any]:
371 | """
372 | Convert the session to a dictionary.
373 |
374 | :return: dict
375 | """
376 | data = {
377 | "type": "Session",
378 | "title": self._title,
379 | "render_counter": self._render_counter,
380 | "messages": self._messages.copy(),
381 | "messages_status": self._messages_status.copy(),
382 | "memor_version": MEMOR_VERSION,
383 | "date_created": self._date_created,
384 | "date_modified": self._date_modified,
385 | }
386 | return data
387 |
388 | def render(self, render_format: RenderFormat = RenderFormat.DEFAULT,
389 | enable_counter: bool = True) -> Union[str, Dict[str, Any], List[Tuple[str, Any]]]:
390 | """
391 | Render the session in the given format.
392 |
393 | :param render_format: render format
394 | :param enable_counter: render counter flag
395 | """
396 | if not isinstance(render_format, RenderFormat):
397 | raise MemorValidationError(INVALID_RENDER_FORMAT_MESSAGE)
398 | result = None
399 | if render_format in [RenderFormat.OPENAI, RenderFormat.AI_STUDIO]:
400 | result = []
401 | for message in self._messages:
402 | if isinstance(message, Session):
403 | result.extend(message.render(render_format=render_format))
404 | else:
405 | result.append(message.render(render_format=render_format))
406 | else:
407 | content = ""
408 | session_dict = self.to_dict()
409 | for message in self._messages:
410 | content += message.render(render_format=RenderFormat.STRING) + "\n"
411 | session_dict["content"] = content
412 | if render_format == RenderFormat.STRING:
413 | result = content
414 | elif render_format == RenderFormat.DICTIONARY:
415 | result = session_dict
416 | elif render_format == RenderFormat.ITEMS:
417 | result = list(session_dict.items())
418 | if enable_counter:
419 | self._render_counter += 1
420 | self._mark_modified()
421 | return result
422 |
423 | def check_render(self) -> bool:
424 | """Check whether the session renders without raising an exception."""
425 | try:
426 | _ = self.render(enable_counter=False)
427 | return True
428 | except Exception:
429 | return False
430 |
431 | def estimate_tokens(self, method: TokensEstimator = TokensEstimator.DEFAULT) -> int:
432 | """
433 | Estimate the number of tokens in the session.
434 |
435 | :param method: token estimator method
436 | """
437 | return method(self.render(render_format=RenderFormat.STRING, enable_counter=False))
438 |
439 | @property
440 | def date_created(self) -> datetime.datetime:
441 | """Get the session creation date."""
442 | return self._date_created
443 |
444 | @property
445 | def date_modified(self) -> datetime.datetime:
446 | """Get the session object modification date."""
447 | return self._date_modified
448 |
449 | @property
450 | def title(self) -> str:
451 | """Get the session title."""
452 | return self._title
453 |
454 | @property
455 | def render_counter(self) -> int:
456 | """Get the render counter."""
457 | return self._render_counter
458 |
459 | @property
460 | def messages(self) -> List[Union[Prompt, Response]]:
461 | """Get the session messages."""
462 | return self._messages
463 |
464 | @property
465 | def messages_status(self) -> List[bool]:
466 | """Get the session messages status."""
467 | return self._messages_status
468 |
469 | @property
470 | def masks(self) -> List[bool]:
471 | """Get the session masks."""
472 | return [not x for x in self._messages_status]
473 |
--------------------------------------------------------------------------------
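The OPENAI/AI_STUDIO branch of `Session.render` above flattens nested sessions by *extending* the result list, while a single prompt or response is *appended*. A minimal standalone sketch of that flattening, using hypothetical plain-dict messages in place of memor objects:

```python
# Hypothetical stand-ins: a dict mimics a rendered Prompt/Response,
# a list of dicts mimics a rendered nested Session.
messages = [
    {"role": "user", "content": "hi"},
    [{"role": "assistant", "content": "hello"},   # nested "session"
     {"role": "user", "content": "bye"}],
]

result = []
for message in messages:
    if isinstance(message, list):  # nested session: merge its messages
        result.extend(message)
    else:                          # single message: append as-is
        result.append(message)

print(len(result))  # → 3
```

Extending rather than appending is what keeps the rendered output a flat chat transcript regardless of how deeply sessions are nested.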
/memor/template.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Template class."""
3 | from typing import Dict, Any, Union
4 | import json
5 | import datetime
6 | from enum import Enum
7 | from .params import DATE_TIME_FORMAT
8 | from .params import DATA_SAVE_SUCCESS_MESSAGE
9 | from .params import INVALID_TEMPLATE_STRUCTURE_MESSAGE
10 | from .params import MEMOR_VERSION
11 | from .errors import MemorValidationError
12 | from .functions import get_time_utc
13 | from .functions import _validate_path, _validate_custom_map
14 | from .functions import _validate_string
15 |
16 |
17 | class PromptTemplate:
18 | r"""
19 | Prompt template.
20 |
21 | >>> template = PromptTemplate(content="Take a deep breath\n{prompt_message}!", title="Greeting")
22 | >>> template.title
23 | 'Greeting'
24 | """
25 |
26 | def __init__(
27 | self,
28 | content: str = None,
29 | file_path: str = None,
30 | title: str = None,
31 | custom_map: Dict[str, str] = None) -> None:
32 | """
33 | Prompt template object initializer.
34 |
35 | :param content: template content
36 | :param file_path: template file path
37 | :param title: template title
38 | :param custom_map: custom map
39 | """
40 | self._content = None
41 | self._title = None
42 | self._date_created = get_time_utc()
43 | self._mark_modified()
44 | self._memor_version = MEMOR_VERSION
45 | self._custom_map = None
46 | if file_path:
47 | self.load(file_path)
48 | else:
49 | if title:
50 | self.update_title(title)
51 | if content:
52 | self.update_content(content)
53 | if custom_map:
54 | self.update_map(custom_map)
55 |
56 | def _mark_modified(self) -> None:
57 | """Mark modification."""
58 | self._date_modified = get_time_utc()
59 |
60 | def __eq__(self, other_template: "PromptTemplate") -> bool:
61 | """
62 | Check templates equality.
63 |
64 | :param other_template: another template
65 | """
66 | if isinstance(other_template, PromptTemplate):
67 | return self._content == other_template._content and self._title == other_template._title and self._custom_map == other_template._custom_map
68 | return False
69 |
70 | def __str__(self) -> str:
71 | """Return string representation of PromptTemplate."""
72 | return self._content
73 |
74 | def __repr__(self) -> str:
75 | """Return string representation of PromptTemplate."""
76 | return "PromptTemplate(content={content})".format(content=self._content)
77 |
78 | def __copy__(self) -> "PromptTemplate":
79 | """Return a copy of the PromptTemplate object."""
80 | _class = self.__class__
81 | result = _class.__new__(_class)
82 | result.__dict__.update(self.__dict__)
83 | return result
84 |
85 | def copy(self) -> "PromptTemplate":
86 | """Return a copy of the PromptTemplate object."""
87 | return self.__copy__()
88 |
89 | def update_title(self, title: str) -> None:
90 | """
91 | Update title.
92 |
93 | :param title: title
94 | """
95 | _validate_string(title, "title")
96 | self._title = title
97 | self._mark_modified()
98 |
99 | def update_content(self, content: str) -> None:
100 | """
101 | Update content.
102 |
103 | :param content: content
104 | """
105 | _validate_string(content, "content")
106 | self._content = content
107 | self._mark_modified()
108 |
109 | def update_map(self, custom_map: Dict[str, str]) -> None:
110 | """
111 | Update custom map.
112 |
113 | :param custom_map: custom map
114 | """
115 | _validate_custom_map(custom_map)
116 | self._custom_map = custom_map
117 | self._mark_modified()
118 |
119 | def save(self, file_path: str) -> Dict[str, Any]:
120 | """
121 | Save method.
122 |
123 | :param file_path: template file path
124 | """
125 | result = {"status": True, "message": DATA_SAVE_SUCCESS_MESSAGE}
126 | try:
127 | with open(file_path, "w") as file:
128 | json.dump(self.to_json(), file)
129 | except Exception as e:
130 | result["status"] = False
131 | result["message"] = str(e)
132 | return result
133 |
134 | def load(self, file_path: str) -> None:
135 | """
136 | Load method.
137 |
138 | :param file_path: template file path
139 | """
140 | _validate_path(file_path)
141 | with open(file_path, "r") as file:
142 | self.from_json(file.read())
143 |
144 | def from_json(self, json_object: Union[str, Dict[str, Any]]) -> None:
145 | """
146 | Load attributes from the JSON object.
147 |
148 | :param json_object: JSON object
149 | """
150 | try:
151 | if isinstance(json_object, str):
152 | loaded_obj = json.loads(json_object)
153 | else:
154 | loaded_obj = json_object.copy()
155 | self._content = loaded_obj["content"]
156 | self._title = loaded_obj["title"]
157 | self._memor_version = loaded_obj["memor_version"]
158 | self._custom_map = loaded_obj["custom_map"]
159 | self._date_created = datetime.datetime.strptime(loaded_obj["date_created"], DATE_TIME_FORMAT)
160 | self._date_modified = datetime.datetime.strptime(loaded_obj["date_modified"], DATE_TIME_FORMAT)
161 | except Exception:
162 | raise MemorValidationError(INVALID_TEMPLATE_STRUCTURE_MESSAGE)
163 |
164 | def to_json(self) -> Dict[str, Any]:
165 | """Convert PromptTemplate to a JSON-serializable dictionary."""
166 | data = self.to_dict().copy()
167 | data["date_created"] = datetime.datetime.strftime(data["date_created"], DATE_TIME_FORMAT)
168 | data["date_modified"] = datetime.datetime.strftime(data["date_modified"], DATE_TIME_FORMAT)
169 | return data
170 |
171 | def to_dict(self) -> Dict[str, Any]:
172 | """Convert PromptTemplate to dict."""
173 | return {
174 | "title": self._title,
175 | "content": self._content,
176 | "memor_version": MEMOR_VERSION,
177 | "custom_map": self._custom_map.copy(),
178 | "date_created": self._date_created,
179 | "date_modified": self._date_modified,
180 | }
181 |
182 | @property
183 | def content(self) -> str:
184 | """Get the PromptTemplate content."""
185 | return self._content
186 |
187 | @property
188 | def title(self) -> str:
189 | """Get the PromptTemplate title."""
190 | return self._title
191 |
192 | @property
193 | def date_created(self) -> datetime.datetime:
194 | """Get the PromptTemplate creation date."""
195 | return self._date_created
196 |
197 | @property
198 | def date_modified(self) -> datetime.datetime:
199 | """Get the PromptTemplate modification date."""
200 | return self._date_modified
201 |
202 | @property
203 | def custom_map(self) -> Dict[str, str]:
204 | """Get the PromptTemplate custom map."""
205 | return self._custom_map
206 |
207 |
208 | PROMPT_INSTRUCTION1 = "I'm providing you with a history of a previous conversation. Please consider this context when responding to my new question.\n"
209 | PROMPT_INSTRUCTION2 = "Here is the context from a prior conversation. Please learn from this information and use it to provide a thoughtful and context-aware response to my next questions.\n"
210 | PROMPT_INSTRUCTION3 = "I am sharing a record of a previous discussion. Use this information to provide a consistent and relevant answer to my next query.\n"
211 |
212 | BASIC_PROMPT_CONTENT = "{instruction}{prompt[message]}"
213 | BASIC_RESPONSE_CONTENT = "{instruction}{response[message]}"
214 | BASIC_RESPONSE0_CONTENT = "{instruction}{responses[0][message]}"
215 | BASIC_RESPONSE1_CONTENT = "{instruction}{responses[1][message]}"
216 | BASIC_RESPONSE2_CONTENT = "{instruction}{responses[2][message]}"
217 | BASIC_RESPONSE3_CONTENT = "{instruction}{responses[3][message]}"
218 | BASIC_PROMPT_CONTENT_LABEL = "{instruction}Prompt: {prompt[message]}"
219 | BASIC_RESPONSE_CONTENT_LABEL = "{instruction}Response: {response[message]}"
220 | BASIC_RESPONSE0_CONTENT_LABEL = "{instruction}Response: {responses[0][message]}"
221 | BASIC_RESPONSE1_CONTENT_LABEL = "{instruction}Response: {responses[1][message]}"
222 | BASIC_RESPONSE2_CONTENT_LABEL = "{instruction}Response: {responses[2][message]}"
223 | BASIC_RESPONSE3_CONTENT_LABEL = "{instruction}Response: {responses[3][message]}"
224 | BASIC_PROMPT_RESPONSE_STANDARD_CONTENT = "{instruction}Prompt: {prompt[message]}\nResponse: {response[message]}"
225 | BASIC_PROMPT_RESPONSE_FULL_CONTENT = """{instruction}
226 | Prompt:
227 | Message: {prompt[message]}
228 | Role: {prompt[role]}
229 | Tokens: {prompt[tokens]}
230 | Date: {prompt[date]}
231 | Response:
232 | Message: {response[message]}
233 | Role: {response[role]}
234 | Temperature: {response[temperature]}
235 | Model: {response[model]}
236 | Score: {response[score]}
237 | Tokens: {response[tokens]}
238 | Inference Time: {response[inference_time]}
239 | Date: {response[date]}"""
240 |
241 |
242 | class _BasicPresetPromptTemplate(Enum):
243 | """Preset basic-prompt templates."""
244 |
245 | PROMPT = PromptTemplate(content=BASIC_PROMPT_CONTENT, title="Basic/Prompt", custom_map={"instruction": ""})
246 | RESPONSE = PromptTemplate(
247 | content=BASIC_RESPONSE_CONTENT,
248 | title="Basic/Response",
249 | custom_map={
250 | "instruction": ""})
251 | RESPONSE0 = PromptTemplate(
252 | content=BASIC_RESPONSE0_CONTENT,
253 | title="Basic/Response0",
254 | custom_map={
255 | "instruction": ""})
256 | RESPONSE1 = PromptTemplate(
257 | content=BASIC_RESPONSE1_CONTENT,
258 | title="Basic/Response1",
259 | custom_map={
260 | "instruction": ""})
261 | RESPONSE2 = PromptTemplate(
262 | content=BASIC_RESPONSE2_CONTENT,
263 | title="Basic/Response2",
264 | custom_map={
265 | "instruction": ""})
266 | RESPONSE3 = PromptTemplate(
267 | content=BASIC_RESPONSE3_CONTENT,
268 | title="Basic/Response3",
269 | custom_map={
270 | "instruction": ""})
271 | PROMPT_WITH_LABEL = PromptTemplate(
272 | content=BASIC_PROMPT_CONTENT_LABEL,
273 | title="Basic/Prompt With Label",
274 | custom_map={
275 | "instruction": ""})
276 | RESPONSE_WITH_LABEL = PromptTemplate(
277 | content=BASIC_RESPONSE_CONTENT_LABEL,
278 | title="Basic/Response With Label",
279 | custom_map={
280 | "instruction": ""})
281 | RESPONSE0_WITH_LABEL = PromptTemplate(
282 | content=BASIC_RESPONSE0_CONTENT_LABEL,
283 | title="Basic/Response0 With Label",
284 | custom_map={
285 | "instruction": ""})
286 | RESPONSE1_WITH_LABEL = PromptTemplate(
287 | content=BASIC_RESPONSE1_CONTENT_LABEL,
288 | title="Basic/Response1 With Label",
289 | custom_map={
290 | "instruction": ""})
291 | RESPONSE2_WITH_LABEL = PromptTemplate(
292 | content=BASIC_RESPONSE2_CONTENT_LABEL,
293 | title="Basic/Response2 With Label",
294 | custom_map={
295 | "instruction": ""})
296 | RESPONSE3_WITH_LABEL = PromptTemplate(
297 | content=BASIC_RESPONSE3_CONTENT_LABEL,
298 | title="Basic/Response3 With Label",
299 | custom_map={
300 | "instruction": ""})
301 | PROMPT_RESPONSE_STANDARD = PromptTemplate(
302 | content=BASIC_PROMPT_RESPONSE_STANDARD_CONTENT,
303 | title="Basic/Prompt-Response Standard",
304 | custom_map={
305 | "instruction": ""})
306 | PROMPT_RESPONSE_FULL = PromptTemplate(
307 | content=BASIC_PROMPT_RESPONSE_FULL_CONTENT,
308 | title="Basic/Prompt-Response Full",
309 | custom_map={
310 | "instruction": ""})
311 |
312 |
313 | class _Instruction1PresetPromptTemplate(Enum):
314 | """Preset instruction1-prompt templates."""
315 |
316 | PROMPT = PromptTemplate(
317 | content=BASIC_PROMPT_CONTENT,
318 | title="Instruction1/Prompt",
319 | custom_map={
320 | "instruction": PROMPT_INSTRUCTION1})
321 | RESPONSE = PromptTemplate(
322 | content=BASIC_RESPONSE_CONTENT,
323 | title="Instruction1/Response",
324 | custom_map={
325 | "instruction": PROMPT_INSTRUCTION1})
326 | RESPONSE0 = PromptTemplate(
327 | content=BASIC_RESPONSE0_CONTENT,
328 | title="Instruction1/Response0",
329 | custom_map={
330 | "instruction": PROMPT_INSTRUCTION1})
331 | RESPONSE1 = PromptTemplate(
332 | content=BASIC_RESPONSE1_CONTENT,
333 | title="Instruction1/Response1",
334 | custom_map={
335 | "instruction": PROMPT_INSTRUCTION1})
336 | RESPONSE2 = PromptTemplate(
337 | content=BASIC_RESPONSE2_CONTENT,
338 | title="Instruction1/Response2",
339 | custom_map={
340 | "instruction": PROMPT_INSTRUCTION1})
341 | RESPONSE3 = PromptTemplate(
342 | content=BASIC_RESPONSE3_CONTENT,
343 | title="Instruction1/Response3",
344 | custom_map={
345 | "instruction": PROMPT_INSTRUCTION1})
346 | PROMPT_WITH_LABEL = PromptTemplate(
347 | content=BASIC_PROMPT_CONTENT_LABEL,
348 | title="Instruction1/Prompt With Label",
349 | custom_map={
350 | "instruction": PROMPT_INSTRUCTION1})
351 | RESPONSE_WITH_LABEL = PromptTemplate(
352 | content=BASIC_RESPONSE_CONTENT_LABEL,
353 | title="Instruction1/Response With Label",
354 | custom_map={
355 | "instruction": PROMPT_INSTRUCTION1})
356 | RESPONSE0_WITH_LABEL = PromptTemplate(
357 | content=BASIC_RESPONSE0_CONTENT_LABEL,
358 | title="Instruction1/Response0 With Label",
359 | custom_map={
360 | "instruction": PROMPT_INSTRUCTION1})
361 | RESPONSE1_WITH_LABEL = PromptTemplate(
362 | content=BASIC_RESPONSE1_CONTENT_LABEL,
363 | title="Instruction1/Response1 With Label",
364 | custom_map={
365 | "instruction": PROMPT_INSTRUCTION1})
366 | RESPONSE2_WITH_LABEL = PromptTemplate(
367 | content=BASIC_RESPONSE2_CONTENT_LABEL,
368 | title="Instruction1/Response2 With Label",
369 | custom_map={
370 | "instruction": PROMPT_INSTRUCTION1})
371 | RESPONSE3_WITH_LABEL = PromptTemplate(
372 | content=BASIC_RESPONSE3_CONTENT_LABEL,
373 | title="Instruction1/Response3 With Label",
374 | custom_map={
375 | "instruction": PROMPT_INSTRUCTION1})
376 | PROMPT_RESPONSE_STANDARD = PromptTemplate(
377 | content=BASIC_PROMPT_RESPONSE_STANDARD_CONTENT,
378 | title="Instruction1/Prompt-Response Standard",
379 | custom_map={
380 | "instruction": PROMPT_INSTRUCTION1})
381 | PROMPT_RESPONSE_FULL = PromptTemplate(
382 | content=BASIC_PROMPT_RESPONSE_FULL_CONTENT,
383 | title="Instruction1/Prompt-Response Full",
384 | custom_map={
385 | "instruction": PROMPT_INSTRUCTION1})
386 |
387 |
388 | class _Instruction2PresetPromptTemplate(Enum):
389 | """Preset instruction2-prompt templates."""
390 |
391 | PROMPT = PromptTemplate(
392 | content=BASIC_PROMPT_CONTENT,
393 | title="Instruction2/Prompt",
394 | custom_map={
395 | "instruction": PROMPT_INSTRUCTION2})
396 | RESPONSE = PromptTemplate(
397 | content=BASIC_RESPONSE_CONTENT,
398 | title="Instruction2/Response",
399 | custom_map={
400 | "instruction": PROMPT_INSTRUCTION2})
401 | RESPONSE0 = PromptTemplate(
402 | content=BASIC_RESPONSE0_CONTENT,
403 | title="Instruction2/Response0",
404 | custom_map={
405 | "instruction": PROMPT_INSTRUCTION2})
406 | RESPONSE1 = PromptTemplate(
407 | content=BASIC_RESPONSE1_CONTENT,
408 | title="Instruction2/Response1",
409 | custom_map={
410 | "instruction": PROMPT_INSTRUCTION2})
411 | RESPONSE2 = PromptTemplate(
412 | content=BASIC_RESPONSE2_CONTENT,
413 | title="Instruction2/Response2",
414 | custom_map={
415 | "instruction": PROMPT_INSTRUCTION2})
416 | RESPONSE3 = PromptTemplate(
417 | content=BASIC_RESPONSE3_CONTENT,
418 | title="Instruction2/Response3",
419 | custom_map={
420 | "instruction": PROMPT_INSTRUCTION2})
421 | PROMPT_WITH_LABEL = PromptTemplate(
422 | content=BASIC_PROMPT_CONTENT_LABEL,
423 | title="Instruction2/Prompt With Label",
424 | custom_map={
425 | "instruction": PROMPT_INSTRUCTION2})
426 | RESPONSE_WITH_LABEL = PromptTemplate(
427 | content=BASIC_RESPONSE_CONTENT_LABEL,
428 | title="Instruction2/Response With Label",
429 | custom_map={
430 | "instruction": PROMPT_INSTRUCTION2})
431 | RESPONSE0_WITH_LABEL = PromptTemplate(
432 | content=BASIC_RESPONSE0_CONTENT_LABEL,
433 | title="Instruction2/Response0 With Label",
434 | custom_map={
435 | "instruction": PROMPT_INSTRUCTION2})
436 | RESPONSE1_WITH_LABEL = PromptTemplate(
437 | content=BASIC_RESPONSE1_CONTENT_LABEL,
438 | title="Instruction2/Response1 With Label",
439 | custom_map={
440 | "instruction": PROMPT_INSTRUCTION2})
441 | RESPONSE2_WITH_LABEL = PromptTemplate(
442 | content=BASIC_RESPONSE2_CONTENT_LABEL,
443 | title="Instruction2/Response2 With Label",
444 | custom_map={
445 | "instruction": PROMPT_INSTRUCTION2})
446 | RESPONSE3_WITH_LABEL = PromptTemplate(
447 | content=BASIC_RESPONSE3_CONTENT_LABEL,
448 | title="Instruction2/Response3 With Label",
449 | custom_map={
450 | "instruction": PROMPT_INSTRUCTION2})
451 | PROMPT_RESPONSE_STANDARD = PromptTemplate(
452 | content=BASIC_PROMPT_RESPONSE_STANDARD_CONTENT,
453 | title="Instruction2/Prompt-Response Standard",
454 | custom_map={
455 | "instruction": PROMPT_INSTRUCTION2})
456 | PROMPT_RESPONSE_FULL = PromptTemplate(
457 | content=BASIC_PROMPT_RESPONSE_FULL_CONTENT,
458 | title="Instruction2/Prompt-Response Full",
459 | custom_map={
460 | "instruction": PROMPT_INSTRUCTION2})
461 |
462 |
463 | class _Instruction3PresetPromptTemplate(Enum):
464 | """Preset instruction3-prompt templates."""
465 |
466 | PROMPT = PromptTemplate(
467 | content=BASIC_PROMPT_CONTENT,
468 | title="Instruction3/Prompt",
469 | custom_map={
470 | "instruction": PROMPT_INSTRUCTION3})
471 | RESPONSE = PromptTemplate(
472 | content=BASIC_RESPONSE_CONTENT,
473 | title="Instruction3/Response",
474 | custom_map={
475 | "instruction": PROMPT_INSTRUCTION3})
476 | RESPONSE0 = PromptTemplate(
477 | content=BASIC_RESPONSE0_CONTENT,
478 | title="Instruction3/Response0",
479 | custom_map={
480 | "instruction": PROMPT_INSTRUCTION3})
481 | RESPONSE1 = PromptTemplate(
482 | content=BASIC_RESPONSE1_CONTENT,
483 | title="Instruction3/Response1",
484 | custom_map={
485 | "instruction": PROMPT_INSTRUCTION3})
486 | RESPONSE2 = PromptTemplate(
487 | content=BASIC_RESPONSE2_CONTENT,
488 | title="Instruction3/Response2",
489 | custom_map={
490 | "instruction": PROMPT_INSTRUCTION3})
491 | RESPONSE3 = PromptTemplate(
492 | content=BASIC_RESPONSE3_CONTENT,
493 | title="Instruction3/Response3",
494 | custom_map={
495 | "instruction": PROMPT_INSTRUCTION3})
496 | PROMPT_WITH_LABEL = PromptTemplate(
497 | content=BASIC_PROMPT_CONTENT_LABEL,
498 | title="Instruction3/Prompt With Label",
499 | custom_map={
500 | "instruction": PROMPT_INSTRUCTION3})
501 | RESPONSE_WITH_LABEL = PromptTemplate(
502 | content=BASIC_RESPONSE_CONTENT_LABEL,
503 | title="Instruction3/Response With Label",
504 | custom_map={
505 | "instruction": PROMPT_INSTRUCTION3})
506 | RESPONSE0_WITH_LABEL = PromptTemplate(
507 | content=BASIC_RESPONSE0_CONTENT_LABEL,
508 | title="Instruction3/Response0 With Label",
509 | custom_map={
510 | "instruction": PROMPT_INSTRUCTION3})
511 | RESPONSE1_WITH_LABEL = PromptTemplate(
512 | content=BASIC_RESPONSE1_CONTENT_LABEL,
513 | title="Instruction3/Response1 With Label",
514 | custom_map={
515 | "instruction": PROMPT_INSTRUCTION3})
516 | RESPONSE2_WITH_LABEL = PromptTemplate(
517 | content=BASIC_RESPONSE2_CONTENT_LABEL,
518 | title="Instruction3/Response2 With Label",
519 | custom_map={
520 | "instruction": PROMPT_INSTRUCTION3})
521 | RESPONSE3_WITH_LABEL = PromptTemplate(
522 | content=BASIC_RESPONSE3_CONTENT_LABEL,
523 | title="Instruction3/Response3 With Label",
524 | custom_map={
525 | "instruction": PROMPT_INSTRUCTION3})
526 | PROMPT_RESPONSE_STANDARD = PromptTemplate(
527 | content=BASIC_PROMPT_RESPONSE_STANDARD_CONTENT,
528 | title="Instruction3/Prompt-Response Standard",
529 | custom_map={
530 | "instruction": PROMPT_INSTRUCTION3})
531 | PROMPT_RESPONSE_FULL = PromptTemplate(
532 | content=BASIC_PROMPT_RESPONSE_FULL_CONTENT,
533 | title="Instruction3/Prompt-Response Full",
534 | custom_map={
535 | "instruction": PROMPT_INSTRUCTION3})
536 |
537 |
538 | class PresetPromptTemplate:
539 | """Preset prompt templates."""
540 |
541 | BASIC = _BasicPresetPromptTemplate
542 | INSTRUCTION1 = _Instruction1PresetPromptTemplate
543 | INSTRUCTION2 = _Instruction2PresetPromptTemplate
544 | INSTRUCTION3 = _Instruction3PresetPromptTemplate
545 | DEFAULT = BASIC.PROMPT
546 |
--------------------------------------------------------------------------------
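The `{prompt[message]}` placeholders in the preset contents above rely on `str.format`'s mapping-index syntax, which looks up a key inside a keyword argument. A minimal standalone sketch (the example prompt/response values are hypothetical):

```python
# Same constant as defined in template.py above.
BASIC_PROMPT_RESPONSE_STANDARD_CONTENT = "{instruction}Prompt: {prompt[message]}\nResponse: {response[message]}"

# "{prompt[message]}" indexes the "message" key of the dict passed as `prompt`;
# the presets inject instruction text via their custom_map.
rendered = BASIC_PROMPT_RESPONSE_STANDARD_CONTENT.format(
    instruction="",
    prompt={"message": "What is 2 + 2?"},
    response={"message": "4"},
)
print(rendered)  # → Prompt: What is 2 + 2?
                 #   Response: 4
```

Note that in `str.format` field syntax the key is used unquoted (`prompt[message]`, not `prompt["message"]`), which is why the template constants read the way they do.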
/memor/tokens_estimator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Tokens estimator functions."""
3 |
4 | import re
5 | from enum import Enum
6 | from typing import Set, List
7 | from .keywords import PROGRAMMING_LANGUAGES_KEYWORDS
8 | from .keywords import COMMON_PREFIXES, COMMON_SUFFIXES
9 |
10 |
11 | def _is_code_snippet(message: str) -> bool:
12 | """
13 | Check if the message is a code snippet based on common coding symbols.
14 |
15 | :param message: The input message to check.
16 | :return: Boolean indicating if the message is a code snippet.
17 | """
18 | return bool(re.search(r"[=<>+\-*/{}();]", message))
19 |
20 |
21 | def _preprocess_message(message: str, is_code: bool) -> str:
22 | """
23 | Preprocess message by replacing contractions in non-code text.
24 |
25 | :param message: The input message to preprocess.
26 | :param is_code: Boolean indicating if the message is code.
27 | :return: Preprocessed message.
28 | """
29 | if not is_code:
30 | return re.sub(r"(?<=\w)'(?=\w)", " ", message)
31 | return message
32 |
33 |
34 | def _tokenize_message(message: str) -> List[str]:
35 | """
36 | Tokenize the message based on words, symbols, and numbers.
37 |
38 | :param message: The input message to tokenize.
39 | :return: List of tokens.
40 | """
41 | return re.findall(r"[A-Za-z_][A-Za-z0-9_]*|[+\-*/=<>(){}[\],.:;]|\"[^\"]*\"|'[^']*'|\d+|\S", message)
42 |
43 |
44 | def _count_code_tokens(token: str, common_keywords: Set[str]) -> int:
45 | """
46 | Count tokens in code snippets considering different token types.
47 |
48 | :param token: The token to count.
49 | :param common_keywords: Set of common keywords in programming languages.
50 | :return: Count of tokens.
51 | """
52 | if token in common_keywords or re.match(r"[+\-*/=<>(){}[\],.:;]", token):
53 | return 1
54 | if token.isdigit():
55 | return max(1, len(token) // 4)
56 | if token.startswith(("'", '"')) and token.endswith(("'", '"')):
57 | return max(1, len(token) // 6)
58 | if "_" in token:
59 | return len(token.split("_"))
60 | if re.search(r"[A-Z]", token):
61 | return len(re.findall(r"[A-Z][a-z]*", token))
62 | return 1
63 |
64 |
65 | def _count_text_tokens(token: str, prefixes: Set[str], suffixes: Set[str]) -> int:
66 | """
67 | Count tokens in text based on prefixes, suffixes, and subwords.
68 |
69 | :param token: The token to count.
70 | :param prefixes: Set of common prefixes.
71 | :param suffixes: Set of common suffixes.
72 | :return: Token count.
73 | """
74 | if len(token) == 1 and not token.isalnum():
75 | return 1
76 | if token.isdigit():
77 | return max(1, len(token) // 4)
78 | prefix_count = sum(token.startswith(p) for p in prefixes if len(token) > len(p) + 3)
79 | suffix_count = sum(token.endswith(s) for s in suffixes if len(token) > len(s) + 3)
80 | parts = re.findall(r"[aeiou]+|[^aeiou]+", token)
81 | subword_count = max(1, len(parts) // 2)
82 |
83 | return prefix_count + suffix_count + subword_count
84 |
85 |
86 | def universal_tokens_estimator(message: str) -> int:
87 | """
88 | Estimate the number of tokens in a given text or code snippet.
89 |
90 | :param message: The input text or code snippet to estimate tokens for.
91 | :return: Estimated number of tokens.
92 | """
93 | is_code = _is_code_snippet(message)
94 | message = _preprocess_message(message, is_code)
95 | tokens = _tokenize_message(message)
96 |
97 | return sum(
98 | _count_code_tokens(
99 | token,
100 | PROGRAMMING_LANGUAGES_KEYWORDS) if is_code else _count_text_tokens(
101 | token,
102 | COMMON_PREFIXES,
103 | COMMON_SUFFIXES) for token in tokens)
104 |
105 |
106 | def _openai_tokens_estimator(text: str) -> float:
107 | """
108 | Estimate the number of tokens in a given text for OpenAI's models.
109 |
110 | :param text: The input text to estimate tokens for.
111 | :return: Estimated number of tokens.
112 | """
113 | char_count = len(text)
114 | token_estimate = char_count / 4
115 |
116 | space_count = text.count(" ")
117 | punctuation_count = sum(1 for char in text if char in ",.?!;:")
118 | token_estimate += (space_count + punctuation_count) * 0.5
119 |
120 | if any(keyword in text for keyword in PROGRAMMING_LANGUAGES_KEYWORDS):
121 | token_estimate *= 1.1
122 |
123 | newline_count = text.count("\n")
124 | token_estimate += newline_count * 0.8
125 |
126 | long_word_penalty = sum(len(word) / 10 for word in text.split() if len(word) > 15)
127 | token_estimate += long_word_penalty
128 |
129 | if "http" in text:
130 | token_estimate *= 1.1
131 |
132 | rare_char_count = sum(1 for char in text if ord(char) > 10000)
133 | token_estimate += rare_char_count * 0.8
134 |
135 | return token_estimate
136 |
137 |
138 | def openai_tokens_estimator_gpt_3_5(text: str) -> int:
139 | """
140 | Estimate the number of tokens in a given text for OpenAI's GPT-3.5 Turbo model.
141 |
142 | :param text: The input text to estimate tokens for.
143 | :return: Estimated number of tokens.
144 | """
145 | token_estimate = _openai_tokens_estimator(text)
146 | return int(max(1, token_estimate))
147 |
148 |
149 | def openai_tokens_estimator_gpt_4(text: str) -> int:
150 | """
151 | Estimate the number of tokens in a given text for OpenAI's GPT-4 model.
152 |
153 | :param text: The input text to estimate tokens for.
154 | :return: Estimated number of tokens.
155 | """
156 | token_estimate = _openai_tokens_estimator(text)
157 | token_estimate *= 1.05 # Adjusting for GPT-4's tokenization
158 | return int(max(1, token_estimate))
159 |
160 |
161 | class TokensEstimator(Enum):
162 | """Token estimator enum."""
163 |
164 | UNIVERSAL = universal_tokens_estimator
165 | OPENAI_GPT_3_5 = openai_tokens_estimator_gpt_3_5
166 | OPENAI_GPT_4 = openai_tokens_estimator_gpt_4
167 | DEFAULT = UNIVERSAL
168 |
--------------------------------------------------------------------------------
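The core heuristic in `_openai_tokens_estimator` above is character-count based: roughly one token per four characters, plus half a token per space or punctuation mark. A simplified, self-contained sketch of just that base rule (the keyword, URL, newline, long-word, and rare-character adjustments are omitted; the function name is hypothetical):

```python
def rough_openai_estimate(text: str) -> float:
    """Base heuristic only: chars/4 plus 0.5 per space or punctuation mark."""
    estimate = len(text) / 4
    spaces = text.count(" ")
    punctuation = sum(1 for char in text if char in ",.?!;:")
    return estimate + (spaces + punctuation) * 0.5

# "Hello, world!" has 13 chars (3.25), 1 space and 2 punctuation marks (1.5).
print(rough_openai_estimate("Hello, world!"))  # → 4.75
```

The public wrappers then scale (GPT-4 gets a 1.05 multiplier) and clamp the result to an integer of at least 1.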
/otherfiles/RELEASE.md:
--------------------------------------------------------------------------------
1 | # Memor Release Instructions
2 |
3 | **Last Update: 2024-12-27**
4 |
5 | 1. Create the `release` branch under `dev`
6 | 2. Update all version tags
7 | 1. `setup.py`
8 | 2. `README.md`
9 | 3. `otherfiles/version_check.py`
10 | 4. `otherfiles/meta.yaml`
11 | 5. `memor/params.py`
12 | 3. Update `CHANGELOG.md`
13 | 1. Add a new header under the `Unreleased` section (Example: `## [0.1] - 2022-08-17`)
14 | 2. Add a new compare link to the end of the file (Example: `[0.2]: https://github.com/openscilab/memor/compare/v0.1...v0.2`)
15 | 3. Update `dev` compare link (Example: `[Unreleased]: https://github.com/openscilab/memor/compare/v0.2...dev`)
16 | 4. Update `.github/ISSUE_TEMPLATE/bug_report.yml`
17 | 1. Add the new version tag to the `Memor version` dropdown options
18 | 5. Create a PR from `release` to `dev`
19 | 1. Title: `Version x.x` (Example: `Version 0.1`)
20 | 2. Tag all related issues
21 | 3. Labels: `release`
22 | 4. Set milestone
23 | 5. Wait for all CI checks to pass
24 | 6. Needs review (**2** reviewers)
25 | 7. Squash and merge
26 | 8. Delete `release` branch
27 | 6. Merge `dev` branch into `main`
28 | 1. `git checkout main`
29 | 2. `git merge dev`
30 | 3. `git push origin main`
31 | 4. Wait for all CI checks to pass
32 | 7. Create a new release
33 | 1. Target branch: `main`
34 | 2. Tag: `vx.x` (Example: `v0.1`)
35 | 3. Title: `Version x.x` (Example: `Version 0.1`)
36 | 4. Copy changelogs
37 | 5. Tag all related issues
38 | 8. Bump!!
39 | 9. Close this version issues
40 | 10. Close milestone
--------------------------------------------------------------------------------
/otherfiles/donation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openscilab/memor/325e17b8ffc5eb62a9bb4044df63021b9968b94b/otherfiles/donation.png
--------------------------------------------------------------------------------
/otherfiles/meta.yaml:
--------------------------------------------------------------------------------
1 | {% set name = "memor" %}
2 | {% set version = "0.6" %}
3 |
4 | package:
5 | name: {{ name|lower }}
6 | version: {{ version }}
7 | source:
8 | git_url: https://github.com/openscilab/memor
9 | git_rev: v{{ version }}
10 | build:
11 | noarch: python
12 | number: 0
13 | script: {{ PYTHON }} -m pip install . -vv
14 | requirements:
15 | host:
16 | - pip
17 | - setuptools
18 | - python >=3.7
19 | run:
20 | - python >=3.7
21 | about:
22 | home: https://github.com/openscilab/memor
23 | license: MIT
24 | license_family: MIT
25 | summary: "Memor: A Python Library for Managing and Transferring Conversational Memory Across LLMs"
26 | description: |
27 | Memor is a library designed to help users manage the memory of their interactions with Large Language Models (LLMs). It enables users to seamlessly access and utilize the history of their conversations when prompting LLMs, creating a more personalized and context-aware experience. Memor stands out by allowing users to transfer conversational history across different LLMs, eliminating cold starts where models don't have information about the user and their preferences. Users can select specific parts of past interactions with one LLM and share them with another. By bridging the gap between isolated LLM instances, Memor revolutionizes the way users interact with AI by making transitions between models smoother.
28 |
29 | Website: https://openscilab.com
30 |
31 | Repo: https://github.com/openscilab/memor
32 | extra:
33 | recipe-maintainers:
34 | - sepandhaghighi
35 | - sadrasabouri
36 |
--------------------------------------------------------------------------------
/otherfiles/requirements-splitter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Requirements splitter."""
3 |
4 | test_req = ""
5 |
6 | with open('dev-requirements.txt', 'r') as f:
7 | for line in f:
8 | if '==' not in line:
9 | test_req += line
10 |
11 | with open('test-requirements.txt', 'w') as f:
12 | f.write(test_req)
13 |
--------------------------------------------------------------------------------
/otherfiles/version_check.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Version-check script."""
3 | import os
4 | import sys
5 | import codecs
6 | Failed = 0
7 | MEMOR_VERSION = "0.6"
8 |
9 |
10 | SETUP_ITEMS = [
11 | "version='{0}'",
12 | 'https://github.com/openscilab/memor/tarball/v{0}']
13 | README_ITEMS = [
14 | "[Version {0}](https://github.com/openscilab/memor/archive/v{0}.zip)",
15 | "pip install memor=={0}"]
16 | CHANGELOG_ITEMS = [
17 | "## [{0}]",
18 | "https://github.com/openscilab/memor/compare/v{0}...dev",
19 | "[{0}]:"]
20 | PARAMS_ITEMS = ['MEMOR_VERSION = "{0}"']
21 | META_ITEMS = ['% set version = "{0}" %']
22 | ISSUE_TEMPLATE_ITEMS = ["- Memor {0}"]
23 | SECURITY_ITEMS = ["| {0} | :white_check_mark: |", "| < {0} | :x: |"]
24 |
25 | FILES = {
26 | os.path.join("otherfiles", "meta.yaml"): META_ITEMS,
27 | "setup.py": SETUP_ITEMS,
28 | "README.md": README_ITEMS,
29 | "CHANGELOG.md": CHANGELOG_ITEMS,
30 | "SECURITY.md": SECURITY_ITEMS,
31 | os.path.join("memor", "params.py"): PARAMS_ITEMS,
32 | os.path.join(".github", "ISSUE_TEMPLATE", "bug_report.yml"): ISSUE_TEMPLATE_ITEMS,
33 | }
34 |
35 | TEST_NUMBER = len(FILES)
36 |
37 |
38 | def print_result(failed: bool = False) -> None:
39 | """
40 | Print final result.
41 |
42 | :param failed: failed flag
43 | """
44 | message = "Version tag tests "
45 | if not failed:
46 | print("\n" + message + "passed!")
47 | else:
48 | print("\n" + message + "failed!")
49 | print("Passed : " + str(TEST_NUMBER - Failed) + "/" + str(TEST_NUMBER))
50 |
51 |
52 | if __name__ == "__main__":
53 | for file_name in FILES:
54 | try:
55 |             with codecs.open(file_name, "r", "utf-8", 'ignore') as file:
56 |                 file_content = file.read()
57 | for test_item in FILES[file_name]:
58 | if file_content.find(test_item.format(MEMOR_VERSION)) == -1:
59 | print("Incorrect version tag in " + file_name)
60 | Failed += 1
61 | break
62 | except Exception as e:
63 | Failed += 1
64 | print("Error in " + file_name + "\n" + "Message : " + str(e))
65 |
66 | if Failed == 0:
67 | print_result(False)
68 | sys.exit(0)
69 | else:
70 | print_result(True)
71 | sys.exit(1)
72 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openscilab/memor/325e17b8ffc5eb62a9bb4044df63021b9968b94b/requirements.txt
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Setup module."""
3 | try:
4 | from setuptools import setup
5 | except ImportError:
6 | from distutils.core import setup
7 |
8 |
9 | def get_requires() -> list:
10 | """Read requirements.txt."""
11 |     with open("requirements.txt", "r") as f:
12 |         return list(filter(lambda x: x != "", f.read().split()))
13 |
14 |
15 | def read_description() -> str:
16 | """Read README.md and CHANGELOG.md."""
17 | try:
18 | with open("README.md") as r:
19 | description = "\n"
20 | description += r.read()
21 | with open("CHANGELOG.md") as c:
22 | description += "\n"
23 | description += c.read()
24 | return description
25 | except Exception:
26 | return '''Memor is a library designed to help users manage the memory of their interactions with Large Language Models (LLMs).
27 | It enables users to seamlessly access and utilize the history of their conversations when prompting LLMs.
28 | That would create a more personalized and context-aware experience.
29 | Memor stands out by allowing users to transfer conversational history across different LLMs, eliminating cold starts where models don't have information about the user and their preferences.
30 | Users can select specific parts of past interactions with one LLM and share them with another. By bridging the gap between isolated LLM instances, Memor revolutionizes the way users interact with AI by making transitions between models smoother.'''
31 |
32 |
33 | setup(
34 | name='memor',
35 | packages=[
36 | 'memor', ],
37 | version='0.6',
38 | description='Memor: A Python Library for Managing and Transferring Conversational Memory Across LLMs',
39 | long_description=read_description(),
40 | long_description_content_type='text/markdown',
41 | author='Memor Development Team',
42 | author_email='memor@openscilab.com',
43 | url='https://github.com/openscilab/memor',
44 | download_url='https://github.com/openscilab/memor/tarball/v0.6',
45 | keywords="llm memory management conversational history ai agent",
46 | project_urls={
47 | 'Source': 'https://github.com/openscilab/memor',
48 | },
49 | install_requires=get_requires(),
50 | python_requires='>=3.7',
51 | classifiers=[
52 | 'Development Status :: 3 - Alpha',
53 | 'Natural Language :: English',
54 | 'License :: OSI Approved :: MIT License',
55 | 'Operating System :: OS Independent',
56 | 'Programming Language :: Python :: 3.7',
57 | 'Programming Language :: Python :: 3.8',
58 | 'Programming Language :: Python :: 3.9',
59 | 'Programming Language :: Python :: 3.10',
60 | 'Programming Language :: Python :: 3.11',
61 | 'Programming Language :: Python :: 3.12',
62 | 'Programming Language :: Python :: 3.13',
63 | 'Intended Audience :: Developers',
64 | 'Intended Audience :: Education',
65 | 'Intended Audience :: End Users/Desktop',
66 | 'Intended Audience :: Manufacturing',
67 | 'Intended Audience :: Science/Research',
68 | 'Topic :: Education',
69 | 'Topic :: Scientific/Engineering',
70 | 'Topic :: Scientific/Engineering :: Information Analysis',
71 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
72 | ],
73 | license='MIT',
74 | )
75 |
--------------------------------------------------------------------------------
/tests/test_prompt.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import uuid
3 | import copy
4 | import pytest
5 | from memor import Prompt, Response, Role, LLMModel
6 | from memor import PresetPromptTemplate, PromptTemplate
7 | from memor import RenderFormat, MemorValidationError, MemorRenderError
8 | from memor import TokensEstimator
9 |
10 | TEST_CASE_NAME = "Prompt tests"
11 |
12 |
13 | def test_message1():
14 | prompt = Prompt(message="Hello, how are you?")
15 | assert prompt.message == "Hello, how are you?"
16 |
17 |
18 | def test_message2():
19 | prompt = Prompt(message="Hello, how are you?")
20 | prompt.update_message("What's Up?")
21 | assert prompt.message == "What's Up?"
22 |
23 |
24 | def test_message3():
25 | prompt = Prompt(message="Hello, how are you?")
26 | with pytest.raises(MemorValidationError, match=r"Invalid value. `message` must be a string."):
27 | prompt.update_message(22)
28 |
29 |
30 | def test_tokens1():
31 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
32 | assert prompt.tokens is None
33 |
34 |
35 | def test_tokens2():
36 | prompt = Prompt(message="Hello, how are you?", role=Role.USER, tokens=4)
37 | assert prompt.tokens == 4
38 |
39 |
40 | def test_tokens3():
41 | prompt = Prompt(message="Hello, how are you?", role=Role.USER, tokens=4)
42 | prompt.update_tokens(7)
43 | assert prompt.tokens == 7
44 |
45 |
46 | def test_tokens4():
47 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
48 | with pytest.raises(MemorValidationError, match=r"Invalid value. `tokens` must be a positive integer."):
49 | prompt.update_tokens("4")
50 |
51 |
52 | def test_estimated_tokens1():
53 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
54 | assert prompt.estimate_tokens(TokensEstimator.UNIVERSAL) == 7
55 |
56 |
57 | def test_estimated_tokens2():
58 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
59 | assert prompt.estimate_tokens(TokensEstimator.OPENAI_GPT_3_5) == 7
60 |
61 |
62 | def test_estimated_tokens3():
63 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
64 | assert prompt.estimate_tokens(TokensEstimator.OPENAI_GPT_4) == 8
65 |
66 |
67 | def test_role1():
68 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
69 | assert prompt.role == Role.USER
70 |
71 |
72 | def test_role2():
73 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
74 | prompt.update_role(Role.SYSTEM)
75 | assert prompt.role == Role.SYSTEM
76 |
77 |
78 | def test_role3():
79 | prompt = Prompt(message="Hello, how are you?", role=None)
80 | assert prompt.role == Role.USER
81 |
82 |
83 | def test_role4():
84 | prompt = Prompt(message="Hello, how are you?", role=None)
85 | with pytest.raises(MemorValidationError, match=r"Invalid role. It must be an instance of Role enum."):
86 | prompt.update_role(2)
87 |
88 |
89 | def test_id1():
90 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
91 | assert uuid.UUID(prompt.id, version=4) == uuid.UUID(prompt._id, version=4)
92 |
93 |
94 | def test_id2():
95 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
96 | prompt._id = "123"
97 | _ = prompt.save("prompt_test3.json")
98 | with pytest.raises(MemorValidationError, match=r"Invalid message ID. It must be a valid UUIDv4."):
99 | _ = Prompt(file_path="prompt_test3.json")
100 |
101 |
102 | def test_responses1():
103 | message = "Hello, how are you?"
104 | response = Response(message="I am fine.")
105 | prompt = Prompt(message=message, responses=[response])
106 | assert prompt.responses[0].message == "I am fine."
107 |
108 |
109 | def test_responses2():
110 | message = "Hello, how are you?"
111 | response0 = Response(message="I am fine.")
112 | response1 = Response(message="Good!")
113 | prompt = Prompt(message=message, responses=[response0, response1])
114 | assert prompt.responses[0].message == "I am fine." and prompt.responses[1].message == "Good!"
115 |
116 |
117 | def test_responses3():
118 | message = "Hello, how are you?"
119 | response0 = Response(message="I am fine.")
120 | response1 = Response(message="Good!")
121 | prompt = Prompt(message=message)
122 | prompt.update_responses([response0, response1])
123 | assert prompt.responses[0].message == "I am fine." and prompt.responses[1].message == "Good!"
124 |
125 |
126 | def test_responses4():
127 | message = "Hello, how are you?"
128 | prompt = Prompt(message=message)
129 | with pytest.raises(MemorValidationError, match=r"Invalid value. `responses` must be a list of `Response`."):
130 | prompt.update_responses({"I am fine.", "Good!"})
131 |
132 |
133 | def test_responses5():
134 | message = "Hello, how are you?"
135 | response0 = Response(message="I am fine.")
136 | prompt = Prompt(message=message)
137 | with pytest.raises(MemorValidationError, match=r"Invalid value. `responses` must be a list of `Response`."):
138 | prompt.update_responses([response0, "Good!"])
139 |
140 |
141 | def test_add_response1():
142 | message = "Hello, how are you?"
143 | response0 = Response(message="I am fine.")
144 | prompt = Prompt(message=message, responses=[response0])
145 | response1 = Response(message="Great!")
146 | prompt.add_response(response1)
147 | assert prompt.responses[0] == response0 and prompt.responses[1] == response1
148 |
149 |
150 | def test_add_response2():
151 | message = "Hello, how are you?"
152 | response0 = Response(message="I am fine.")
153 | prompt = Prompt(message=message, responses=[response0])
154 | response1 = Response(message="Great!")
155 | prompt.add_response(response1, index=0)
156 | assert prompt.responses[0] == response1 and prompt.responses[1] == response0
157 |
158 |
159 | def test_add_response3():
160 | message = "Hello, how are you?"
161 | response0 = Response(message="I am fine.")
162 | prompt = Prompt(message=message, responses=[response0])
163 | with pytest.raises(MemorValidationError, match=r"Invalid response. It must be an instance of `Response`."):
164 | prompt.add_response(1)
165 |
166 |
167 | def test_remove_response():
168 | message = "Hello, how are you?"
169 | response0 = Response(message="I am fine.")
170 | response1 = Response(message="Great!")
171 | prompt = Prompt(message=message, responses=[response0, response1])
172 | prompt.remove_response(0)
173 | assert response0 not in prompt.responses
174 |
175 |
176 | def test_select_response():
177 | message = "Hello, how are you?"
178 | response0 = Response(message="I am fine.")
179 | prompt = Prompt(message=message, responses=[response0])
180 | response1 = Response(message="Great!")
181 | prompt.add_response(response1)
182 | prompt.select_response(index=1)
183 | assert prompt.selected_response == response1
184 |
185 |
186 | def test_template1():
187 | message = "Hello, how are you?"
188 | prompt = Prompt(message=message, template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD, init_check=False)
189 | assert prompt.template == PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD.value
190 |
191 |
192 | def test_template2():
193 | message = "Hello, how are you?"
194 | prompt = Prompt(message=message, template=PresetPromptTemplate.BASIC.RESPONSE, init_check=False)
195 | prompt.update_template(PresetPromptTemplate.INSTRUCTION1.PROMPT)
196 | assert prompt.template.content == PresetPromptTemplate.INSTRUCTION1.PROMPT.value.content
197 |
198 |
199 | def test_template3():
200 | message = "Hello, how are you?"
201 | template = PromptTemplate(content="{message}-{response}")
202 | prompt = Prompt(message=message, template=template, init_check=False)
203 | assert prompt.template.content == "{message}-{response}"
204 |
205 |
206 | def test_template4():
207 | message = "Hello, how are you?"
208 | prompt = Prompt(message=message, template=None)
209 | assert prompt.template == PresetPromptTemplate.DEFAULT.value
210 |
211 |
212 | def test_template5():
213 | message = "Hello, how are you?"
214 | prompt = Prompt(message=message, template=PresetPromptTemplate.BASIC.RESPONSE, init_check=False)
215 | with pytest.raises(MemorValidationError, match=r"Invalid template. It must be an instance of `PromptTemplate` or `PresetPromptTemplate`."):
216 | prompt.update_template("{prompt_message}")
217 |
218 |
219 | def test_copy1():
220 | message = "Hello, how are you?"
221 | response = Response(message="I am fine.")
222 | prompt1 = Prompt(message=message, responses=[response], role=Role.USER,
223 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
224 | prompt2 = copy.copy(prompt1)
225 | assert id(prompt1) != id(prompt2) and prompt1.id != prompt2.id
226 |
227 |
228 | def test_copy2():
229 | message = "Hello, how are you?"
230 | response = Response(message="I am fine.")
231 | prompt1 = Prompt(message=message, responses=[response], role=Role.USER,
232 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
233 | prompt2 = prompt1.copy()
234 | assert id(prompt1) != id(prompt2) and prompt1.id != prompt2.id
235 |
236 |
237 | def test_str():
238 | message = "Hello, how are you?"
239 | response = Response(message="I am fine.")
240 | prompt = Prompt(message=message, responses=[response], role=Role.USER,
241 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
242 | assert str(prompt) == prompt.render(render_format=RenderFormat.STRING)
243 |
244 |
245 | def test_repr():
246 | message = "Hello, how are you?"
247 | response = Response(message="I am fine.")
248 | prompt = Prompt(message=message, responses=[response], role=Role.USER,
249 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
250 | assert repr(prompt) == "Prompt(message={message})".format(message=prompt.message)
251 |
252 |
253 | def test_json1():
254 | message = "Hello, how are you?"
255 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
256 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
257 | prompt1 = Prompt(
258 | message=message,
259 | responses=[
260 | response1,
261 | response2],
262 | role=Role.USER,
263 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
264 | prompt1_json = prompt1.to_json()
265 | prompt2 = Prompt()
266 | prompt2.from_json(prompt1_json)
267 | assert prompt1 == prompt2
268 |
269 |
270 | def test_json2():
271 | message = "Hello, how are you?"
272 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
273 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
274 | prompt1 = Prompt(
275 | message=message,
276 | responses=[
277 | response1,
278 | response2],
279 | role=Role.USER,
280 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
281 | prompt1_json = prompt1.to_json(save_template=False)
282 | prompt2 = Prompt()
283 | prompt2.from_json(prompt1_json)
284 | assert prompt1 != prompt2 and prompt1.template == PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD.value and prompt2.template == PresetPromptTemplate.DEFAULT.value
285 |
286 |
287 | def test_json3():
288 | prompt = Prompt()
289 | with pytest.raises(MemorValidationError, match=r"Invalid prompt structure. It should be a JSON object with proper fields."):
290 | prompt.from_json("{}")
291 |
292 |
293 | def test_save1():
294 | message = "Hello, how are you?"
295 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
296 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
297 | prompt = Prompt(
298 | message=message,
299 | responses=[
300 | response1,
301 | response2],
302 | role=Role.USER,
303 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
304 | result = prompt.save("f:/")
305 |     assert not result["status"]
306 |
307 |
308 | def test_save2():
309 | message = "Hello, how are you?"
310 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
311 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
312 | prompt1 = Prompt(
313 | message=message,
314 | responses=[
315 | response1,
316 | response2],
317 | role=Role.USER,
318 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
319 | result = prompt1.save("prompt_test1.json")
320 | prompt2 = Prompt(file_path="prompt_test1.json")
321 | assert result["status"] and prompt1 == prompt2
322 |
323 |
324 | def test_load1():
325 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: 22"):
326 | _ = Prompt(file_path=22)
327 |
328 |
329 | def test_load2():
330 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: prompt_test10.json"):
331 | _ = Prompt(file_path="prompt_test10.json")
332 |
333 |
334 | def test_save3():
335 | message = "Hello, how are you?"
336 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
337 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
338 | prompt1 = Prompt(
339 | message=message,
340 | responses=[
341 | response1,
342 | response2],
343 | role=Role.USER,
344 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
345 | result = prompt1.save("prompt_test2.json", save_template=False)
346 | prompt2 = Prompt(file_path="prompt_test2.json")
347 | assert result["status"] and prompt1 != prompt2 and prompt1.template == PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD.value and prompt2.template == PresetPromptTemplate.DEFAULT.value
348 |
349 |
350 | def test_render1():
351 | message = "Hello, how are you?"
352 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
353 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
354 | prompt = Prompt(
355 | message=message,
356 | responses=[
357 | response1,
358 | response2],
359 | role=Role.USER,
360 | template=PresetPromptTemplate.BASIC.PROMPT)
361 | assert prompt.render() == "Hello, how are you?"
362 |
363 |
364 | def test_render2():
365 | message = "Hello, how are you?"
366 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
367 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
368 | prompt = Prompt(
369 | message=message,
370 | responses=[
371 | response1,
372 | response2],
373 | role=Role.USER,
374 | template=PresetPromptTemplate.BASIC.PROMPT)
375 | assert prompt.render(RenderFormat.OPENAI) == {"role": "user", "content": "Hello, how are you?"}
376 |
377 |
378 | def test_render3():
379 | message = "Hello, how are you?"
380 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
381 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
382 | prompt = Prompt(
383 | message=message,
384 | responses=[
385 | response1,
386 | response2],
387 | role=Role.USER,
388 | template=PresetPromptTemplate.BASIC.PROMPT)
389 | assert prompt.render(RenderFormat.DICTIONARY)["content"] == "Hello, how are you?"
390 |
391 |
392 | def test_render4():
393 | message = "Hello, how are you?"
394 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
395 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
396 | prompt = Prompt(
397 | message=message,
398 | responses=[
399 | response1,
400 | response2],
401 | role=Role.USER,
402 | template=PresetPromptTemplate.BASIC.PROMPT)
403 | assert ("content", "Hello, how are you?") in prompt.render(RenderFormat.ITEMS)
404 |
405 |
406 | def test_render5():
407 | message = "How are you?"
408 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
409 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
410 | template = PromptTemplate(content="{instruction}, {prompt[message]}", custom_map={"instruction": "Hi"})
411 | prompt = Prompt(
412 | message=message,
413 | responses=[
414 | response1,
415 | response2],
416 | role=Role.USER,
417 | template=template)
418 | assert prompt.render(RenderFormat.OPENAI) == {"role": "user", "content": "Hi, How are you?"}
419 |
420 |
421 | def test_render6():
422 | message = "Hello, how are you?"
423 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
424 | template = PromptTemplate(content="{response[2][message]}")
425 | prompt = Prompt(
426 | message=message,
427 | responses=[response],
428 | role=Role.USER,
429 | template=template,
430 | init_check=False)
431 | with pytest.raises(MemorRenderError, match=r"Prompt template and properties are incompatible."):
432 | prompt.render()
433 |
434 |
435 | def test_render7():
436 | message = "How are you?"
437 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
438 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
439 | template = PromptTemplate(content="{instruction}, {prompt[message]}", custom_map={"instruction": "Hi"})
440 | prompt = Prompt(
441 | message=message,
442 | responses=[
443 | response1,
444 | response2],
445 | role=Role.USER,
446 | template=template)
447 | assert prompt.render(RenderFormat.AI_STUDIO) == {'role': 'user', 'parts': [{'text': 'Hi, How are you?'}]}
448 |
449 |
450 | def test_init_check():
451 | message = "Hello, how are you?"
452 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
453 | template = PromptTemplate(content="{response[2][message]}")
454 | with pytest.raises(MemorRenderError, match=r"Prompt template and properties are incompatible."):
455 | _ = Prompt(message=message, responses=[response], role=Role.USER, template=template)
456 |
457 |
458 | def test_check_render1():
459 | message = "Hello, how are you?"
460 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
461 | template = PromptTemplate(content="{response[2][message]}")
462 | prompt = Prompt(
463 | message=message,
464 | responses=[response],
465 | role=Role.USER,
466 | template=template,
467 | init_check=False)
468 | assert not prompt.check_render()
469 |
470 |
471 | def test_check_render2():
472 | message = "How are you?"
473 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
474 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
475 | template = PromptTemplate(content="{instruction}, {prompt[message]}", custom_map={"instruction": "Hi"})
476 | prompt = Prompt(
477 | message=message,
478 | responses=[
479 | response1,
480 | response2],
481 | role=Role.USER,
482 | template=template)
483 | assert prompt.check_render()
484 |
485 |
486 | def test_equality1():
487 | message = "Hello, how are you?"
488 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
489 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
490 | prompt1 = Prompt(
491 | message=message,
492 | responses=[
493 | response1,
494 | response2],
495 | role=Role.USER,
496 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
497 | prompt2 = prompt1.copy()
498 | assert prompt1 == prompt2
499 |
500 |
501 | def test_equality2():
502 | message = "Hello, how are you?"
503 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
504 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
505 | prompt1 = Prompt(message=message, responses=[response1], role=Role.USER,
506 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
507 | prompt2 = Prompt(message=message, responses=[response2], role=Role.USER,
508 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
509 | assert prompt1 != prompt2
510 |
511 |
512 | def test_equality3():
513 | message = "Hello, how are you?"
514 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
515 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
516 | prompt1 = Prompt(
517 | message=message,
518 | responses=[
519 | response1,
520 | response2],
521 | role=Role.USER,
522 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
523 | prompt2 = Prompt(
524 | message=message,
525 | responses=[
526 | response1,
527 | response2],
528 | role=Role.USER,
529 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
530 | assert prompt1 == prompt2
531 |
532 |
533 | def test_equality4():
534 | message = "Hello, how are you?"
535 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
536 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
537 | prompt = Prompt(
538 | message=message,
539 | responses=[
540 | response1,
541 | response2],
542 | role=Role.USER,
543 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
544 | assert prompt != 2
545 |
546 |
547 | def test_length1():
548 | prompt = Prompt(message="Hello, how are you?")
549 | assert len(prompt) == 19
550 |
551 |
552 | def test_length2():
553 | message = "Hello, how are you?"
554 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
555 | template = PromptTemplate(content="{response[2][message]}")
556 | prompt = Prompt(
557 | message=message,
558 | responses=[response],
559 | role=Role.USER,
560 | template=template,
561 | init_check=False)
562 | assert len(prompt) == 0
563 |
564 |
565 | def test_length3():
566 | prompt = Prompt()
567 | assert len(prompt) == 0
568 |
569 |
570 | def test_date_modified():
571 | message = "Hello, how are you?"
572 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
573 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
574 | prompt = Prompt(
575 | message=message,
576 | responses=[
577 | response1,
578 | response2],
579 | role=Role.USER,
580 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
581 | assert isinstance(prompt.date_modified, datetime.datetime)
582 |
583 |
584 | def test_date_created():
585 | message = "Hello, how are you?"
586 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
587 | response2 = Response(message="Thanks!", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
588 | prompt = Prompt(
589 | message=message,
590 | responses=[
591 | response1,
592 | response2],
593 | role=Role.USER,
594 | template=PresetPromptTemplate.BASIC.PROMPT_RESPONSE_STANDARD)
595 | assert isinstance(prompt.date_created, datetime.datetime)
596 |
--------------------------------------------------------------------------------
/tests/test_prompt_template.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import copy
4 | import pytest
5 | from memor import PromptTemplate, MemorValidationError
6 |
7 | TEST_CASE_NAME = "PromptTemplate tests"
8 |
9 |
10 | def test_title1():
11 | template = PromptTemplate(
12 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
13 | custom_map={
14 | "language": "Python"})
15 | assert template.title is None
16 |
17 |
18 | def test_title2():
19 | template = PromptTemplate(
20 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
21 | custom_map={
22 | "language": "Python"})
23 | template.update_title("template1")
24 | assert template.title == "template1"
25 |
26 |
27 | def test_title3():
28 | template = PromptTemplate(
29 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
30 | custom_map={
31 | "language": "Python"},
32 | title=None)
33 | assert template.title is None
34 |
35 |
36 | def test_title4():
37 | template = PromptTemplate(
38 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
39 | custom_map={
40 | "language": "Python"},
41 | title=None)
42 | with pytest.raises(MemorValidationError, match=r"Invalid value. `title` must be a string."):
43 | template.update_title(25)
44 |
45 |
46 | def test_content1():
47 | template = PromptTemplate(
48 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
49 | custom_map={
50 | "language": "Python"})
51 | assert template.content == "Act as a {language} developer and respond to this question:\n{prompt_message}"
52 |
53 |
54 | def test_content2():
55 | template = PromptTemplate(
56 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
57 | custom_map={
58 | "language": "Python"})
59 | template.update_content(content="Act as a {language} developer and respond to this query:\n{prompt_message}")
60 | assert template.content == "Act as a {language} developer and respond to this query:\n{prompt_message}"
61 |
62 |
63 | def test_content3():
64 | template = PromptTemplate(
65 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
66 | custom_map={
67 | "language": "Python"})
68 | with pytest.raises(MemorValidationError, match=r"Invalid value. `content` must be a string."):
69 | template.update_content(content=22)
70 |
71 |
72 | def test_custom_map1():
73 | template = PromptTemplate(
74 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
75 | custom_map={
76 | "language": "Python"})
77 | assert template.custom_map == {"language": "Python"}
78 |
79 |
80 | def test_custom_map2():
81 | template = PromptTemplate(
82 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
83 | custom_map={
84 | "language": "Python"})
85 | template.update_map({"language": "C++"})
86 | assert template.custom_map == {"language": "C++"}
87 |
88 |
89 | def test_custom_map3():
90 | template = PromptTemplate(
91 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
92 | custom_map={
93 | "language": "Python"})
94 | with pytest.raises(MemorValidationError, match=r"Invalid custom map: it must be a dictionary with keys and values that can be converted to strings."):
95 | template.update_map(["C++"])
96 |
97 |
98 | def test_date_modified():
99 | template = PromptTemplate(
100 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
101 | custom_map={
102 | "language": "Python"})
103 | assert isinstance(template.date_modified, datetime.datetime)
104 |
105 |
106 | def test_date_created():
107 | template = PromptTemplate(
108 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
109 | custom_map={
110 | "language": "Python"})
111 | assert isinstance(template.date_created, datetime.datetime)
112 |
113 |
114 | def test_json1():
115 | template1 = PromptTemplate(
116 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
117 | custom_map={
118 | "language": "Python"})
119 | template1_json = template1.to_json()
120 | template2 = PromptTemplate()
121 | template2.from_json(template1_json)
122 | assert template1 == template2
123 |
124 |
125 | def test_json2():
126 | template = PromptTemplate()
127 | with pytest.raises(MemorValidationError, match=r"Invalid template structure. It should be a JSON object with proper fields."):
128 | template.from_json("{}")
129 |
130 |
131 | def test_save1():
132 | template = PromptTemplate(
133 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
134 | custom_map={
135 | "language": "Python"})
136 | result = template.save("template_test1.json")
137 | with open("template_test1.json", "r") as file:
138 | saved_template = json.loads(file.read())
139 | assert result["status"] and template.to_json() == saved_template
140 |
141 |
142 | def test_save2():
143 | template = PromptTemplate(
144 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
145 | custom_map={
146 | "language": "Python"})
147 | result = template.save("f:/")
148 |     assert result["status"] is False
149 |
150 |
151 | def test_load1():
152 | template1 = PromptTemplate(
153 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
154 | custom_map={
155 | "language": "Python"})
156 | result = template1.save("template_test2.json")
157 | template2 = PromptTemplate(file_path="template_test2.json")
158 | assert result["status"] and template1 == template2
159 |
160 |
161 | def test_load2():
162 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: 22"):
163 | _ = PromptTemplate(file_path=22)
164 |
165 |
166 | def test_load3():
167 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: template_test10.json"):
168 | _ = PromptTemplate(file_path="template_test10.json")
169 |
170 |
171 | def test_copy1():
172 | template1 = PromptTemplate(
173 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
174 | custom_map={
175 | "language": "Python"})
176 | template2 = copy.copy(template1)
177 | assert id(template1) != id(template2)
178 |
179 |
180 | def test_copy2():
181 | template1 = PromptTemplate(
182 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
183 | custom_map={
184 | "language": "Python"})
185 | template2 = template1.copy()
186 | assert id(template1) != id(template2)
187 |
188 |
189 | def test_str():
190 | template = PromptTemplate(
191 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
192 | custom_map={
193 | "language": "Python"})
194 | assert str(template) == template.content
195 |
196 |
197 | def test_repr():
198 | template = PromptTemplate(
199 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
200 | custom_map={
201 | "language": "Python"})
202 | assert repr(template) == "PromptTemplate(content={content})".format(content=template.content)
203 |
204 |
205 | def test_equality1():
206 | template1 = PromptTemplate(
207 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
208 | custom_map={
209 | "language": "Python"})
210 | template2 = template1.copy()
211 | assert template1 == template2
212 |
213 |
214 | def test_equality2():
215 | template1 = PromptTemplate(
216 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
217 | custom_map={
218 | "language": "Python"},
219 | title="template1")
220 | template2 = PromptTemplate(
221 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
222 | custom_map={
223 | "language": "Python"},
224 | title="template2")
225 | assert template1 != template2
226 |
227 |
228 | def test_equality3():
229 | template1 = PromptTemplate(
230 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
231 | custom_map={
232 | "language": "Python"},
233 | title="template1")
234 | template2 = PromptTemplate(
235 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
236 | custom_map={
237 | "language": "Python"},
238 | title="template1")
239 | assert template1 == template2
240 |
241 |
242 | def test_equality4():
243 | template = PromptTemplate(
244 | content="Act as a {language} developer and respond to this question:\n{prompt_message}",
245 | custom_map={
246 | "language": "Python"},
247 | title="template1")
248 | assert template != 2
249 |
--------------------------------------------------------------------------------
/tests/test_response.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import uuid
3 | import json
4 | import copy
5 | import pytest
6 | from memor import Response, Role, LLMModel, MemorValidationError
7 | from memor import RenderFormat
8 | from memor import TokensEstimator
9 |
10 | TEST_CASE_NAME = "Response tests"
11 |
12 |
13 | def test_message1():
14 | response = Response(message="I am fine.")
15 | assert response.message == "I am fine."
16 |
17 |
18 | def test_message2():
19 | response = Response(message="I am fine.")
20 | response.update_message("OK!")
21 | assert response.message == "OK!"
22 |
23 |
24 | def test_message3():
25 | response = Response(message="I am fine.")
26 | with pytest.raises(MemorValidationError, match=r"Invalid value. `message` must be a string."):
27 | response.update_message(22)
28 |
29 |
30 | def test_tokens1():
31 | response = Response(message="I am fine.")
32 | assert response.tokens is None
33 |
34 |
35 | def test_tokens2():
36 | response = Response(message="I am fine.", tokens=4)
37 | assert response.tokens == 4
38 |
39 |
40 | def test_tokens3():
41 | response = Response(message="I am fine.", tokens=4)
42 | response.update_tokens(6)
43 | assert response.tokens == 6
44 |
45 |
46 | def test_estimated_tokens1():
47 | response = Response(message="I am fine.")
48 | assert response.estimate_tokens(TokensEstimator.UNIVERSAL) == 5
49 |
50 |
51 | def test_estimated_tokens2():
52 | response = Response(message="I am fine.")
53 | assert response.estimate_tokens(TokensEstimator.OPENAI_GPT_3_5) == 4
54 |
55 |
56 | def test_estimated_tokens3():
57 | response = Response(message="I am fine.")
58 | assert response.estimate_tokens(TokensEstimator.OPENAI_GPT_4) == 4
59 |
60 |
61 | def test_tokens4():
62 | response = Response(message="I am fine.", tokens=4)
63 | with pytest.raises(MemorValidationError, match=r"Invalid value. `tokens` must be a positive integer."):
64 | response.update_tokens(-2)
65 |
66 |
67 | def test_inference_time1():
68 | response = Response(message="I am fine.")
69 | assert response.inference_time is None
70 |
71 |
72 | def test_inference_time2():
73 | response = Response(message="I am fine.", inference_time=8.2)
74 | assert response.inference_time == 8.2
75 |
76 |
77 | def test_inference_time3():
78 | response = Response(message="I am fine.", inference_time=8.2)
79 | response.update_inference_time(9.5)
80 | assert response.inference_time == 9.5
81 |
82 |
83 | def test_inference_time4():
84 | response = Response(message="I am fine.", inference_time=8.2)
85 | with pytest.raises(MemorValidationError, match=r"Invalid value. `inference_time` must be a positive float."):
86 | response.update_inference_time(-5)
87 |
88 |
89 | def test_score1():
90 | response = Response(message="I am fine.", score=0.9)
91 | assert response.score == 0.9
92 |
93 |
94 | def test_score2():
95 | response = Response(message="I am fine.", score=0.9)
96 | response.update_score(0.5)
97 | assert response.score == 0.5
98 |
99 |
100 | def test_score3():
101 | response = Response(message="I am fine.", score=0.9)
102 | with pytest.raises(MemorValidationError, match=r"Invalid value. `score` must be a value between 0 and 1."):
103 | response.update_score(-2)
104 |
105 |
106 | def test_role1():
107 | response = Response(message="I am fine.", role=Role.ASSISTANT)
108 | assert response.role == Role.ASSISTANT
109 |
110 |
111 | def test_role2():
112 | response = Response(message="I am fine.", role=Role.ASSISTANT)
113 | response.update_role(Role.USER)
114 | assert response.role == Role.USER
115 |
116 |
117 | def test_role3():
118 | response = Response(message="I am fine.", role=None)
119 | assert response.role == Role.ASSISTANT
120 |
121 |
122 | def test_role4():
123 | response = Response(message="I am fine.", role=Role.ASSISTANT)
124 | with pytest.raises(MemorValidationError, match=r"Invalid role. It must be an instance of Role enum."):
125 | response.update_role(2)
126 |
127 |
128 | def test_temperature1():
129 | response = Response(message="I am fine.", temperature=0.2)
130 | assert response.temperature == 0.2
131 |
132 |
133 | def test_temperature2():
134 | response = Response(message="I am fine.", temperature=0.2)
135 | response.update_temperature(0.7)
136 | assert response.temperature == 0.7
137 |
138 |
139 | def test_temperature3():
140 | response = Response(message="I am fine.", temperature=0.2)
141 | with pytest.raises(MemorValidationError, match=r"Invalid value. `temperature` must be a positive float."):
142 | response.update_temperature(-22)
143 |
144 |
145 | def test_model1():
146 | response = Response(message="I am fine.", model=LLMModel.GPT_4)
147 | assert response.model == LLMModel.GPT_4.value
148 |
149 |
150 | def test_model2():
151 | response = Response(message="I am fine.", model=LLMModel.GPT_4)
152 | response.update_model(LLMModel.GPT_4O)
153 | assert response.model == LLMModel.GPT_4O.value
154 |
155 |
156 | def test_model3():
157 | response = Response(message="I am fine.", model=LLMModel.GPT_4)
158 | response.update_model("my-trained-llm-instruct")
159 | assert response.model == "my-trained-llm-instruct"
160 |
161 |
162 | def test_model4():
163 | response = Response(message="I am fine.", model=LLMModel.GPT_4)
164 | with pytest.raises(MemorValidationError, match=r"Invalid model. It must be an instance of LLMModel enum."):
165 | response.update_model(4)
166 |
167 |
168 | def test_id1():
169 | response = Response(message="I am fine.", model=LLMModel.GPT_4)
170 | assert uuid.UUID(response.id, version=4) == uuid.UUID(response._id, version=4)
171 |
172 |
173 | def test_id2():
174 | response = Response(message="I am fine.", model=LLMModel.GPT_4)
175 | response._id = "123"
176 | _ = response.save("response_test3.json")
177 | with pytest.raises(MemorValidationError, match=r"Invalid message ID. It must be a valid UUIDv4."):
178 | _ = Response(file_path="response_test3.json")
179 |
180 |
181 | def test_date1():
182 | date_time_utc = datetime.datetime.now(datetime.timezone.utc)
183 | response = Response(message="I am fine.", date=date_time_utc)
184 | assert response.date_created == date_time_utc
185 |
186 |
187 | def test_date2():
188 | response = Response(message="I am fine.", date=None)
189 | assert isinstance(response.date_created, datetime.datetime)
190 |
191 |
192 | def test_date3():
193 | with pytest.raises(MemorValidationError, match=r"Invalid value. `date` must be a datetime object that includes timezone information."):
194 | _ = Response(message="I am fine.", date="2/25/2025")
195 |
196 |
197 | def test_date4():
198 | with pytest.raises(MemorValidationError, match=r"Invalid value. `date` must be a datetime object that includes timezone information."):
199 | _ = Response(message="I am fine.", date=datetime.datetime.now())
200 |
201 |
202 | def test_json1():
203 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
204 | response1_json = response1.to_json()
205 | response2 = Response()
206 | response2.from_json(response1_json)
207 | assert response1 == response2
208 |
209 |
210 | def test_json2():
211 | response = Response()
212 | with pytest.raises(MemorValidationError, match=r"Invalid response structure. It should be a JSON object with proper fields."):
213 | response.from_json("{}")
214 |
215 |
216 | def test_save1():
217 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
218 | result = response.save("response_test1.json")
219 | with open("response_test1.json", "r") as file:
220 | saved_response = json.loads(file.read())
221 | assert result["status"] and response.to_json() == saved_response
222 |
223 |
224 | def test_save2():
225 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
226 | result = response.save("f:/")
227 |     assert result["status"] is False
228 |
229 |
230 | def test_load1():
231 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
232 | result = response1.save("response_test2.json")
233 | response2 = Response(file_path="response_test2.json")
234 | assert result["status"] and response1 == response2
235 |
236 |
237 | def test_load2():
238 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: 2"):
239 |         _ = Response(file_path=2)
240 |
241 |
242 | def test_load3():
243 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: response_test10.json"):
244 |         _ = Response(file_path="response_test10.json")
245 |
246 |
247 | def test_copy1():
248 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
249 | response2 = copy.copy(response1)
250 | assert id(response1) != id(response2) and response1.id != response2.id
251 |
252 |
253 | def test_copy2():
254 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
255 | response2 = response1.copy()
256 | assert id(response1) != id(response2) and response1.id != response2.id
257 |
258 |
259 | def test_str():
260 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
261 | assert str(response) == response.message
262 |
263 |
264 | def test_repr():
265 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
266 | assert repr(response) == "Response(message={message})".format(message=response.message)
267 |
268 |
269 | def test_render1():
270 | response = Response(message="I am fine.")
271 | assert response.render() == "I am fine."
272 |
273 |
274 | def test_render2():
275 | response = Response(message="I am fine.")
276 | assert response.render(RenderFormat.OPENAI) == {"role": "assistant", "content": "I am fine."}
277 |
278 |
279 | def test_render3():
280 | response = Response(message="I am fine.")
281 | assert response.render(RenderFormat.DICTIONARY) == response.to_dict()
282 |
283 |
284 | def test_render4():
285 | response = Response(message="I am fine.")
286 | assert response.render(RenderFormat.ITEMS) == response.to_dict().items()
287 |
288 |
289 | def test_render5():
290 | response = Response(message="I am fine.")
291 | with pytest.raises(MemorValidationError, match=r"Invalid render format. It must be an instance of RenderFormat enum."):
292 | response.render("OPENAI")
293 |
294 |
295 | def test_render6():
296 | response = Response(message="I am fine.")
297 | assert response.render(RenderFormat.AI_STUDIO) == {'role': 'assistant', 'parts': [{'text': 'I am fine.'}]}
298 |
299 |
300 | def test_equality1():
301 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
302 | response2 = response1.copy()
303 | assert response1 == response2
304 |
305 |
306 | def test_equality2():
307 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
308 | response2 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.6)
309 | assert response1 != response2
310 |
311 |
312 | def test_equality3():
313 | response1 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
314 | response2 = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
315 | assert response1 == response2
316 |
317 |
318 | def test_equality4():
319 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
320 | assert response != 2
321 |
322 |
323 | def test_length1():
324 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
325 | assert len(response) == 10
326 |
327 |
328 | def test_length2():
329 | response = Response()
330 | assert len(response) == 0
331 |
332 |
333 | def test_date_modified():
334 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
335 | assert isinstance(response.date_modified, datetime.datetime)
336 |
337 |
338 | def test_date_created():
339 | response = Response(message="I am fine.", model=LLMModel.GPT_4, temperature=0.5, role=Role.USER, score=0.8)
340 | assert isinstance(response.date_created, datetime.datetime)
341 |
--------------------------------------------------------------------------------
/tests/test_session.py:
--------------------------------------------------------------------------------
1 | import re
2 | import datetime
3 | import copy
4 | import pytest
5 | from memor import Session, Prompt, Response, Role
6 | from memor import PromptTemplate
7 | from memor import RenderFormat
8 | from memor import MemorRenderError, MemorValidationError
9 | from memor import TokensEstimator
10 |
11 | TEST_CASE_NAME = "Session tests"
12 |
13 |
14 | def test_title1():
15 | session = Session(title="session1")
16 | assert session.title == "session1"
17 |
18 |
19 | def test_title2():
20 | session = Session(title="session1")
21 | session.update_title("session2")
22 | assert session.title == "session2"
23 |
24 |
25 | def test_title3():
26 | session = Session(title="session1")
27 | with pytest.raises(MemorValidationError, match=r"Invalid value. `title` must be a string."):
28 | session.update_title(2)
29 |
30 |
31 | def test_messages1():
32 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
33 | response = Response(message="I am fine.")
34 | session = Session(messages=[prompt, response])
35 | assert session.messages == [prompt, response]
36 |
37 |
38 | def test_messages2():
39 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
40 | response = Response(message="I am fine.")
41 | session = Session(messages=[prompt, response])
42 | session.update_messages([prompt, response, prompt, response])
43 | assert session.messages == [prompt, response, prompt, response]
44 |
45 |
46 | def test_messages3():
47 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
48 | session = Session(messages=[prompt])
49 | with pytest.raises(MemorValidationError, match=r"Invalid value. `messages` must be a list of `Prompt` or `Response`."):
50 | session.update_messages([prompt, "I am fine."])
51 |
52 |
53 | def test_messages4():
54 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
55 | session = Session(messages=[prompt])
56 | with pytest.raises(MemorValidationError, match=r"Invalid value. `messages` must be a list of `Prompt` or `Response`."):
57 | session.update_messages("I am fine.")
58 |
59 |
60 | def test_messages_status1():
61 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
62 | response = Response(message="I am fine.")
63 | session = Session(messages=[prompt, response])
64 | assert session.messages_status == [True, True]
65 |
66 |
67 | def test_messages_status2():
68 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
69 | response = Response(message="I am fine.")
70 | session = Session(messages=[prompt, response])
71 | session.update_messages_status([False, True])
72 | assert session.messages_status == [False, True]
73 |
74 |
75 | def test_messages_status3():
76 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
77 | response = Response(message="I am fine.")
78 | session = Session(messages=[prompt, response])
79 | with pytest.raises(MemorValidationError, match=r"Invalid value. `status` must be a list of booleans."):
80 | session.update_messages_status(["False", True])
81 |
82 |
83 | def test_messages_status4():
84 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
85 | response = Response(message="I am fine.")
86 | session = Session(messages=[prompt, response])
87 | with pytest.raises(MemorValidationError, match=r"Invalid message status length. It must be equal to the number of messages."):
88 | session.update_messages_status([False, True, True])
89 |
90 |
91 | def test_enable_message():
92 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
93 | response = Response(message="I am fine.")
94 | session = Session(messages=[prompt, response])
95 | session.update_messages_status([False, False])
96 | session.enable_message(0)
97 | assert session.messages_status == [True, False]
98 |
99 |
100 | def test_disable_message():
101 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
102 | response = Response(message="I am fine.")
103 | session = Session(messages=[prompt, response])
104 | session.update_messages_status([True, True])
105 | session.disable_message(0)
106 | assert session.messages_status == [False, True]
107 |
108 |
109 | def test_mask_message():
110 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
111 | response = Response(message="I am fine.")
112 | session = Session(messages=[prompt, response])
113 | session.mask_message(0)
114 | assert session.messages_status == [False, True]
115 |
116 |
117 | def test_unmask_message():
118 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
119 | response = Response(message="I am fine.")
120 | session = Session(messages=[prompt, response])
121 | session.update_messages_status([False, False])
122 | session.unmask_message(0)
123 | assert session.messages_status == [True, False]
124 |
125 |
126 | def test_masks():
127 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
128 | response = Response(message="I am fine.")
129 | session = Session(messages=[prompt, response])
130 | session.update_messages_status([False, True])
131 | assert session.masks == [True, False]
132 |
133 |
134 | def test_add_message1():
135 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
136 | response = Response(message="I am fine.")
137 | session = Session(messages=[prompt, response])
138 | session.add_message(Response("Good!"))
139 | assert session.messages[2] == Response("Good!")
140 |
141 |
142 | def test_add_message2():
143 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
144 | response = Response(message="I am fine.")
145 | session = Session(messages=[prompt, response])
146 | session.add_message(message=Response("Good!"), status=False, index=0)
147 |     assert session.messages[0] == Response("Good!") and session.messages_status[0] is False
148 |
149 |
150 | def test_add_message3():
151 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
152 | response = Response(message="I am fine.")
153 | session = Session(messages=[prompt, response])
154 | with pytest.raises(MemorValidationError, match=r"Invalid message. It must be an instance of `Prompt` or `Response`."):
155 | session.add_message(message="Good!", status=False, index=0)
156 |
157 |
158 | def test_add_message4():
159 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
160 | response = Response(message="I am fine.")
161 | session = Session(messages=[prompt, response])
162 | with pytest.raises(MemorValidationError, match=r"Invalid value. `status` must be a boolean."):
163 | session.add_message(message=prompt, status="False", index=0)
164 |
165 |
166 | def test_remove_message1():
167 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
168 | response = Response(message="I am fine.")
169 | session = Session(messages=[prompt, response])
170 | session.remove_message(1)
171 | assert session.messages == [prompt] and session.messages_status == [True]
172 |
173 |
174 | def test_remove_message2():
175 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
176 | response = Response(message="I am fine.")
177 | session = Session(messages=[prompt, response])
178 | session.remove_message_by_index(1)
179 | assert session.messages == [prompt] and session.messages_status == [True]
180 |
181 |
182 | def test_remove_message3():
183 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
184 | response = Response(message="I am fine.")
185 | session = Session(messages=[prompt, response])
186 | session.remove_message_by_id(response.id)
187 | assert session.messages == [prompt] and session.messages_status == [True]
188 |
189 |
190 | def test_remove_message4():
191 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
192 | response = Response(message="I am fine.")
193 | session = Session(messages=[prompt, response])
194 | session.remove_message(response.id)
195 | assert session.messages == [prompt] and session.messages_status == [True]
196 |
197 |
198 | def test_remove_message5():
199 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
200 | response = Response(message="I am fine.")
201 | session = Session(messages=[prompt, response])
202 | with pytest.raises(MemorValidationError, match=r"Invalid value. `identifier` must be an integer or a string."):
203 | session.remove_message(3.5)
204 |
205 |
206 | def test_clear_messages():
207 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
208 | response = Response(message="I am fine.")
209 | session = Session(messages=[prompt, response])
210 | assert len(session) == 2
211 | session.clear_messages()
212 | assert len(session) == 0
213 |
214 |
215 | def test_copy1():
216 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
217 | response = Response(message="I am fine.")
218 | session1 = Session(messages=[prompt, response], title="session")
219 | session2 = copy.copy(session1)
220 | assert id(session1) != id(session2)
221 |
222 |
223 | def test_copy2():
224 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
225 | response = Response(message="I am fine.")
226 | session1 = Session(messages=[prompt, response], title="session")
227 | session2 = session1.copy()
228 | assert id(session1) != id(session2)
229 |
230 |
231 | def test_str():
232 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
233 | response = Response(message="I am fine.")
234 | session = Session(messages=[prompt, response], title="session1")
235 | assert str(session) == session.render(render_format=RenderFormat.STRING)
236 |
237 |
238 | def test_repr():
239 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
240 | response = Response(message="I am fine.")
241 | session = Session(messages=[prompt, response], title="session1")
242 | assert repr(session) == "Session(title={title})".format(title=session.title)
243 |
244 |
245 | def test_json():
246 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
247 | response = Response(message="I am fine.")
248 | session1 = Session(messages=[prompt, response], title="session1")
249 | session1_json = session1.to_json()
250 | session2 = Session()
251 | session2.from_json(session1_json)
252 | assert session1 == session2
253 |
254 |
255 | def test_save1():
256 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
257 | response = Response(message="I am fine.")
258 | session = Session(messages=[prompt, response], title="session1")
259 | result = session.save("f:/")
260 |     assert not result["status"]
261 |
262 |
263 | def test_save2():
264 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
265 | response = Response(message="I am fine.")
266 | session1 = Session(messages=[prompt, response], title="session1")
267 | _ = session1.render()
268 | result = session1.save("session_test1.json")
269 | session2 = Session(file_path="session_test1.json")
270 | assert result["status"] and session1 == session2 and session2.render_counter == 1
271 |
272 |
273 | def test_load1():
274 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: 22"):
275 | _ = Session(file_path=22)
276 |
277 |
278 | def test_load2():
279 | with pytest.raises(FileNotFoundError, match=r"Invalid path: must be a string and refer to an existing location. Given path: session_test10.json"):
280 | _ = Session(file_path="session_test10.json")
281 |
282 |
283 | def test_render1():
284 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
285 | response = Response(message="I am fine.")
286 | session = Session(messages=[prompt, response], title="session1")
287 | assert session.render() == "Hello, how are you?\nI am fine.\n"
288 |
289 |
290 | def test_render2():
291 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
292 | response = Response(message="I am fine.")
293 | session = Session(messages=[prompt, response], title="session1")
294 | assert session.render(RenderFormat.OPENAI) == [{"role": "user", "content": "Hello, how are you?"}, {
295 | "role": "assistant", "content": "I am fine."}]
296 |
297 |
298 | def test_render3():
299 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
300 | response = Response(message="I am fine.")
301 | session = Session(messages=[prompt, response], title="session1")
302 | assert session.render(RenderFormat.DICTIONARY)["content"] == "Hello, how are you?\nI am fine.\n"
303 |
304 |
305 | def test_render4():
306 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
307 | response = Response(message="I am fine.")
308 | session = Session(messages=[prompt, response], title="session1")
309 | assert ("content", "Hello, how are you?\nI am fine.\n") in session.render(RenderFormat.ITEMS)
310 |
311 |
312 | def test_render5():
313 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
314 | response = Response(message="I am fine.")
315 | session = Session(messages=[prompt, response], title="session1")
316 | with pytest.raises(MemorValidationError, match=r"Invalid render format. It must be an instance of RenderFormat enum."):
317 | session.render("OPENAI")
318 |
319 |
320 | def test_render6():
321 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
322 | response = Response(message="I am fine.")
323 | session = Session(messages=[prompt, response], title="session1")
324 | assert session.render(RenderFormat.AI_STUDIO) == [{'role': 'user', 'parts': [{'text': 'Hello, how are you?'}]}, {
325 | 'role': 'assistant', 'parts': [{'text': 'I am fine.'}]}]
326 |
327 |
328 | def test_check_render1():
329 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
330 | response = Response(message="I am fine.")
331 | session = Session(messages=[prompt, response], title="session1")
332 | assert session.check_render()
333 |
334 |
335 | def test_check_render2():
336 | template = PromptTemplate(content="{response[2][message]}")
337 | prompt = Prompt(message="Hello, how are you?", role=Role.USER, template=template, init_check=False)
338 | response = Response(message="I am fine.")
339 | session = Session(messages=[prompt, response], title="session1", init_check=False)
340 | assert not session.check_render()
341 |
342 |
343 | def test_render_counter1():
344 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
345 | response = Response(message="I am fine.")
346 | session = Session(messages=[prompt, response], title="session1")
347 | assert session.render_counter == 0
348 | for _ in range(10):
349 | __ = session.render()
350 | assert session.render_counter == 10
351 |
352 |
353 | def test_render_counter2():
354 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
355 | response = Response(message="I am fine.")
356 | session = Session(messages=[prompt, response], title="session1")
357 | assert session.render_counter == 0
358 | for _ in range(10):
359 | __ = session.render()
360 | for _ in range(2):
361 | __ = session.render(enable_counter=False)
362 | assert session.render_counter == 10
363 |
364 |
365 | def test_render_counter3():
366 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
367 | response = Response(message="I am fine.")
368 | session = Session(messages=[prompt, response], title="session1", init_check=True)
369 | _ = str(session)
370 | _ = session.check_render()
371 | _ = session.estimate_tokens()
372 | assert session.render_counter == 0
373 |
374 |
375 | def test_init_check():
376 | template = PromptTemplate(content="{response[2][message]}")
377 | prompt = Prompt(message="Hello, how are you?", role=Role.USER, template=template, init_check=False)
378 | response = Response(message="I am fine.")
379 | with pytest.raises(MemorRenderError, match=r"Prompt template and properties are incompatible."):
380 | _ = Session(messages=[prompt, response], title="session1")
381 |
382 |
383 | def test_equality1():
384 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
385 | response = Response(message="I am fine.")
386 | session1 = Session(messages=[prompt, response], title="session1")
387 | session2 = session1.copy()
388 | assert session1 == session2
389 |
390 |
391 | def test_equality2():
392 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
393 | response = Response(message="I am fine.")
394 | session1 = Session(messages=[prompt, response], title="session1")
395 | session2 = Session(messages=[prompt, response], title="session2")
396 | assert session1 != session2
397 |
398 |
399 | def test_equality3():
400 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
401 | response = Response(message="I am fine.")
402 | session1 = Session(messages=[prompt, response], title="session1")
403 | session2 = Session(messages=[prompt, response], title="session1")
404 | assert session1 == session2
405 |
406 |
407 | def test_equality4():
408 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
409 | response = Response(message="I am fine.")
410 | session = Session(messages=[prompt, response], title="session1")
411 | assert session != 2
412 |
413 |
414 | def test_date_modified():
415 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
416 | response = Response(message="I am fine.")
417 | session = Session(messages=[prompt, response], title="session1")
418 | assert isinstance(session.date_modified, datetime.datetime)
419 |
420 |
421 | def test_date_created():
422 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
423 | response = Response(message="I am fine.")
424 | session = Session(messages=[prompt, response], title="session1")
425 | assert isinstance(session.date_created, datetime.datetime)
426 |
427 |
428 | def test_length():
429 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
430 | response = Response(message="I am fine.")
431 | session = Session(messages=[prompt, response], title="session1")
432 | assert len(session) == len(session.messages) and len(session) == 2
433 |
434 |
435 | def test_iter():
436 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
437 | response = Response(message="I am fine.")
438 | session = Session(messages=[prompt, response, prompt, response], title="session1")
439 | messages = []
440 | for message in session:
441 | messages.append(message)
442 | assert session.messages == messages
443 |
444 |
445 | def test_addition1():
446 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
447 | response = Response(message="I am fine.")
448 | session1 = Session(messages=[prompt, response, prompt, response], title="session1")
449 | session2 = Session(messages=[prompt, prompt, response, response], title="session2")
450 | session3 = session1 + session2
451 | assert session3.title is None and session3.messages == session1.messages + session2.messages
452 |
453 |
454 | def test_addition2():
455 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
456 | response = Response(message="I am fine.")
457 | session1 = Session(messages=[prompt, response, prompt, response], title="session1")
458 | session2 = Session(messages=[prompt, prompt, response, response], title="session2")
459 | session3 = session2 + session1
460 | assert session3.title is None and session3.messages == session2.messages + session1.messages
461 |
462 |
463 | def test_addition3():
464 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
465 | response = Response(message="I am fine.")
466 | session1 = Session(messages=[prompt, response, prompt, response], title="session1")
467 | session2 = session1 + response
468 | assert session2.title == "session1" and session2.messages == session1.messages + [response]
469 |
470 |
471 | def test_addition4():
472 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
473 | response = Response(message="I am fine.")
474 | session1 = Session(messages=[prompt, response, prompt, response], title="session1")
475 | session2 = session1 + prompt
476 | assert session2.title == "session1" and session2.messages == session1.messages + [prompt]
477 |
478 |
479 | def test_addition5():
480 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
481 | response = Response(message="I am fine.")
482 | session1 = Session(messages=[prompt, response, prompt, response], title="session1")
483 | session2 = response + session1
484 | assert session2.title == "session1" and session2.messages == [response] + session1.messages
485 |
486 |
487 | def test_addition6():
488 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
489 | response = Response(message="I am fine.")
490 | session1 = Session(messages=[prompt, response, prompt, response], title="session1")
491 | session2 = prompt + session1
492 | assert session2.title == "session1" and session2.messages == [prompt] + session1.messages
493 |
494 |
495 | def test_addition7():
496 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
497 | response = Response(message="I am fine.")
498 | session1 = Session(messages=[prompt, response, prompt, response], title="session1")
499 | with pytest.raises(TypeError, match=re.escape(r"Unsupported operand type(s) for +: `Session` and `int`")):
500 | _ = session1 + 2
501 |
502 |
503 | def test_addition8():
504 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
505 | response = Response(message="I am fine.")
506 | session1 = Session(messages=[prompt, response, prompt, response], title="session1")
507 | with pytest.raises(TypeError, match=re.escape(r"Unsupported operand type(s) for +: `Session` and `int`")):
508 | _ = 2 + session1
509 |
510 |
511 | def test_contains1():
512 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
513 | response = Response(message="I am fine.")
514 | session = Session(messages=[prompt, response], title="session")
515 | assert prompt in session and response in session
516 |
517 |
518 | def test_contains2():
519 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
520 | response1 = Response(message="I am fine.")
521 | response2 = Response(message="Good!")
522 | session = Session(messages=[prompt, response1], title="session")
523 | assert response2 not in session
524 |
525 |
526 | def test_contains3():
527 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
528 | response = Response(message="I am fine.")
529 | session = Session(messages=[prompt, response], title="session")
530 | assert "I am fine." not in session
531 |
532 |
533 | def test_getitem1():
534 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
535 | response = Response(message="I am fine.")
536 | session = Session(messages=[prompt, response], title="session")
537 | assert session[0] == session.messages[0] and session[1] == session.messages[1]
538 |
539 |
540 | def test_getitem2():
541 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
542 | response = Response(message="I am fine.")
543 | session = Session(messages=[prompt, response, response, response], title="session")
544 | assert session[:] == session.messages
545 |
546 |
547 | def test_getitem3():
548 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
549 | response = Response(message="I am fine.")
550 | session = Session(messages=[prompt, response], title="session")
551 | assert session[0] == session.get_message_by_index(0) and session[1] == session.get_message_by_index(1)
552 |
553 |
554 | def test_getitem4():
555 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
556 | response = Response(message="I am fine.")
557 | session = Session(messages=[prompt, response], title="session")
558 | assert session[0] == session.get_message(0) and session[1] == session.get_message(1)
559 |
560 |
561 | def test_getitem5():
562 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
563 | response = Response(message="I am fine.")
564 | session = Session(messages=[prompt, response], title="session")
565 | assert session[0] == session.get_message_by_id(prompt.id) and session[1] == session.get_message_by_id(response.id)
566 |
567 |
568 | def test_getitem6():
569 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
570 | response = Response(message="I am fine.")
571 | session = Session(messages=[prompt, response], title="session")
572 | assert session[0] == session.get_message(prompt.id) and session[1] == session.get_message(response.id)
573 |
574 |
575 | def test_getitem7():
576 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
577 | response = Response(message="I am fine.")
578 | session = Session(messages=[prompt, response], title="session")
579 | with pytest.raises(MemorValidationError, match=r"Invalid value. `identifier` must be an integer, string or a slice."):
580 | _ = session[3.5]
581 |
582 |
583 | def test_getitem8():
584 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
585 | response = Response(message="I am fine.")
586 | session = Session(messages=[prompt, response], title="session")
587 | with pytest.raises(MemorValidationError, match=r"Invalid value. `identifier` must be an integer, string or a slice."):
588 | _ = session.get_message(3.5)
589 |
590 |
591 | def test_estimated_tokens1():
592 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
593 | response = Response(message="I am fine.")
594 | session = Session(messages=[prompt, response], title="session")
595 | assert session.estimate_tokens(TokensEstimator.UNIVERSAL) == 12
596 |
597 |
598 | def test_estimated_tokens2():
599 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
600 | response = Response(message="I am fine.")
601 | session = Session(messages=[prompt, response], title="session")
602 | assert session.estimate_tokens(TokensEstimator.OPENAI_GPT_3_5) == 14
603 |
604 |
605 | def test_estimated_tokens3():
606 | prompt = Prompt(message="Hello, how are you?", role=Role.USER)
607 | response = Response(message="I am fine.")
608 | session = Session(messages=[prompt, response], title="session")
609 | assert session.estimate_tokens(TokensEstimator.OPENAI_GPT_4) == 15
610 |
--------------------------------------------------------------------------------
/tests/test_token_estimators.py:
--------------------------------------------------------------------------------
1 | from memor.tokens_estimator import openai_tokens_estimator_gpt_3_5, openai_tokens_estimator_gpt_4, universal_tokens_estimator
2 |
3 | TEST_CASE_NAME = "Token Estimators tests"
4 |
5 |
6 | def test_universal_tokens_estimator_with_contractions():
7 | message = "I'm going to the park."
8 | assert universal_tokens_estimator(message) == 7
9 | message = "They'll be here soon."
10 | assert universal_tokens_estimator(message) == 7
11 |
12 |
13 | def test_universal_tokens_estimator_with_code_snippets():
14 | message = "def foo(): return 42"
15 | assert universal_tokens_estimator(message) == 7
16 | message = "if x == 10:"
17 | assert universal_tokens_estimator(message) == 6
18 |
19 |
20 | def test_universal_tokens_estimator_with_loops():
21 | message = "for i in range(10):"
22 | assert universal_tokens_estimator(message) == 8
23 | message = "while True:"
24 | assert universal_tokens_estimator(message) == 4
25 |
26 |
27 | def test_universal_tokens_estimator_with_long_sentences():
28 | message = "Understanding natural language processing is fun!"
29 | assert universal_tokens_estimator(message) == 17
30 | message = "Tokenization involves splitting text into meaningful units."
31 | assert universal_tokens_estimator(message) == 24
32 |
33 |
34 | def test_universal_tokens_estimator_with_variable_names():
35 | message = "some_variable_name = 100"
36 | assert universal_tokens_estimator(message) == 5
37 | message = "another_long_var_name = 'test'"
38 | assert universal_tokens_estimator(message) == 6
39 |
40 |
41 | def test_universal_tokens_estimator_with_function_definitions():
42 | message = "The function `def add(x, y): return x + y` adds two numbers."
43 | assert universal_tokens_estimator(message) == 20
44 | message = "Use `for i in range(5):` to loop."
45 | assert universal_tokens_estimator(message) == 14
46 |
47 |
48 | def test_universal_tokens_estimator_with_numbers():
49 | message = "The year 2023 was great!"
50 | assert universal_tokens_estimator(message) == 6
51 | message = "42 is the answer to everything."
52 | assert universal_tokens_estimator(message) == 11
53 |
54 |
55 | def test_universal_tokens_estimator_with_print_statements():
56 | message = "print('Hello, world!')"
57 | assert universal_tokens_estimator(message) == 5
58 | message = "name = \"Alice\""
59 | assert universal_tokens_estimator(message) == 3
60 |
61 |
62 | def test_openai_tokens_estimator_with_function_definition():
63 | message = "def add(a, b): return a + b"
64 | assert openai_tokens_estimator_gpt_3_5(message) == 11
65 |
66 |
67 | def test_openai_tokens_estimator_with_url():
68 | message = "Visit https://openai.com for more info."
69 | assert openai_tokens_estimator_gpt_3_5(message) == 18
70 |
71 |
72 | def test_openai_tokens_estimator_with_long_words():
73 | message = "This is a verylongwordwithoutspaces and should be counted properly."
74 | assert openai_tokens_estimator_gpt_3_5(message) == 25
75 |
76 |
77 | def test_openai_tokens_estimator_with_newlines():
78 | message = "Line1\nLine2\nLine3\n"
79 | assert openai_tokens_estimator_gpt_3_5(message) == 7
80 |
81 |
82 | def test_openai_tokens_estimator_with_gpt4_model():
83 | message = "This is a test sentence that should be counted properly even with GPT-4. I am making it longer to test the model."
84 | assert openai_tokens_estimator_gpt_4(message) == 45
85 |
--------------------------------------------------------------------------------