├── .env ├── .github ├── actions │ └── poetry_setup │ │ └── action.yml └── workflows │ ├── _integration_test.yml │ ├── _release.yml │ ├── _test.yml │ └── _test_release.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.en.md ├── README.md ├── docs └── img │ ├── MetaGLM.png │ ├── all_tools.png │ └── demo.mp4 ├── langchain_glm ├── __init__.py ├── agent_toolkits │ ├── __init__.py │ └── all_tools │ │ ├── __init__.py │ │ ├── code_interpreter_tool.py │ │ ├── drawing_tool.py │ │ ├── registry.py │ │ ├── struct_type.py │ │ ├── tool.py │ │ └── web_browser_tool.py ├── agents │ ├── __init__.py │ ├── all_tools_agent.py │ ├── all_tools_bind │ │ └── base.py │ ├── format_scratchpad │ │ └── all_tools.py │ ├── output_parsers │ │ ├── __init__.py │ │ ├── _utils.py │ │ ├── base.py │ │ ├── code_interpreter.py │ │ ├── drawing_tool.py │ │ ├── function.py │ │ ├── tools.py │ │ ├── web_browser.py │ │ └── zhipuai_all_tools.py │ └── zhipuai_all_tools │ │ ├── __init__.py │ │ ├── base.py │ │ └── schema.py ├── callbacks │ ├── __init__.py │ └── agent_callback_handler.py ├── chat_models │ ├── __init__.py │ ├── all_tools_message.py │ └── base.py ├── embeddings │ ├── __init__.py │ └── base.py └── utils │ ├── __init__.py │ └── history.py ├── poetry.toml ├── pyproject.toml ├── scripts ├── add_encoding_declaration.py ├── check_imports.py ├── check_pydantic.sh └── lint_imports.sh └── tests ├── assistant ├── chatchat_icon_blue_square_v2.png ├── client.py ├── dialogue.py ├── server │ └── server.py ├── start_chat.py ├── utils.py └── webui.py ├── conftest.py ├── integration_tests ├── all_tools │ └── test_alltools.py ├── demo │ ├── test_demo_1_openai.py │ ├── test_demo_1_zhipuai.py │ ├── test_demo_2_openai.py │ └── test_demo_2_zhipuai.py ├── embeddings │ └── test_embeddings.py ├── prompt │ └── test_hub.py └── tools │ └── tool_use.py └── unit_tests ├── output_parsers └── test_message_tool_paser_chunk.py ├── test_code_interpreter.py ├── tittoken.py └── tools_bind └── test_tools_bind.py /.env: 
-------------------------------------------------------------------------------- 1 | PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring 2 | PYTHONIOENCODING=utf-8 -------------------------------------------------------------------------------- /.github/actions/poetry_setup/action.yml: -------------------------------------------------------------------------------- 1 | # An action for setting up poetry install with caching. 2 | # Using a custom action since the default action does not 3 | # take poetry install groups into account. 4 | # Action code from: 5 | # https://github.com/actions/setup-python/issues/505#issuecomment-1273013236 6 | name: poetry-install-with-caching 7 | description: Poetry install with support for caching of dependency groups. 8 | 9 | inputs: 10 | python-version: 11 | description: Python version, supporting MAJOR.MINOR only 12 | required: true 13 | 14 | poetry-version: 15 | description: Poetry version 16 | required: true 17 | 18 | cache-key: 19 | description: Cache key to use for manual handling of caching 20 | required: true 21 | 22 | working-directory: 23 | description: Directory whose poetry.lock file should be cached 24 | required: true 25 | 26 | runs: 27 | using: composite 28 | steps: 29 | - uses: actions/setup-python@v5 30 | name: Setup python ${{ inputs.python-version }} 31 | id: setup-python 32 | with: 33 | python-version: ${{ inputs.python-version }} 34 | 35 | - uses: actions/cache@v4 36 | id: cache-bin-poetry 37 | name: Cache Poetry binary - Python ${{ inputs.python-version }} 38 | env: 39 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1" 40 | with: 41 | path: | 42 | /opt/pipx/venvs/poetry 43 | # This step caches the poetry installation, so make sure it's keyed on the poetry version as well. 
44 | key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }} 45 | 46 | - name: Refresh shell hashtable and fixup softlinks 47 | if: steps.cache-bin-poetry.outputs.cache-hit == 'true' 48 | shell: bash 49 | env: 50 | POETRY_VERSION: ${{ inputs.poetry-version }} 51 | PYTHON_VERSION: ${{ inputs.python-version }} 52 | run: | 53 | set -eux 54 | 55 | # Refresh the shell hashtable, to ensure correct `which` output. 56 | hash -r 57 | 58 | # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks. 59 | # Delete and recreate the softlinks pipx expects to have. 60 | rm /opt/pipx/venvs/poetry/bin/python 61 | cd /opt/pipx/venvs/poetry/bin 62 | ln -s "$(which "python$PYTHON_VERSION")" python 63 | chmod +x python 64 | cd /opt/pipx_bin/ 65 | ln -s /opt/pipx/venvs/poetry/bin/poetry poetry 66 | chmod +x poetry 67 | 68 | # Ensure everything got set up correctly. 69 | /opt/pipx/venvs/poetry/bin/python --version 70 | /opt/pipx_bin/poetry --version 71 | 72 | - name: Install poetry 73 | if: steps.cache-bin-poetry.outputs.cache-hit != 'true' 74 | shell: bash 75 | env: 76 | POETRY_VERSION: ${{ inputs.poetry-version }} 77 | PYTHON_VERSION: ${{ inputs.python-version }} 78 | # Install poetry using the python version installed by setup-python step. 79 | run: pipx install "poetry==$POETRY_VERSION" --python '${{ steps.setup-python.outputs.python-path }}' --verbose 80 | 81 | - name: Restore pip and poetry cached dependencies 82 | uses: actions/cache@v4 83 | env: 84 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4" 85 | WORKDIR: ${{ inputs.working-directory == '' && '.' 
|| inputs.working-directory }} 86 | with: 87 | path: | 88 | ~/.cache/pip 89 | ~/.cache/pypoetry/virtualenvs 90 | ~/.cache/pypoetry/cache 91 | ~/.cache/pypoetry/artifacts 92 | ${{ env.WORKDIR }}/.venv 93 | key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }} -------------------------------------------------------------------------------- /.github/workflows/_integration_test.yml: -------------------------------------------------------------------------------- 1 | name: integration_test 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | default: '.' 10 | description: "From which folder this pipeline executes" 11 | 12 | env: 13 | POETRY_VERSION: "1.7.1" 14 | 15 | jobs: 16 | build: 17 | if: github.ref == 'refs/heads/main' 18 | 19 | environment: Scheduled testing publish 20 | outputs: 21 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 22 | version: ${{ steps.check-version.outputs.version }} 23 | runs-on: ${{ matrix.os }} 24 | strategy: 25 | matrix: 26 | os: [ubuntu-latest, windows-latest, macos-latest] 27 | python-version: ["3.8", "3.9", "3.10", "3.11"] 28 | 29 | name: "make integration_test #${{ matrix.os }} Python ${{ matrix.python-version }}" 30 | steps: 31 | - uses: actions/checkout@v4 32 | 33 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 34 | uses: "./.github/actions/poetry_setup" 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | poetry-version: ${{ env.POETRY_VERSION }} 38 | working-directory: ${{ inputs.working-directory }} 39 | cache-key: core 40 | 41 | 42 | - name: Import test dependencies 43 | run: poetry install --with test 44 | working-directory: ${{ inputs.working-directory }} 45 | 46 | - name: Run integration tests 47 | shell: bash 48 | env: 49 | PYTHONIOENCODING: "utf-8" 50 | 
ZHIPUAI_API_KEY: ${{ secrets.ZHIPUAI_API_KEY }} 51 | ZHIPUAI_BASE_URL: ${{ secrets.ZHIPUAI_BASE_URL }} 52 | run: | 53 | make integration_tests 54 | 55 | - name: Ensure the tests did not create any additional files 56 | shell: bash 57 | run: | 58 | set -eu 59 | 60 | STATUS="$(git status)" 61 | echo "$STATUS" 62 | 63 | # grep will exit non-zero if the target message isn't found, 64 | # and `set -e` above will cause the step to fail. 65 | echo "$STATUS" | grep 'nothing to commit, working tree clean' -------------------------------------------------------------------------------- /.github/workflows/_release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | run-name: Release ${{ inputs.working-directory }} by @${{ github.actor }} 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | workflow_dispatch: 11 | inputs: 12 | working-directory: 13 | required: true 14 | type: string 15 | default: '.' 16 | description: "From which folder this pipeline executes" 17 | env: 18 | PYTHON_VERSION: "3.9" 19 | POETRY_VERSION: "1.7.1" 20 | 21 | jobs: 22 | build: 23 | if: github.ref == 'refs/heads/main' 24 | environment: Scheduled testing publish 25 | runs-on: ubuntu-latest 26 | 27 | outputs: 28 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 29 | version: ${{ steps.check-version.outputs.version }} 30 | 31 | steps: 32 | - uses: actions/checkout@v4 33 | 34 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 35 | uses: "./.github/actions/poetry_setup" 36 | with: 37 | python-version: ${{ env.PYTHON_VERSION }} 38 | poetry-version: ${{ env.POETRY_VERSION }} 39 | working-directory: ${{ inputs.working-directory }} 40 | cache-key: release 41 | 42 | # We want to keep this build stage *separate* from the release stage, 43 | # so that there's no sharing of permissions between them. 
44 | # The release stage has trusted publishing and GitHub repo contents write access, 45 | # and we want to keep the scope of that access limited just to the release job. 46 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 47 | # could get access to our GitHub or PyPI credentials. 48 | # 49 | # Per the trusted publishing GitHub Action: 50 | # > It is strongly advised to separate jobs for building [...] 51 | # > from the publish job. 52 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 53 | - name: Build project for distribution 54 | run: poetry build 55 | working-directory: ${{ inputs.working-directory }} 56 | 57 | - name: Upload build 58 | uses: actions/upload-artifact@v4 59 | with: 60 | name: dist 61 | path: ${{ inputs.working-directory }}/dist/ 62 | 63 | - name: Check Version 64 | id: check-version 65 | shell: bash 66 | working-directory: ${{ inputs.working-directory }} 67 | run: | 68 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 69 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 70 | 71 | test-pypi-publish: 72 | needs: 73 | - build 74 | uses: 75 | ./.github/workflows/_test_release.yml 76 | with: 77 | working-directory: ${{ inputs.working-directory }} 78 | secrets: inherit 79 | 80 | pre-release-checks: 81 | needs: 82 | - build 83 | - test-pypi-publish 84 | environment: Scheduled testing publish 85 | runs-on: ubuntu-latest 86 | steps: 87 | - uses: actions/checkout@v4 88 | 89 | # We explicitly *don't* set up caching here. This ensures our tests are 90 | # maximally sensitive to catching breakage. 91 | # 92 | # For example, here's a way that caching can cause a falsely-passing test: 93 | # - Make the langchain package manifest no longer list a dependency package 94 | # as a requirement. This means it won't be installed by `pip install`, 95 | # and attempting to use it would cause a crash. 96 | # - That dependency used to be required, so it may have been cached. 
97 | # When restoring the venv packages from cache, that dependency gets included. 98 | # - Tests pass, because the dependency is present even though it wasn't specified. 99 | # - The package is published, and it breaks on the missing dependency when 100 | # used in the real world. 101 | 102 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 103 | uses: "./.github/actions/poetry_setup" 104 | with: 105 | python-version: ${{ env.PYTHON_VERSION }} 106 | poetry-version: ${{ env.POETRY_VERSION }} 107 | working-directory: ${{ inputs.working-directory }} 108 | 109 | - name: Import published package 110 | shell: bash 111 | working-directory: ${{ inputs.working-directory }} 112 | env: 113 | PKG_NAME: ${{ needs.build.outputs.pkg-name }} 114 | VERSION: ${{ needs.build.outputs.version }} 115 | # Here we use: 116 | # - The default regular PyPI index as the *primary* index, meaning 117 | # that it takes priority (https://pypi.org/simple) 118 | # - The test PyPI index as an extra index, so that any dependencies that 119 | # are not found on test PyPI can be resolved and installed anyway. 120 | # (https://test.pypi.org/simple). This will include the PKG_NAME==VERSION 121 | # package because VERSION will not have been uploaded to regular PyPI yet. 122 | # - attempt install again after 5 seconds if it fails because there is 123 | # sometimes a delay in availability on test pypi 124 | run: | 125 | poetry run pip install \ 126 | --extra-index-url https://test.pypi.org/simple/ \ 127 | "$PKG_NAME==$VERSION" || \ 128 | ( \ 129 | sleep 5 && \ 130 | poetry run pip install \ 131 | --extra-index-url https://test.pypi.org/simple/ \ 132 | "$PKG_NAME==$VERSION" \ 133 | ) 134 | 135 | # Replace all dashes in the package name with underscores, 136 | # since that's how Python imports packages with dashes in the name. 
137 | IMPORT_NAME="$(echo "$PKG_NAME" | sed s/-/_/g)" 138 | 139 | poetry run python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))" 140 | 141 | - name: Import test dependencies 142 | run: poetry install --with lint,test -v 143 | working-directory: ${{ inputs.working-directory }} 144 | 145 | # Overwrite the local version of the package with the test PyPI version. 146 | - name: Import published package (again) 147 | working-directory: ${{ inputs.working-directory }} 148 | shell: bash 149 | env: 150 | PKG_NAME: ${{ needs.build.outputs.pkg-name }} 151 | VERSION: ${{ needs.build.outputs.version }} 152 | run: | 153 | poetry run pip install \ 154 | --extra-index-url https://test.pypi.org/simple/ \ 155 | "$PKG_NAME==$VERSION" 156 | 157 | - name: Run unit tests 158 | run: make tests 159 | env: 160 | PYTHONIOENCODING: "utf-8" 161 | ZHIPUAI_API_KEY: ${{ secrets.ZHIPUAI_API_KEY }} 162 | ZHIPUAI_BASE_URL: ${{ secrets.ZHIPUAI_BASE_URL }} 163 | working-directory: ${{ inputs.working-directory }} 164 | 165 | - name: Run integration tests 166 | env: 167 | PYTHONIOENCODING: "utf-8" 168 | ZHIPUAI_API_KEY: ${{ secrets.ZHIPUAI_API_KEY }} 169 | ZHIPUAI_BASE_URL: ${{ secrets.ZHIPUAI_BASE_URL }} 170 | run: make integration_tests 171 | working-directory: ${{ inputs.working-directory }} 172 | 173 | publish: 174 | needs: 175 | - build 176 | - test-pypi-publish 177 | - pre-release-checks 178 | environment: Scheduled testing publish 179 | runs-on: ubuntu-latest 180 | 181 | defaults: 182 | run: 183 | working-directory: ${{ inputs.working-directory }} 184 | 185 | steps: 186 | - uses: actions/checkout@v4 187 | 188 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 189 | uses: "./.github/actions/poetry_setup" 190 | with: 191 | python-version: ${{ env.PYTHON_VERSION }} 192 | poetry-version: ${{ env.POETRY_VERSION }} 193 | working-directory: ${{ inputs.working-directory }} 194 | cache-key: release 195 | 196 | - uses: actions/download-artifact@v4 197 | with: 198 | name: dist 199 | path: ${{ 
inputs.working-directory }}/dist/ 200 | 201 | - name: Publish package distributions to PyPI 202 | uses: pypa/gh-action-pypi-publish@release/v1 203 | 204 | with: 205 | packages-dir: ${{ inputs.working-directory }}/dist/ 206 | verbose: true 207 | print-hash: true 208 | user: __token__ 209 | password: ${{ secrets.PYPI_API_TOKEN }} 210 | # We overwrite any existing distributions with the same name and version. 211 | # This is *only for CI use* and is *extremely dangerous* otherwise! 212 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates 213 | skip-existing: true 214 | 215 | mark-release: 216 | needs: 217 | - build 218 | - test-pypi-publish 219 | - pre-release-checks 220 | - publish 221 | runs-on: ubuntu-latest 222 | permissions: 223 | # This permission is needed by `ncipollo/release-action` to 224 | # create the GitHub release. 225 | contents: write 226 | id-token: none 227 | 228 | defaults: 229 | run: 230 | working-directory: ${{ inputs.working-directory }} 231 | 232 | steps: 233 | - uses: actions/checkout@v4 234 | 235 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 236 | uses: "./.github/actions/poetry_setup" 237 | with: 238 | python-version: ${{ env.PYTHON_VERSION }} 239 | poetry-version: ${{ env.POETRY_VERSION }} 240 | working-directory: ${{ inputs.working-directory }} 241 | cache-key: release 242 | 243 | - uses: actions/download-artifact@v4 244 | with: 245 | name: dist 246 | path: ${{ inputs.working-directory }}/dist/ 247 | 248 | - name: Create Release 249 | uses: ncipollo/release-action@v1 250 | if: ${{ inputs.working-directory == '.' 
}} 251 | with: 252 | artifacts: "dist/*" 253 | token: ${{ secrets.GITHUB_TOKEN }} 254 | draft: false 255 | generateReleaseNotes: true 256 | tag: v${{ needs.build.outputs.version }} 257 | commit: main 258 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | default: '.' 10 | description: "From which folder this pipeline executes" 11 | 12 | env: 13 | POETRY_VERSION: "1.7.1" 14 | 15 | jobs: 16 | build: 17 | defaults: 18 | run: 19 | working-directory: ${{ inputs.working-directory }} 20 | runs-on: ${{ matrix.os }} 21 | strategy: 22 | matrix: 23 | os: [ubuntu-latest, windows-latest, macos-latest] 24 | python-version: ["3.8", "3.9", "3.10", "3.11"] 25 | 26 | name: "make test #${{ matrix.os }} Python ${{ matrix.python-version }}" 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 31 | uses: "./.github/actions/poetry_setup" 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | poetry-version: ${{ env.POETRY_VERSION }} 35 | working-directory: ${{ inputs.working-directory }} 36 | cache-key: core 37 | 38 | 39 | - name: Import test dependencies 40 | run: poetry install --with test 41 | working-directory: ${{ inputs.working-directory }} 42 | 43 | - name: Run core tests 44 | shell: bash 45 | env: 46 | PYTHONIOENCODING: "utf-8" 47 | run: | 48 | make test 49 | 50 | - name: Ensure the tests did not create any additional files 51 | shell: bash 52 | run: | 53 | set -eu 54 | 55 | STATUS="$(git status)" 56 | echo "$STATUS" 57 | 58 | # grep will exit non-zero if the target message isn't found, 59 | # and `set -e` above will cause the step to fail. 
60 | echo "$STATUS" | grep 'nothing to commit, working tree clean' -------------------------------------------------------------------------------- /.github/workflows/_test_release.yml: -------------------------------------------------------------------------------- 1 | name: test-release 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | PYTHON_VERSION: "3.9" 14 | 15 | jobs: 16 | build: 17 | if: github.ref == 'refs/heads/main' 18 | runs-on: ubuntu-latest 19 | 20 | outputs: 21 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 22 | version: ${{ steps.check-version.outputs.version }} 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 28 | uses: "./.github/actions/poetry_setup" 29 | with: 30 | python-version: ${{ env.PYTHON_VERSION }} 31 | poetry-version: ${{ env.POETRY_VERSION }} 32 | working-directory: ${{ inputs.working-directory }} 33 | cache-key: release 34 | 35 | # We want to keep this build stage *separate* from the release stage, 36 | # so that there's no sharing of permissions between them. 37 | # The release stage has trusted publishing and GitHub repo contents write access, 38 | # and we want to keep the scope of that access limited just to the release job. 39 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 40 | # could get access to our GitHub or PyPI credentials. 41 | # 42 | # Per the trusted publishing GitHub Action: 43 | # > It is strongly advised to separate jobs for building [...] 44 | # > from the publish job. 
45 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 46 | - name: Build project for distribution 47 | run: poetry build 48 | working-directory: ${{ inputs.working-directory }} 49 | 50 | - name: Upload build 51 | uses: actions/upload-artifact@v4 52 | with: 53 | name: test-dist 54 | path: ${{ inputs.working-directory }}/dist/ 55 | 56 | - name: Check Version 57 | id: check-version 58 | shell: bash 59 | working-directory: ${{ inputs.working-directory }} 60 | run: | 61 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 62 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 63 | 64 | publish: 65 | needs: 66 | - build 67 | runs-on: ubuntu-latest 68 | environment: Scheduled testing publish 69 | # permissions: 70 | # id-token: none # This is required for requesting the JWT 71 | 72 | steps: 73 | - uses: actions/checkout@v4 74 | 75 | - uses: actions/download-artifact@v4 76 | with: 77 | name: test-dist 78 | path: ${{ inputs.working-directory }}/dist/ 79 | 80 | - name: Publish to test PyPI 81 | uses: pypa/gh-action-pypi-publish@release/v1 82 | with: 83 | user: __token__ 84 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 85 | packages-dir: ${{ inputs.working-directory }}/dist/ 86 | verbose: true 87 | print-hash: true 88 | repository-url: https://test.pypi.org/legacy/ 89 | # We overwrite any existing distributions with the same name and version. 90 | # This is *only for CI use* and is *extremely dangerous* otherwise! 
91 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates 92 | skip-existing: true 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include 
Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | .idea 157 | 158 | poetry.lock 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) LangChain, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | ###################### 7 | # TESTING AND COVERAGE 8 | ###################### 9 | 10 | # Define a variable for the test file path. 11 | TEST_FILE ?= tests/unit_tests/ 12 | 13 | # Run unit tests and generate a coverage report. 
14 | coverage: 15 | poetry run pytest --cov \ 16 | --cov-config=.coveragerc \ 17 | --cov-report xml \ 18 | --cov-report term-missing:skip-covered \ 19 | $(TEST_FILE) 20 | 21 | test tests: 22 | poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE) 23 | 24 | extended_tests: 25 | poetry run pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests 26 | 27 | test_watch: 28 | poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket tests/unit_tests 29 | 30 | test_watch_extended: 31 | poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests 32 | 33 | integration_tests: 34 | poetry run pytest tests/integration_tests 35 | 36 | scheduled_tests: 37 | poetry run pytest -m scheduled tests/integration_tests 38 | 39 | 40 | ###################### 41 | # LINTING AND FORMATTING 42 | ###################### 43 | 44 | # Define a variable for Python and notebook files. 45 | PYTHON_FILES=. 46 | MYPY_CACHE=.mypy_cache 47 | lint format: PYTHON_FILES=. 48 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') 49 | lint_package: PYTHON_FILES=model_providers 50 | lint_tests: PYTHON_FILES=tests 51 | lint_tests: MYPY_CACHE=.mypy_cache_test 52 | 53 | lint lint_diff lint_package lint_tests: 54 | ./scripts/check_pydantic.sh . 55 | ./scripts/lint_imports.sh 56 | poetry run ruff . 
57 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff 58 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES) 59 | [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 60 | 61 | format format_diff: 62 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) 63 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES) 64 | 65 | spell_check: 66 | poetry run codespell --toml pyproject.toml 67 | 68 | spell_fix: 69 | poetry run codespell --toml pyproject.toml -w 70 | 71 | ###################### 72 | # HELP 73 | ###################### 74 | 75 | help: 76 | @echo '-- LINTING --' 77 | @echo 'format - run code formatters' 78 | @echo 'lint - run linters' 79 | @echo 'spell_check - run codespell on the project' 80 | @echo 'spell_fix - run codespell on the project and fix the errors' 81 | @echo '-- TESTS --' 82 | @echo 'coverage - run unit tests and generate coverage report' 83 | @echo 'test - run unit tests' 84 | @echo 'tests - run unit tests (alias for "make test")' 85 | @echo 'test TEST_FILE= - run all tests in file' 86 | -------------------------------------------------------------------------------- /README.en.md: -------------------------------------------------------------------------------- 1 | # 🔗 LangChain-GLM 2 | 3 | 4 | ## Project Overview 5 | This project utilizes the foundational components of LangChain to implement a comprehensive framework 6 | that supports intelligent agents and related tasks. The core is built on Zhiyuan AI's latest 7 | GLM-4 All Tools. Through Zhiyuan AI's API interface, it can autonomously understand user intentions, 8 | plan complex instructions, and invoke one or more tools (such as web browsers, Python interpreters, 9 | and text-to-image models) to accomplish intricate tasks. 10 | 11 | ![all_tools.png](docs/img/all_tools.png) 12 | 13 | > Fig. 
| The overall process of GLM-4 All Tools and custom GLMs (agents). 14 | 15 | ## Project Structure 16 | 17 | | Package Path | Description | 18 | | ------------------------------------------------------------------ | ------------------------------------------------------------- | 19 | | [agent_toolkits](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/agent_toolkits) | Platform tool AdapterAllTool adapter, a platform adapter tool that provides a unified interface for various tools, aiming for seamless integration and execution across different platforms. This tool adapts to specific platform parameters to ensure compatibility and consistent output. | 20 | | [agents](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/agents) | Encapsulates the input, output, agent sessions, tool parameters, and tool execution strategies for the AgentExecutor. | 21 | | [callbacks](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/callbacks) | Abstracts some interactive events during the AgentExecutor process, displaying information through events. | 22 | | [chat_models](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/chat_models) | A wrapper layer for the Zhipu AI SDK, providing integration with LangChain's BaseChatModel and formatting input and output as message bodies. | 23 | | [embeddings](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/embeddings) | A wrapper layer for the Zhipu AI SDK, providing integration with LangChain's Embeddings. | 24 | | [utils](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/utils) | Various session tools. 
|
25 |
26 |
27 | ## Quick Start
28 |
29 | - Install from the repository
30 | https://github.com/MetaGLM/langchain-glm/releases
31 | - Install directly from the source using pip
32 | ```bash
33 | pip install git+https://github.com/MetaGLM/langchain-glm.git -v
34 | ```
35 | - Install from PyPI
36 | ```bash
37 | pip install langchain-glm
38 | ```
39 |
40 | > Before using, please set the environment variable `ZHIPUAI_API_KEY` with the value of your Zhipu AI API Key.
41 |
42 | #### Tool Usage
43 | - Set environment variables
44 | ```python
45 | import getpass
46 | import os
47 |
48 | os.environ["ZHIPUAI_API_KEY"] = getpass.getpass()
49 |
50 | ```
51 | ```python
52 | from langchain_glm import ChatZhipuAI
53 | llm = ChatZhipuAI(model="glm-4")
54 | ```
55 |
56 |
57 | - example tools:
58 | ```python
59 | from langchain_core.tools import tool
60 |
61 | @tool
62 | def multiply(first_int: int, second_int: int) -> int:
63 | """Multiply two integers together."""
64 | return first_int * second_int
65 |
66 | @tool
67 | def add(first_int: int, second_int: int) -> int:
68 | "Add two integers."
69 | return first_int + second_int
70 |
71 | @tool
72 | def exponentiate(base: int, exponent: int) -> int:
73 | "Exponentiate the base to the exponent power."
74 | return base**exponent 75 | ``` 76 | - Build Chain 77 | Bind tools to the language model and invoke: 78 | ```python 79 | from operator import itemgetter 80 | from typing import Dict, List, Union 81 | 82 | from langchain_core.messages import AIMessage 83 | from langchain_core.runnables import ( 84 | Runnable, 85 | RunnableLambda, 86 | RunnableMap, 87 | RunnablePassthrough, 88 | ) 89 | 90 | tools = [multiply, exponentiate, add] 91 | llm_with_tools = llm.bind_tools(tools) 92 | tool_map = {tool.name: tool for tool in tools} 93 | 94 | 95 | def call_tools(msg: AIMessage) -> Runnable: 96 | """Simple sequential tool calling helper.""" 97 | tool_map = {tool.name: tool for tool in tools} 98 | tool_calls = msg.tool_calls.copy() 99 | for tool_call in tool_calls: 100 | tool_call["output"] = tool_map[tool_call["name"]].invoke(tool_call["args"]) 101 | return tool_calls 102 | 103 | 104 | chain = llm_with_tools | call_tools 105 | ``` 106 | 107 | - invoke 108 | ```python 109 | chain.invoke( 110 | "What's 23 times 7" 111 | ) 112 | ``` 113 | 114 | #### Example Code 115 | 116 | - Agent Executor 117 | Our `glm-4-alltools` model provides platform tools. With `ZhipuAIAllToolsRunnable`, you can easily set up an executor to run multiple tools. 118 | 119 | `code_interpreter`: Use `sandbox` to specify the code sandbox environment. 120 | - Default = auto, which automatically uses the sandbox environment to execute code. 121 | - Set `sandbox = none` to disable the sandbox environment. 122 | 123 | `web_browser`: Use `web_browser` to specify the browser tool. 124 | `drawing_tool`: Use `drawing_tool` to specify the drawing tool. 
125 | 126 | 127 | ```python 128 | 129 | from langchain_glm.agents.zhipuai_all_tools import ZhipuAIAllToolsRunnable 130 | agent_executor = ZhipuAIAllToolsRunnable.create_agent_executor( 131 | model_name="glm-4-alltools", 132 | tools=[ 133 | {"type": "code_interpreter", "code_interpreter": {"sandbox": "none"}}, 134 | {"type": "web_browser"}, 135 | {"type": "drawing_tool"}, 136 | multiply, exponentiate, add 137 | ], 138 | ) 139 | 140 | ``` 141 | 142 | 143 | - Execute `agent_executor` and Print Results 144 | This section uses an agent to run a shell command and prints the output once available. It checks the result type and prints the relevant information. 145 | The `invoke` method returns an asynchronous iterator that can handle the agent's output. 146 | You can call the `invoke` method multiple times, with each call returning a new iterator. 147 | `ZhipuAIAllToolsRunnable` automatically handles state saving and recovery. Some state information is stored within the instance. 148 | You can access the status of `intermediate_steps` through the `callback` attribute. 
149 | 150 | 151 | ```python 152 | from langchain_glm.agents.zhipuai_all_tools.base import ( 153 | AllToolsAction, 154 | AllToolsActionToolEnd, 155 | AllToolsActionToolStart, 156 | AllToolsFinish, 157 | AllToolsLLMStatus 158 | ) 159 | from langchain_glm.callbacks.agent_callback_handler import AgentStatus 160 | 161 | 162 | chat_iterator = agent_executor.invoke( 163 | chat_input="What's 23 times 7, and what's five times 18 and add a million plus a billion and cube thirty-seven" 164 | ) 165 | async for item in chat_iterator: 166 | if isinstance(item, AllToolsAction): 167 | print("AllToolsAction:" + str(item.to_json())) 168 | elif isinstance(item, AllToolsFinish): 169 | print("AllToolsFinish:" + str(item.to_json())) 170 | elif isinstance(item, AllToolsActionToolStart): 171 | print("AllToolsActionToolStart:" + str(item.to_json())) 172 | elif isinstance(item, AllToolsActionToolEnd): 173 | print("AllToolsActionToolEnd:" + str(item.to_json())) 174 | elif isinstance(item, AllToolsLLMStatus): 175 | if item.status == AgentStatus.llm_end: 176 | print("llm_end:" + item.text) 177 | ``` 178 | 179 | ## Integrated Demo 180 | We provide an integrated demo that you can run directly to see the results. 
181 | - Install dependencies 182 | ```shell 183 | fastapi = "~0.109.2" 184 | sse_starlette = "~1.8.2" 185 | uvicorn = ">=0.27.0.post1" 186 | # webui 187 | streamlit = "1.34.0" 188 | streamlit-option-menu = "0.3.12" 189 | streamlit-antd-components = "0.3.1" 190 | streamlit-chatbox = "1.1.12.post4" 191 | streamlit-modal = "0.1.0" 192 | streamlit-aggrid = "1.0.5" 193 | streamlit-extras = "0.4.2" 194 | ``` 195 | 196 | - server[server.py](tests/assistant/server/server.py) 197 | ```shell 198 | python tests/assistant/server/server.py 199 | ``` 200 | 201 | - client[start_chat.py](tests/assistant/start_chat.py) 202 | ```shell 203 | python tests/assistant/start_chat.py 204 | ``` 205 | 206 | > show 207 | 208 | 209 | https://github.com/MetaGLM/langchain-zhipuai/assets/16206043/06863f9c-cd03-4a74-b76a-daa315718104 210 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🔗 LangChain-GLM 2 | 3 | 4 | ## 项目介绍 5 | 本项目通过langchain的基础组件,实现了完整的支持智能体和相关任务架构。底层采用智谱AI的最新的 `GLM-4 All Tools`, 通过智谱AI的API接口, 6 | 能够自主理解用户的意图,规划复杂的指令,并能够调用一个或多个工具(例如网络浏览器、Python解释器和文本到图像模型)以完成复杂的任务。 7 | 8 | ![all_tools.png](docs/img/all_tools.png) 9 | 10 | > 图|GLM-4 All Tools 和定制 GLMs(智能体)的整体流程。 11 | 12 | ## 项目结构 13 | 14 | | 包路径 | 说明 | 15 | | ------------------------------------------------------------ | ------------------------------------------------------------ | 16 | | [agent_toolkits](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/agent_toolkits) | 平台工具AdapterAllTool适配器, 是一个用于为各种工具提供统一接口的平台适配器工具,目的是在不同平台上实现无缝集成和执行。该工具通过适配特定的平台参数,确保兼容性和一致的输出。 | 17 | | [agents](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/agents) | 定义AgentExecutor的输入、输出、智能体会话、工具参数、工具执行策略的封装 | 18 | | [callbacks](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/callbacks) | 抽象AgentExecutor过程中的一些交互事件,通过事件展示信息 | 19 | | 
[chat_models](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/chat_models) | zhipuai sdk的封装层,提供langchain的BaseChatModel集成,格式化输入输出为消息体 | 20 | | [embeddings](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/embeddings) | zhipuai sdk的封装层,提供langchain的Embeddings集成 | 21 | | [utils](https://github.com/MetaGLM/langchain-zhipuai/tree/main/langchain_glm/utils) | 一些会话工具 | 22 | 23 | 24 | ## 快速使用 25 | ### Python版本支持 26 | 正式的 Python (3.8, 3.9, 3.10, 3.11, 3.12) 27 | 28 | - 从 repo 安装 29 | https://github.com/MetaGLM/langchain-glm/releases 30 | - 直接使用pip源码安装 31 | pip install git+https://github.com/MetaGLM/langchain-glm.git -v 32 | - 从pypi安装 33 | pip install langchain-glm 34 | 35 | > 使用前请设置环境变量`ZHIPUAI_API_KEY`,值为智谱AI的API Key。 36 | 37 | 38 | #### 工具使用 39 | - Set environment variables 40 | ```python 41 | import getpass 42 | import os 43 | 44 | os.environ["ZHIPUAI_API_KEY"] = getpass.getpass() 45 | 46 | ``` 47 | ```python 48 | from langchain_glm import ChatZhipuAI 49 | llm = ChatZhipuAI(model="glm-4") 50 | ``` 51 | 52 | 53 | - 定义一些示例工具: 54 | ```python 55 | from langchain_core.tools import tool 56 | 57 | @tool 58 | def multiply(first_int: int, second_int: int) -> int: 59 | """Multiply two integers together.""" 60 | return first_int * second_int 61 | 62 | @tool 63 | def add(first_int: int, second_int: int) -> int: 64 | "Add two integers." 65 | return first_int + second_int 66 | 67 | @tool 68 | def exponentiate(base: int, exponent: int) -> int: 69 | "Exponentiate the base to the exponent power." 
70 | return base**exponent 71 | ``` 72 | - 构建chain 73 | 绑定工具到语言模型并调用: 74 | ```python 75 | from operator import itemgetter 76 | from typing import Dict, List, Union 77 | 78 | from langchain_core.messages import AIMessage 79 | from langchain_core.runnables import ( 80 | Runnable, 81 | RunnableLambda, 82 | RunnableMap, 83 | RunnablePassthrough, 84 | ) 85 | 86 | tools = [multiply, exponentiate, add] 87 | llm_with_tools = llm.bind_tools(tools) 88 | tool_map = {tool.name: tool for tool in tools} 89 | 90 | 91 | def call_tools(msg: AIMessage) -> Runnable: 92 | """Simple sequential tool calling helper.""" 93 | tool_map = {tool.name: tool for tool in tools} 94 | tool_calls = msg.tool_calls.copy() 95 | for tool_call in tool_calls: 96 | tool_call["output"] = tool_map[tool_call["name"]].invoke(tool_call["args"]) 97 | return tool_calls 98 | 99 | 100 | chain = llm_with_tools | call_tools 101 | ``` 102 | 103 | - 调用chain 104 | ```python 105 | chain.invoke( 106 | "What's 23 times 7, and what's five times 18 and add a million plus a billion and cube thirty-seven" 107 | ) 108 | ``` 109 | 110 | #### 代码解析使用示例 111 | 112 | 113 | - 创建一个代理执行器 114 | 我们的glm-4-alltools的模型提供了平台工具,通过ZhipuAIAllToolsRunnable,你可以非常方便的设置了一个执行器来运行多个工具。 115 | 116 | code_interpreter:使用`sandbox`指定代码沙盒环境, 117 | 默认 = auto,即自动调用沙盒环境执行代码。 118 | 设置 sandbox = none,不启用沙盒环境。 119 | 120 | web_browser:使用`web_browser`指定浏览器工具。 121 | drawing_tool:使用`drawing_tool`指定绘图工具。 122 | 123 | ```python 124 | 125 | from langchain_glm.agents.zhipuai_all_tools import ZhipuAIAllToolsRunnable 126 | agent_executor = ZhipuAIAllToolsRunnable.create_agent_executor( 127 | model_name="glm-4-alltools", 128 | tools=[ 129 | {"type": "code_interpreter", "code_interpreter": {"sandbox": "none"}}, 130 | {"type": "web_browser"}, 131 | {"type": "drawing_tool"}, 132 | multiply, exponentiate, add 133 | ], 134 | ) 135 | 136 | ``` 137 | 138 | 139 | - 执行agent_executor并打印结果 140 | 这部分使用代理来运行一个Shell命令,并在结果出现时打印出来。它检查结果的类型并打印相关信息。 141 | 这个invoke返回一个异步迭代器,可以用来处理代理的输出。 142 | 
你可以多次调用invoke方法,每次调用都会返回一个新的迭代器。
143 | ZhipuAIAllToolsRunnable会自动处理状态保存和恢复,一些状态信息会被保存在实例中
144 | 你可以通过callback属性获取intermediate_steps的状态信息。
145 | ```python
146 | from langchain_glm.agents.zhipuai_all_tools.base import (
147 | AllToolsAction,
148 | AllToolsActionToolEnd,
149 | AllToolsActionToolStart,
150 | AllToolsFinish,
151 | AllToolsLLMStatus
152 | )
153 | from langchain_glm.callbacks.agent_callback_handler import AgentStatus
154 |
155 |
156 | chat_iterator = agent_executor.invoke(
157 | chat_input="看下本地文件有哪些,告诉我你用的是什么文件,查看当前目录"
158 | )
159 | async for item in chat_iterator:
160 | if isinstance(item, AllToolsAction):
161 | print("AllToolsAction:" + str(item.to_json()))
162 | elif isinstance(item, AllToolsFinish):
163 | print("AllToolsFinish:" + str(item.to_json()))
164 | elif isinstance(item, AllToolsActionToolStart):
165 | print("AllToolsActionToolStart:" + str(item.to_json()))
166 | elif isinstance(item, AllToolsActionToolEnd):
167 | print("AllToolsActionToolEnd:" + str(item.to_json()))
168 | elif isinstance(item, AllToolsLLMStatus):
169 | if item.status == AgentStatus.llm_end:
170 | print("llm_end:" + item.text)
171 | ```
172 |
173 | ## 集成demo
174 | 我们提供了一个集成的demo,可以直接运行,查看效果。
175 | - 安装依赖
176 | ```shell
177 | fastapi = "~0.109.2"
178 | sse_starlette = "~1.8.2"
179 | uvicorn = ">=0.27.0.post1"
180 | # webui
181 | streamlit = "1.34.0"
182 | streamlit-option-menu = "0.3.12"
183 | streamlit-antd-components = "0.3.1"
184 | streamlit-chatbox = "1.1.12.post4"
185 | streamlit-modal = "0.1.0"
186 | streamlit-aggrid = "1.0.5"
187 | streamlit-extras = "0.4.2"
188 | ```
189 |
190 | - 运行后端服务[server.py](tests/assistant/server/server.py)
191 | ```shell
192 | python tests/assistant/server/server.py
193 | ```
194 |
195 | - 运行前端服务[start_chat.py](tests/assistant/start_chat.py)
196 | ```shell
197 | python tests/assistant/start_chat.py
198 | ```
199 |
200 | > 展示
201 |
202 |
203 | https://github.com/MetaGLM/langchain-zhipuai/assets/16206043/06863f9c-cd03-4a74-b76a-daa315718104
204 |
-------------------------------------------------------------------------------- /docs/img/MetaGLM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MetaGLM/langchain-zhipuai/29efaeacbcb4db7572c2f09e60a4196771eefd24/docs/img/MetaGLM.png -------------------------------------------------------------------------------- /docs/img/all_tools.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MetaGLM/langchain-zhipuai/29efaeacbcb4db7572c2f09e60a4196771eefd24/docs/img/all_tools.png -------------------------------------------------------------------------------- /docs/img/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MetaGLM/langchain-zhipuai/29efaeacbcb4db7572c2f09e60a4196771eefd24/docs/img/demo.mp4 -------------------------------------------------------------------------------- /langchain_glm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # ruff: noqa: E402 3 | """Main entrypoint into package.""" 4 | from importlib import metadata 5 | 6 | from langchain_glm.agents import ZhipuAIAllToolsRunnable 7 | from langchain_glm.chat_models import ChatZhipuAI 8 | 9 | try: 10 | __version__ = metadata.version(__package__) 11 | except metadata.PackageNotFoundError: 12 | # Case where package metadata is not available. 
13 | __version__ = "" 14 | del metadata # optional, avoids polluting the results of dir(__package__) 15 | 16 | 17 | __all__ = [ 18 | "ChatZhipuAI", 19 | "ZhipuAIAllToolsRunnable", 20 | ] 21 | -------------------------------------------------------------------------------- /langchain_glm/agent_toolkits/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from langchain_glm.agent_toolkits.all_tools import AdapterAllTool, BaseToolOutput 3 | 4 | __all__ = ["BaseToolOutput", "AdapterAllTool"] 5 | -------------------------------------------------------------------------------- /langchain_glm/agent_toolkits/all_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from langchain_glm.agent_toolkits.all_tools.tool import ( 3 | AdapterAllTool, 4 | BaseToolOutput, 5 | ) 6 | 7 | __all__ = ["BaseToolOutput", "AdapterAllTool"] 8 | -------------------------------------------------------------------------------- /langchain_glm/agent_toolkits/all_tools/code_interpreter_tool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import logging 4 | from dataclasses import dataclass 5 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union 6 | 7 | from langchain_core.agents import AgentAction 8 | from langchain_core.callbacks import ( 9 | AsyncCallbackManagerForChainRun, 10 | AsyncCallbackManagerForToolRun, 11 | CallbackManagerForToolRun, 12 | ) 13 | 14 | from langchain_glm.agent_toolkits import AdapterAllTool 15 | from langchain_glm.agent_toolkits.all_tools.tool import ( 16 | AllToolExecutor, 17 | BaseToolOutput, 18 | ) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class CodeInterpreterToolOutput(BaseToolOutput): 24 | platform_params: Dict[str, Any] 25 | tool: str 26 | code_input: str 27 | code_output: Dict[str, Any] 28 | 
29 | def __init__( 30 | self, 31 | tool: str, 32 | code_input: str, 33 | code_output: Dict[str, Any], 34 | platform_params: Dict[str, Any], 35 | **extras: Any, 36 | ) -> None: 37 | data = CodeInterpreterToolOutput.paser_data( 38 | tool=tool, code_input=code_input, code_output=code_output 39 | ) 40 | super().__init__(data, "", "", **extras) 41 | self.platform_params = platform_params 42 | self.tool = tool 43 | self.code_input = code_input 44 | self.code_output = code_output 45 | 46 | @staticmethod 47 | def paser_data(tool: str, code_input: str, code_output: Dict[str, Any]) -> str: 48 | return f"""Access:{tool}, Message: {code_input},{code_output}""" 49 | 50 | 51 | @dataclass 52 | class CodeInterpreterAllToolExecutor(AllToolExecutor): 53 | """platform adapter tool for code interpreter tool""" 54 | 55 | name: str 56 | 57 | @staticmethod 58 | def _python_ast_interpreter( 59 | code_input: str, platform_params: Dict[str, Any] = None 60 | ): 61 | """Use Shell to execute system shell commands""" 62 | 63 | try: 64 | from langchain_experimental.tools import PythonAstREPLTool 65 | 66 | tool = PythonAstREPLTool() 67 | out = tool.run(tool_input=code_input) 68 | if str(out) == "": 69 | raise ValueError(f"Tool {tool.name} local sandbox is out empty") 70 | return CodeInterpreterToolOutput( 71 | tool=tool.name, 72 | code_input=code_input, 73 | code_output=out, 74 | platform_params=platform_params, 75 | ) 76 | except ImportError: 77 | raise AttributeError( 78 | "This tool has been moved to langchain experiment. " 79 | "This tool has access to a python REPL. " 80 | "For best practices make sure to sandbox this tool. 
" 81 | "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md " 82 | "To keep using this code as is, install langchain experimental and " 83 | "update relevant imports replacing 'langchain' with 'langchain_experimental'" 84 | ) 85 | 86 | def run( 87 | self, 88 | tool: str, 89 | tool_input: str, 90 | log: str, 91 | outputs: List[Union[str, dict]] = None, 92 | run_manager: Optional[CallbackManagerForToolRun] = None, 93 | ) -> CodeInterpreterToolOutput: 94 | if outputs is None or str(outputs).strip() == "": 95 | if "auto" == self.platform_params.get("sandbox", "auto"): 96 | raise ValueError( 97 | f"Tool {self.name} sandbox is auto , but log is None, is server error" 98 | ) 99 | elif "none" == self.platform_params.get("sandbox", "auto"): 100 | logger.warning( 101 | f"Tool {self.name} sandbox is local!!!, this not safe, please use jupyter sandbox it" 102 | ) 103 | return self._python_ast_interpreter( 104 | code_input=tool_input, platform_params=self.platform_params 105 | ) 106 | 107 | return CodeInterpreterToolOutput( 108 | tool=tool, 109 | code_input=tool_input, 110 | code_output=json.dumps(outputs), 111 | platform_params=self.platform_params, 112 | ) 113 | 114 | async def arun( 115 | self, 116 | tool: str, 117 | tool_input: str, 118 | log: str, 119 | outputs: List[Union[str, dict]] = None, 120 | run_manager: Optional[AsyncCallbackManagerForToolRun] = None, 121 | ) -> CodeInterpreterToolOutput: 122 | """Use the tool asynchronously.""" 123 | if outputs is None or str(outputs).strip() == "" or len(outputs) == 0: 124 | if "auto" == self.platform_params.get("sandbox", "auto"): 125 | raise ValueError( 126 | f"Tool {self.name} sandbox is auto , but log is None, is server error" 127 | ) 128 | elif "none" == self.platform_params.get("sandbox", "auto"): 129 | logger.warning( 130 | f"Tool {self.name} sandbox is local!!!, this not safe, please use jupyter sandbox it" 131 | ) 132 | return self._python_ast_interpreter( 133 | code_input=tool_input, 
platform_params=self.platform_params 134 | ) 135 | 136 | return CodeInterpreterToolOutput( 137 | tool=tool, 138 | code_input=tool_input, 139 | code_output=json.dumps(outputs), 140 | platform_params=self.platform_params, 141 | ) 142 | 143 | 144 | class CodeInterpreterAdapterAllTool(AdapterAllTool[CodeInterpreterAllToolExecutor]): 145 | @classmethod 146 | def get_type(cls) -> str: 147 | return "CodeInterpreterAdapterAllTool" 148 | 149 | def _build_adapter_all_tool( 150 | self, platform_params: Dict[str, Any] 151 | ) -> CodeInterpreterAllToolExecutor: 152 | return CodeInterpreterAllToolExecutor( 153 | name="code_interpreter", platform_params=platform_params 154 | ) 155 | -------------------------------------------------------------------------------- /langchain_glm/agent_toolkits/all_tools/drawing_tool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from dataclasses import dataclass 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union 5 | 6 | from langchain_core.agents import AgentAction 7 | from langchain_core.callbacks import ( 8 | AsyncCallbackManagerForChainRun, 9 | AsyncCallbackManagerForToolRun, 10 | CallbackManagerForToolRun, 11 | ) 12 | 13 | from langchain_glm.agent_toolkits import AdapterAllTool 14 | from langchain_glm.agent_toolkits.all_tools.struct_type import ( 15 | AdapterAllToolStructType, 16 | ) 17 | from langchain_glm.agent_toolkits.all_tools.tool import ( 18 | AllToolExecutor, 19 | BaseToolOutput, 20 | ) 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class DrawingToolOutput(BaseToolOutput): 26 | platform_params: Dict[str, Any] 27 | 28 | def __init__( 29 | self, 30 | data: Any, 31 | platform_params: Dict[str, Any], 32 | **extras: Any, 33 | ) -> None: 34 | super().__init__(data, "", "", **extras) 35 | self.platform_params = platform_params 36 | 37 | 38 | @dataclass 39 | class DrawingAllToolExecutor(AllToolExecutor): 40 | 
"""platform adapter tool for code interpreter tool""" 41 | 42 | name: str 43 | 44 | def run( 45 | self, 46 | tool: str, 47 | tool_input: str, 48 | log: str, 49 | outputs: List[Union[str, dict]] = None, 50 | run_manager: Optional[CallbackManagerForToolRun] = None, 51 | ) -> DrawingToolOutput: 52 | if outputs is None or str(outputs).strip() == "": 53 | raise ValueError(f"Tool {self.name} is server error") 54 | 55 | return DrawingToolOutput( 56 | data=f"""Access:{tool}, Message: {tool_input},{log}""", 57 | platform_params=self.platform_params, 58 | ) 59 | 60 | async def arun( 61 | self, 62 | tool: str, 63 | tool_input: str, 64 | log: str, 65 | outputs: List[Union[str, dict]] = None, 66 | run_manager: Optional[AsyncCallbackManagerForToolRun] = None, 67 | ) -> DrawingToolOutput: 68 | """Use the tool asynchronously.""" 69 | if outputs is None or str(outputs).strip() == "" or len(outputs) == 0: 70 | raise ValueError(f"Tool {self.name} is server error") 71 | 72 | return DrawingToolOutput( 73 | data=f"""Access:{tool}, Message: {tool_input},{log}""", 74 | platform_params=self.platform_params, 75 | ) 76 | 77 | 78 | class DrawingAdapterAllTool(AdapterAllTool[DrawingAllToolExecutor]): 79 | @classmethod 80 | def get_type(cls) -> str: 81 | return "DrawingAdapterAllTool" 82 | 83 | def _build_adapter_all_tool( 84 | self, platform_params: Dict[str, Any] 85 | ) -> DrawingAllToolExecutor: 86 | return DrawingAllToolExecutor( 87 | name=AdapterAllToolStructType.DRAWING_TOOL, platform_params=platform_params 88 | ) 89 | -------------------------------------------------------------------------------- /langchain_glm/agent_toolkits/all_tools/registry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Dict, Type 3 | 4 | from langchain_glm.agent_toolkits import AdapterAllTool 5 | from langchain_glm.agent_toolkits.all_tools.code_interpreter_tool import ( 6 | CodeInterpreterAdapterAllTool, 7 | ) 8 | from 
langchain_glm.agent_toolkits.all_tools.drawing_tool import ( 9 | DrawingAdapterAllTool, 10 | ) 11 | from langchain_glm.agent_toolkits.all_tools.struct_type import ( 12 | AdapterAllToolStructType, 13 | ) 14 | from langchain_glm.agent_toolkits.all_tools.web_browser_tool import ( 15 | WebBrowserAdapterAllTool, 16 | ) 17 | 18 | TOOL_STRUCT_TYPE_TO_TOOL_CLASS: Dict[AdapterAllToolStructType, Type[AdapterAllTool]] = { 19 | AdapterAllToolStructType.CODE_INTERPRETER: CodeInterpreterAdapterAllTool, 20 | AdapterAllToolStructType.DRAWING_TOOL: DrawingAdapterAllTool, 21 | AdapterAllToolStructType.WEB_BROWSER: WebBrowserAdapterAllTool, 22 | } 23 | -------------------------------------------------------------------------------- /langchain_glm/agent_toolkits/all_tools/struct_type.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """IndexStructType class.""" 3 | 4 | from enum import Enum 5 | 6 | 7 | class AdapterAllToolStructType(str, Enum): 8 | """ 9 | 10 | Attributes: 11 | DICT ("dict"): 12 | 13 | """ 14 | 15 | # TODO: refactor so these are properties on the base class 16 | 17 | CODE_INTERPRETER = "code_interpreter" 18 | DRAWING_TOOL = "drawing_tool" 19 | WEB_BROWSER = "web_browser" 20 | -------------------------------------------------------------------------------- /langchain_glm/agent_toolkits/all_tools/tool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """platform adapter tool """ 3 | 4 | from __future__ import annotations 5 | 6 | import json 7 | import logging 8 | from abc import abstractmethod 9 | from dataclasses import dataclass 10 | from pathlib import Path 11 | from typing import ( 12 | TYPE_CHECKING, 13 | Any, 14 | Dict, 15 | Generic, 16 | Optional, 17 | Tuple, 18 | Type, 19 | TypeVar, 20 | Union, 21 | ) 22 | 23 | from dataclasses_json import DataClassJsonMixin 24 | from langchain_core.agents import AgentAction, AgentFinish, 
AgentStep 25 | from langchain_core.callbacks import ( 26 | AsyncCallbackManagerForChainRun, 27 | AsyncCallbackManagerForToolRun, 28 | BaseCallbackManager, 29 | CallbackManagerForChainRun, 30 | CallbackManagerForToolRun, 31 | Callbacks, 32 | ) 33 | from langchain_core.tools import BaseTool 34 | 35 | from langchain_glm.agent_toolkits.all_tools.struct_type import ( 36 | AdapterAllToolStructType, 37 | ) 38 | from langchain_glm.agents.output_parsers.code_interpreter import ( 39 | CodeInterpreterAgentAction, 40 | ) 41 | from langchain_glm.agents.output_parsers.drawing_tool import DrawingToolAgentAction 42 | from langchain_glm.agents.output_parsers.web_browser import WebBrowserAgentAction 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | 47 | class BaseToolOutput: 48 | """ 49 | LLM 要求 Tool 的输出为 str,但 Tool 用在别处时希望它正常返回结构化数据。 50 | 只需要将 Tool 返回值用该类封装,能同时满足两者的需要。 51 | 基类简单的将返回值字符串化,或指定 format="json" 将其转为 json。 52 | 用户也可以继承该类定义自己的转换方法。 53 | """ 54 | 55 | def __init__( 56 | self, 57 | data: Any, 58 | format: str = "", 59 | data_alias: str = "", 60 | **extras: Any, 61 | ) -> None: 62 | self.data = data 63 | self.format = format 64 | self.extras = extras 65 | if data_alias: 66 | setattr(self, data_alias, property(lambda obj: obj.data)) 67 | 68 | def __str__(self) -> str: 69 | if self.format == "json": 70 | return json.dumps(self.data, ensure_ascii=False, indent=2) 71 | else: 72 | return str(self.data) 73 | 74 | 75 | @dataclass 76 | class AllToolExecutor(DataClassJsonMixin): 77 | platform_params: Dict[str, Any] 78 | 79 | @abstractmethod 80 | def run(self, *args: Any, **kwargs: Any) -> BaseToolOutput: 81 | pass 82 | 83 | @abstractmethod 84 | async def arun( 85 | self, 86 | *args: Any, 87 | **kwargs: Any, 88 | ) -> BaseToolOutput: 89 | pass 90 | 91 | 92 | E = TypeVar("E", bound=AllToolExecutor) 93 | 94 | 95 | class AdapterAllTool(BaseTool, Generic[E]): 96 | """platform adapter tool for all tools.""" 97 | 98 | name: str 99 | description: str 100 | 101 | platform_params: 
Dict[str, Any] 102 | """tools params """ 103 | adapter_all_tool: E 104 | 105 | def __init__(self, name: str, platform_params: Dict[str, Any], **data: Any): 106 | super().__init__( 107 | name=name, 108 | description=f"platform adapter tool for {name}", 109 | platform_params=platform_params, 110 | adapter_all_tool=self._build_adapter_all_tool(platform_params), 111 | **data, 112 | ) 113 | 114 | @abstractmethod 115 | def _build_adapter_all_tool(self, platform_params: Dict[str, Any]) -> E: 116 | raise NotImplementedError 117 | 118 | @classmethod 119 | @abstractmethod 120 | def get_type(cls) -> str: 121 | raise NotImplementedError 122 | 123 | def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: 124 | # For backwards compatibility, if run_input is a string, 125 | # pass as a positional argument. 126 | if tool_input is None: 127 | return (), {} 128 | if isinstance(tool_input, str): 129 | return (tool_input,), {} 130 | else: 131 | # for tool defined with `*args` parameters 132 | # the args_schema has a field named `args` 133 | # it should be expanded to actual *args 134 | # e.g.: test_tools 135 | # .test_named_tool_decorator_return_direct 136 | # .search_api 137 | if "args" in tool_input: 138 | args = tool_input["args"] 139 | if args is None: 140 | tool_input.pop("args") 141 | return (), tool_input 142 | elif isinstance(args, tuple): 143 | tool_input.pop("args") 144 | return args, tool_input 145 | return (), tool_input 146 | 147 | def _run( 148 | self, 149 | agent_action: AgentAction, 150 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None, 151 | **tool_run_kwargs: Any, 152 | ) -> Any: 153 | if ( 154 | AdapterAllToolStructType.CODE_INTERPRETER == agent_action.tool 155 | and isinstance(agent_action, CodeInterpreterAgentAction) 156 | ): 157 | return self.adapter_all_tool.run( 158 | **{ 159 | "tool": agent_action.tool, 160 | "tool_input": agent_action.tool_input, 161 | "log": agent_action.log, 162 | "outputs": agent_action.outputs, 163 
| }, 164 | **tool_run_kwargs, 165 | ) 166 | elif AdapterAllToolStructType.DRAWING_TOOL == agent_action.tool and isinstance( 167 | agent_action, DrawingToolAgentAction 168 | ): 169 | return self.adapter_all_tool.run( 170 | **{ 171 | "tool": agent_action.tool, 172 | "tool_input": agent_action.tool_input, 173 | "log": agent_action.log, 174 | "outputs": agent_action.outputs, 175 | }, 176 | **tool_run_kwargs, 177 | ) 178 | elif AdapterAllToolStructType.WEB_BROWSER == agent_action.tool and isinstance( 179 | agent_action, WebBrowserAgentAction 180 | ): 181 | return self.adapter_all_tool.run( 182 | **{ 183 | "tool": agent_action.tool, 184 | "tool_input": agent_action.tool_input, 185 | "log": agent_action.log, 186 | "outputs": agent_action.outputs, 187 | }, 188 | **tool_run_kwargs, 189 | ) 190 | else: 191 | raise KeyError() 192 | 193 | async def _arun( 194 | self, 195 | agent_action: AgentAction, 196 | run_manager: Optional[AsyncCallbackManagerForChainRun] = None, 197 | **tool_run_kwargs: Any, 198 | ) -> Any: 199 | if ( 200 | AdapterAllToolStructType.CODE_INTERPRETER == agent_action.tool 201 | and isinstance(agent_action, CodeInterpreterAgentAction) 202 | ): 203 | return await self.adapter_all_tool.arun( 204 | **{ 205 | "tool": agent_action.tool, 206 | "tool_input": agent_action.tool_input, 207 | "log": agent_action.log, 208 | "outputs": agent_action.outputs, 209 | }, 210 | **tool_run_kwargs, 211 | ) 212 | 213 | elif AdapterAllToolStructType.DRAWING_TOOL == agent_action.tool and isinstance( 214 | agent_action, DrawingToolAgentAction 215 | ): 216 | return await self.adapter_all_tool.arun( 217 | **{ 218 | "tool": agent_action.tool, 219 | "tool_input": agent_action.tool_input, 220 | "log": agent_action.log, 221 | "outputs": agent_action.outputs, 222 | }, 223 | **tool_run_kwargs, 224 | ) 225 | elif AdapterAllToolStructType.WEB_BROWSER == agent_action.tool and isinstance( 226 | agent_action, WebBrowserAgentAction 227 | ): 228 | return await self.adapter_all_tool.arun( 229 | **{ 
230 | "tool": agent_action.tool, 231 | "tool_input": agent_action.tool_input, 232 | "log": agent_action.log, 233 | "outputs": agent_action.outputs, 234 | }, 235 | **tool_run_kwargs, 236 | ) 237 | else: 238 | raise KeyError() 239 | -------------------------------------------------------------------------------- /langchain_glm/agent_toolkits/all_tools/web_browser_tool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from dataclasses import dataclass 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union 5 | 6 | from langchain_core.agents import AgentAction 7 | from langchain_core.callbacks import ( 8 | AsyncCallbackManagerForChainRun, 9 | AsyncCallbackManagerForToolRun, 10 | CallbackManagerForToolRun, 11 | ) 12 | 13 | from langchain_glm.agent_toolkits import AdapterAllTool 14 | from langchain_glm.agent_toolkits.all_tools.struct_type import ( 15 | AdapterAllToolStructType, 16 | ) 17 | from langchain_glm.agent_toolkits.all_tools.tool import ( 18 | AllToolExecutor, 19 | BaseToolOutput, 20 | ) 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class WebBrowserToolOutput(BaseToolOutput): 26 | platform_params: Dict[str, Any] 27 | 28 | def __init__( 29 | self, 30 | data: Any, 31 | platform_params: Dict[str, Any], 32 | **extras: Any, 33 | ) -> None: 34 | super().__init__(data, "", "", **extras) 35 | self.platform_params = platform_params 36 | 37 | 38 | @dataclass 39 | class WebBrowserAllToolExecutor(AllToolExecutor): 40 | """platform adapter tool for code interpreter tool""" 41 | 42 | name: str 43 | 44 | def run( 45 | self, 46 | tool: str, 47 | tool_input: str, 48 | log: str, 49 | outputs: List[Union[str, dict]] = None, 50 | run_manager: Optional[CallbackManagerForToolRun] = None, 51 | ) -> WebBrowserToolOutput: 52 | if outputs is None or str(outputs).strip() == "": 53 | raise ValueError(f"Tool {self.name} is server error") 54 | 55 | return 
def create_zhipuai_tools_agent(
    prompt: ChatPromptTemplate,
    llm_with_all_tools: RunnableBindingBase = None,
) -> Runnable:
    """Create an agent that uses ZhipuAI all-tools style tool calling.

    Args:
        prompt: The prompt to use. It must declare an ``agent_scratchpad``
            variable (a ``MessagesPlaceholder``) — intermediate agent actions
            and tool output messages are injected there.
        llm_with_all_tools: A language model with the desired tools already
            bound (e.g. ``model.bind(...)``). Required at call time even
            though the signature defaults to ``None``.

    Returns:
        A Runnable sequence representing an agent. It takes the same input
        variables as the prompt and outputs either an AgentAction or an
        AgentFinish.

    Raises:
        ValueError: If ``llm_with_all_tools`` is not provided, or the prompt
            is missing the ``agent_scratchpad`` variable.

    Example:

        .. code-block:: python

            from langchain import hub
            from langchain.agents import AgentExecutor
            from langchain_glm.agents.all_tools_bind import create_zhipuai_tools_agent

            prompt = hub.pull("hwchase17/openai-tools-agent")
            agent = create_zhipuai_tools_agent(prompt, llm_with_all_tools)
            agent_executor = AgentExecutor(agent=agent, tools=tools)
            agent_executor.invoke({"input": "hi"})
    """
    if llm_with_all_tools is None:
        # Fail fast with a clear message instead of a TypeError raised deep
        # inside the runnable composition below.
        raise ValueError("llm_with_all_tools must be provided")

    missing_vars = {"agent_scratchpad"}.difference(
        prompt.input_variables + list(prompt.partial_variables)
    )
    if missing_vars:
        raise ValueError(f"Prompt missing required variables: {missing_vars}")

    agent = (
        RunnablePassthrough.assign(
            agent_scratchpad=lambda x: format_to_zhipuai_all_tool_messages(
                x["intermediate_steps"]
            )
        )
        | prompt
        | llm_with_all_tools
        | ZhipuAiALLToolsAgentOutputParser()
    )

    return agent
def _create_tool_message(
    agent_action: ToolAgentAction, observation: Union[str, BaseToolOutput]
) -> ToolMessage:
    """Convert an agent action and its observation into a ToolMessage.

    Args:
        agent_action: The tool invocation request from the agent.
        observation: The result of the tool invocation.

    Returns:
        A ToolMessage carrying the serialized observation, tagged with the
        originating tool call id and tool name.
    """
    if isinstance(observation, str):
        content = observation
    else:
        try:
            content = json.dumps(observation, ensure_ascii=False)
        except (TypeError, ValueError):
            # json.dumps raises TypeError for unserializable objects and
            # ValueError for circular references; fall back to plain str().
            content = str(observation)
    return ToolMessage(
        tool_call_id=agent_action.tool_call_id,
        content=content,
        additional_kwargs={"name": agent_action.tool},
    )


def _extend_unique(
    messages: List[BaseMessage], new_messages: List[BaseMessage]
) -> None:
    """Append each of *new_messages* not already present in *messages*.

    The list comprehension is evaluated fully before extending, so membership
    is checked against *messages* as it was before the call — matching the
    original inline idiom exactly.
    """
    messages.extend([new for new in new_messages if new not in messages])


def format_to_zhipuai_all_tool_messages(
    intermediate_steps: Sequence[Tuple[AgentAction, BaseToolOutput]],
) -> List[BaseMessage]:
    """Convert (AgentAction, tool output) tuples into chat messages.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with
            observations.

    Returns:
        List of messages to send to the LLM for the next prediction.

    Raises:
        ValueError: For an unknown sandbox mode, or an observation type that
            does not match the action type.
    """
    messages = []
    # NOTE: the platform action classes subclass ToolAgentAction, so the
    # specific isinstance checks must run before the generic one.
    for agent_action, observation in intermediate_steps:
        if isinstance(agent_action, CodeInterpreterAgentAction):
            if not isinstance(observation, CodeInterpreterToolOutput):
                raise ValueError(f"Unknown observation type: {type(observation)}")
            sandbox = observation.platform_params.get("sandbox", "auto")
            if "auto" == sandbox:
                # Sandboxed run: the platform executed the code; forward the
                # whole observation.
                _extend_unique(
                    messages,
                    [
                        AIMessage(content=str(observation.code_input)),
                        _create_tool_message(agent_action, observation),
                    ],
                )
            elif "none" == sandbox:
                # Local run: only the captured code output is reported.
                _extend_unique(
                    messages,
                    [
                        AIMessage(content=str(observation.code_input)),
                        _create_tool_message(agent_action, observation.code_output),
                    ],
                )
            else:
                raise ValueError(
                    f"Unknown sandbox type: {observation.platform_params.get('sandbox', 'auto')}"
                )

        elif isinstance(agent_action, DrawingToolAgentAction):
            if not isinstance(observation, DrawingToolOutput):
                raise ValueError(f"Unknown observation type: {type(observation)}")
            _extend_unique(messages, [AIMessage(content=str(observation))])

        elif isinstance(agent_action, WebBrowserAgentAction):
            if not isinstance(observation, WebBrowserToolOutput):
                raise ValueError(f"Unknown observation type: {type(observation)}")
            _extend_unique(messages, [AIMessage(content=str(observation))])

        elif isinstance(agent_action, ToolAgentAction):
            ai_msgs = AIMessage(
                content=f"arguments='{agent_action.tool_input}', name='{agent_action.tool}'"
            )
            _extend_unique(
                messages, [ai_msgs, _create_tool_message(agent_action, observation)]
            )
        else:
            messages.append(AIMessage(content=agent_action.log))
    return messages
12 | """ 13 | 14 | from langchain_glm.agents.output_parsers.zhipuai_all_tools import ( 15 | ZhipuAiALLToolsAgentOutputParser, 16 | ) 17 | 18 | __all__ = ["ZhipuAiALLToolsAgentOutputParser"] 19 | -------------------------------------------------------------------------------- /langchain_glm/agents/output_parsers/_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Function to find positions of object() instances 3 | def find_object_positions(log_chunk, obj): 4 | return [i for i, x in enumerate(log_chunk) if x == obj] 5 | 6 | 7 | # Function to concatenate segments based on object positions 8 | def concatenate_segments(log_chunk, positions): 9 | segments = [] 10 | start = 0 11 | for pos in positions: 12 | segments.append("".join(map(str, log_chunk[start:pos]))) 13 | start = pos + 1 14 | return segments 15 | -------------------------------------------------------------------------------- /langchain_glm/agents/output_parsers/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Dict, Optional 3 | 4 | from zhipuai.core import BaseModel 5 | 6 | 7 | class AllToolsMessageToolCall(BaseModel): 8 | name: Optional[str] 9 | args: Optional[Dict[str, Any]] 10 | id: Optional[str] 11 | 12 | 13 | class AllToolsMessageToolCallChunk(BaseModel): 14 | name: Optional[str] 15 | args: Optional[Dict[str, Any]] 16 | id: Optional[str] 17 | index: Optional[int] 18 | -------------------------------------------------------------------------------- /langchain_glm/agents/output_parsers/code_interpreter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import logging 4 | from collections import deque 5 | from typing import Any, Deque, Dict, List, Union 6 | 7 | from langchain.agents.output_parsers.tools import ToolAgentAction 8 | from langchain_core.agents 
class CodeInterpreterAgentAction(ToolAgentAction):
    """Agent action for the platform code-interpreter built-in tool."""

    outputs: List[Union[str, dict]] = None
    """Output of the tool call."""
    platform_params: dict = None


def _best_effort_parse_code_interpreter_tool_calls(
    tool_call_chunks: List[dict],
) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]:
    """Filter *tool_call_chunks* down to code-interpreter calls.

    Chunks whose args already contain ``outputs`` are complete calls; the
    rest are partial streaming chunks.
    """
    parsed: List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]] = []
    for chunk in tool_call_chunks:
        if AdapterAllToolStructType.CODE_INTERPRETER != chunk["name"]:
            continue
        raw_args = chunk["args"]
        args_ = parse_partial_json(raw_args) if isinstance(raw_args, str) else raw_args
        if not isinstance(args_, dict):
            raise ValueError("Malformed args.")

        if "outputs" in args_:
            parsed.append(
                AllToolsMessageToolCall(
                    name=chunk["name"],
                    args=args_,
                    id=chunk["id"],
                )
            )
        else:
            parsed.append(
                AllToolsMessageToolCallChunk(
                    name=chunk["name"],
                    args=args_,
                    id=chunk["id"],
                    index=chunk.get("index"),
                )
            )
    return parsed


def _paser_code_interpreter_chunk_input(
    message: BaseMessage,
    code_interpreter_chunk: List[
        Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]
    ],
) -> Deque[CodeInterpreterAgentAction]:
    """Assemble streamed code-interpreter chunks into agent actions.

    Inputs are concatenated per call; a sentinel object marks where each
    completed call (one with ``outputs``) ends.

    Raises:
        OutputParserException: If the chunk stream cannot be assembled.
    """
    try:
        input_parts = []
        outputs: List[List[dict]] = []
        sentinel = object()  # marks the boundary after each completed call
        for record in code_interpreter_chunk:
            record_args = record.args
            if "input" in record_args:
                input_parts.append(record_args["input"])
            if "outputs" in record_args:
                input_parts.append(sentinel)
                outputs.append(record_args["outputs"])

        # Make sure the trailing segment is terminated as well.
        if input_parts[-1] is not sentinel:
            input_parts.append(sentinel)
        # Segment the accumulated parts at the sentinel positions and join
        # each segment into one input string.
        positions = find_object_positions(input_parts, sentinel)
        result_actions = concatenate_segments(input_parts, positions)

        tool_call_id = (
            code_interpreter_chunk[0].id if code_interpreter_chunk[0].id else "abc"
        )
        action_stack: Deque[CodeInterpreterAgentAction] = deque()
        for i, action in enumerate(result_actions):
            if len(result_actions) > len(outputs):
                # Pad: this segment produced no outputs.
                outputs.insert(i, [])

            out_logs = [logs["logs"] for logs in outputs[i] if "logs" in logs]
            out_str = "\n".join(out_logs)
            log = f"{action}\r\n{out_str}"
            action_stack.append(
                CodeInterpreterAgentAction(
                    tool=AdapterAllToolStructType.CODE_INTERPRETER,
                    tool_input=action,
                    outputs=outputs[i],
                    log=log,
                    message_log=[message],
                    tool_call_id=tool_call_id,
                )
            )
        return action_stack

    except Exception as e:
        logger.error(f"Error parsing code_interpreter_chunk: {e}", exc_info=True)
        raise OutputParserException(
            f"Could not parse tool input: code_interpreter because {e}"
        )


# --- langchain_glm/agents/output_parsers/drawing_tool.py ---


class DrawingToolAgentAction(ToolAgentAction):
    """Agent action for the platform drawing built-in tool."""

    outputs: List[Union[str, dict]] = None
    """Output of the tool call."""
    platform_params: dict = None


def _best_effort_parse_drawing_tool_tool_calls(
    tool_call_chunks: List[dict],
) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]:
    """Filter *tool_call_chunks* down to drawing-tool calls.

    Chunks whose args already contain ``outputs`` are complete calls; the
    rest are partial streaming chunks.
    """
    parsed: List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]] = []
    for chunk in tool_call_chunks:
        if AdapterAllToolStructType.DRAWING_TOOL != chunk["name"]:
            continue
        raw_args = chunk["args"]
        args_ = parse_partial_json(raw_args) if isinstance(raw_args, str) else raw_args
        if not isinstance(args_, dict):
            raise ValueError("Malformed args.")

        if "outputs" in args_:
            parsed.append(
                AllToolsMessageToolCall(
                    name=chunk["name"],
                    args=args_,
                    id=chunk["id"],
                )
            )
        else:
            parsed.append(
                AllToolsMessageToolCallChunk(
                    name=chunk["name"],
                    args=args_,
                    id=chunk["id"],
                    index=chunk.get("index"),
                )
            )
    return parsed


def _paser_drawing_tool_chunk_input(
    message: BaseMessage,
    drawing_tool_chunk: List[
        Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]
    ],
) -> Deque[DrawingToolAgentAction]:
    """Assemble streamed drawing-tool chunks into agent actions.

    Raises:
        OutputParserException: If the chunk stream cannot be assembled.
    """
    try:
        input_parts = []
        outputs: List[List[dict]] = []
        sentinel = object()  # marks the boundary after each completed call
        for record in drawing_tool_chunk:
            record_args = record.args
            if "input" in record_args:
                input_parts.append(record_args["input"])
            if "outputs" in record_args:
                input_parts.append(sentinel)
                outputs.append(record_args["outputs"])

        if input_parts[-1] is not sentinel:
            input_parts.append(sentinel)
        # Segment the accumulated parts at the sentinel positions and join
        # each segment into one input string.
        positions = find_object_positions(input_parts, sentinel)
        result_actions = concatenate_segments(input_parts, positions)

        tool_call_id = drawing_tool_chunk[0].id if drawing_tool_chunk[0].id else "abc"
        action_stack: Deque[DrawingToolAgentAction] = deque()
        for i, action in enumerate(result_actions):
            if len(result_actions) > len(outputs):
                # Pad: this segment produced no outputs.
                outputs.insert(i, [])

            # NOTE(review): this literal is empty in the copy reviewed —
            # it looks like image markup (e.g. an <img> tag built from
            # logs["image"]) was stripped; confirm against upstream source.
            out_logs = [
                f''
                for logs in outputs[i]
                if "image" in logs
            ]

            out_str = "\n".join(out_logs)
            log = f"{action}\r\n{out_str}"

            action_stack.append(
                DrawingToolAgentAction(
                    tool=AdapterAllToolStructType.DRAWING_TOOL,
                    tool_input=action,
                    outputs=outputs[i],
                    log=log,
                    message_log=[message],
                    tool_call_id=tool_call_id,
                )
            )
        return action_stack

    except Exception as e:
        logger.error(f"Error parsing drawing_tool_chunk: {e}", exc_info=True)
        raise OutputParserException(
            f"Could not parse tool input: drawing_tool because {e}"
        )
def _best_effort_parse_function_tool_calls(
    tool_call_chunks: List[dict],
) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]:
    """Best-effort conversion of raw chunks into regular function-call records.

    Chunks whose name matches a platform built-in tool are skipped — those are
    handled by their own dedicated parsers. A chunk with non-empty args is
    treated as a complete call; an empty-args chunk stays a streaming chunk.
    """
    parsed: List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]] = []
    for chunk in tool_call_chunks:
        if chunk["name"] in AdapterAllToolStructType.__members__.values():
            continue  # platform built-ins are parsed elsewhere
        raw_args = chunk["args"]
        args_ = parse_partial_json(raw_args) if isinstance(raw_args, str) else raw_args
        if not isinstance(args_, dict):
            raise ValueError("Malformed args.")

        if len(args_.keys()) > 0:
            parsed.append(
                AllToolsMessageToolCall(
                    name=chunk["name"],
                    args=args_,
                    id=chunk["id"],
                )
            )
        else:
            parsed.append(
                AllToolsMessageToolCallChunk(
                    name=chunk["name"],
                    args=args_,
                    id=chunk["id"],
                    index=chunk.get("index"),
                )
            )
    return parsed


def _paser_function_chunk_input(
    message: BaseMessage,
    function_chunk: List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]],
) -> Deque[ToolAgentAction]:
    """Turn fully parsed function-call records into ToolAgentActions.

    Partial chunks (``AllToolsMessageToolCallChunk``) are skipped — they carry
    no runnable call yet.

    Raises:
        OutputParserException: If building the actions fails for any reason.
    """
    try:
        action_stack: Deque[ToolAgentAction] = deque()
        for record in function_chunk:
            if not isinstance(record, AllToolsMessageToolCall):
                continue
            function_name = record.name
            raw_input = record.args
            tool_call_id = record.id if record.id else "abc"
            # Single-argument convention: unwrap "__arg1" to its bare value.
            if "__arg1" in raw_input:
                tool_input = raw_input["__arg1"]
            else:
                tool_input = raw_input

            content_msg = f"responded: {message.content}\n" if message.content else "\n"
            log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"

            action_stack.append(
                ToolAgentAction(
                    tool=function_name,
                    tool_input=tool_input,
                    log=log,
                    message_log=[message],
                    tool_call_id=tool_call_id,
                )
            )

        return action_stack

    except Exception as e:
        logger.error(f"Error parsing function_chunk: {e}", exc_info=True)
        raise OutputParserException(f"Error parsing function_chunk: {e} ")
def paser_ai_message_to_tool_calls(
    message: BaseMessage,
):
    """Collect the tool calls from *message*, or an AgentFinish when none exist.

    Prefers the already-parsed ``message.tool_calls``; otherwise rebuilds the
    calls from ``additional_kwargs["tool_calls"]``.
    """
    if message.tool_calls:
        return message.tool_calls

    raw_calls = message.additional_kwargs.get("tool_calls")
    if not raw_calls:
        # No tool requested: the model produced a final answer.
        return AgentFinish(
            return_values={"output": message.content}, log=str(message.content)
        )

    tool_calls = []
    # Best-effort parsing of the raw (already serialized) tool calls.
    for tool_call in raw_calls:
        if "function" == tool_call["type"]:
            function = tool_call["function"]
            function_name = function["name"]
            try:
                args = json.loads(function["arguments"] or "{}")
            except JSONDecodeError:
                raise OutputParserException(
                    f"Could not parse tool input: {function} because "
                    f"the `arguments` is not valid JSON."
                )
            tool_calls.append(
                ToolCall(
                    name=function_name,
                    args=args,
                    id=tool_call["id"] if tool_call["id"] else "abc",
                )
            )
        elif tool_call["type"] in AdapterAllToolStructType.__members__.values():
            # Platform built-in tool: its payload sits under a key named
            # after the type itself.
            adapter_tool = tool_call[tool_call["type"]]
            tool_calls.append(
                ToolCall(
                    name=tool_call["type"],
                    args=adapter_tool if adapter_tool else {},
                    id=tool_call["id"] if tool_call["id"] else "abc",
                )
            )

    return tool_calls


def parse_ai_message_to_tool_action(
    message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
    """Parse an AI message potentially containing tool_calls."""
    if not isinstance(message, AIMessage):
        raise TypeError(f"Expected an AI message got {type(message)}")

    # TODO: parse platform tools built-in @langchain_glm.agents.zhipuai_all_tools.base._get_assistants_tool
    # type in the future "function" or "code_interpreter"
    # for @ToolAgentAction from langchain.agents.output_parsers.tools
    # import with langchain.agents.format_scratchpad.tools.format_to_tool_messages
    tool_calls = paser_ai_message_to_tool_calls(message)
    if isinstance(tool_calls, AgentFinish):
        return tool_calls

    def _collect(best_effort_parser):
        """Run one best-effort parser over the richest chunk source available."""
        if not message.tool_calls:
            return []
        if isinstance(message, ALLToolsMessageChunk):
            # Streaming messages keep the raw chunk stream.
            return best_effort_parser(message.tool_call_chunks)
        return best_effort_parser(tool_calls)

    code_interpreter_action_result_stack: deque = deque()
    web_browser_action_result_stack: deque = deque()
    drawing_tool_result_stack: deque = deque()

    code_interpreter_chunk = _collect(_best_effort_parse_code_interpreter_tool_calls)
    if code_interpreter_chunk and len(code_interpreter_chunk) > 1:
        code_interpreter_action_result_stack = _paser_code_interpreter_chunk_input(
            message, code_interpreter_chunk
        )

    drawing_tool_chunk = _collect(_best_effort_parse_drawing_tool_tool_calls)
    if drawing_tool_chunk and len(drawing_tool_chunk) > 1:
        drawing_tool_result_stack = _paser_drawing_tool_chunk_input(
            message, drawing_tool_chunk
        )

    web_browser_chunk = _collect(_best_effort_parse_web_browser_tool_calls)
    if web_browser_chunk and len(web_browser_chunk) > 1:
        web_browser_action_result_stack = _paser_web_browser_chunk_input(
            message, web_browser_chunk
        )

    # TODO: parse platform tools built-in @langchain_glm
    # delete AdapterAllToolStructType from tool_calls
    function_tool_chunk = _best_effort_parse_function_tool_calls(tool_calls)
    function_tool_result_stack = _paser_function_chunk_input(
        message, function_tool_chunk
    )

    actions: List = []
    if isinstance(message, ALLToolsMessageChunk):
        # Streaming path: replay the actions in the order in which the chunk
        # names appeared in the stream.
        for tool_call_name in _paser_object_positions(message.tool_call_chunks):
            if tool_call_name == AdapterAllToolStructType.CODE_INTERPRETER:
                actions.append(code_interpreter_action_result_stack.popleft())
            elif tool_call_name == AdapterAllToolStructType.WEB_BROWSER:
                actions.append(web_browser_action_result_stack.popleft())
            elif tool_call_name == AdapterAllToolStructType.DRAWING_TOOL:
                actions.append(drawing_tool_result_stack.popleft())
            else:
                actions.append(function_tool_result_stack.popleft())
    else:
        # Non-streaming path: follow the order of the consolidated tool calls.
        for tool_call in tool_calls:
            call_name = tool_call["name"]
            if call_name not in AdapterAllToolStructType.__members__.values():
                actions.append(function_tool_result_stack.popleft())
            elif call_name == AdapterAllToolStructType.CODE_INTERPRETER:
                actions.append(code_interpreter_action_result_stack.popleft())
            elif call_name == AdapterAllToolStructType.WEB_BROWSER:
                actions.append(web_browser_action_result_stack.popleft())
            elif call_name == AdapterAllToolStructType.DRAWING_TOOL:
                actions.append(drawing_tool_result_stack.popleft())

    return actions


def _paser_object_positions(tool_call_chunks: List[ToolCallChunk]):
    """Derive the ordered list of tool names from a raw chunk stream.

    Platform built-in chunks count only once their args contain ``outputs``
    (i.e. the call completed); other tools count on every name change. The
    final chunk's name is appended if it is not already last.
    """
    call_chunks = []
    last_name = None
    if not tool_call_chunks:
        return call_chunks
    for call_chunk in tool_call_chunks:
        chunk_name = call_chunk["name"]
        if chunk_name in AdapterAllToolStructType.__members__.values():
            raw_args = call_chunk["args"]
            args_ = (
                parse_partial_json(raw_args) if isinstance(raw_args, str) else raw_args
            )
            if not isinstance(args_, dict):
                raise ValueError("Malformed args.")

            if "outputs" in args_:
                call_chunks.append(chunk_name)
                last_name = chunk_name
        elif chunk_name != last_name:
            call_chunks.append(chunk_name)
            last_name = chunk_name

    tail_name = tool_call_chunks[-1]["name"]
    if not call_chunks or call_chunks[-1] != tail_name:
        call_chunks.append(tail_name)
    return call_chunks
def _best_effort_parse_web_browser_tool_calls(
    tool_call_chunks: List[dict],
) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]:
    """Best-effort extraction of web_browser calls from streamed chunks.

    Chunks whose args already contain ``"outputs"`` are complete and become
    ``AllToolsMessageToolCall``; everything else stays a partial
    ``AllToolsMessageToolCallChunk``. Non-web_browser chunks are ignored.
    """
    parsed_calls: List[
        Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]
    ] = []
    for chunk in tool_call_chunks:
        if AdapterAllToolStructType.WEB_BROWSER != chunk["name"]:
            continue
        raw_args = chunk["args"]
        args_ = parse_partial_json(raw_args) if isinstance(raw_args, str) else raw_args
        if not isinstance(args_, dict):
            raise ValueError("Malformed args.")

        if "outputs" in args_:
            # Finished call: the platform already returned its results.
            parsed_calls.append(
                AllToolsMessageToolCall(
                    name=chunk["name"],
                    args=args_,
                    id=chunk["id"],
                )
            )
        else:
            # Still streaming: keep the chunk form (with its stream index).
            parsed_calls.append(
                AllToolsMessageToolCallChunk(
                    name=chunk["name"],
                    args=args_,
                    id=chunk["id"],
                    index=chunk.get("index"),
                )
            )
    return parsed_calls
def _paser_web_browser_chunk_input(
    message: BaseMessage,
    web_browser_chunk: List[
        Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]
    ],
) -> Deque[WebBrowserAgentAction]:
    """Assemble streamed web_browser chunks into a queue of agent actions.

    Input fragments are concatenated per call; each call's search results are
    rendered into the action log. Raises ``OutputParserException`` on any
    parsing failure.

    Args:
        message: The originating AI message, attached as the action's log.
        web_browser_chunk: Parsed (complete or partial) web_browser calls.

    Returns:
        A deque of ``WebBrowserAgentAction``, one per reconstructed input.
    """
    try:
        input_log_chunk = []

        outputs: List[List[dict]] = []
        # Sentinel separating one call's input fragments from the next.
        obj = object()
        for interpreter_chunk in web_browser_chunk:
            interpreter_chunk_args = interpreter_chunk.args

            if "input" in interpreter_chunk_args:
                input_log_chunk.append(interpreter_chunk_args["input"])
            if "outputs" in interpreter_chunk_args:
                input_log_chunk.append(obj)
                outputs.append(interpreter_chunk_args["outputs"])

        # Fix: previously `input_log_chunk[-1]` raised IndexError when no
        # chunk carried an "input" fragment (surfacing as a confusing
        # OutputParserException); degrade to an empty action stack instead.
        if not input_log_chunk:
            return deque()
        if input_log_chunk[-1] is not obj:
            input_log_chunk.append(obj)
        # Segment the list at the sentinel positions, then concatenate each
        # segment into one input string per call.
        positions = find_object_positions(input_log_chunk, obj)
        result_actions = concatenate_segments(input_log_chunk, positions)

        tool_call_id = web_browser_chunk[0].id if web_browser_chunk[0].id else "abc"
        web_browser_action_result_stack: Deque[WebBrowserAgentAction] = deque()
        for i, action in enumerate(result_actions):
            # Pad outputs so every reconstructed input has a (possibly empty)
            # result list to pair with.
            if len(result_actions) > len(outputs):
                outputs.insert(i, [])

            out_logs = [
                f"title:{logs['title']}\nlink:{logs['link']}\ncontent:{logs['content']}"
                for logs in outputs[i]
                if "title" in logs
            ]
            out_str = "\n".join(out_logs)
            log = f"{action}\r\n{out_str}"
            web_browser_action = WebBrowserAgentAction(
                tool=AdapterAllToolStructType.WEB_BROWSER,
                tool_input=action,
                outputs=outputs[i],
                log=log,
                message_log=[message],
                tool_call_id=tool_call_id,
            )
            web_browser_action_result_stack.append(web_browser_action)
        return web_browser_action_result_stack
    except Exception as e:
        logger.error(f"Error parsing web_browser_chunk: {e}", exc_info=True)
        raise OutputParserException(f"Could not parse tool input: web_browser {e} ")
def parse_ai_message_to_zhipuai_all_tool_action(
    message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
    """Parse an AI message potentially containing tool_calls."""
    parsed = parse_ai_message_to_tool_action(message)
    if isinstance(parsed, AgentFinish):
        return parsed

    # Platform built-in actions pass through untouched; they must be checked
    # before the generic ToolAgentAction branch since they subclass it.
    platform_action_types = (
        CodeInterpreterAgentAction,
        DrawingToolAgentAction,
        WebBrowserAgentAction,
    )

    final_actions: List[AgentAction] = []
    for action in parsed:
        if isinstance(action, platform_action_types):
            final_actions.append(action)
        elif isinstance(action, ToolAgentAction):
            # Re-wrap plain tool actions under the ZhipuAI alias type.
            final_actions.append(
                ZhipuAiALLToolAgentAction(
                    tool=action.tool,
                    tool_input=action.tool_input,
                    log=action.log,
                    message_log=action.message_log,
                    tool_call_id=action.tool_call_id,
                )
            )
        else:
            final_actions.append(action)
    return final_actions
class ZhipuAiALLToolsAgentOutputParser(MultiActionAgentOutputParser):
    """Turn a chat generation into agent actions or an AgentFinish.

    Relies on the OpenAI-style ``tool_calls`` payload to determine which
    tools to invoke and with what inputs; a message without tool calls is
    treated as the final answer. Only message-based results are supported.
    """

    @property
    def _type(self) -> str:
        return "zhipuai-all-tools-agent-output-parser"

    def parse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> Union[List[AgentAction], AgentFinish]:
        first_generation = result[0]
        if not isinstance(first_generation, ChatGeneration):
            raise ValueError("This output parser only works on ChatGeneration output")
        return parse_ai_message_to_zhipuai_all_tool_action(first_generation.message)

    def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
        # Raw-text parsing is unsupported; callers must go through parse_result.
        raise ValueError("Can only parse messages")
class AllToolsBaseComponent(BaseModel):
    """Base class for serializable all-tools stream events.

    Provides dict/JSON round-tripping that tags payloads with the concrete
    subclass name via :meth:`class_name`, so consumers can dispatch on type.
    """

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
    else:

        class Config:
            arbitrary_types_allowed = True

    @classmethod
    @abstractmethod
    def class_name(cls) -> str:
        """Get class name."""

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Serialize to a dict, adding a ``class_name`` discriminator key."""
        data = self.dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def to_json(self, **kwargs: Any) -> str:
        """Serialize to JSON, keeping non-ASCII characters readable."""
        data = self.to_dict(**kwargs)
        return json.dumps(data, ensure_ascii=False)

    # TODO: return type here not supported by current mypy version
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        """Build an instance from a dict; extra kwargs override dict entries."""
        # Fix: `kwargs` is always a dict, so the previous
        # `if isinstance(kwargs, dict)` guard was dead code.
        data.update(kwargs)
        # Drop the discriminator added by to_dict(); it is not a model field.
        data.pop("class_name", None)
        return cls(**data)

    @classmethod
    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
        """Build an instance from a JSON string via :meth:`from_dict`."""
        data = json.loads(data_str)
        return cls.from_dict(data, **kwargs)
class AllToolsLLMStatus(AllToolsBaseComponent):
    """LLM streaming status event carrying a text (or media) payload."""

    run_id: str  # identifier of the originating run
    status: int  # AgentStatus
    text: str  # streamed text payload for this status update
    message_type: int = MsgType.TEXT  # MsgType constant; defaults to plain text

    @classmethod
    def class_name(cls) -> str:
        return "AllToolsLLMStatus"
def dumps(obj: Dict) -> str:
    """Serialize *obj* to a JSON string without escaping non-ASCII characters."""
    return json.dumps(obj, ensure_ascii=False)
self.queue.put_nowait(dumps(data)) 53 | 54 | async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: 55 | special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"] 56 | for stoken in special_tokens: 57 | if stoken in token: 58 | before_action = token.split(stoken)[0] 59 | data = { 60 | "status": AgentStatus.llm_new_token, 61 | "text": before_action + "\n", 62 | } 63 | self.queue.put_nowait(dumps(data)) 64 | self.out = False 65 | break 66 | 67 | if token is not None and token != "" and not self.out: 68 | data = { 69 | "run_id": str(kwargs["run_id"]), 70 | "status": AgentStatus.llm_new_token, 71 | "text": token, 72 | } 73 | self.queue.put_nowait(dumps(data)) 74 | 75 | async def on_chat_model_start( 76 | self, 77 | serialized: Dict[str, Any], 78 | messages: List[List], 79 | *, 80 | run_id: UUID, 81 | parent_run_id: Optional[UUID] = None, 82 | tags: Optional[List[str]] = None, 83 | metadata: Optional[Dict[str, Any]] = None, 84 | **kwargs: Any, 85 | ) -> None: 86 | data = { 87 | "run_id": str(run_id), 88 | "status": AgentStatus.llm_start, 89 | "text": "", 90 | } 91 | self.done.clear() 92 | self.queue.put_nowait(dumps(data)) 93 | 94 | async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: 95 | data = { 96 | "run_id": str(kwargs["run_id"]), 97 | "status": AgentStatus.llm_end, 98 | "text": response.generations[0][0].message.content, 99 | } 100 | 101 | self.queue.put_nowait(dumps(data)) 102 | 103 | async def on_llm_error( 104 | self, error: Exception | KeyboardInterrupt, **kwargs: Any 105 | ) -> None: 106 | data = { 107 | "status": AgentStatus.error, 108 | "text": str(error), 109 | } 110 | self.queue.put_nowait(dumps(data)) 111 | 112 | async def on_tool_start( 113 | self, 114 | serialized: Dict[str, Any], 115 | input_str: str, 116 | *, 117 | run_id: UUID, 118 | parent_run_id: Optional[UUID] = None, 119 | tags: Optional[List[str]] = None, 120 | metadata: Optional[Dict[str, Any]] = None, 121 | **kwargs: Any, 122 | ) -> None: 123 | data 
= { 124 | "run_id": str(run_id), 125 | "status": AgentStatus.tool_start, 126 | "tool": serialized["name"], 127 | "tool_input": input_str, 128 | } 129 | self.done.clear() 130 | self.queue.put_nowait(dumps(data)) 131 | 132 | async def on_tool_end( 133 | self, 134 | output: Any, 135 | *, 136 | run_id: UUID, 137 | parent_run_id: Optional[UUID] = None, 138 | tags: Optional[List[str]] = None, 139 | **kwargs: Any, 140 | ) -> None: 141 | """Run when tool ends running.""" 142 | data = { 143 | "run_id": str(run_id), 144 | "status": AgentStatus.tool_end, 145 | "tool": kwargs["name"], 146 | "tool_output": str(output), 147 | } 148 | self.queue.put_nowait(dumps(data)) 149 | 150 | async def on_tool_error( 151 | self, 152 | error: BaseException, 153 | *, 154 | run_id: UUID, 155 | parent_run_id: Optional[UUID] = None, 156 | tags: Optional[List[str]] = None, 157 | **kwargs: Any, 158 | ) -> None: 159 | """Run when tool errors.""" 160 | data = { 161 | "run_id": str(run_id), 162 | "status": AgentStatus.error, 163 | "tool_output": str(error), 164 | "is_error": True, 165 | } 166 | 167 | self.queue.put_nowait(dumps(data)) 168 | 169 | async def on_agent_action( 170 | self, 171 | action: AgentAction, 172 | *, 173 | run_id: UUID, 174 | parent_run_id: Optional[UUID] = None, 175 | tags: Optional[List[str]] = None, 176 | **kwargs: Any, 177 | ) -> None: 178 | data = { 179 | "run_id": str(run_id), 180 | "status": AgentStatus.agent_action, 181 | "action": { 182 | "tool": action.tool, 183 | "tool_input": action.tool_input, 184 | "log": action.log, 185 | }, 186 | } 187 | self.queue.put_nowait(dumps(data)) 188 | 189 | async def on_agent_finish( 190 | self, 191 | finish: AgentFinish, 192 | *, 193 | run_id: UUID, 194 | parent_run_id: Optional[UUID] = None, 195 | tags: Optional[List[str]] = None, 196 | **kwargs: Any, 197 | ) -> None: 198 | if "Thought:" in finish.return_values["output"]: 199 | finish.return_values["output"] = finish.return_values["output"].replace( 200 | "Thought:", "" 201 | ) 202 | 203 
| data = { 204 | "run_id": str(run_id), 205 | "status": AgentStatus.agent_finish, 206 | "finish": { 207 | "return_values": finish.return_values, 208 | "log": finish.log, 209 | }, 210 | } 211 | 212 | self.queue.put_nowait(dumps(data)) 213 | 214 | async def on_chain_start( 215 | self, 216 | serialized: Dict[str, Any], 217 | inputs: Dict[str, Any], 218 | *, 219 | run_id: UUID, 220 | parent_run_id: Optional[UUID] = None, 221 | tags: Optional[List[str]] = None, 222 | metadata: Optional[Dict[str, Any]] = None, 223 | **kwargs: Any, 224 | ) -> None: 225 | """Run when chain starts running.""" 226 | if "agent_scratchpad" in inputs: 227 | del inputs["agent_scratchpad"] 228 | if "chat_history" in inputs: 229 | inputs["chat_history"] = [ 230 | History.from_message(message).to_msg_tuple() 231 | for message in inputs["chat_history"] 232 | ] 233 | data = { 234 | "run_id": str(run_id), 235 | "status": AgentStatus.chain_start, 236 | "inputs": inputs, 237 | "parent_run_id": parent_run_id, 238 | "tags": tags, 239 | "metadata": metadata, 240 | } 241 | 242 | self.done.clear() 243 | self.out = False 244 | self.queue.put_nowait(dumps(data)) 245 | 246 | async def on_chain_error( 247 | self, 248 | error: BaseException, 249 | *, 250 | run_id: UUID, 251 | parent_run_id: Optional[UUID] = None, 252 | tags: Optional[List[str]] = None, 253 | **kwargs: Any, 254 | ) -> None: 255 | """Run when chain errors.""" 256 | data = { 257 | "run_id": str(run_id), 258 | "status": AgentStatus.error, 259 | "error": str(error), 260 | } 261 | self.queue.put_nowait(dumps(data)) 262 | 263 | async def on_chain_end( 264 | self, 265 | outputs: Dict[str, Any], 266 | *, 267 | run_id: UUID, 268 | parent_run_id: UUID | None = None, 269 | tags: List[str] | None = None, 270 | **kwargs: Any, 271 | ) -> None: 272 | if "intermediate_steps" in outputs: 273 | self.intermediate_steps = outputs["intermediate_steps"] 274 | self.outputs = outputs 275 | del outputs["intermediate_steps"] 276 | data = { 277 | "run_id": str(run_id), 278 
def default_all_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
    """Best-effort parsing of all tool chunks.

    Handles OpenAI-style ``function`` calls plus ZhipuAI's platform built-ins
    (``code_interpreter``, ``drawing_tool``, ``web_browser``), whose dict
    payloads are re-serialized to JSON so every chunk carries ``args`` as
    text. Unrecognized shapes produce a chunk with name/args of ``None``.
    """
    # The three platform built-ins share one shape: the tool name is the key
    # and the value is the payload dict — consolidate the copy-pasted branches.
    builtin_tools = ("code_interpreter", "drawing_tool", "web_browser")
    tool_call_chunks = []
    for tool_call in raw_tool_calls:
        if "function" in tool_call and tool_call["function"] is not None:
            function_args = tool_call["function"]["arguments"]
            function_name = tool_call["function"]["name"]
        else:
            function_name = next(
                (name for name in builtin_tools if tool_call.get(name) is not None),
                None,
            )
            if function_name is not None:
                function_args = json.dumps(
                    tool_call[function_name], ensure_ascii=False
                )
            else:
                function_args = None
        parsed = ToolCallChunk(
            name=function_name,
            args=function_args,
            id=tool_call.get("id"),
            index=tool_call.get("index"),
        )
        tool_call_chunks.append(parsed)
    return tool_call_chunks
class ALLToolsMessageChunk(AIMessage, BaseMessageChunk):
    """Message chunk from an AI.

    Streaming variant of the all-tools AI message: accumulates partial
    ``tool_call_chunks`` across chunks and keeps ``tool_calls`` /
    ``invalid_tool_calls`` in sync with them via root validators.
    """

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["ALLToolsMessageChunk"] = "ALLToolsMessageChunk"  # type: ignore[assignment] # noqa: E501

    tool_call_chunks: List[ToolCallChunk] = []
    """If provided, tool call chunks associated with the message."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "messages"]

    @property
    def lc_attributes(self) -> Dict:
        """Attrs to be serialized even if they are derived from other init args."""
        return {
            "tool_calls": self.tool_calls,
            "invalid_tool_calls": self.invalid_tool_calls,
        }

    @root_validator(allow_reuse=True)
    def _backwards_compat_tool_calls(cls, values: dict) -> dict:
        # Legacy payloads only carry raw tool calls inside additional_kwargs;
        # promote them to typed chunks (chunk classes) or parsed calls.
        raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
        tool_calls = (
            values.get("tool_calls")
            or values.get("invalid_tool_calls")
            or values.get("tool_call_chunks")
        )
        if raw_tool_calls and not tool_calls:
            try:
                if issubclass(cls, BaseMessageChunk):  # type: ignore
                    values["tool_call_chunks"] = default_all_tool_chunk_parser(
                        raw_tool_calls
                    )
                else:
                    tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
                    values["tool_calls"] = tool_calls
                    values["invalid_tool_calls"] = invalid_tool_calls
            except Exception as e:
                # Best-effort only: malformed legacy payloads are ignored
                # rather than failing message construction.
                pass
        return values

    @root_validator(allow_reuse=True)
    def init_tool_calls(cls, values: dict) -> dict:
        # Keep the three tool-call representations mutually consistent,
        # whichever one was supplied at construction time.
        if not values["tool_call_chunks"]:
            # No chunks: synthesize them from the parsed (or invalid) calls.
            if values["tool_calls"]:
                values["tool_call_chunks"] = [
                    tool_call_chunk(
                        name=tc["name"],
                        args=json.dumps(tc["args"]),
                        id=tc["id"],
                        index=None,
                    )
                    for tc in values["tool_calls"]
                ]
            if values["invalid_tool_calls"]:
                tool_call_chunks = values.get("tool_call_chunks", [])
                tool_call_chunks.extend(
                    [
                        tool_call_chunk(
                            name=tc["name"], args=tc["args"], id=tc["id"], index=None
                        )
                        for tc in values["invalid_tool_calls"]
                    ]
                )
                values["tool_call_chunks"] = tool_call_chunks

            return values

        # Chunks present: (re)derive the parsed calls from them.
        tool_calls, invalid_tool_calls = _paser_chunk(values["tool_call_chunks"])
        values["tool_calls"] = tool_calls
        values["invalid_tool_calls"] = invalid_tool_calls
        return values

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        # Concatenate two streamed chunks: content/kwargs/metadata merge via
        # the langchain helpers; tool_call_chunks merge positionally by index.
        if isinstance(other, ALLToolsMessageChunk):
            if self.example != other.example:
                raise ValueError(
                    "Cannot concatenate ALLToolsMessageChunks with different example values."
                )

            content = merge_content(self.content, other.content)
            additional_kwargs = merge_dicts(
                self.additional_kwargs, other.additional_kwargs
            )
            response_metadata = merge_dicts(
                self.response_metadata, other.response_metadata
            )

            # Merge tool call chunks
            if self.tool_call_chunks or other.tool_call_chunks:
                raw_tool_calls = merge_lists(
                    self.tool_call_chunks,
                    other.tool_call_chunks,
                )
                if raw_tool_calls:
                    tool_call_chunks = [
                        ToolCallChunk(
                            name=rtc.get("name"),
                            args=rtc.get("args"),
                            index=rtc.get("index"),
                            id=rtc.get("id"),
                        )
                        for rtc in raw_tool_calls
                    ]
                else:
                    tool_call_chunks = []
            else:
                tool_call_chunks = []

            return self.__class__(
                example=self.example,
                content=content,
                additional_kwargs=additional_kwargs,
                tool_call_chunks=tool_call_chunks,
                response_metadata=response_metadata,
                id=self.id,
            )

        return super().__add__(other)
def _paser_chunk(tool_call_chunks):
    """Split streamed tool-call chunks into complete and partial calls.

    Platform built-ins (code_interpreter / drawing_tool / web_browser) are
    complete only once their args carry ``"outputs"``; anything still
    streaming, or unparseable, lands in ``invalid_tool_calls``.

    Returns:
        Tuple ``(tool_calls, invalid_tool_calls)``.
    """
    # The three built-in branches were byte-identical — consolidated here.
    builtin_tools = ("code_interpreter", "drawing_tool", "web_browser")
    tool_calls = []
    invalid_tool_calls = []
    for chunk in tool_call_chunks:
        try:
            if any(tool in chunk["name"] for tool in builtin_tools):
                args_ = parse_partial_json(chunk["args"])

                if not isinstance(args_, dict):
                    raise ValueError("Malformed args.")

                if "outputs" in args_:
                    # Call finished: platform returned its outputs.
                    tool_calls.append(
                        ToolCall(
                            name=chunk["name"] or "",
                            args=args_,
                            id=chunk["id"],
                        )
                    )
                else:
                    # Still streaming: keep the raw args as an invalid call.
                    invalid_tool_calls.append(
                        InvalidToolCall(
                            name=chunk["name"],
                            args=chunk["args"],
                            id=chunk["id"],
                            error=None,
                        )
                    )
            else:
                args_ = parse_partial_json(chunk["args"])

                if isinstance(args_, dict):
                    # Normalize whitespace-padded keys emitted mid-stream.
                    cleaned_args = {
                        key.strip(): value for key, value in args_.items()
                    }
                    tool_calls.append(
                        ToolCall(
                            name=chunk["name"] or "",
                            args=cleaned_args,
                            id=chunk["id"],
                        )
                    )
                else:
                    raise ValueError("Malformed args.")
        except Exception:
            invalid_tool_calls.append(
                InvalidToolCall(
                    name=chunk["name"],
                    args=chunk["args"],
                    id=chunk["id"],
                    error=None,
                )
            )
    return tool_calls, invalid_tool_calls
# -*- coding: utf-8 -*-
# Fix: __all__ previously exported the misspelled, nonexistent name
# "ZhipuAIAIEmbeddings" and imported nothing, so both
# `from langchain_glm.embeddings import ZhipuAIEmbeddings` and
# `from langchain_glm.embeddings import *` failed.
from langchain_glm.embeddings.base import ZhipuAIEmbeddings

__all__ = [
    "ZhipuAIEmbeddings",
]
code-block:: python 49 | 50 | from langchain_glm import ZhipuAIEmbeddings 51 | 52 | zhipuai = ZhipuAIEmbeddings(model=""text_embedding") 53 | 54 | 55 | """ 56 | 57 | client: Any = Field(default=None, exclude=True) #: :meta private: 58 | model: str = "embedding-2" 59 | zhipuai_api_base: Optional[str] = Field(default=None, alias="base_url") 60 | """Base URL path for API requests, leave blank if not using a proxy or service 61 | emulator.""" 62 | zhipuai_proxy: Optional[str] = None 63 | embedding_ctx_length: int = 8191 64 | """The maximum number of tokens to embed at once.""" 65 | zhipuai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") 66 | """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" 67 | 68 | chunk_size: int = 1000 69 | """Maximum number of texts to embed in each batch""" 70 | max_retries: int = 2 71 | """Maximum number of retries to make when generating.""" 72 | request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field( 73 | default=None, alias="timeout" 74 | ) 75 | """Timeout for requests to OpenAI completion API. 
    @root_validator(pre=True, allow_reuse=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any keyword that is not a declared model field is funneled into
        ``model_kwargs`` (with a warning); supplying a value both directly
        and inside ``model_kwargs`` raises a ValueError.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        # Declared fields must be passed explicitly, never via model_kwargs.
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values
114 | ) 115 | 116 | values["model_kwargs"] = extra 117 | return values 118 | 119 | @root_validator(allow_reuse=True) 120 | def validate_environment(cls, values: Dict) -> Dict: 121 | """Validate that api key and python package exists in environment.""" 122 | zhipuai_api_key = get_from_dict_or_env( 123 | values, "zhipuai_api_key", "ZHIPUAI_API_KEY" 124 | ) 125 | values["zhipuai_api_key"] = ( 126 | convert_to_secret_str(zhipuai_api_key) if zhipuai_api_key else None 127 | ) 128 | values["zhipuai_api_base"] = values["zhipuai_api_base"] or os.getenv( 129 | "OPENAI_API_BASE" 130 | ) 131 | values["zhipuai_api_type"] = get_from_dict_or_env( 132 | values, 133 | "zhipuai_api_type", 134 | "OPENAI_API_TYPE", 135 | default="", 136 | ) 137 | values["zhipuai_proxy"] = get_from_dict_or_env( 138 | values, 139 | "zhipuai_proxy", 140 | "OPENAI_PROXY", 141 | default="", 142 | ) 143 | 144 | client_params = { 145 | "api_key": values["zhipuai_api_key"].get_secret_value() 146 | if values["zhipuai_api_key"] 147 | else None, 148 | "base_url": values["zhipuai_api_base"], 149 | "timeout": values["request_timeout"], 150 | "max_retries": values["max_retries"], 151 | "http_client": values["http_client"], 152 | } 153 | if not values.get("client"): 154 | values["client"] = zhipuai.ZhipuAI(**client_params).embeddings 155 | return values 156 | 157 | @property 158 | def _invocation_params(self) -> Dict[str, Any]: 159 | params: Dict = {"model": self.model, **self.model_kwargs} 160 | return params 161 | 162 | def _get_len_safe_embeddings( 163 | self, texts: List[str], *, chunk_size: Optional[int] = None 164 | ) -> List[List[float]]: 165 | """ 166 | Generate length-safe embeddings for a list of texts. 167 | Args: 168 | texts (List[str]): A list of texts to embed. 169 | chunk_size (Optional[int]): The size of chunks for processing embeddings. 170 | 171 | Returns: 172 | List[List[float]]: A list of embeddings for each input text. 
173 | """ 174 | 175 | _chunk_size = chunk_size or self.chunk_size 176 | 177 | if self.show_progress_bar: 178 | try: 179 | from tqdm.auto import tqdm 180 | 181 | _iter: Iterable = tqdm(range(0, len(texts), _chunk_size)) 182 | except ImportError: 183 | _iter = range(0, len(texts), _chunk_size) 184 | else: 185 | _iter = range(0, len(texts), _chunk_size) 186 | 187 | batched_embeddings: List[List[float]] = [] 188 | for i in _iter: 189 | response = self.client.create( 190 | input=texts[i : i + _chunk_size], **self._invocation_params 191 | ) 192 | if not isinstance(response, dict): 193 | response = response.dict() 194 | batched_embeddings.extend(r["embedding"] for r in response["data"]) 195 | 196 | return batched_embeddings 197 | 198 | def embed_documents( 199 | self, texts: List[str], chunk_size: Optional[int] = 0 200 | ) -> List[List[float]]: 201 | """Call out to OpenAI's embedding endpoint for embedding search docs. 202 | 203 | Args: 204 | texts: The list of texts to embed. 205 | chunk_size: The chunk size of embeddings. If None, will use the chunk size 206 | specified by the class. 207 | 208 | Returns: 209 | List of embeddings, one for each text. 210 | """ 211 | return self._get_len_safe_embeddings(texts) 212 | 213 | def embed_query(self, text: str) -> List[float]: 214 | """Call out to OpenAI's embedding endpoint for embedding query text. 215 | 216 | Args: 217 | text: The text to embed. 218 | 219 | Returns: 220 | Embedding for the text. 
221 | """ 222 | return self.embed_documents([text])[0] 223 | -------------------------------------------------------------------------------- /langchain_glm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from langchain_glm.utils.history import History 3 | 4 | __all__ = ["History"] 5 | -------------------------------------------------------------------------------- /langchain_glm/utils/history.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from functools import lru_cache 4 | from typing import Any, Dict, List, Tuple, Union 5 | 6 | from langchain.prompts.chat import ChatMessagePromptTemplate 7 | from langchain_core.messages import ( 8 | AIMessage, 9 | AIMessageChunk, 10 | BaseMessage, 11 | BaseMessageChunk, 12 | ChatMessage, 13 | ChatMessageChunk, 14 | FunctionMessage, 15 | FunctionMessageChunk, 16 | HumanMessage, 17 | HumanMessageChunk, 18 | SystemMessage, 19 | SystemMessageChunk, 20 | ToolMessage, 21 | ToolMessageChunk, 22 | ) 23 | from zhipuai.core import BaseModel 24 | 25 | logger = logging.getLogger() 26 | 27 | 28 | def _convert_message_to_dict(message: BaseMessage) -> dict: 29 | """Convert a LangChain message to a dictionary. 30 | 31 | Args: 32 | message: The LangChain message. 33 | 34 | Returns: 35 | The dictionary. 
36 | """ 37 | message_dict: Dict[str, Any] 38 | if isinstance(message, ChatMessage): 39 | message_dict = {"role": message.role, "content": message.content} 40 | elif isinstance(message, HumanMessage): 41 | message_dict = {"role": "user", "content": message.content} 42 | elif isinstance(message, AIMessage): 43 | message_dict = {"role": "assistant", "content": message.content} 44 | if "function_call" in message.additional_kwargs: 45 | message_dict["function_call"] = message.additional_kwargs["function_call"] 46 | # If function call only, content is None not empty string 47 | if message_dict["content"] == "": 48 | message_dict["content"] = None 49 | if "tool_calls" in message.additional_kwargs: 50 | message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] 51 | # If tool calls only, content is None not empty string 52 | if message_dict["content"] == "": 53 | message_dict["content"] = None 54 | elif isinstance(message, SystemMessage): 55 | message_dict = {"role": "system", "content": message.content} 56 | elif isinstance(message, FunctionMessage): 57 | message_dict = { 58 | "role": "function", 59 | "content": message.content, 60 | "name": message.name, 61 | } 62 | elif isinstance(message, ToolMessage): 63 | message_dict = { 64 | "role": "tool", 65 | "content": message.content, 66 | "tool_call_id": message.tool_call_id, 67 | } 68 | else: 69 | raise TypeError(f"Got unknown type {message}") 70 | if "name" in message.additional_kwargs: 71 | message_dict["name"] = message.additional_kwargs["name"] 72 | return message_dict 73 | 74 | 75 | class History(BaseModel): 76 | """ 77 | 对话历史 78 | 可从dict生成,如 79 | h = History(**{"role":"user","content":"你好"}) 80 | 也可转换为tuple,如 81 | h.to_msy_tuple = ("human", "你好") 82 | """ 83 | 84 | role: str 85 | content: str 86 | 87 | def to_msg_tuple(self): 88 | return "ai" if self.role == "assistant" else "human", self.content 89 | 90 | def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate: 91 | role_maps = { 92 | "ai": 
"assistant", 93 | "human": "user", 94 | } 95 | role = role_maps.get(self.role, self.role) 96 | if is_raw: # 当前默认历史消息都是没有input_variable的文本。 97 | content = "{% raw %}" + self.content + "{% endraw %}" 98 | else: 99 | content = self.content 100 | 101 | return ChatMessagePromptTemplate.from_template( 102 | content, 103 | "jinja2", 104 | role=role, 105 | ) 106 | 107 | @classmethod 108 | def from_data(cls, h: Union[List, Tuple, Dict]) -> "History": 109 | if isinstance(h, (list, tuple)) and len(h) >= 2: 110 | h = cls(role=h[0], content=h[1]) 111 | elif isinstance(h, dict): 112 | h = cls(**h) 113 | 114 | return h 115 | 116 | @classmethod 117 | def from_message(cls, message: BaseMessage) -> "History": 118 | return cls.from_data(_convert_message_to_dict(message=message)) 119 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = true 3 | 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "langchain-glm" 3 | version = "0.0.2" 4 | description = "" 5 | authors = ["glide-the "] 6 | readme = "README.md" 7 | packages = [ 8 | {include = "langchain_glm"} 9 | ] 10 | license = "MIT" 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.8.1,<3.12,!=3.9.7" 14 | zhipuai = { version = ">=2.1.0.20240521", python = ">=3.8.1,<3.12,!=3.9.7" } 15 | langchain = { version = ">=0.2.2,<0.3", python = ">=3.8.1,<3.12,!=3.9.7" } 16 | langchainhub = { version = ">=0.1.14,<0.3", python=">=3.8.1,<3.9.7 || >3.9.7,<4.0" } 17 | langchain-community = { version = ">=0.2.0,<0.3", python=">=3.8.1,<3.9.7 || >3.9.7,<4.0" } 18 | langchain-experimental = { version = ">=0.0.58,<0.3", python=">=3.8.1,<3.9.7 || >3.9.7,<4.0" } 19 | 20 | 21 | 22 | [tool.poetry.group.test.dependencies] 23 | # The only 
dependencies that should be added are 24 | # dependencies used for running tests (e.g., pytest, freezegun, response). 25 | # Any dependencies that do not meet that criteria will be removed. 26 | pytest = "^7.3.0" 27 | pytest-cov = "^4.0.0" 28 | pytest-dotenv = "^0.5.2" 29 | duckdb-engine = "^0.9.2" 30 | pytest-watcher = "^0.2.6" 31 | freezegun = "^1.2.2" 32 | responses = "^0.22.0" 33 | pytest-asyncio = "^0.23.2" 34 | lark = "^1.1.5" 35 | pandas = "^2.0.0" 36 | pytest-mock = "^3.10.0" 37 | pytest-socket = "^0.6.0" 38 | syrupy = "^4.0.2" 39 | requests-mock = "^1.11.0" 40 | 41 | langchain-openai = { version = ">=0.0.6", python = ">=3.8.1,<3.12,!=3.9.7" } 42 | 43 | 44 | [tool.poetry.group.streamlit] 45 | optional = true 46 | 47 | [tool.poetry.group.streamlit.dependencies] 48 | 49 | fastapi = "~0.109.2" 50 | sse_starlette = "~1.8.2" 51 | uvicorn = ">=0.27.0.post1" 52 | # webui 53 | streamlit = "1.34.0" 54 | streamlit-option-menu = "0.3.12" 55 | streamlit-antd-components = "0.3.1" 56 | streamlit-chatbox = "1.1.12.post4" 57 | streamlit-modal = "0.1.0" 58 | streamlit-aggrid = "1.0.5" 59 | streamlit-extras = "0.4.2" 60 | 61 | 62 | [tool.poetry.group.lint] 63 | optional = true 64 | 65 | [tool.poetry.group.lint.dependencies] 66 | ruff = "^0.1.5" 67 | 68 | 69 | 70 | [tool.poetry.group.dev] 71 | optional = true 72 | 73 | [tool.poetry.group.dev.dependencies] 74 | jupyter = "^1.0.0" 75 | setuptools = "^67.6.1" 76 | 77 | [tool.poetry.extras] 78 | 79 | 80 | 81 | [tool.ruff] 82 | exclude = [ 83 | "tests/examples/non-utf8-encoding.py", 84 | "tests/integration_tests/examples/non-utf8-encoding.py", 85 | ] 86 | 87 | [tool.ruff.lint] 88 | select = [ 89 | "E", # pycodestyle 90 | "F", # pyflakes 91 | "I", # isort 92 | "T201", # print 93 | ] 94 | 95 | [tool.mypy] 96 | ignore_missing_imports = "True" 97 | disallow_untyped_defs = "True" 98 | exclude = ["notebooks", "examples", "example_data"] 99 | 100 | [tool.coverage.run] 101 | omit = [ 102 | "tests/*", 103 | ] 104 | 105 | [build-system] 106 
| requires = ["poetry-core>=1.0.0", "poetry-plugin-pypi-mirror==0.4.2"] 107 | build-backend = "poetry.core.masonry.api" 108 | 109 | [tool.pytest.ini_options] 110 | # --strict-markers will raise errors on unknown marks. 111 | # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks 112 | # 113 | # https://docs.pytest.org/en/7.1.x/reference/reference.html 114 | # --strict-config any warnings encountered while parsing the `pytest` 115 | # section of the configuration file raise errors. 116 | # 117 | # https://github.com/tophat/syrupy 118 | # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. 119 | addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv" 120 | # Registering custom markers. 121 | # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers 122 | markers = [ 123 | "requires: mark tests as requiring a specific library", 124 | "scheduled: mark tests to run in scheduled testing", 125 | "compile: mark placeholder test used to compile integration tests without running them" 126 | ] 127 | asyncio_mode = "auto" 128 | 129 | 130 | 131 | [tool.codespell] 132 | skip = '.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,*.trig,*.json,*.md,*.html,*.txt,*.csv' 133 | # Ignore latin etc 134 | ignore-regex = '.*(Stati Uniti|Tense=Pres).*' 135 | # whats is a typo but used frequently in queries so kept as is 136 | # aapply - async apply 137 | # unsecure - typo but part of API, decided to not bother for now 138 | ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin' 139 | 140 | [tool.poetry.plugins.dotenv] 141 | ignore = "false" 142 | dotenv = "dotenv:plugin" 143 | 144 | # https://python-poetry.org/docs/repositories/ 145 | #[[tool.poetry.source]] 146 | #name = "tsinghua" 147 | #url = 
"https://pypi.tuna.tsinghua.edu.cn/simple/" 148 | #priority = "default" 149 | -------------------------------------------------------------------------------- /scripts/add_encoding_declaration.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | 5 | def add_encoding_declaration(directory): 6 | for root, dirs, files in os.walk(directory): 7 | for file in files: 8 | if file.endswith(".py"): 9 | file_path = os.path.join(root, file) 10 | with open(file_path, "r+", encoding="utf-8") as f: 11 | content = f.read() 12 | if not content.startswith("# -*- coding: utf-8 -*-"): 13 | f.seek(0, 0) 14 | f.write("# -*- coding: utf-8 -*-\n" + content) 15 | 16 | 17 | if __name__ == "__main__": 18 | # 使用你的项目路径 19 | project_directory = "/media/gpt4-pdf-chatbot-langchain/langchain-zhipuai" 20 | add_encoding_declaration(project_directory) 21 | -------------------------------------------------------------------------------- /scripts/check_imports.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | import string 4 | import sys 5 | import traceback 6 | from importlib.machinery import SourceFileLoader 7 | 8 | if __name__ == "__main__": 9 | files = sys.argv[1:] 10 | has_failure = False 11 | for file in files: 12 | try: 13 | module_name = "".join( 14 | random.choice(string.ascii_letters) for _ in range(20) 15 | ) 16 | SourceFileLoader(module_name, file).load_module() 17 | except Exception: 18 | has_failure = True 19 | print(file) # noqa: T201 20 | traceback.print_exc() 21 | print() # noqa: T201 22 | 23 | sys.exit(1 if has_failure else 0) 24 | -------------------------------------------------------------------------------- /scripts/check_pydantic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script searches for lines starting with "import pydantic" or "from 
pydantic" 4 | # in tracked files within a Git repository. 5 | # 6 | # Usage: ./scripts/check_pydantic.sh /path/to/repository 7 | 8 | # Check if a path argument is provided 9 | if [ $# -ne 1 ]; then 10 | echo "Usage: $0 /path/to/repository" 11 | exit 1 12 | fi 13 | 14 | repository_path="$1" 15 | 16 | # Search for lines matching the pattern within the specified repository 17 | result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') 18 | 19 | # Check if any matching lines were found 20 | if [ -n "$result" ]; then 21 | echo "ERROR: The following lines need to be updated:" 22 | echo "$result" 23 | echo "Please replace the code with an import from langchain_core.pydantic_v1." 24 | echo "For example, replace 'from pydantic import BaseModel'" 25 | echo "with 'from langchain_core.pydantic_v1 import BaseModel'" 26 | exit 1 27 | fi 28 | -------------------------------------------------------------------------------- /scripts/lint_imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # Initialize a variable to keep track of errors 6 | errors=0 7 | 8 | # make sure not importing from langchain_glm 9 | git --no-pager grep '^from langchain_glm\.' . 
&& errors=$((errors+1)) 10 | 11 | # Decide on an exit status based on the errors 12 | if [ "$errors" -gt 0 ]; then 13 | exit 1 14 | else 15 | exit 0 16 | fi 17 | -------------------------------------------------------------------------------- /tests/assistant/chatchat_icon_blue_square_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MetaGLM/langchain-zhipuai/29efaeacbcb4db7572c2f09e60a4196771eefd24/tests/assistant/chatchat_icon_blue_square_v2.png -------------------------------------------------------------------------------- /tests/assistant/client.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import contextlib 3 | import json 4 | import logging 5 | import os 6 | from typing import Any, Callable, Dict, Iterator, List, Tuple, Union 7 | 8 | import httpx 9 | 10 | # httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。 11 | HTTPX_DEFAULT_TIMEOUT = 300.0 12 | 13 | logger = logging.getLogger(__name__) 14 | log_verbose = False 15 | 16 | 17 | def get_httpx_client( 18 | use_async: bool = False, 19 | proxies: Union[str, Dict] = None, 20 | timeout: float = HTTPX_DEFAULT_TIMEOUT, 21 | unused_proxies: List[str] = [], 22 | **kwargs, 23 | ) -> Union[httpx.Client, httpx.AsyncClient]: 24 | """ 25 | helper to get httpx client with default proxies that bypass local addesses. 
26 | """ 27 | default_proxies = { 28 | # do not use proxy for locahost 29 | "all://127.0.0.1": None, 30 | "all://localhost": None, 31 | } 32 | # do not use proxy for user deployed fastchat servers 33 | for x in unused_proxies: 34 | host = ":".join(x.split(":")[:2]) 35 | default_proxies.update({host: None}) 36 | 37 | # get proxies from system envionrent 38 | # proxy not str empty string, None, False, 0, [] or {} 39 | default_proxies.update( 40 | { 41 | "http://": ( 42 | os.environ.get("http_proxy") 43 | if os.environ.get("http_proxy") 44 | and len(os.environ.get("http_proxy").strip()) 45 | else None 46 | ), 47 | "https://": ( 48 | os.environ.get("https_proxy") 49 | if os.environ.get("https_proxy") 50 | and len(os.environ.get("https_proxy").strip()) 51 | else None 52 | ), 53 | "all://": ( 54 | os.environ.get("all_proxy") 55 | if os.environ.get("all_proxy") 56 | and len(os.environ.get("all_proxy").strip()) 57 | else None 58 | ), 59 | } 60 | ) 61 | for host in os.environ.get("no_proxy", "").split(","): 62 | if host := host.strip(): 63 | # default_proxies.update({host: None}) # Origin code 64 | default_proxies.update( 65 | {"all://" + host: None} 66 | ) # PR 1838 fix, if not add 'all://', httpx will raise error 67 | 68 | # merge default proxies with user provided proxies 69 | if isinstance(proxies, str): 70 | proxies = {"all://": proxies} 71 | 72 | if isinstance(proxies, dict): 73 | default_proxies.update(proxies) 74 | 75 | # construct Client 76 | kwargs.update(timeout=timeout, proxies=default_proxies) 77 | 78 | if log_verbose: 79 | logger.info(f"{get_httpx_client.__class__.__name__}:kwargs: {kwargs}") 80 | 81 | if use_async: 82 | return httpx.AsyncClient(**kwargs) 83 | else: 84 | return httpx.Client(**kwargs) 85 | 86 | 87 | class ZhipuAIPluginsClient: 88 | """ """ 89 | 90 | def __init__( 91 | self, 92 | base_url: str, 93 | timeout: float = HTTPX_DEFAULT_TIMEOUT, 94 | use_async: bool = False, 95 | ): 96 | self.base_url = base_url 97 | self.timeout = timeout 98 | 
self._use_async = use_async 99 | self._client = None 100 | 101 | @property 102 | def client(self): 103 | if self._client is None or self._client.is_closed: 104 | self._client = get_httpx_client( 105 | base_url=self.base_url, 106 | use_async=self._use_async, 107 | timeout=self.timeout, 108 | unused_proxies=[self.base_url], 109 | ) 110 | return self._client 111 | 112 | def get( 113 | self, 114 | url: str, 115 | params: Union[Dict, List[Tuple], bytes] = None, 116 | stream: bool = False, 117 | **kwargs: Any, 118 | ) -> Union[httpx.Response, Iterator[httpx.Response], None]: 119 | try: 120 | if stream: 121 | return self.client.stream("GET", url, params=params, **kwargs) 122 | else: 123 | return self.client.get(url, params=params, **kwargs) 124 | except Exception as e: 125 | msg = f"error when get {url}: {e}" 126 | logger.error( 127 | f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None 128 | ) 129 | 130 | def post( 131 | self, 132 | url: str, 133 | data: Dict = None, 134 | json: Dict = None, 135 | stream: bool = False, 136 | **kwargs: Any, 137 | ) -> Union[httpx.Response, Iterator[httpx.Response], None]: 138 | try: 139 | # print(kwargs) 140 | if stream: 141 | return self.client.stream("POST", url, data=data, json=json, **kwargs) 142 | else: 143 | return self.client.post(url, data=data, json=json, **kwargs) 144 | except Exception as e: 145 | msg = f"error when post {url}: {e}" 146 | logger.error( 147 | f"{e.__class__.__name__}: {msg}", exc_info=e if log_verbose else None 148 | ) 149 | 150 | def _httpx_stream2generator( 151 | self, 152 | response: contextlib._GeneratorContextManager, 153 | as_json: bool = False, 154 | ): 155 | """ 156 | 将httpx.stream返回的GeneratorContextManager转化为普通生成器 157 | """ 158 | 159 | async def ret_async(response, as_json): 160 | try: 161 | async with response as r: 162 | chunk_cache = "" 163 | async for chunk in r.aiter_text(None): 164 | if not chunk: # fastchat api yield empty bytes on start and end 165 | continue 166 | if as_json: 
167 | try: 168 | if chunk.startswith("data: "): 169 | data = json.loads(chunk_cache + chunk[6:-2]) 170 | elif chunk.startswith(":"): # skip sse comment line 171 | continue 172 | else: 173 | data = json.loads(chunk_cache + chunk) 174 | 175 | chunk_cache = "" 176 | yield data 177 | except Exception as e: 178 | msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" 179 | logger.error( 180 | f"{e.__class__.__name__}: {msg}", 181 | exc_info=e if log_verbose else None, 182 | ) 183 | 184 | if chunk.startswith("data: "): 185 | chunk_cache += chunk[6:-2] 186 | elif chunk.startswith(":"): # skip sse comment line 187 | continue 188 | else: 189 | chunk_cache += chunk 190 | continue 191 | else: 192 | # print(chunk, end="", flush=True) 193 | yield chunk 194 | except httpx.ConnectError as e: 195 | msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" 196 | logger.error(msg) 197 | yield {"code": 500, "msg": msg} 198 | except httpx.ReadTimeout as e: 199 | msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})" 200 | logger.error(msg) 201 | yield {"code": 500, "msg": msg} 202 | except Exception as e: 203 | msg = f"API通信遇到错误:{e}" 204 | logger.error( 205 | f"{e.__class__.__name__}: {msg}", 206 | exc_info=e if log_verbose else None, 207 | ) 208 | yield {"code": 500, "msg": msg} 209 | 210 | def ret_sync(response, as_json): 211 | try: 212 | with response as r: 213 | chunk_cache = "" 214 | for chunk in r.iter_text(None): 215 | if not chunk: # fastchat api yield empty bytes on start and end 216 | continue 217 | if as_json: 218 | try: 219 | if chunk.startswith("data: "): 220 | data = json.loads(chunk_cache + chunk[6:-4]) 221 | elif chunk.startswith(":"): # skip sse comment line 222 | continue 223 | else: 224 | data = json.loads(chunk_cache + chunk) 225 | 226 | chunk_cache = "" 227 | yield data 228 | except Exception as e: 229 | if chunk.startswith("data: "): 230 | chunk_cache += chunk[6:] 231 | elif chunk.startswith(":"): # skip sse comment line 232 | continue 233 | else: 234 | chunk_cache += 
chunk 235 | continue 236 | else: 237 | # print(chunk, end="", flush=True) 238 | yield chunk 239 | except httpx.ConnectError as e: 240 | msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" 241 | logger.error(msg) 242 | yield {"code": 500, "msg": msg} 243 | except httpx.ReadTimeout as e: 244 | msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})" 245 | logger.error(msg) 246 | yield {"code": 500, "msg": msg} 247 | except Exception as e: 248 | msg = f"API通信遇到错误:{e}" 249 | logger.error( 250 | f"{e.__class__.__name__}: {msg}", 251 | exc_info=e if log_verbose else None, 252 | ) 253 | yield {"code": 500, "msg": msg} 254 | 255 | if self._use_async: 256 | return ret_async(response, as_json) 257 | else: 258 | return ret_sync(response, as_json) 259 | 260 | def _get_response_value( 261 | self, 262 | response: httpx.Response, 263 | as_json: bool = False, 264 | value_func: Callable = None, 265 | ): 266 | """ 267 | 转换同步或异步请求返回的响应 268 | `as_json`: 返回json 269 | `value_func`: 用户可以自定义返回值,该函数接受response或json 270 | """ 271 | 272 | def to_json(r): 273 | try: 274 | return r.json() 275 | except Exception as e: 276 | msg = "API未能返回正确的JSON。" + str(e) 277 | if log_verbose: 278 | logger.error( 279 | f"{e.__class__.__name__}: {msg}", 280 | exc_info=e if log_verbose else None, 281 | ) 282 | return {"code": 500, "msg": msg, "data": None} 283 | 284 | if value_func is None: 285 | value_func = lambda r: r 286 | 287 | async def ret_async(response): 288 | if as_json: 289 | return value_func(to_json(await response)) 290 | else: 291 | return value_func(await response) 292 | 293 | if self._use_async: 294 | return ret_async(response) 295 | else: 296 | if as_json: 297 | return value_func(to_json(response)) 298 | else: 299 | return value_func(response) 300 | 301 | def chat( 302 | self, 303 | query: str, 304 | history: List[Dict] = [], 305 | ): 306 | """ """ 307 | data = { 308 | "query": query, 309 | "history": history, 310 | } 311 | 312 | response = self.post( 313 | "/chat", 314 | json=data, 
315 | stream=True, 316 | ) 317 | return self._httpx_stream2generator(response, as_json=True) 318 | -------------------------------------------------------------------------------- /tests/assistant/server/server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging.config 3 | import threading 4 | from typing import List, Tuple 5 | 6 | from fastapi import APIRouter, Body, FastAPI, status 7 | from fastapi.middleware.cors import CORSMiddleware 8 | from langchain.agents import tool 9 | from langchain_community.tools import ShellTool 10 | from langchain_core.agents import AgentAction 11 | from pydantic.v1 import Extra, Field 12 | from sse_starlette.sse import EventSourceResponse 13 | from uvicorn import Config, Server 14 | from zhipuai.core.logs import ( 15 | get_config_dict, 16 | get_log_file, 17 | get_timestamp_ms, 18 | ) 19 | 20 | from langchain_glm.agent_toolkits import BaseToolOutput 21 | from langchain_glm.agents.zhipuai_all_tools import ZhipuAIAllToolsRunnable 22 | from langchain_glm.agents.zhipuai_all_tools.base import OutputType 23 | 24 | 25 | @tool 26 | def calculate(text: str = Field(description="a math expression")) -> BaseToolOutput: 27 | """ 28 | Useful to answer questions about simple calculations. 29 | translate user question to a math expression that can be evaluated by numexpr. 
30 | """ 31 | import numexpr 32 | 33 | try: 34 | ret = str(numexpr.evaluate(text)) 35 | except Exception as e: 36 | ret = f"wrong: {e}" 37 | 38 | return BaseToolOutput(ret) 39 | 40 | 41 | @tool 42 | def shell(query: str = Field(description="The command to execute")): 43 | """Use Shell to execute system shell commands""" 44 | tool = ShellTool() 45 | return BaseToolOutput(tool.run(tool_input=query)) 46 | 47 | 48 | intermediate_steps: List[Tuple[AgentAction, BaseToolOutput]] = [] 49 | 50 | 51 | async def chat( 52 | query: str = Body(..., description="用户输入", examples=["帮我计算100+1"]), 53 | message_id: str = Body(None, description="数据库消息ID"), 54 | history: List = Body( 55 | [], 56 | description="历史对话,设为一个整数可以从数据库中读取历史消息", 57 | examples=[ 58 | [ 59 | {"role": "user", "content": "你好"}, 60 | {"role": "assistant", "content": "有什么需要帮助的"}, 61 | ] 62 | ], 63 | ), 64 | ): 65 | """Agent 对话""" 66 | agent_executor = ZhipuAIAllToolsRunnable.create_agent_executor( 67 | model_name="glm-4-alltools", 68 | history=history, 69 | intermediate_steps=intermediate_steps, 70 | tools=[ 71 | {"type": "code_interpreter"}, 72 | {"type": "web_browser"}, 73 | {"type": "drawing_tool"}, 74 | calculate, 75 | ], 76 | ) 77 | chat_iterator = agent_executor.invoke(chat_input=query) 78 | 79 | async def chat_generator(): 80 | async for chat_output in chat_iterator: 81 | yield chat_output.to_json() 82 | 83 | # if agent_executor.callback.out: 84 | # intermediate_steps.extend(agent_executor.callback.intermediate_steps) 85 | 86 | return EventSourceResponse(chat_generator()) 87 | 88 | 89 | if __name__ == "__main__": 90 | logging_conf = get_config_dict( 91 | "debug", 92 | get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"), 93 | 1024 * 1024 * 1024 * 3, 94 | 1024 * 1024 * 1024 * 3, 95 | ) 96 | logging.config.dictConfig(logging_conf) # type: ignore 97 | app = FastAPI() 98 | app.add_middleware( 99 | CORSMiddleware, 100 | allow_origins=["*"], 101 | allow_credentials=True, 102 | allow_methods=["*"], 103 
| allow_headers=["*"], 104 | ) 105 | 106 | chat_router = APIRouter() 107 | 108 | chat_router.add_api_route( 109 | "/chat", 110 | chat, 111 | response_model=OutputType, 112 | status_code=status.HTTP_200_OK, 113 | methods=["POST"], 114 | description="与llm模型对话(通过LLMChain)", 115 | ) 116 | app.include_router(chat_router) 117 | 118 | config = Config( 119 | app=app, 120 | host="127.0.0.1", 121 | port=10000, 122 | log_config=logging_conf, 123 | ) 124 | _server = Server(config) 125 | 126 | def run_server(): 127 | _server.shutdown_timeout = 2 # 设置为2秒 128 | 129 | _server.run() 130 | 131 | _server_thread = threading.Thread(target=run_server) 132 | _server_thread.start() 133 | _server_thread.join() 134 | -------------------------------------------------------------------------------- /tests/assistant/start_chat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | 5 | if __name__ == "__main__": 6 | script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "webui.py") 7 | try: 8 | # for streamlit >= 1.12.1 9 | from streamlit.web import bootstrap 10 | except ImportError: 11 | from streamlit import bootstrap 12 | 13 | flag_options = {"server_address": "127.0.0.1", "server_port": 8501} 14 | args = [] 15 | bootstrap.load_config_options(flag_options=flag_options) 16 | bootstrap.run(script_dir, False, args, flag_options) 17 | -------------------------------------------------------------------------------- /tests/assistant/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | import base64 4 | from io import BytesIO 5 | 6 | 7 | def get_img_base64(file_name: str) -> str: 8 | """ 9 | get_img_base64 used in streamlit. 10 | absolute local path not working on windows. 
11 | """ 12 | # 读取图片 13 | with open(file_name, "rb") as f: 14 | buffer = BytesIO(f.read()) 15 | base_str = base64.b64encode(buffer.getvalue()).decode() 16 | return f"data:image/png;base64,{base_str}" 17 | 18 | 19 | def ensure_event_loop(): 20 | try: 21 | loop = asyncio.get_event_loop() 22 | except RuntimeError: 23 | loop = asyncio.new_event_loop() 24 | asyncio.set_event_loop(loop) 25 | return loop 26 | 27 | 28 | # Create an event loop to run the async functions synchronously 29 | def run_sync(func, *args, **kwargs): 30 | loop = ensure_event_loop() 31 | asyncio.set_event_loop(loop) 32 | 33 | return asyncio.run(func(*args, **kwargs)) 34 | -------------------------------------------------------------------------------- /tests/assistant/webui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import streamlit as st 5 | 6 | # from chatchat.webui_pages.loom_view_client import update_store 7 | # from chatchat.webui_pages.openai_plugins import openai_plugins_page 8 | from streamlit_option_menu import option_menu 9 | 10 | from tests.assistant.client import ZhipuAIPluginsClient 11 | from tests.assistant.dialogue import dialogue_page 12 | from tests.assistant.utils import get_img_base64 13 | 14 | api = ZhipuAIPluginsClient(base_url="http://127.0.0.1:10000") 15 | 16 | 17 | if __name__ == "__main__": 18 | st.set_page_config( 19 | "assistant", 20 | get_img_base64( 21 | os.path.join( 22 | os.path.dirname(os.path.abspath(__file__)), 23 | "chatchat_icon_blue_square_v2.png", 24 | ) 25 | ), 26 | initial_sidebar_state="expanded", 27 | menu_items={}, 28 | layout="wide", 29 | ) 30 | 31 | # use the following code to set the app to wide mode and the html markdown to increase the sidebar width 32 | st.markdown( 33 | """ 34 |