├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── labeler.yml ├── pull_request_template.md └── workflows │ ├── labeler.yml │ ├── latest-changes.yml.off │ ├── main.yml │ ├── mkdocs_ci.yml.off │ ├── publish-to-pypi.yml │ ├── stale.yml │ └── welcome.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── assets ├── chatbot-demo.mp4 ├── llama-inference-api-min.png ├── llm-inference-llama2_chatbot.png └── llm-inference-min.png ├── docs ├── CHANGELOG.md ├── CNAME ├── index.md ├── overrides │ └── main.html └── requirements.txt ├── examples ├── chatbot │ ├── README.md │ ├── chatbot-tutorial.ipynb │ ├── chatbot.ipynb │ ├── discord_bot.py │ ├── gradio_demo.py │ └── llama_bot_ui.py └── inference-demo.ipynb ├── mkdocs.yml ├── pyproject.toml ├── requirements ├── dev.txt └── requirements.txt ├── setup.cfg ├── setup.py ├── src ├── llm_chain │ ├── __init__.py │ ├── conversation_chain.py │ ├── llm.py │ ├── templates.py │ └── ui │ │ ├── __init__.py │ │ └── main.py └── llm_inference │ ├── __init__.py │ ├── download.py │ ├── model.py │ ├── serve.py │ └── token_manipulation.py └── tests ├── __init__.py ├── __main__.py └── llm_chain └── test_chain.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Lines starting with '#' are comments. 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # More details are here: https://help.github.com/articles/about-codeowners/ 5 | 6 | # The '*' pattern is global owners. 7 | 8 | # Order is important. The last matching pattern has the most precedence. 9 | # The folders are ordered as follows: 10 | 11 | # In each subsection folders are ordered first by depth, then alphabetically. 12 | # This should make it easy to add new rules without breaking existing ones. 13 | 14 | # Global rule: 15 | * @aniketmaurya 16 | 17 | # tests 18 | /tests/** @aniketmaurya 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | #### Bug description 12 | 13 | 14 | #### Expected result 15 | 16 | 17 | #### Actual result 18 | 19 | 20 | #### Steps to reproduce 21 | 22 | 23 | 1. 24 | 2. 25 | 3. 26 | #### Context 27 | 28 | 29 | 30 | #### Your Environment 31 | 32 | 33 | * Version used: 34 | * Operating System and version: 35 | * Link to your fork: 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | #### Is your feature request related to a problem? Please describe. 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | #### Describe the solution you'd like 14 | A clear and concise description of what you want to happen. 15 | 16 | #### Describe alternatives you've considered 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | 19 | #### Additional context 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | # Add 'docs' to any changes within 'docs' folder or any subfolders 2 | documentation: 3 | - docs/**/* 4 | 5 | example: 6 | - examples/**/* 7 | 8 | test: 9 | - tests/**/* 10 | 11 | CI: 12 | - .github/**/* 13 | - "*.yaml" 14 | - "*.yml" 15 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | #### Changes 2 | 3 | 4 | 5 | 6 | Fixes # (issue) 7 | 8 | 9 | #### Type of change 10 | 11 | - [ ] 📚 Documentation Update 12 | - [ ] 🧪 Tests Cases 13 | - [ ] 🐞 Bug fix (non-breaking change which fixes an issue) 14 | - [ ] 🔬 New feature (non-breaking change which adds functionality) 15 | - [ ] 🚨 Breaking change (fix or feature that would cause existing functionality to not work as expected) 16 | - [ ] 📝 This change requires a documentation update 17 | 18 | 19 | #### Checklist 20 | 21 | - [ ] My code follows the style guidelines of this project 22 | - [ ] I have performed a self-review of my own code 23 | - [ ] I have commented my code, particularly in hard-to-understand areas 24 | - [ ] I have made corresponding changes to the documentation 25 | - [ ] My changes generate no new warnings 26 | - [ ] Did you update CHANGELOG in case of a major change? 27 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: "Pull Request Labeler" 2 | on: 3 | - pull_request_target 4 | 5 | jobs: 6 | triage: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/labeler@v4 10 | with: 11 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 12 | -------------------------------------------------------------------------------- /.github/workflows/latest-changes.yml.off: -------------------------------------------------------------------------------- 1 | name: Latest Changes 2 | 3 | on: 4 | pull_request_target: 5 | branches: 6 | - main 7 | types: 8 | - closed 9 | # For manually triggering it 10 | workflow_dispatch: 11 | inputs: 12 | number: 13 | description: PR number 14 | required: true 15 | 16 | jobs: 17 | latest-changes: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v2 21 | with: 22 | token: ${{ secrets.ACTIONS_TOKEN }} 23 | - uses: docker://tiangolo/latest-changes:0.0.3 24 | with: 25 | token: ${{ secrets.GITHUB_TOKEN }} 26 | latest_changes_file: docs/CHANGELOG.md 27 | latest_changes_header: '## 0.0.3\n' 28 | debug_logs: true 29 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | 8 | 9 | jobs: 10 | pytest: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ ubuntu-latest, macos-latest ] 15 | python-version: [3.8, 3.9, "3.10"] 16 | include: 17 | - os: ubuntu-latest 18 | path: ~/.cache/pip 19 | - os: macos-latest 20 | path: ~/Library/Caches/pip 21 | env: 22 | OS: ${{ matrix.os }} 23 | PYTHON: '3.10' 24 | 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | with: 29 | 
fetch-depth: 0 # Shallow clones should be disabled for a better relevancy of analysis 30 | 31 | - name: Set up Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@v2 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | 36 | - name: Cache pip 37 | uses: actions/cache@v2 38 | with: 39 | path: ${{ matrix.path }} 40 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 41 | restore-keys: | 42 | ${{ runner.os }}-pip- 43 | ${{ runner.os }}- 44 | 45 | - name: Install dependencies 46 | env: 47 | TORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html 48 | run: | 49 | python --version 50 | pip --version 51 | python -m pip install --upgrade pip coverage pytest 52 | pip install --index-url https://download.pytorch.org/whl/nightly/cpu --pre 'torch>=2.1.0dev' 53 | pip install lit_gpt@git+https://github.com/aniketmaurya/install-lit-gpt.git@install 54 | pip install . --quiet 55 | pip list 56 | shell: bash 57 | 58 | - name: Prepare Test 59 | run: | 60 | python tests # download test data 61 | 62 | - name: Run Test with Coverage 63 | run: | 64 | coverage erase 65 | coverage run -m pytest 66 | 67 | - name: Generate Coverage Report 68 | run: | 69 | coverage report -m -i 70 | coverage xml -i 71 | 72 | - name: Upload Coverage to Codecov 73 | if: runner.os != 'macOS' 74 | uses: codecov/codecov-action@v1 75 | with: 76 | token: ${{ secrets.CODECOV_TOKEN }} 77 | file: ./coverage.xml 78 | flags: unittests 79 | env_vars: OS,PYTHON 80 | name: codecov-umbrella 81 | fail_ci_if_error: false 82 | -------------------------------------------------------------------------------- /.github/workflows/mkdocs_ci.yml.off: -------------------------------------------------------------------------------- 1 | name: MkDocs 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions/setup-python@v2 13 | with: 14 | python-version: 3.x 15 | - run: pip install -r docs/requirements.txt 16 | - run: mkdocs gh-deploy --force 17 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | name: Build and publish Python 🐍 distributions 📦 to PyPI 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@master 12 | - name: Set up Python 3.9 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.9 16 | 17 | - name: Install pypa/build 18 | run: >- 19 | python -m 20 | pip install 21 | build 22 | --user 23 | - name: Build a binary wheel and a source tarball 24 | run: >- 25 | make clean && make build 26 | 27 | - name: Publish distribution 📦 to PyPI 28 | if: startsWith(github.ref, 'refs/tags') 29 | uses: pypa/gh-action-pypi-publish@master 30 | with: 31 | password: ${{ secrets.PYPI_API_TOKEN }} 32 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: Mark stale issues and pull requests 2 | 3 | on: 4 | schedule: 5 | - cron: "30 1 * * *" 6 | 7 | jobs: 8 | stale: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/stale@v3 14 | with: 15 | repo-token: ${{ secrets.GITHUB_TOKEN }} 16 | stale-issue-message: 'Stale issue message' 17 | stale-pr-message: 'Stale 
pull request message' 18 | stale-issue-label: 'no-issue-activity' 19 | stale-pr-label: 'no-pr-activity' 20 | -------------------------------------------------------------------------------- /.github/workflows/welcome.yml: -------------------------------------------------------------------------------- 1 | name: Greet New Contributors 2 | 3 | on: [pull_request_target, issues] 4 | 5 | jobs: 6 | greeting: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/first-interaction@v1 10 | with: 11 | repo-token: ${{ secrets.GITHUB_TOKEN }} 12 | issue-message: "👋 @${{github.actor}}! Thank you for opening your first issue in this repo. We are so happy that you have decided to contribute and value your contribution. Please read these materials before proceeding: [Contributing Guide](https://github.com/gradsflow/gradsflow/blob/master/CONTRIBUTING.md) and [Code of Conduct](https://github.com/gradsflow/gradsflow/blob/master/CODE_OF_CONDUCT.md)." 13 | pr-message: "👋 @${{github.actor}}! Thank you for opening your first pull request in this repo. We are so happy that you have decided to contribute and value your contribution. Please read these materials before proceeding: [Contributing Guide](https://github.com/gradsflow/gradsflow/blob/master/CONTRIBUTING.md) and [Code of Conduct](https://github.com/gradsflow/gradsflow/blob/master/CODE_OF_CONDUCT.md)." 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | #.idea/ 153 | *.pth 154 | *.ckpt 155 | *.tokenizer 156 | *.model 157 | .DS_Store 158 | checkpoints/ 159 | .vscode/ 160 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | default_language_version: 4 | python: python3 5 | 6 | ci: 7 | autofix_prs: true 8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 9 | autoupdate_schedule: quarterly 10 | # submodules: true 11 | 12 | repos: 13 | - repo: https://github.com/pre-commit/pre-commit-hooks 14 | rev: v4.5.0 15 | hooks: 16 | - id: end-of-file-fixer 17 | - id: trailing-whitespace 18 | - id: check-yaml 19 | - id: check-docstring-first 20 | - id: check-toml 21 | - id: check-case-conflict 22 | - id: detect-private-key 23 | 24 | - repo: https://github.com/psf/black 25 | rev: 24.3.0 26 | hooks: 27 | - id: black 28 | name: "Black: The uncompromising Python code formatter" 29 | 30 | - repo: https://github.com/PyCQA/isort 31 | rev: 5.13.2 32 | hooks: 33 | - id: isort 34 | name: "Sort Imports" 35 | args: [ "--profile", "black" ] 36 | 37 | - repo: https://github.com/codespell-project/codespell 38 | rev: v2.2.6 39 | hooks: 40 | - id: codespell 41 | args: 42 | - --ignore-words-list 43 | - "ans,hist" 44 | - --skip 45 | - "*.bib,*.ipynb" 46 | 47 | - repo: https://github.com/asottile/pyupgrade 48 | rev: v3.15.2 49 | hooks: 50 | - id: pyupgrade 51 | args: [ --py39-plus ] 52 | 53 | - repo: https://github.com/PyCQA/bandit 54 | rev: 1.7.8 55 | hooks: 56 | - id: bandit 57 | language_version: python3 58 | exclude: tests/ 59 | args: 60 | - -s 61 | - "B404,B602,B603,B607,B101" 62 | 63 | - repo: https://github.com/kynan/nbstripout 64 | rev: 0.7.1 65 | hooks: 66 | - id: nbstripout 67 | args: [ "--max-size=100k" ] 68 | --------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | hello@domain.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. 
Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | 👍🎉 First off, thanks for taking the time to contribute! 🎉👍 4 | 5 | The following is a set of guidelines for contributing to Python-Project-Template and its packages, which are hosted in the Python-Project-Template Organization on GitHub. These are mostly guidelines, not rules. Use your best judgment, and feel free to propose changes to this document in a pull request. 6 | 7 | We welcome any kind of contribution to our software, from simple comment or question to a full fledged [pull request](https://help.github.com/articles/about-pull-requests/). Please read and follow our [Code of Conduct](CODE_OF_CONDUCT.md). 8 | 9 | A contribution can be one of the following cases: 10 | 11 | 1. you have a question; 12 | 1. you think you may have found a bug (including unexpected behavior); 13 | 1. 
you want to make some kind of change to the code base (e.g. to fix a bug, to add a new feature, to update documentation); 14 | 1. you want to make a new release of the code base. 15 | 16 | The sections below outline the steps in each case. 17 | 18 | ## You have a question 19 | 20 | 1. use the search functionality [here](https://github.com/aniketmaurya/python-project-template/issues) to see if someone already filed the same issue; 21 | 2. if your issue search did not yield any relevant results, make a new issue; 22 | 3. apply the "Question" label; apply other labels when relevant. 23 | 4. You can join our Slack group as well. 24 | 25 | ## You think you may have found a bug 26 | 27 | 1. use the search functionality [here](https://github.com/aniketmaurya/python-project-template/issues) to see if someone already filed the same issue; 28 | 1. if your issue search did not yield any relevant results, make a new issue, making sure to provide enough information to the rest of the community to understand the cause and context of the problem. Depending on the issue, you may want to include: 29 | - the [SHA hashcode](https://help.github.com/articles/autolinked-references-and-urls/#commit-shas) of the commit that is causing your problem; 30 | - some identifying information (name and version number) for dependencies you're using; 31 | - information about the operating system; 32 | 1. apply relevant labels to the newly created issue. 33 | 34 | ## You want to make some kind of change to the code base 35 | 36 | 1. (**important**) announce your plan to the rest of the community *before you start working*. This announcement should be in the form of a (new) issue; 37 | 1. (**important**) wait until some kind of consensus is reached about your idea being a good idea; 38 | 1. if needed, fork the repository to your own Github profile and create your own feature branch off of the latest master commit. While working on your feature branch, make sure to stay up to date with the master branch by pulling in changes, possibly from the 'upstream' repository (follow the instructions [here](https://help.github.com/articles/configuring-a-remote-for-a-fork/) and [here](https://help.github.com/articles/syncing-a-fork/)); 39 | 1. make sure the existing tests still work by running ``pytest``; 40 | 1. add your own tests (if necessary); 41 | 1. update or expand the documentation; 42 | 1. update the `docs/CHANGELOG.md` file with change; 43 | 1. push your feature branch to (your fork of) the https://github.com/aniketmaurya/python-project-template repository on GitHub; 44 | 1. create the pull request, e.g. following the instructions [here](https://help.github.com/articles/creating-a-pull-request/). 45 | 46 | In case you feel like you've made a valuable contribution, but you don't know how to write or run tests for it, or how to generate the documentation: don't let this discourage you from making the pull request; we can help you! Just go ahead and submit the pull request, but keep in mind that you might be asked to append additional commits to your pull request. 
47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Aniket Maurya 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Manifest syntax https://docs.python.org/2/distutils/sourcedist.html 2 | graft wheelhouse 3 | 4 | recursive-exclude __pycache__ *.py[cod] *.orig 5 | 6 | # Include the README and CHANGELOG 7 | include *.md 8 | recursive-include assets *.png 9 | 10 | exclude app.py 11 | exclude .lightning 12 | exclude .lightningignore 13 | 14 | # Include the license file 15 | include LICENSE 16 | 17 | # Exclude build configs 18 | exclude *.sh 19 | exclude *.toml 20 | exclude *.svg 21 | exclude *.yml 22 | exclude *.yaml 23 | 24 | # exclude tests from package 25 | recursive-exclude tests * 26 | recursive-exclude site * 27 | exclude tests 28 | 29 | # Exclude the documentation files 30 | recursive-exclude docs * 31 | exclude docs 32 | 33 | # Include the Requirements 34 | include requirements/requirements.txt 35 | recursive-include requirements/ *.txt 36 | 37 | # Exclude Makefile 38 | exclude Makefile 39 | 40 | prune .git 41 | prune .github 42 | prune scripts 43 | prune temp* 44 | prune test* 45 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | build-docs: 2 | cp README.md docs/index.md 3 | 4 | docsserve: 5 | mkdocs serve --dirtyreload --livereload 6 | 7 | test: 8 | python tests/__init__.py 9 | pytest 10 | 11 | coverage: ## Run tests with coverage 12 | coverage erase 13 | coverage run -m pytest 14 | coverage report -m 15 | coverage xml 16 | 17 | clean: 18 | rm -rf dist 19 | find . -type f -name "*.DS_Store" -ls -delete 20 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 21 | find . | grep -E ".pytest_cache" | xargs rm -rf 22 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 23 | rm -f .coverage 24 | 25 | style: 26 | black . 27 | isort --profile black . 
28 | 29 | push: 30 | git push && git push --tags 31 | 32 | build: 33 | python -m build 34 | 35 | publish-test: style clean build 36 | twine upload -r testpypi dist/* 37 | 38 | publish-prod: style clean build 39 | twine upload dist/* 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Large Language Model (LLM) Inference API and Chatbot 🦙 2 | 3 | ![project banner](https://github.com/aniketmaurya/llm-inference/raw/main/assets/llm-inference-min.png) 4 | 5 | Inference API for LLMs like LLaMA and Falcon, powered by Lit-GPT from [Lightning AI](https://lightning.ai). 6 | 7 | ``` 8 | pip install llm-inference 9 | ``` 10 | 11 | ### Install from main branch 12 | ```bash 13 | pip install git+https://github.com/aniketmaurya/llm-inference.git@main 14 | 15 | # You need to manually install [Lit-GPT](https://github.com/Lightning-AI/lit-gpt) and set up the model weights to use this project. 16 | pip install lit_gpt@git+https://github.com/aniketmaurya/install-lit-gpt.git@install 17 | ``` 18 | 19 | ## For Inference 20 | 21 | ```python 22 | from llm_inference import LLMInference, prepare_weights 23 | 24 | path = prepare_weights("EleutherAI/pythia-70m") 25 | model = LLMInference(checkpoint_dir=path) 26 | 27 | print(model("New York is located in")) 28 | ``` 29 | 30 | 31 | ## How to use the Chatbot 32 | 33 | ![chatbot image](./assets/llm-inference-llama2_chatbot.png) 34 | 35 | ```python 36 | from llm_chain import LitGPTConversationChain, LitGPTLLM 37 | from llm_chain.templates import llama2_prompt_template 38 | from llm_inference import prepare_weights 39 | 40 | path = str(prepare_weights("meta-llama/Llama-2-7b-chat-hf")) 41 | llm = LitGPTLLM(checkpoint_dir=path, quantize="bnb.nf4") # 7GB GPU memory 42 | bot = LitGPTConversationChain.from_llm(llm=llm, prompt=llama2_prompt_template) 43 | 44 | print(bot.send("hi, what is the capital of France?")) 45 | ``` 46 | 47 | ## Launch Chatbot App 48 | 49 | 52 | 53 | **1. Download weights** 54 | ```py 55 | from llm_inference import prepare_weights 56 | path = prepare_weights("meta-llama/Llama-2-7b-chat-hf") 57 | ``` 58 | 59 | **2.
Launch Gradio App** 60 | 61 | ``` 62 | python examples/chatbot/gradio_demo.py 63 | ``` 64 | -------------------------------------------------------------------------------- /assets/chatbot-demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aniketmaurya/llm-inference/5bb323c4cce70dcbe81cf794aaa0a66b87fe3083/assets/chatbot-demo.mp4 -------------------------------------------------------------------------------- /assets/llama-inference-api-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aniketmaurya/llm-inference/5bb323c4cce70dcbe81cf794aaa0a66b87fe3083/assets/llama-inference-api-min.png -------------------------------------------------------------------------------- /assets/llm-inference-llama2_chatbot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aniketmaurya/llm-inference/5bb323c4cce70dcbe81cf794aaa0a66b87fe3083/assets/llm-inference-llama2_chatbot.png -------------------------------------------------------------------------------- /assets/llm-inference-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aniketmaurya/llm-inference/5bb323c4cce70dcbe81cf794aaa0a66b87fe3083/assets/llm-inference-min.png -------------------------------------------------------------------------------- /docs/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Release Notes 2 | 3 | ## 0.0.1 4 | * Setup repo 5 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | CUSTOM_DOMAIN.com 2 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aniketmaurya/llm-inference/5bb323c4cce70dcbe81cf794aaa0a66b87fe3083/docs/index.md -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block extrahead %} 4 | {% set title = config.site_name %} 5 | {% if page and page.meta and page.meta.title %} 6 | {% set title = title ~ " - " ~ page.meta.title %} 7 | {% elif page and page.title and not page.is_homepage %} 8 | {% set title = title ~ " - " ~ page.title | striptags %} 9 | {% endif %} 10 | 11 | 12 | {{ title }} 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | {% endblock %} 30 | 31 | {% block outdated %} 32 | You're not viewing the latest version. 33 | 34 | Click here to go to latest. 
35 | 36 | {% endblock %} 37 | 38 | 39 | {% set extracopyright %} 40 | Copyright (c) 2021 Aniket Maurya 41 | {% endset %} 42 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs==1.2.2 2 | mkdocs-material==7.2.4 3 | mkdocs-material-extensions==1.0.1 4 | mkdocs-git-revision-date-localized-plugin==0.9.2 5 | mkdocs-macros-plugin==0.6.0 6 | mkdocs-autorefs==0.2.1 7 | mkdocstrings==0.15.2 8 | tags-macros-plugin @ git+https://github.com/jldiaz/mkdocs-plugin-tags.git@d26e2f124e4f3471639d426459e281080988fe7a 9 | mkdocs-jupyter 10 | mkdocs-meta-descriptions-plugin 11 | jupyter_contrib_nbextensions 12 | -------------------------------------------------------------------------------- /examples/chatbot/README.md: -------------------------------------------------------------------------------- 1 | # ChatServer 2 | 3 | ChatBot System built with LangChain and Lightning AI 4 | 5 | ## How to run 6 | 7 | ```bash 8 | git clone https://github.com/aniketmaurya/chatbot-server.git 9 | cd chatbot-server 10 | 11 | pip install -e . 12 | lightning run app app.py 13 | ``` 14 | 15 | Please initiate a conversation with the chatbot. 16 | 17 | ![](./assets/chatserver-min.png) 18 | -------------------------------------------------------------------------------- /examples/chatbot/chatbot-tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# LLM Generation at high level\n", 8 | "\n", 9 | "```python\n", 10 | "query = \"Capital of\"\n", 11 | "\n", 12 | "output = \"\"\n", 13 | "for i in range(MAX_GENERATED_TOKENS):\n", 14 | " output = output + LLM(output)\n", 15 | "```" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "0. output = Capital of\n", 23 | "1. output = Capital of France\n", 24 | "1. output = Capital of France is\n", 25 | "1. 
output = Capital of France is Paris" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "```\n", 54 | "query = A chat between a curious user and an artificial intelligence assistant.\n", 55 | " The assistant gives helpful, detailed, and polite answers to the user's questions.\n", 56 | " USER: My name is Aniket\n", 57 | " ASSISTANT:\n", 58 | "```" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "longchat_template = \"\"\"A chat between a curious user and an artificial intelligence assistant.\n", 82 | "The assistant gives helpful, detailed, and polite answers to the user's questions.\n", 83 | "USER: {input}\n", 84 | "ASSISTANT:\"\"\"" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "print(longchat_template.format(input=\"My name is Aniket\"))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "print(longchat_template.format(input=\"What is the capital of France?\"))" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "from llm_inference import LLMInference, prepare_weights\n", 147 | "from rich import print" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "# path = str(prepare_weights(\"meta-llama/Llama-2-7b-chat-hf\"))\n", 157 | "# model = LLMInference(checkpoint_dir=path, quantize=\"bnb.nf4\")" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "path = str(prepare_weights(\"lmsys/longchat-7b-16k\"))\n", 167 | "model = LLMInference(checkpoint_dir=path, quantize=\"bnb.nf4\")" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "longchat_template = \"\"\"A chat between a 
curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n", 177 | "USER: {input}\n", 178 | "ASSISTANT:\"\"\"" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "query = longchat_template.format(input=\"What is the capital of France?\")\n", 209 | "output = model.chat(query)\n", 210 | "print(output)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "output = model.chat(longchat_template.format(input=\"My name is Aniket\"))\n", 241 | "print(output)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "output = model.chat(longchat_template.format(input=\"Write a poem on Lightning AI\"))\n", 251 | "print(output) " 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "## Memory" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "output = model.chat(longchat_template.format(input=\"My name is Aniket?\"))\n", 310 | "print(output)\n", 311 | "\n", 312 | "\n", 313 | "output = model.chat(longchat_template.format(input=\"What is my name?\"))\n", 314 | "print(output)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | 
"source": [] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": null, 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [ 386 | "longchat_template = \"\"\"A chat between a curious user and an artificial intelligence assistant.\n", 387 | "The assistant gives helpful, detailed, and polite answers to the user's questions.\n", 388 | "Context:\n", 389 | "User: My name is Aniket\n", 390 | "Assistant: Hi, Aniket how are you?\n", 391 | "\n", 392 | "USER: {input}\n", 393 | "ASSISTANT:\"\"\"" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "longchat_template = \"\"\"A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\n", 424 | "Context:\n", 425 | "USER: My name is Aniket!\n", 426 | "ASSISTANT: How can I help you Aniket?\n", 427 | "USER: {input}\n", 428 | "ASSISTANT:\"\"\"\n", 429 | "\n", 430 | "output = model.chat(longchat_template.format(input=\"What is my name?\"))\n", 431 | "print(output)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": null, 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "metadata": {}, 480 | "outputs": [], 481 | "source": [] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": null, 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "longchat_template = \"\"\"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n", 497 | "Context:\n", 498 | "{history}\n", 499 | "USER: {input}\n", 500 | "ASSISTANT:\"\"\"\n", 501 | "\n", 502 | "history =\"USER: Hi, I am Aniket!\\nAssistant: How can I help you Aniket?\"\n", 503 | "\n", 504 | "query = longchat_template.format(input=\"What is my name?\", history=history)\n", 505 | "output = model.chat(query)\n", 506 | "print(output)" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": null, 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "metadata": {}, 534 | "outputs": [], 535 | "source": [] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "metadata": {}, 541 | "outputs": [], 542 | "source": [] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": {}, 548 | "outputs": [], 549 | "source": [] 550 | }, 551 | { 552 | "cell_type": "markdown", 553 | "metadata": {}, 554 | "source": [ 555 | "[PromptTemplate doc](https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/)" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "metadata": {}, 562 | "outputs": [], 563 | "source": [ 564 | "from langchain.prompts import PromptTemplate\n", 565 | "\n", 566 | "longchat_template = \"\"\"A chat between a curious user and an artificial intelligence assistant.\n", 567 | "The assistant gives helpful, detailed, and polite 
answers to the user's questions.\n", 568 | "Context:\n", 569 | "{history}\n", 570 | "USER: {input}\n", 571 | "ASSISTANT:\"\"\"\n", 572 | "\n", 573 | "longchat_prompt_template = PromptTemplate(\n", 574 | " input_variables=[\"input\", \"history\"], template=longchat_template\n", 575 | ")" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "metadata": {}, 582 | "outputs": [], 583 | "source": [] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": null, 588 | "metadata": {}, 589 | "outputs": [], 590 | "source": [] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "print(longchat_prompt_template.format(\n", 599 | " input = \"What is my name?\",\n", 600 | " history =\"USER: Hi, I am Aniket!\\nAssistant: How can I help you Aniket?\"\n", 601 | "))" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": null, 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "from langchain.chains import ConversationChain\n", 618 | "from langchain.memory import ConversationBufferWindowMemory\n", 619 | "\n", 620 | "from llm_chain import LitGPTLLM\n", 621 | "\n", 622 | "\n", 623 | "llm = LitGPTLLM(model=model)\n", 624 | "\n", 625 | "\n", 626 | "conversation = ConversationChain(\n", 627 | " llm=llm,\n", 628 | " prompt=longchat_prompt_template,\n", 629 | " verbose=False,\n", 630 | " memory=ConversationBufferWindowMemory(ai_prefix=\"Assistant\", human_prefix=\"User\", k=2),\n", 631 | ")" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": null, 637 | "metadata": {}, 638 | "outputs": [], 639 | "source": [ 640 | "conversation(\"hi, I am Aniket\")[\"response\"]" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": null, 646 | "metadata": {}, 647 | "outputs": [], 648 | "source": [ 649 | "conversation(\"What is my name?\")[\"response\"]" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": null, 655 | "metadata": {}, 656 | "outputs": [], 657 | "source": [ 658 | "conversation(\"What is the timezone of London?\")[\"response\"]" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": null, 664 | "metadata": {}, 665 | "outputs": [], 666 | "source": [ 667 | "print(conversation.memory.chat_memory)" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": null, 673 | "metadata": {}, 674 | "outputs": [], 675 | "source": [] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": null, 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": null, 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [] 690 | }, 691 | { 692 | "cell_type": "code", 693 | "execution_count": null, 694 | "metadata": {}, 695 | "outputs": [], 696 | "source": [] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": null, 701 | "metadata": {}, 702 | "outputs": [], 703 | "source": [] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "execution_count": null, 708 | "metadata": {}, 709 | "outputs": [], 710 | "source": [ 711 | "from langchain.memory import ConversationBufferMemory\n", 712 | "\n", 713 | "conversation = ConversationChain(\n", 714 | " llm=llm,\n", 715 | " prompt=longchat_prompt_template,\n", 716 | " verbose=False,\n", 
717 | " memory=ConversationBufferMemory(ai_prefix=\"Assistant\", human_prefix=\"User\"),\n", 718 | ")" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": null, 724 | "metadata": {}, 725 | "outputs": [], 726 | "source": [ 727 | "output = conversation(\n", 728 | " \"PyTorch Lightning is an open-source library developed by Lightning AI team.\"\n", 729 | ")[\"response\"]\n", 730 | "print(output)" 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": null, 736 | "metadata": {}, 737 | "outputs": [], 738 | "source": [ 739 | "output = conversation(\n", 740 | " \"who developed PyTorch Lightning? just give me the name of the team or person and nothing else.\"\n", 741 | ")[\"response\"]\n", 742 | "print(output)" 743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": null, 748 | "metadata": {}, 749 | "outputs": [], 750 | "source": [] 751 | }, 752 | { 753 | "cell_type": "markdown", 754 | "metadata": {}, 755 | "source": [ 756 | "* https://twitter.com/yanndubs/status/1681644889145237504?s=20" 757 | ] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "execution_count": null, 762 | "metadata": {}, 763 | "outputs": [], 764 | "source": [] 765 | } 766 | ], 767 | "metadata": { 768 | "kernelspec": { 769 | "display_name": "Python 3 (ipykernel)", 770 | "language": "python", 771 | "name": "python3" 772 | }, 773 | "language_info": { 774 | "codemirror_mode": { 775 | "name": "ipython", 776 | "version": 3 777 | }, 778 | "file_extension": ".py", 779 | "mimetype": "text/x-python", 780 | "name": "python", 781 | "nbconvert_exporter": "python", 782 | "pygments_lexer": "ipython3", 783 | "version": "3.10.12" 784 | } 785 | }, 786 | "nbformat": 4, 787 | "nbformat_minor": 4 788 | } 789 | -------------------------------------------------------------------------------- /examples/chatbot/chatbot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from llm_chain import LitGPTConversationChain, LitGPTLLM\n", 10 | "from llm_inference import prepare_weights\n", 11 | "from rich import print" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "bot = LitGPTConversationChain.from_llm(llm=\"dummy\", verbose=True)\n", 21 | "print(bot(\"Hi, I am Aniket!\"))" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# path = str(prepare_weights(\"EleutherAI/pythia-70m\"))\n", 31 | "# llm = LitGPTLLM(checkpoint_dir=path, accelerator=\"cpu\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "path = str(prepare_weights(\"lmsys/longchat-13b-16k\"))\n", 41 | "llm = LitGPTLLM(checkpoint_dir=path, quantize=\"bnb.nf4\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "from llm_chain.templates import longchat_prompt_template\n", 51 | "\n", 52 | "print(longchat_prompt_template)\n", 53 | "\n", 54 | "bot = LitGPTConversationChain.from_llm(\n", 55 | " llm=llm, prompt=longchat_prompt_template, verbose=True\n", 56 | ")" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | 
"print(bot.send(\"Hi, I am Adam\"))" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "print(bot.send(\"What is the timezone of London?\"))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "bot.send(\"my monitor screen size is 32 inch.\")" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "bot.send(\"What did I tell you about my monitor screen size?\")" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "bot.clear() # clears the chatbot memory\n", 102 | "bot.send(\"What did I tell you about my monitor screen size?\")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "Managing the memory" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "from langchain.memory import ConversationBufferMemory\n", 119 | "from langchain.chains import ConversationChain" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# Here it is by default set to \"AI\"\n", 129 | "conversation = ConversationChain(\n", 130 | " llm=llm,\n", 131 | " prompt=longchat_prompt_template,\n", 132 | " verbose=False,\n", 133 | " memory=ConversationBufferMemory(ai_prefix=\"Assistant\", human_prefix=\"User\"),\n", 134 | ")" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "output = conversation(\n", 144 | " \"PyTorch Lightning is an open-source library developed by Lightning AI team.\"\n", 145 | ")[\"response\"]\n", 146 | "print(output)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "output = conversation(\n", 156 | " \"who developed PyTorch Lightning? 
just give me the name of the team or person and nothing else.\"\n", 157 | ")[\"response\"]\n", 158 | "print(output)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [] 174 | } 175 | ], 176 | "metadata": { 177 | "kernelspec": { 178 | "display_name": "am", 179 | "language": "python", 180 | "name": "python3" 181 | }, 182 | "language_info": { 183 | "codemirror_mode": { 184 | "name": "ipython", 185 | "version": 3 186 | }, 187 | "file_extension": ".py", 188 | "mimetype": "text/x-python", 189 | "name": "python", 190 | "nbconvert_exporter": "python", 191 | "pygments_lexer": "ipython3", 192 | "version": "3.11.3" 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 2 197 | } 198 | -------------------------------------------------------------------------------- /examples/chatbot/discord_bot.py: -------------------------------------------------------------------------------- 1 | # pip install discord.py 2 | # Learn more here - https://github.com/aniketmaurya/docs-QnA-discord-bot/tree/main 3 | import os 4 | 5 | import discord 6 | from dotenv import load_dotenv 7 | 8 | from llm_chain import LitGPTConversationChain, LitGPTLLM 9 | from llm_chain.templates import longchat_prompt_template 10 | from llm_inference import prepare_weights 11 | 12 | load_dotenv() 13 | 14 | # path = prepare_weights("lmsys/longchat-7b-16k") 15 | path = "checkpoints/lmsys/longchat-13b-16k" 16 | llm = LitGPTLLM(checkpoint_dir=path, quantize="bnb.nf4") 17 | llm("warm up!") 18 | TOKEN = os.environ.get("DISCORD_BOT_TOKEN") 19 | 20 | 21 | class MyClient(discord.Client): 22 | BOT_INSTANCE = {} 23 | 24 | def chat(self, user_id, query): 25 | if user_id in self.BOT_INSTANCE: 26 | return self.BOT_INSTANCE[user_id].send(query) 27 | 28 | self.BOT_INSTANCE[user_id] = LitGPTConversationChain.from_llm( 29 | llm=llm, prompt=longchat_prompt_template 30 | ) 31 | return self.BOT_INSTANCE[user_id].send(query) 32 | 33 | bot = LitGPTConversationChain.from_llm(llm=llm, prompt=longchat_prompt_template) 34 | 35 | async def on_ready(self): 36 | print(f"Logged on as {self.user}!") 37 | 38 | async def on_message(self, message): 39 | if message.author.id == self.user.id: 40 | return 41 | print(f"Message from {message.author}: {message.content}") 42 | 43 | if message.content.startswith("!help"): 44 | query = message.content.replace("!help", "") 45 | result = self.bot.send(query) 46 | await message.reply(result, mention_author=True) 47 | 48 | 49 | intents = discord.Intents.default() 50 | intents.message_content = True 51 | 52 | client = MyClient(intents=intents) 53 | client.run(TOKEN) 54 | -------------------------------------------------------------------------------- /examples/chatbot/gradio_demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | from llm_chain import LitGPTConversationChain, LitGPTLLM 4 | from llm_chain.templates import llama2_prompt_template 5 | from llm_inference import prepare_weights 6 | 7 | path = str(prepare_weights("meta-llama/Llama-2-7b-chat-hf")) 8 | llm = LitGPTLLM(checkpoint_dir=path, quantize="bnb.nf4") 9 | llm("warmup") 10 | bot = LitGPTConversationChain.from_llm(llm=llm, prompt=llama2_prompt_template) 11 | 12 | 13 | with gr.Blocks() as demo: 14 | chatbot = gr.Chatbot() 15 | msg = gr.Textbox() 16 | clear = gr.ClearButton([msg, 
chatbot]) 17 | clear.click(fn=bot.clear) 18 | 19 | def respond(message, chat_history): 20 | bot_message = bot.send(message) 21 | chat_history.append((f"👤 {message}", f"{bot_message}")) 22 | return "", chat_history 23 | 24 | msg.submit(respond, [msg, chatbot], [msg, chatbot]) 25 | 26 | if __name__ == "__main__": 27 | demo.launch() 28 | -------------------------------------------------------------------------------- /examples/chatbot/llama_bot_ui.py: -------------------------------------------------------------------------------- 1 | # please use the gradio demo instead - https://github.com/aniketmaurya/llm-inference/blob/main/examples/chatbot/gradio_demo.py 2 | # This script has to be updated to the latest version 3 | 4 | import lightning as L 5 | import lightning.app.frontend as frontend 6 | 7 | from llm_chain.ui import ui_render_fn 8 | from llm_inference.serve import PromptRequest, Response, ServeLLaMA 9 | 10 | checkpoint_path = "weights/state_dict.pth" 11 | tokenizer_path = "weights/tokenizer.model" 12 | 13 | 14 | class ChatBotApp(L.LightningFlow): 15 | def __init__(self): 16 | super().__init__() 17 | self.llm_serve = ServeLLaMA( 18 | input_type=PromptRequest, 19 | output_type=Response, 20 | checkpoint_path=checkpoint_path, 21 | tokenizer_path=tokenizer_path, 22 | ) 23 | self.llm_url = "" 24 | 25 | def run(self): 26 | self.llm_serve.run() 27 | if self.llm_serve.url: 28 | print("url is ready:", self.llm_serve.url) 29 | self.llm_url = self.llm_serve.url 30 | 31 | def configure_layout(self): 32 | return frontend.StreamlitFrontend(render_fn=ui_render_fn) 33 | 34 | 35 | if __name__ == "__main__": 36 | app = L.LightningApp(ChatBotApp()) 37 | -------------------------------------------------------------------------------- /examples/inference-demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from llm_inference import LLMInference, prepare_weights\n", 10 | "from rich import print" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "path = prepare_weights(\"EleutherAI/pythia-70m\")" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "model = LLMInference(checkpoint_dir=path, precision=32, accelerator=\"cpu\")" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "print(model(\"New York is located in\", temperature=1))" 38 | ] 39 | } 40 | ], 41 | "metadata": { 42 | "kernelspec": { 43 | "display_name": "pl", 44 | "language": "python", 45 | "name": "python3" 46 | }, 47 | "language_info": { 48 | "codemirror_mode": { 49 | "name": "ipython", 50 | "version": 3 51 | }, 52 | "file_extension": ".py", 53 | "mimetype": "text/x-python", 54 | "name": "python", 55 | "nbconvert_exporter": "python", 56 | "pygments_lexer": "ipython3", 57 | "version": "3.10.11" 58 | } 59 | }, 60 | "nbformat": 4, 61 | "nbformat_minor": 2 62 | } 63 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: TEMPLATE_NAME 2 | site_description: TEMPLATE_DESCRIPTION 3 | site_author: AUTHOR NAME 4 | copyright: 'Copyright © 2021 AUTHOR NAME' 5 | 6 | banner_url: 
https://IMAGE_URL 7 | repo_url: https://github.com/aniketmaurya/python-project-template/ 8 | repo_name: aniketmaurya/python-project-template 9 | 10 | theme: 11 | name: material 12 | custom_dir: docs/overrides 13 | palette: 14 | - scheme: default 15 | primary: black 16 | accent: deep orange 17 | toggle: 18 | icon: material/lightbulb-outline 19 | name: Switch to dark mode 20 | 21 | - scheme: slate 22 | primary: black 23 | accent: deep orange 24 | toggle: 25 | icon: material/lightbulb 26 | name: Switch to light mode 27 | 28 | logo: https://IMAGE_URL 29 | favicon: https://IMAGE_URL 30 | features: 31 | - search.suggest 32 | - search.highlight 33 | 34 | # Necessary for search to work properly 35 | include_search_page: false 36 | search_index_only: true 37 | 38 | markdown_extensions: 39 | - meta 40 | - pymdownx.highlight 41 | - pymdownx.superfences 42 | - pymdownx.details 43 | - pymdownx.superfences 44 | - admonition 45 | - pymdownx.emoji: 46 | emoji_index: "!!python/name:materialx.emoji.twemoji" 47 | emoji_generator: "!!python/name:materialx.emoji.to_svg" 48 | - toc: 49 | permalink: true 50 | 51 | plugins: 52 | - git-revision-date-localized 53 | - search 54 | - autorefs 55 | - mkdocs-jupyter 56 | - mkdocstrings: 57 | default_handler: python 58 | handlers: 59 | python: 60 | rendering: 61 | show_source: false 62 | 63 | extra: 64 | homepage: https://TEMPLATE_URL 65 | 66 | nav: 67 | - Introduction: 'index.md' 68 | - Release Notes: 'CHANGELOG.md' 69 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 40.9.0", 4 | ] 5 | build-backend = "setuptools.build_meta" 6 | 7 | [tool.isort] 8 | profile = "black" 9 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pre-commit 3 | black 4 | isort 5 | build 6 | twine 7 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | python-dotenv 2 | huggingface_hub 3 | langchain>=0.0.94 4 | openai>=0.26.5 5 | streamlit>=1.19.0 6 | streamlit-chat>=0.0.2.1 7 | transformers>=4.26.1 8 | fastapi>=0.88.0 9 | gradio 10 | bitsandbytes>=0.40.0 11 | sentencepiece 12 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = llm_inference 3 | version = attr: llm_inference.__version__ 4 | author = Aniket Maurya 5 | author_email = theaniketmaurya@gmail.com 6 | description = Large Language Models Inference API and Applications 7 | description-file = README.md 8 | long_description = file: README.md, LICENSE.md 9 | long_description_content_type = text/markdown 10 | url = https://github.com/aniketmaurya/llm-inference 11 | license = Apache License 2.0 12 | keywords = LLM, LLaMA, GPT, Falcon 13 | 14 | [options] 15 | python_requires = >=3.8 16 | package_dir = 17 | llm_inference = src/llm_inference 18 | llm_chain = src/llm_chain 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | def get_requirements(file): 5 | with 
open(file) as f: 6 | required = f.read().splitlines() 7 | return required 8 | 9 | 10 | required = get_requirements("requirements/requirements.txt") 11 | dev_required = get_requirements("requirements/dev.txt") 12 | extras = {"dev": dev_required} 13 | 14 | setup(install_requires=required, extras_require=extras) 15 | -------------------------------------------------------------------------------- /src/llm_chain/__init__.py: -------------------------------------------------------------------------------- 1 | from .conversation_chain import LitGPTConversationChain 2 | from .llm import LitGPTLLM 3 | -------------------------------------------------------------------------------- /src/llm_chain/conversation_chain.py: -------------------------------------------------------------------------------- 1 | """Wrapper around Lightning App.""" 2 | 3 | import logging 4 | from collections import deque 5 | from typing import Optional, Union 6 | 7 | from langchain.chains import ConversationChain 8 | from langchain.chains.conversation.memory import ConversationBufferWindowMemory 9 | from langchain.prompts import PromptTemplate 10 | 11 | from .llm import DummyLLM, LitGPTLLM, ServerLLM 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class LitGPTConversationChain(ConversationChain): 17 | @staticmethod 18 | def from_llm( 19 | llm: Union[str, LitGPTLLM], 20 | prompt: Optional[PromptTemplate] = None, 21 | memory: Optional[None] = None, 22 | input_key="input", 23 | output_key="response", 24 | verbose=False, 25 | url: Optional[str] = None, 26 | ): 27 | if llm == "dummy": 28 | llm = DummyLLM() 29 | 30 | if llm == "server": 31 | llm = ServerLLM(url=url) 32 | 33 | if not memory: 34 | memory = ConversationBufferWindowMemory( 35 | llm=llm, 36 | k=5, 37 | output_key=output_key, 38 | input_key=input_key, 39 | ai_prefix="Assistant", 40 | human_prefix="User", 41 | ) 42 | chain = LitGPTConversationChain( 43 | llm=llm, 44 | verbose=verbose, 45 | memory=memory, 46 | output_key=output_key, 47 | input_key=input_key, 48 | ) 49 | if prompt: 50 | chain.prompt = prompt 51 | return chain 52 | 53 | @staticmethod 54 | def from_lit_gpt( 55 | checkpoint_dir: str, 56 | precision: str = "bf16-mixed", 57 | quantize: Optional[str] = None, 58 | accelerator: str = "auto", 59 | input_key="input", 60 | output_key="response", 61 | verbose=False, 62 | ): 63 | llm = LitGPTLLM( 64 | checkpoint_dir=checkpoint_dir, 65 | precision=precision, 66 | quantize=quantize, 67 | accelerator=accelerator, 68 | ) 69 | return LitGPTConversationChain.from_llm( 70 | llm=llm, 71 | input_key=input_key, 72 | output_key=output_key, 73 | verbose=verbose, 74 | ) 75 | 76 | def send(self, prompt: str, **kwargs): 77 | return self(prompt)["response"] 78 | 79 | @property 80 | def history(self): 81 | return self.memory.buffer 82 | 83 | def clear(self): 84 | self.memory.clear() 85 | 86 | 87 | def build_server_chain( 88 | url: str, input_key: str = "input", output_key: str = "response" 89 | ) -> ConversationChain: 90 | """Logic for loading the chain you want to use should go here.""" 91 | 92 | logger.info(f"Initializing ServerLLM using url: {url}") 93 | 94 | llm = ServerLLM(url=url) 95 | 96 | memory = ConversationSummaryBufferMemory( 97 | llm=llm, output_key=output_key, input_key=input_key 98 | ) 99 | chain = ConversationChain( 100 | llm=llm, verbose=True, memory=memory, output_key=output_key, input_key=input_key 101 | ) 102 | logger.info("Created Conversational Chain") 103 | return chain 104 | 
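
A minimal usage sketch of the conversation chain defined above, assuming the package and its pinned langchain dependency are installed; `from_llm` accepts the literal string "dummy" to wire in `DummyLLM`, so the chain can be exercised without downloading any model weights. The prompt text below is illustrative only.

from llm_chain import LitGPTConversationChain

# Build a chain backed by DummyLLM; when no memory is supplied, from_llm creates a
# ConversationBufferWindowMemory(k=5) with ai_prefix="Assistant" and human_prefix="User".
bot = LitGPTConversationChain.from_llm(llm="dummy", verbose=False)

print(bot.send("Hi, I am Aniket!"))  # DummyLLM always replies "Hi, I am a helpful chatbot!"
print(bot.history)                   # the memory buffer accumulated so far
bot.clear()                          # reset the conversation memory
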
-------------------------------------------------------------------------------- /src/llm_chain/llm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, List, Optional 3 | 4 | import requests 5 | from langchain.llms.base import LLM 6 | from pydantic import BaseModel 7 | 8 | from llm_inference import LLMInference 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class DummyLLM(LLM, BaseModel): 14 | def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str: 15 | return f"Hi, I am a helpful chatbot!" 16 | 17 | @property 18 | def _llm_type(self) -> str: 19 | """Return type of llm.""" 20 | return "Dummy LLM" 21 | 22 | 23 | class LitGPTLLM(LLM, BaseModel): 24 | checkpoint_dir: str = "" 25 | model: Any = None 26 | quantize: Optional[str] = None 27 | accelerator: Optional[str] = "auto" 28 | model_configs: dict = {} 29 | 30 | def _call( 31 | self, 32 | prompt: str, 33 | stop: Optional[list] = None, 34 | temperature=1e-5, 35 | **kwargs: Any, 36 | ) -> str: 37 | if not self.model: 38 | print("Loading model for first time...") 39 | self.model = LLMInference( 40 | checkpoint_dir=self.checkpoint_dir, 41 | quantize=self.quantize, 42 | accelerator=self.accelerator, 43 | **self.model_configs, 44 | ) 45 | 46 | return self.model.chat(prompt, temperature=temperature, **kwargs) 47 | 48 | @property 49 | def _llm_type(self) -> str: 50 | """Return type of llm.""" 51 | return "Lit-GPT LLM" 52 | 53 | 54 | class ServerLLM(LLM, BaseModel): 55 | url: str = "" 56 | TIMEOUT: float = 60.0 57 | 58 | def _call(self, prompt: str, stop: Optional[list] = None) -> str: 59 | """Run the LLM on the given prompt and input.""" 60 | if self.url == "": 61 | raise Exception("Server URL not set!") 62 | 63 | headers = { 64 | "accept": "application/json", 65 | "Content-Type": "application/json", 66 | } 67 | assert isinstance(prompt, str) 68 | json_data = {"prompt": prompt} 69 | response = requests.post( 70 | url=self.url + "/predict", 71 | headers=headers, 72 | json=json_data, 73 | timeout=self.TIMEOUT, 74 | ) 75 | logger.error(response.raise_for_status()) 76 | return response.json()["result"] 77 | 78 | @property 79 | def _llm_type(self) -> str: 80 | """Return type of llm.""" 81 | return "Server LLM" 82 | -------------------------------------------------------------------------------- /src/llm_chain/templates.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts import PromptTemplate 2 | 3 | chatgpt_template = """Assistant is a large language model trained by OpenAI. 4 | 5 | Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. 6 | 7 | Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics. 
8 | 9 | Overall, Assistant is a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist. 10 | 11 | {history} 12 | Human: {human_input} 13 | Assistant:""" 14 | 15 | chatgpt_prompt_template = PromptTemplate( 16 | input_variables=["human_input", "history"], template=chatgpt_template 17 | ) 18 | 19 | 20 | question_template = """Question: {question} 21 | 22 | Answer:""" 23 | 24 | 25 | longchat_template = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 26 | Context: {history} 27 | USER: {input} ASSISTANT: 28 | """ 29 | 30 | longchat_prompt_template = PromptTemplate( 31 | input_variables=["input", "history"], template=longchat_template 32 | ) 33 | 34 | 35 | # llama2 template 36 | b_inst, e_inst = "[INST]", "[/INST]" 37 | b_sys, e_sys = "<>\n", "\n<>\n\n" 38 | llama2_template = ( 39 | f"{b_inst} {b_sys}You are a helpful, respectful and honest assistant. Always answer as helpfully as" 40 | " possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist," 41 | " toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and" 42 | " positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why" 43 | " instead of answering something not correct. If you don't know the answer to a question, please don't" 44 | f" share false information.{{history}} {e_sys} {{input}} {e_inst} " 45 | ) 46 | 47 | llama2_prompt_template = PromptTemplate( 48 | input_variables=["input", "history"], template=llama2_template 49 | ) 50 | -------------------------------------------------------------------------------- /src/llm_chain/ui/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import run as ui_render_fn 2 | -------------------------------------------------------------------------------- /src/llm_chain/ui/main.py: -------------------------------------------------------------------------------- 1 | """Python file to serve as the frontend""" 2 | 3 | import logging 4 | 5 | import rich 6 | import streamlit as st 7 | from streamlit_chat import message 8 | 9 | from llm_chain import ServerChatBot 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def run(lightning_app_state): 15 | if not lightning_app_state.llm_url: 16 | st.info("Waiting for server to get ready...") 17 | return 18 | 19 | print("lightning_app_state", lightning_app_state) 20 | 21 | if "model" not in st.session_state: 22 | # build unique conversational chain per session state 23 | bot = ServerChatBot(lightning_app_state.llm_url) 24 | st.session_state["model"] = bot 25 | logger.info("loaded model into state session") 26 | 27 | else: 28 | bot = st.session_state["model"] 29 | 30 | # From here down is all the StreamLit UI. 
31 | st.set_page_config(page_title="LLaMA Demo", page_icon=":robot:") 32 | st.header("LLM Demo") 33 | 34 | if "generated" not in st.session_state: 35 | st.session_state["generated"] = [] 36 | 37 | if "past" not in st.session_state: 38 | st.session_state["past"] = [] 39 | 40 | def get_text(): 41 | input_text = st.text_input("You: ", "Hello, how are you?", key="input") 42 | return input_text 43 | 44 | user_input = get_text() 45 | 46 | if user_input: 47 | rich.print("user input:", user_input) 48 | output = bot.predict(input=user_input) 49 | rich.print("buffer:", bot.memory.buffer) 50 | 51 | st.session_state.past.append(user_input) 52 | st.session_state.generated.append(output) 53 | 54 | if st.session_state["generated"]: 55 | for i in range(len(st.session_state["generated"]) - 1, -1, -1): 56 | message(st.session_state["generated"][i], key=str(i)) 57 | message(st.session_state["past"][i], is_user=True, key=str(i) + "_user") 58 | -------------------------------------------------------------------------------- /src/llm_inference/__init__.py: -------------------------------------------------------------------------------- 1 | """Inference API for LLaMA""" 2 | 3 | from .download import prepare_weights 4 | from .model import LLMInference 5 | 6 | # from .serve import ServeLitGPT 7 | 8 | __version__ = "0.0.7" 9 | -------------------------------------------------------------------------------- /src/llm_inference/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from lit_gpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint 5 | from lit_gpt.scripts.download import download_from_hub 6 | 7 | 8 | def prepare_weights( 9 | repo_id: str, 10 | ): 11 | local_dir = Path(f"checkpoints/{repo_id}") 12 | if local_dir.exists(): 13 | print(f"weights already exists at {local_dir}") 14 | return local_dir 15 | download_from_hub(repo_id=repo_id) 16 | convert_hf_checkpoint(checkpoint_dir=local_dir) 17 | return local_dir 18 | -------------------------------------------------------------------------------- /src/llm_inference/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import time 5 | import warnings 6 | from functools import partial 7 | from pathlib import Path 8 | from typing import Any, Literal, Optional, Union 9 | 10 | import lightning as L 11 | import torch 12 | from dotenv import load_dotenv 13 | from lightning.fabric.strategies import FSDPStrategy 14 | from lit_gpt import GPT, Config, Tokenizer 15 | from lit_gpt.adapter_v2 import add_adapter_v2_parameters_to_linear_layers 16 | from lit_gpt.model import Block 17 | from lit_gpt.utils import check_valid_checkpoint_dir, lazy_load, quantization 18 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 19 | 20 | from .token_manipulation import get_stop_tokens 21 | 22 | load_dotenv() 23 | 24 | WEIGHTS_PATH = os.environ.get("WEIGHTS") 25 | 26 | 27 | def generate_prompt(example): 28 | """Generates a standardized message to prompt the model with an instruction, optional input and a 29 | 'response' field.""" 30 | 31 | if example["input"]: 32 | return ( 33 | "Below is an instruction that describes a task, paired with an input that provides further context. 
" 34 | "Write a response that appropriately completes the request.\n\n" 35 | f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" 36 | ) 37 | return ( 38 | "Below is an instruction that describes a task. " 39 | "Write a response that appropriately completes the request.\n\n" 40 | f"### Instruction:\n{example['instruction']}\n\n### Response:" 41 | ) 42 | 43 | 44 | @torch.inference_mode() 45 | def _generate( 46 | model: torch.nn.Module, 47 | idx: torch.Tensor, 48 | max_returned_tokens: int, 49 | max_seq_length: int, 50 | *, 51 | temperature: float = 1.0, 52 | top_k: Optional[int] = None, 53 | eos_id: Optional[int] = None, 54 | ) -> torch.Tensor: 55 | """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 56 | 57 | The implementation of this function is modified from A. Karpathy's nanoGPT. 58 | 59 | Args: 60 | model: The model to use. 61 | idx: Tensor of shape (T) with indices of the prompt sequence. 62 | max_returned_tokens: The maximum number of tokens to return (given plus generated). 63 | max_seq_length: The maximum sequence length allowed. Should be less or equal than the block size. 64 | temperature: Scales the predicted logits by 1 / temperature. 65 | top_k: If specified, only sample among the tokens with the k highest probabilities. 66 | eos_id: If specified, stop generating any more token once the token is triggered. 67 | """ 68 | T = idx.size(0) 69 | assert max_returned_tokens > T 70 | device, dtype = idx.device, idx.dtype 71 | # create an empty tensor of the expected final shape and fill in the current tokens 72 | empty = torch.empty(max_returned_tokens, dtype=dtype, device=device) 73 | empty[:T] = idx 74 | idx = empty 75 | input_pos = torch.arange(0, T, device=device) 76 | 77 | if idx.device.type == "xla": 78 | import torch_xla.core.xla_model as xm 79 | 80 | xm.mark_step() 81 | 82 | # generate up to a fixed number of tokens 83 | for _ in range(max_returned_tokens - T): 84 | x = idx.index_select(0, input_pos).view(1, -1) 85 | 86 | # forward 87 | logits = model(x, max_seq_length, input_pos) 88 | logits = logits[0, -1] / temperature 89 | 90 | # optionally crop the logits to only the top k options 91 | if top_k is not None: 92 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 93 | logits = torch.where(logits < v[[-1]], -float("Inf"), logits) 94 | 95 | probs = torch.nn.functional.softmax(logits, dim=-1) 96 | idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype) 97 | 98 | # advance 99 | input_pos = input_pos[-1:] + 1 100 | 101 | if idx.device.type == "xla": 102 | xm.mark_step() 103 | 104 | # concatenate the new generation 105 | idx = idx.index_copy(0, input_pos, idx_next) 106 | 107 | # if token is triggered, return the output (stop generation) 108 | if idx_next == eos_id: 109 | return idx[:input_pos] # include the EOS token 110 | 111 | return idx 112 | 113 | 114 | class LLMInference: 115 | def __init__( 116 | self, 117 | checkpoint_dir: Path = Path(f"checkpoints/tiiuae/falcon-7b"), 118 | quantize: Literal["llm.int8", "gptq.int4"] = None, 119 | accelerator: str = "auto", 120 | strategy: str = "auto", 121 | devices: int = 1, 122 | precision: str = "bf16-true", 123 | adapter_path: Optional[Path] = None, 124 | ) -> None: 125 | self.quantize = quantize 126 | 127 | checkpoint_dir = Path(checkpoint_dir) 128 | 129 | if strategy == "fsdp": 130 | auto_wrap_policy = partial( 131 | transformer_auto_wrap_policy, transformer_layer_cls={Block} 132 | ) 133 | strategy = FSDPStrategy( 134 | 
auto_wrap_policy=auto_wrap_policy, cpu_offload=False 135 | ) 136 | self.fabric = fabric = L.Fabric( 137 | devices=devices, 138 | precision=precision, 139 | strategy=strategy, 140 | accelerator=accelerator, 141 | ) 142 | fabric.launch() 143 | 144 | check_valid_checkpoint_dir(checkpoint_dir) 145 | 146 | with open(checkpoint_dir / "lit_config.json") as fp: 147 | self.config = config = Config(**json.load(fp)) 148 | 149 | if quantize is not None and devices > 1: 150 | raise NotImplementedError 151 | if quantize == "gptq.int4": 152 | model_file = "lit_model_gptq.4bit.pth" 153 | if not (checkpoint_dir / model_file).is_file(): 154 | raise ValueError("Please run `python quantize/gptq.py` first") 155 | else: 156 | model_file = "lit_model.pth" 157 | checkpoint_path = checkpoint_dir / model_file 158 | 159 | if adapter_path: 160 | model = self.load_adapter_model( 161 | checkpoint_path=checkpoint_path, adapter_path=adapter_path 162 | ) 163 | 164 | else: 165 | model = self.load_model(checkpoint_path=checkpoint_path) 166 | 167 | model.eval() 168 | self.model = fabric.setup_module(model) 169 | self.tokenizer = Tokenizer(checkpoint_dir) 170 | 171 | def __call__( 172 | self, 173 | prompt: str, 174 | max_new_tokens: int = 100, 175 | top_k: int = 200, 176 | temperature: float = 0.1, 177 | eos_id=None, 178 | ) -> str: 179 | tokenizer = self.tokenizer 180 | model = self.model 181 | fabric = self.fabric 182 | 183 | encoded = tokenizer.encode(prompt, device=fabric.device) 184 | prompt_length = encoded.size(0) 185 | max_returned_tokens = prompt_length + max_new_tokens 186 | 187 | t0 = time.perf_counter() 188 | y = _generate( 189 | model, 190 | encoded, 191 | max_returned_tokens, 192 | max_seq_length=max_returned_tokens, 193 | temperature=temperature, 194 | top_k=top_k, 195 | eos_id=eos_id, 196 | ) 197 | t = time.perf_counter() - t0 198 | 199 | model.reset_cache() 200 | output = tokenizer.decode(y[prompt_length:]) 201 | tokens_generated = y.size(0) - prompt_length 202 | fabric.print( 203 | f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", 204 | file=sys.stderr, 205 | ) 206 | if fabric.device.type == "cuda": 207 | fabric.print( 208 | f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", 209 | file=sys.stderr, 210 | ) 211 | 212 | return output 213 | 214 | def instruction_predict( 215 | self, 216 | prompt: str, 217 | max_new_tokens: int = 100, 218 | top_k: int = 200, 219 | temperature: float = 0.1, 220 | ) -> str: 221 | sample = {"instruction": prompt, "input": input} 222 | prompt = generate_prompt(sample) 223 | output = self.__call__( 224 | prompt=prompt, 225 | max_new_tokens=max_new_tokens, 226 | top_k=top_k, 227 | temperature=temperature, 228 | eos_id=self.tokenizer.eos_id, 229 | ) 230 | output = output.split("### Response:")[1].strip() 231 | return output 232 | 233 | def chat( 234 | self, 235 | prompt: str, 236 | max_new_tokens: int = 100, 237 | top_k: int = 200, 238 | temperature: float = 0.1, 239 | eos_id=None, 240 | ) -> str: 241 | tokenizer = self.tokenizer 242 | model = self.model 243 | fabric = self.fabric 244 | 245 | encoded = tokenizer.encode(prompt, device=fabric.device) 246 | prompt_length = encoded.size(0) 247 | max_returned_tokens = model.config.block_size 248 | 249 | t0 = time.perf_counter() 250 | y = _generate( 251 | model, 252 | encoded, 253 | max_returned_tokens, 254 | max_seq_length=max_returned_tokens, 255 | temperature=temperature, 256 | top_k=top_k, 257 | eos_id=self.tokenizer.eos_id, 258 | ) 259 | t = time.perf_counter() - t0 260 | 261 | 
model.reset_cache() 262 | output = tokenizer.decode(y[prompt_length:]) 263 | tokens_generated = y.size(0) - prompt_length 264 | fabric.print( 265 | f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", 266 | file=sys.stderr, 267 | ) 268 | if fabric.device.type == "cuda": 269 | fabric.print( 270 | f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", 271 | file=sys.stderr, 272 | ) 273 | 274 | return output 275 | 276 | def eval(self): 277 | self.model.eval() 278 | 279 | def load_model(self, checkpoint_path: Union[Path, str]): 280 | fabric = self.fabric 281 | quantize = self.quantize 282 | config = self.config 283 | 284 | fabric.print( 285 | f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", 286 | file=sys.stderr, 287 | ) 288 | t0 = time.time() 289 | with fabric.init_module(empty_init=True), quantization(quantize): 290 | model = GPT(config) 291 | fabric.print( 292 | f"Time to instantiate model: {time.time() - t0:.02f} seconds.", 293 | file=sys.stderr, 294 | ) 295 | 296 | t0 = time.time() 297 | with lazy_load(checkpoint_path) as checkpoint: 298 | model.load_state_dict( 299 | checkpoint.get("model", checkpoint), strict=quantize is None 300 | ) 301 | fabric.print( 302 | f"Time to load the model weights: {time.time() - t0:.02f} seconds.", 303 | file=sys.stderr, 304 | ) 305 | return model 306 | 307 | def load_lora_model(self, checkpoint_path, lora_path: str): 308 | return self.load_model(checkpoint_path) 309 | 310 | def load_adapter_model(self, checkpoint_path, adapter_path: str): 311 | fabric = self.fabric 312 | quantize = self.quantize 313 | config = self.config 314 | 315 | fabric.print( 316 | f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", 317 | file=sys.stderr, 318 | ) 319 | t0 = time.time() 320 | with fabric.init_module(empty_init=True), quantization(quantize): 321 | model = GPT(config) 322 | add_adapter_v2_parameters_to_linear_layers(model) 323 | fabric.print( 324 | f"Time to instantiate model: {time.time() - t0:.02f} seconds.", 325 | file=sys.stderr, 326 | ) 327 | 328 | t0 = time.time() 329 | with lazy_load(checkpoint_path) as checkpoint, lazy_load( 330 | adapter_path 331 | ) as adapter_checkpoint: 332 | checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint)) 333 | model.load_state_dict(checkpoint, strict=quantize is None) 334 | fabric.print( 335 | f"Time to load the model weights: {time.time() - t0:.02f} seconds.", 336 | file=sys.stderr, 337 | ) 338 | return model 339 | -------------------------------------------------------------------------------- /src/llm_inference/serve.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import lightning as L 4 | from lightning.app.components import PythonServer 5 | from pydantic import BaseModel 6 | 7 | from llm_inference.model import LLMInference 8 | 9 | 10 | class PromptRequest(BaseModel): 11 | prompt: str 12 | 13 | 14 | class Response(BaseModel): 15 | result: str 16 | 17 | 18 | class ServeLitGPT(PythonServer): 19 | def __init__( 20 | self, 21 | input_type, 22 | output_type, 23 | checkpoint_dir: str = None, 24 | ): 25 | super().__init__(input_type, output_type) 26 | self.checkpoint_dir = checkpoint_dir 27 | 28 | def setup(self, *args: Any, **kwargs: Any) -> None: 29 | self._model = LLMInference( 30 | checkpoint_dir=self.checkpoint_dir, 31 | ) 32 | 33 | def predict(self, request: PromptRequest) -> Any: 34 | result = self._model.chat(request.prompt) 35 | return Response(result=result) 36 | 37 
| 38 | if __name__ == "__main__": 39 | component = ServeLitGPT( 40 | input_type=PromptRequest, 41 | output_type=Response, 42 | checkpoint_dir="examples/chatbot/checkpoints/lmsys/longchat-7b-16k/", 43 | ) 44 | app = L.LightningApp(component) 45 | -------------------------------------------------------------------------------- /src/llm_inference/token_manipulation.py: -------------------------------------------------------------------------------- 1 | def get_stop_tokens(tokenizer): 2 | stop_tokens = ([tokenizer.eos_id],) 3 | return stop_tokens 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aniketmaurya/llm-inference/5bb323c4cce70dcbe81cf794aaa0a66b87fe3083/tests/__init__.py -------------------------------------------------------------------------------- /tests/__main__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aniketmaurya/llm-inference/5bb323c4cce70dcbe81cf794aaa0a66b87fe3083/tests/__main__.py -------------------------------------------------------------------------------- /tests/llm_chain/test_chain.py: -------------------------------------------------------------------------------- 1 | from llm_chain.conversation_chain import LitGPTConversationChain 2 | 3 | 4 | def test_dummybot(): 5 | bot = LitGPTConversationChain.from_llm("dummy") 6 | prompt = "Hello, I am testing you!" 7 | response = bot.send(prompt) 8 | assert isinstance(response, str) 9 | assert response == "Hi, I am a helpful chatbot!" 10 | --------------------------------------------------------------------------------
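
A hedged sketch of one additional check in the style of tests/llm_chain/test_chain.py above, assuming the dev dependencies from requirements/dev.txt (pytest and the pinned langchain) are installed; it is illustrative, not part of the repository, and can be run with `pytest tests/` alongside the existing test.

from llm_chain.conversation_chain import LitGPTConversationChain


def test_dummybot_clear():
    # DummyLLM ignores its input and always returns the same canned reply,
    # so the assertion holds both before and after the memory is cleared.
    bot = LitGPTConversationChain.from_llm("dummy")
    assert bot.send("Hello, I am testing you!") == "Hi, I am a helpful chatbot!"
    bot.clear()  # clears the ConversationBufferWindowMemory created by from_llm
    assert bot.send("Do you remember me?") == "Hi, I am a helpful chatbot!"
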