├── generate_colab_link.sh ├── .gitignore ├── README.md ├── Simon-LLM-Application-OpenAI-Python-SDK-Google-Gemini-API-v2.ipynb ├── LICENSE ├── Simon_LLM_Application_Google_Gemini_model_Openai_agent_sdk_example.ipynb ├── Simon-LLM-Application-Use_Gemini_Model_to_build_AI_Tools_and_AI_Agents.ipynb ├── Simon_LLM_Application_Google_Vertex_AI_GenAISDK_Full_Method.ipynb ├── Simon_LLM_Application_VLLM_Tool_Google_Gemma3_Model_Service.ipynb ├── Simon-LLM-Application-premier-12-chinese-taipei-performance-data-AI-Agent.ipynb └── Simon-LLM-Application-Gemma-2b-LORA-Fine-Tuning.ipynb /generate_colab_link.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 設置 GitHub 倉庫基本資訊 4 | GITHUB_USER="LiuYuWei" 5 | GITHUB_REPO="llm-colab-application" 6 | 7 | # 檢查是否有 .ipynb 文件,排除掉 -checkpoint.ipynb 的檔案 8 | NOTEBOOK_FILES=$(find . -name "*.ipynb" -type f | grep -v "\-checkpoint.ipynb") 9 | 10 | if [ -z "$NOTEBOOK_FILES" ]; then 11 | echo "沒有找到任何 .ipynb 文件。" 12 | exit 1 13 | fi 14 | 15 | # 初始化 README 文件 16 | if [ ! -f "README.md" ]; then 17 | echo "# $GITHUB_REPO" > README.md 18 | echo -e "\n## Google Colab 連結\n" >> README.md 19 | else 20 | # 清除 README 中舊的 Colab 連結部分(-i.bak 搭配刪除備份檔的寫法可同時相容 GNU sed 與 BSD/macOS sed) 21 | sed -i.bak '/## Google Colab 連結/,$d' README.md && rm -f README.md.bak 22 | echo -e "\n## Google Colab 連結\n\n" >> README.md 23 | fi 24 | 25 | # 對每個 .ipynb 文件生成 Colab 連結並添加到 README 文件中 26 | for FILE in $NOTEBOOK_FILES; do 27 | # 提取 Notebook 文件名 28 | NOTEBOOK_NAME=$(basename "$FILE") 29 | 30 | # 生成 Colab 連結 31 | COLAB_LINK="https://colab.research.google.com/github/$GITHUB_USER/$GITHUB_REPO/blob/main/$NOTEBOOK_NAME" 32 | 33 | # 添加到 README 34 | echo -e "- $NOTEBOOK_NAME\n\n [$NOTEBOOK_NAME]($COLAB_LINK)\n\n" >> README.md 35 | done 36 | 37 | echo "所有 Colab 連結已生成並添加到 README.md。" 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llm-colab-application 2 | LLM AI 模型 Colab 應用 - 由 [Simon Liu](https://tinyurl.com/simonliuyuwei) 整理與撰寫的 Google Colab Notebook ,目的是希望能夠提供 LLM 模型實作應用方案,便於快速驗證、教學。 3 | 4 | ## Google Colab 連結 5 | 6 | ### Gemini 7 | 8 | [ Model - Python SDK ] 9 | 10 | - Google Gemini - Google AI Studio - google-genai python SDK: 11 | 12 | [Simon_LLM_Application_Google_AI_Studio_GenAISDK_Full_Method.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Application_Google_AI_Studio_GenAISDK_Full_Method.ipynb) 13 | 14 | - Google Gemini - Google Cloud Vertex AI - google-genai python SDK: 15 | 16 | [Simon_LLM_Application_Google_Vertex_AI_GenAISDK_Full_Method.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Application_Google_Vertex_AI_GenAISDK_Full_Method.ipynb) 17 | 18 | - Google Gemini google-genai python SDK Example: (準備下架) 19 | 20 | [Simon_LLM_Application_Google_Gemini_GenAISDK_Full_Method.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Application_Google_Gemini_GenAISDK_Full_Method.ipynb) 21 | 22 | - OpenAI Python SDK - Google Gemini API Example: 23 | 24 | [Simon-LLM-Application-OpenAI-Python-SDK-Google-Gemini-API-v2.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon-LLM-Application-OpenAI-Python-SDK-Google-Gemini-API-v2.ipynb) 25 | 26 | [ AI Agent ] 27 | 28 | - CWA API AI Agent - Gemini Model API: 29 | 30 | [Simon-LLM-Application-Gemini-2-Example-Gemini-AI-Agent-Application.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon-LLM-Application-Gemini-2-Example-Gemini-AI-Agent-Application.ipynb) 31 | 32 | - Four arithmetic operations(四則運算) AI Agent - Gemini Model API: 33 | 34 | [Simon-LLM-Application-Use_Gemini_Model_to_build_AI_Tools_and_AI_Agents.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon-LLM-Application-Use_Gemini_Model_to_build_AI_Tools_and_AI_Agents.ipynb) 35 | 36 | - Premier 12 CT Performance AI Agent - Gemini Model API: 37 | 38 | [Simon-LLM-Application-premier-12-chinese-taipei-performance-data-AI-Agent.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon-LLM-Application-premier-12-chinese-taipei-performance-data-AI-Agent.ipynb) 39 | 40 | - OpenAI SDK AI Agent - Gemini Model API: 41 | 42 | [Simon_LLM_Application_Google_Gemini_model_Openai_agent_sdk_example.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Application_Google_Gemini_model_Openai_agent_sdk_example.ipynb) 43 | 44 | [ MCP ] 45 | 46 | - MCP - google-genai python SDK - Gemini Model API: 47 | 48 | [Simon_LLM_Google_Gemini_Model_API_SDK_MCP_Example.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Google_Gemini_Model_API_SDK_MCP_Example.ipynb) 49 | 50 | [ Model Evaluation ] 51 | 52 | - Twinkle Eval Model Evaluation - Gemini Model API: 53 | 54 | [Simon_LLM_Application_Twinkle_Eval_Tool_Google_Gemini_Model_Evaluation.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Application_Twinkle_Eval_Tool_Google_Gemini_Model_Evaluation.ipynb) 55 | 56 | ### Gemma 57 | 58 | [ LLM as a Service ] 59 | 60 
| - Ollama tool Model Service Example - Google Gemma Model: 61 | 62 | [Simon-LLM-Application-Ollama-Model-Service.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon-LLM-Application-Ollama-Model-Service.ipynb) 63 | 64 | - vllm tool Model Service Example - Google Gemma Model: 65 | 66 | [Simon_LLM_Application_VLLM_Tool_Google_Gemma3_Model_Service.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Application_VLLM_Tool_Google_Gemma3_Model_Service.ipynb) 67 | 68 | - Expose Ngrok Link for Ollama Model Service: 69 | 70 | [Simon_LLM_Application_Ollama_Ngrok_Llm_Service.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Application_Ollama_Ngrok_Llm_Service.ipynb) 71 | 72 | - gpt-oss ollama service 73 | 74 | [Simon_LLM_Application_gpt_oss_Ollama_Llm_Service.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon_LLM_Application_gpt_oss_Ollama_Llm_Service.ipynb) 75 | 76 | [ RAG ] 77 | 78 | - Simon-LLM-Application-Ollama-LLM-Model-and-RAG.ipynb 79 | 80 | [Simon-LLM-Application-Ollama-LLM-Model-and-RAG.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon-LLM-Application-Ollama-LLM-Model-and-RAG.ipynb) 81 | 82 | [ Fine-tuning ] 83 | 84 | - Gemma 2b LORA Fine-tuning: 85 | 86 | [Simon-LLM-Application-Gemma-2b-LORA-Fine-Tuning.ipynb](https://colab.research.google.com/github/LiuYuWei/llm-colab-application/blob/main/Simon-LLM-Application-Gemma-2b-LORA-Fine-Tuning.ipynb) 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /Simon-LLM-Application-OpenAI-Python-SDK-Google-Gemini-API-v2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "-QhPWE1lwZHH" 7 | }, 8 | "source": [ 9 | "# 使用 OpenAI Python SDK 來去使用 Google Gemini API\n", 10 | "\n", 11 | "Made by SimonLiu" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "id": "db29b8d4247e" 18 | }, 19 | "source": [ 20 | "本 Colab 將介紹如何透過 OpenAI 的 Python SDK 來訪問並使用 Google Gemini API。我們將一步步說明如何設定環境、安裝必要的套件,以及撰寫 Python 程式碼來調用 Gemini API,讓開發者可以簡單地與 Google 的服務進行互動。" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": { 26 | "id": "NNNg43Ymw54e" 27 | }, 28 | "source": [ 29 | "## 先決條件\n", 30 | "\n", 31 | "\n", 32 | "您可以在 Google Colab 中運行此教程,無需額外的環境配置。\n", 33 | "或者,若要在本地完成此快速入門,請參閱 Gemini API 開始指南 中的 Python 指南。" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": { 39 | "id": "kHkHARdb1ZID" 40 | }, 41 | "source": [ 42 | "## Install the SDK\n", 43 | "\n", 44 | "安裝 OpenAI Python SDK - 使用 pip 安裝:" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 1, 50 | "metadata": { 51 | "colab": { 52 | "base_uri": "https://localhost:8080/" 53 | }, 54 | "id": "J6Pd9SFJ1yVi", 55 | "outputId": "2b7c7d05-f9eb-48e2-8986-d76e029b4de3" 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "!pip install -q -U openai" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": { 65 | "id": "EeMCtmx9ykyx" 66 | }, 67 | "source": [ 68 | "## 設置您的 API 密鑰\n", 69 | "\n", 70 | "若要使用 Gemini API,您需要一個 API 密鑰。如果您還沒有密鑰,可以在 Google AI Studio 中創建一個密鑰。\n", 71 | "\n", 72 | "Get an API key\n", 73 | "\n", 74 | "\n", 75 | "在 Colab 中,將密鑰添加到左側面板中的 \"🔑\" 秘密管理器。將其命名為 
GOOGLE_API_KEY。然後將密鑰傳遞給 SDK:" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 2, 81 | "metadata": { 82 | "id": "HTiaTu6O1LRC" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": { 92 | "id": "nXxypzJH4MUl" 93 | }, 94 | "source": [ 95 | "## 生成文本\n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 8, 101 | "metadata": { 102 | "colab": { 103 | "base_uri": "https://localhost:8080/", 104 | "height": 301 105 | }, 106 | "id": "j51mcrLD4Y2W", 107 | "outputId": "0500f866-e58d-49de-c2c1-76b1a12c3a1d" 108 | }, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "您好!我是一個大型語言模型,由 Google 訓練而成。我的目標是提供資訊、協助完成任務,以及進行有意義的對話。 \n", 115 | "\n", 116 | "我可以理解和回應各種問題,生成各種文字內容,包括故事、詩歌、程式碼等等。我還可以翻譯語言、總結資訊,以及回答您的問題。\n", 117 | "\n", 118 | "我仍在不斷學習和進步,希望能成為您最可靠的助手。如果您有任何需要幫忙的地方,請隨時告訴我! \n", 119 | "\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "from openai import OpenAI\n", 125 | "\n", 126 | "client = OpenAI(\n", 127 | " api_key=GOOGLE_API_KEY,\n", 128 | " base_url=\"https://generativelanguage.googleapis.com/v1beta/\"\n", 129 | ")\n", 130 | "\n", 131 | "response = client.chat.completions.create(\n", 132 | " model=\"gemini-1.5-flash\",\n", 133 | " n=1,\n", 134 | " messages=[\n", 135 | " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", 136 | " {\n", 137 | " \"role\": \"user\",\n", 138 | " \"content\": \"請使用繁體中文介紹你自己。\"\n", 139 | " }\n", 140 | " ]\n", 141 | ")\n", 142 | "\n", 143 | "print(response.choices[0].message.content)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "zUUAQS9u4biH", 150 | "jp-MarkdownHeadingCollapsed": true, 151 | "tags": [] 152 | }, 153 | "source": [ 154 | "## 相關資訊\n", 155 | "\n", 156 | "- [Simon Liu 文章](https://medium.com/@simon3458/lab-openai-python-sdk-google-gemini-api-30710fa54b48)\n", 157 | "- [Google 官方部落格資訊](https://developers.googleblog.com/en/gemini-is-now-accessible-from-the-openai-library/)" 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "colab": { 163 | "provenance": [] 164 | }, 165 | "kernelspec": { 166 | "display_name": "conda_python3", 167 | "language": "python", 168 | "name": "conda_python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.10.15" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 4 185 | } 186 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Simon_LLM_Application_Google_Gemini_model_Openai_agent_sdk_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "source": [ 20 | "# 使用 openai-agent 的 python-sdk 來 透過 Google Gemini Model API 建立各種 Agents\n", 21 | "\n", 22 | "> 作者: Simon Liu" 23 | ], 24 | "metadata": { 25 | "id": "BLY2gaU843dx" 26 | } 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "source": [ 31 | "## I. 安裝 OpenAI Agents python package" 32 | ], 33 | "metadata": { 34 | "id": "It_ocGyB6JWm" 35 | } 36 | }, 37 | { 38 | "cell_type": "code", 39 | "source": [ 40 | "!pip install -q openai-agents" 41 | ], 42 | "metadata": { 43 | "id": "I6jUbYjKzjo_" 44 | }, 45 | "execution_count": null, 46 | "outputs": [] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "source": [ 51 | "## II. 設定 Agent SDK ,並且將 Google Gemini API Key 等資訊設定進去。" 52 | ], 53 | "metadata": { 54 | "id": "3rLaDL9d6Sen" 55 | } 56 | }, 57 | { 58 | "cell_type": "code", 59 | "source": [ 60 | "import os\n", 61 | "import asyncio\n", 62 | "\n", 63 | "from openai import AsyncOpenAI\n", 64 | "from agents import Agent, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled" 65 | ], 66 | "metadata": { 67 | "id": "w4wMirh9zirp" 68 | }, 69 | "execution_count": 3, 70 | "outputs": [] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "source": [ 75 | "from google.colab import userdata\n", 76 | "\n", 77 | "BASE_URL = os.getenv(\"EXAMPLE_BASE_URL\") or \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n", 78 | "API_KEY = os.getenv(\"EXAMPLE_API_KEY\") or userdata.get('GOOGLE_API_KEY')\n", 79 | "MODEL_NAME = os.getenv(\"EXAMPLE_MODEL_NAME\") or \"gemini-2.0-flash\"\n", 80 | "\n", 81 | "if not BASE_URL or not API_KEY or not MODEL_NAME:\n", 82 | " raise ValueError(\n", 83 | " \"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code.\"\n", 84 | " )" 85 | ], 86 | "metadata": { 87 | "id": "5dx8e1Agzy1h" 88 | }, 89 | "execution_count": 5, 90 | "outputs": [] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "source": [ 95 | "client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)\n", 96 | "set_tracing_disabled(disabled=True)" 97 | ], 98 | "metadata": { 99 | "id": "skRFPl7-0NBT" 100 | }, 101 | "execution_count": 6, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "source": [ 107 | "## III. 
建立第一個 Agent: 取得地區天氣(非真實串接,只是吐回文字)" 108 | ], 109 | "metadata": { 110 | "id": "am2NpN196e28" 111 | } 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 11, 116 | "metadata": { 117 | "id": "OWTbyI1czhPw" 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "@function_tool\n", 122 | "def get_weather(city: str):\n", 123 | " print(f\"[debug] getting weather for {city}\")\n", 124 | " return f\"The weather in {city} is sunny.\"\n", 125 | "\n", 126 | "\n", 127 | "async def main(prompt = \"What's the weather in Tokyo?\"):\n", 128 | " # This agent will use the custom LLM provider\n", 129 | " agent = Agent(\n", 130 | " name=\"Assistant\",\n", 131 | " instructions=\"請使用繁體中文回覆結果\",\n", 132 | " model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),\n", 133 | " tools=[get_weather],\n", 134 | " )\n", 135 | "\n", 136 | " result = await Runner.run(agent, prompt)\n", 137 | " print(result.final_output)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "source": [ 143 | "await main(\"請問台北的天氣如何?\")" 144 | ], 145 | "metadata": { 146 | "colab": { 147 | "base_uri": "https://localhost:8080/" 148 | }, 149 | "id": "ZVdl22_F0RQl", 150 | "outputId": "a0cb52ee-87e7-4373-fbbd-6eb37d0d609d" 151 | }, 152 | "execution_count": 18, 153 | "outputs": [ 154 | { 155 | "output_type": "stream", 156 | "name": "stdout", 157 | "text": [ 158 | "[debug] getting weather for 台北\n", 159 | "台北的天氣是晴朗的。\n", 160 | "\n" 161 | ] 162 | } 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "source": [ 168 | "## IV. 中階 Agnet: 四則運算 Agent" 169 | ], 170 | "metadata": { 171 | "id": "3U5JQPy57J_a" 172 | } 173 | }, 174 | { 175 | "cell_type": "code", 176 | "source": [ 177 | "@function_tool\n", 178 | "def calculate(a: float, b: float, operator: str):\n", 179 | " print(f\"[debug] calculating: {a} {operator} {b}\")\n", 180 | " if operator == '+':\n", 181 | " return f\"結果是 {a + b}\"\n", 182 | " elif operator == '-':\n", 183 | " return f\"結果是 {a - b}\"\n", 184 | " elif operator == '*':\n", 185 | " return f\"結果是 {a * b}\"\n", 186 | " elif operator == '/':\n", 187 | " if b == 0:\n", 188 | " return \"錯誤:除數不能為零。\"\n", 189 | " return f\"結果是 {a / b}\"\n", 190 | " else:\n", 191 | " return \"錯誤:不支援的運算符,請使用 '+', '-', '*', '/' 其中之一。\"\n", 192 | "\n", 193 | "async def main(prompt=\"幫我算 12 除以 4\"):\n", 194 | " agent = Agent(\n", 195 | " name=\"Calculator\",\n", 196 | " instructions=\"請使用繁體中文回覆結果\",\n", 197 | " model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),\n", 198 | " tools=[calculate],\n", 199 | " )\n", 200 | "\n", 201 | " result = await Runner.run(agent, prompt)\n", 202 | " print(result.final_output)\n" 203 | ], 204 | "metadata": { 205 | "id": "Sp5DP0zO01aG" 206 | }, 207 | "execution_count": 19, 208 | "outputs": [] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "source": [ 213 | "await main(\"請問312除以四的平方,然後再加555是多少?\")" 214 | ], 215 | "metadata": { 216 | "colab": { 217 | "base_uri": "https://localhost:8080/" 218 | }, 219 | "id": "F37OjRzO1EiW", 220 | "outputId": "8c422502-0441-49a6-b2c5-bb48e24a80ef" 221 | }, 222 | "execution_count": 21, 223 | "outputs": [ 224 | { 225 | "output_type": "stream", 226 | "name": "stdout", 227 | "text": [ 228 | "[debug] calculating: 4.0 * 4.0\n", 229 | "[debug] calculating: 312.0 / 16.0\n", 230 | "[debug] calculating: 19.5 + 555.0\n", 231 | "312 除以四的平方,然後再加 555 是 574.5。\n", 232 | "\n" 233 | ] 234 | } 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "source": [ 240 | "## V. 
高階 Agents: 透過 xgboost 進行模型訓練" 241 | ], 242 | "metadata": { 243 | "id": "obOdTu6H1eHk" 244 | } 245 | }, 246 | { 247 | "cell_type": "code", 248 | "source": [ 249 | "import pandas as pd\n", 250 | "import xgboost as xgb\n", 251 | "from sklearn.model_selection import train_test_split\n", 252 | "from sklearn.metrics import accuracy_score\n", 253 | "from sklearn.preprocessing import LabelEncoder\n", 254 | "import joblib\n", 255 | "\n", 256 | "@function_tool\n", 257 | "def train_xgboost_model(\n", 258 | " data_path: str,\n", 259 | " target_column: str,\n", 260 | " model_output_path: str = \"xgboost_model.pkl\"\n", 261 | "):\n", 262 | " \"\"\"\n", 263 | " 訓練一個 XGBoost 分類模型。\n", 264 | "\n", 265 | " Parameters:\n", 266 | " - data_path: 資料的 CSV 檔案路徑\n", 267 | " - target_column: 目標變數欄位名稱\n", 268 | " - model_output_path: 訓練好的模型儲存檔案路徑\n", 269 | "\n", 270 | " Returns:\n", 271 | " - 模型訓練的準確率與儲存位置\n", 272 | " \"\"\"\n", 273 | " print(f\"[debug] loading data from {data_path}\")\n", 274 | " df = pd.read_csv(data_path)\n", 275 | "\n", 276 | " if target_column not in df.columns:\n", 277 | " return f\"錯誤:指定的目標欄位 '{target_column}' 不存在於資料集中。\"\n", 278 | "\n", 279 | " X = df.drop(columns=[target_column])\n", 280 | " y = df[target_column]\n", 281 | "\n", 282 | " # 將非數值類別轉為整數編碼\n", 283 | " if y.dtype == 'object' or y.dtype.name == 'category':\n", 284 | " print(\"[debug] encoding target labels\")\n", 285 | " label_encoder = LabelEncoder()\n", 286 | " y = label_encoder.fit_transform(y)\n", 287 | " # 可選:儲存 label encoder,供預測階段使用\n", 288 | " joblib.dump(label_encoder, model_output_path.replace(\".pkl\", \"_label_encoder.pkl\"))\n", 289 | "\n", 290 | " print(f\"[debug] splitting data\")\n", 291 | " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 292 | "\n", 293 | " print(f\"[debug] training XGBoost model\")\n", 294 | " model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss')\n", 295 | " model.fit(X_train, y_train)\n", 296 | "\n", 297 | " print(f\"[debug] evaluating model\")\n", 298 | " y_pred = model.predict(X_test)\n", 299 | " acc = accuracy_score(y_test, y_pred)\n", 300 | "\n", 301 | " print(f\"[debug] saving model to {model_output_path}\")\n", 302 | " joblib.dump(model, model_output_path)\n", 303 | "\n", 304 | " return f\"模型訓練完成。準確率為 {acc:.2%},模型與 LabelEncoder 分別儲存於 {model_output_path} 與 {model_output_path.replace('.pkl', '_label_encoder.pkl')}\"\n" 305 | ], 306 | "metadata": { 307 | "id": "buA3_frb1eZV" 308 | }, 309 | "execution_count": 30, 310 | "outputs": [] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "source": [ 315 | "async def main(prompt):\n", 316 | " agent = Agent(\n", 317 | " name=\"MLTrainer\",\n", 318 | " instructions=\"請使用繁體中文回覆結果\",\n", 319 | " model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),\n", 320 | " tools=[train_xgboost_model],\n", 321 | " )\n", 322 | "\n", 323 | " result = await Runner.run(agent, prompt)\n", 324 | " print(result.final_output)\n" 325 | ], 326 | "metadata": { 327 | "id": "DizxuK9J1iK1" 328 | }, 329 | "execution_count": 41, 330 | "outputs": [] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "source": [ 335 | "### Case 1: Iris dataset 訓練" 336 | ], 337 | "metadata": { 338 | "id": "B-pgTwi6-uSi" 339 | } 340 | }, 341 | { 342 | "cell_type": "code", 343 | "source": [ 344 | "# 創建相關資料夾\n", 345 | "\n", 346 | "!mkdir data/ model/" 347 | ], 348 | "metadata": { 349 | "id": "uQh2GYll3Jqj" 350 | }, 351 | "execution_count": 38, 352 | "outputs": [] 353 | }, 354 | { 355 | "cell_type": "code", 356 | 
"source": [ 357 | "# 下載 iris dataset csv 檔案\n", 358 | "\n", 359 | "!curl --output ./data/iris.csv \\\n", 360 | " --url https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv" 361 | ], 362 | "metadata": { 363 | "colab": { 364 | "base_uri": "https://localhost:8080/" 365 | }, 366 | "id": "ToBBk7ux1lDi", 367 | "outputId": "d9be7cef-eed4-41c6-e095-e1dd5cbbdaa7" 368 | }, 369 | "execution_count": 40, 370 | "outputs": [ 371 | { 372 | "output_type": "stream", 373 | "name": "stdout", 374 | "text": [ 375 | " % Total % Received % Xferd Average Speed Time Time Time Current\n", 376 | " Dload Upload Total Spent Left Speed\n", 377 | "\r 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\r100 3858 100 3858 0 0 27656 0 --:--:-- --:--:-- --:--:-- 27755\n" 378 | ] 379 | } 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "source": [ 385 | "# 下 Prompt 來進行模型訓練\n", 386 | "\n", 387 | "await main(\"請用檔案路徑 ./data/iris.csv 訓練一個目標欄位為 species 的 XGBoost 模型,模型訓練好請存到 ./model/iris_model.pkl\")" 388 | ], 389 | "metadata": { 390 | "colab": { 391 | "base_uri": "https://localhost:8080/" 392 | }, 393 | "id": "NuaXpwCW1uQT", 394 | "outputId": "300f32ef-5001-408f-b147-34a653b17cad" 395 | }, 396 | "execution_count": 42, 397 | "outputs": [ 398 | { 399 | "output_type": "stream", 400 | "name": "stdout", 401 | "text": [ 402 | "[debug] loading data from ./data/iris.csv\n", 403 | "[debug] encoding target labels\n", 404 | "[debug] splitting data\n", 405 | "[debug] training XGBoost model\n" 406 | ] 407 | }, 408 | { 409 | "output_type": "stream", 410 | "name": "stderr", 411 | "text": [ 412 | "/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [07:46:12] WARNING: /workspace/src/learner.cc:740: \n", 413 | "Parameters: { \"use_label_encoder\" } are not used.\n", 414 | "\n", 415 | " warnings.warn(smsg, UserWarning)\n" 416 | ] 417 | }, 418 | { 419 | "output_type": "stream", 420 | "name": "stdout", 421 | "text": [ 422 | "[debug] evaluating model\n", 423 | "[debug] saving model to ./model/iris_model.pkl\n", 424 | "模型訓練完成。準確率為 100.00%,模型與 LabelEncoder 分別儲存於 ./model/iris_model.pkl 與 ./model/iris_model_label_encoder.pkl。\n", 425 | "\n" 426 | ] 427 | } 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "source": [ 433 | "### Case 2: Wine dataset 分類演算法模型訓練" 434 | ], 435 | "metadata": { 436 | "id": "5HF8D762-10u" 437 | } 438 | }, 439 | { 440 | "cell_type": "code", 441 | "source": [ 442 | "!curl --output ./data/wine.csv \\\n", 443 | " --url https://gist.githubusercontent.com/tijptjik/9408623/raw/b237fa5848349a14a14e5d4107dc7897c21951f5/wine.csv" 444 | ], 445 | "metadata": { 446 | "colab": { 447 | "base_uri": "https://localhost:8080/" 448 | }, 449 | "id": "ejVoBP8_2hnQ", 450 | "outputId": "10074ff8-181a-4f8e-fea5-dd84e4f466df" 451 | }, 452 | "execution_count": 43, 453 | "outputs": [ 454 | { 455 | "output_type": "stream", 456 | "name": "stdout", 457 | "text": [ 458 | " % Total % Received % Xferd Average Speed Time Time Time Current\n", 459 | " Dload Upload Total Spent Left Speed\n", 460 | "\r 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\r 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\r100 10889 100 10889 0 0 40949 0 --:--:-- --:--:-- --:--:-- 40936\n" 461 | ] 462 | } 463 | ] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "source": [ 468 | "await main(\"請用檔案路徑 ./data/wine.csv 訓練一個目標欄位為 Wine 的 XGBoost 模型,模型訓練好請存到 ./model/wine_model.pkl\")" 469 | ], 470 | "metadata": { 471 | "colab": { 472 | "base_uri": "https://localhost:8080/" 473 | }, 474 | "id": 
"EYAzSGyJ3nw7", 475 | "outputId": "507fd6c2-2360-49f1-e88e-c7ed0e1fefb6" 476 | }, 477 | "execution_count": 45, 478 | "outputs": [ 479 | { 480 | "output_type": "stream", 481 | "name": "stdout", 482 | "text": [ 483 | "[debug] loading data from ./data/wine.csv\n", 484 | "[debug] splitting data\n", 485 | "[debug] training XGBoost model\n", 486 | "模型訓練發生錯誤。錯誤訊息指出從目標變數 `Wine` 推斷出的類別無效。預期的類別是 `[0 1 2]`,但實際得到的類別是 `[1 2 3]`。這表示目標變數的類別編碼從 1 開始,而不是從 0 開始。XGBoost 模型預期類別從 0 開始編碼。\n", 487 | "\n", 488 | "為了修正這個問題,您需要預先處理您的資料,將 `Wine` 欄位中的類別編碼從 `[1 2 3]` 轉換為 `[0 1 2]`。您可以使用 Python 和 Pandas 讀取 CSV 檔案,然後將 `Wine` 欄位的值減 1。\n", 489 | "\n", 490 | "如果修正資料後,您仍然遇到問題,請檢查資料路徑、目標欄位名稱和模型輸出路徑是否正確。\n", 491 | "\n" 492 | ] 493 | } 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "source": [ 499 | "-> 你應該會得到類別無效的錯誤訊息,這是正常的,因為 target label 跟 iris dataset 是不一樣的,你有兩個解法:\n", 500 | "\n", 501 | "1. 自己整理成一樣的資料樣子\n", 502 | "2. 寫一個 function tool 來進行資料整理\n", 503 | "\n", 504 | "因為要示範 Agent ,所以這邊示範第二種。" 505 | ], 506 | "metadata": { 507 | "id": "KvezgQr-_DL2" 508 | } 509 | }, 510 | { 511 | "cell_type": "code", 512 | "source": [ 513 | "@function_tool\n", 514 | "def normalize_class_labels(\n", 515 | " data_path: str,\n", 516 | " target_column: str,\n", 517 | " output_path: str = None\n", 518 | "):\n", 519 | " \"\"\"\n", 520 | " 將分類欄位中的類別值調整為從 0 開始的整數(例如從 [1,2,3] → [0,1,2])。\n", 521 | "\n", 522 | " Parameters:\n", 523 | " - data_path: 原始 CSV 檔案路徑\n", 524 | " - target_column: 需要標準化的目標欄位名稱\n", 525 | " - output_path: 轉換後資料儲存路徑(如未提供,將覆寫原始檔案)\n", 526 | "\n", 527 | " Returns:\n", 528 | " - 新類別對應表與儲存檔案路徑\n", 529 | " \"\"\"\n", 530 | " print(f\"[debug] loading data from {data_path}\")\n", 531 | " df = pd.read_csv(data_path)\n", 532 | "\n", 533 | " if target_column not in df.columns:\n", 534 | " return f\"錯誤:目標欄位 '{target_column}' 不存在於資料中。\"\n", 535 | "\n", 536 | " print(f\"[debug] normalizing target column '{target_column}'\")\n", 537 | " unique_classes = sorted(df[target_column].unique())\n", 538 | " class_mapping = {label: idx for idx, label in enumerate(unique_classes)}\n", 539 | " df[target_column] = df[target_column].map(class_mapping)\n", 540 | "\n", 541 | " output_file = output_path if output_path else data_path\n", 542 | " df.to_csv(output_file, index=False)\n", 543 | "\n", 544 | " return f\"類別欄位已標準化為從 0 開始的整數。對應關係為:{class_mapping}。資料已儲存至 {output_file}\"\n" 545 | ], 546 | "metadata": { 547 | "id": "D3pBJeoK4fXh" 548 | }, 549 | "execution_count": 46, 550 | "outputs": [] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "source": [ 555 | "async def main(prompt):\n", 556 | " agent = Agent(\n", 557 | " name=\"MLTrainer\",\n", 558 | " instructions=\"請使用繁體中文回覆結果\",\n", 559 | " model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),\n", 560 | " tools=[train_xgboost_model, normalize_class_labels],\n", 561 | " )\n", 562 | "\n", 563 | " result = await Runner.run(agent, prompt)\n", 564 | " print(result.final_output)\n" 565 | ], 566 | "metadata": { 567 | "id": "APOwMBaB4gYB" 568 | }, 569 | "execution_count": 47, 570 | "outputs": [] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "source": [ 575 | "await main(\"請用檔案路徑 ./data/wine.csv 訓練一個目標欄位為 Wine 的 XGBoost 模型,模型訓練好請存到 ./model/wine_model.pkl\")" 576 | ], 577 | "metadata": { 578 | "colab": { 579 | "base_uri": "https://localhost:8080/" 580 | }, 581 | "id": "NDfchmPg4j6M", 582 | "outputId": "3cb956f3-9619-4b4c-c1d7-7e019257bbef" 583 | }, 584 | "execution_count": 48, 585 | "outputs": [ 586 | { 587 | "output_type": "stream", 588 | "name": "stdout", 589 | "text": [ 590 | 
"[debug] loading data from ./data/wine.csv\n", 591 | "[debug] splitting data\n", 592 | "[debug] training XGBoost model\n", 593 | "[debug] loading data from ./data/wine.csv\n", 594 | "[debug] normalizing target column 'Wine'\n", 595 | "[debug] loading data from ./data/wine_normalized.csv\n", 596 | "[debug] splitting data\n", 597 | "[debug] training XGBoost model\n", 598 | "[debug] evaluating model\n", 599 | "[debug] saving model to ./model/wine_model.pkl\n" 600 | ] 601 | }, 602 | { 603 | "output_type": "stream", 604 | "name": "stderr", 605 | "text": [ 606 | "/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [07:51:27] WARNING: /workspace/src/learner.cc:740: \n", 607 | "Parameters: { \"use_label_encoder\" } are not used.\n", 608 | "\n", 609 | " warnings.warn(smsg, UserWarning)\n" 610 | ] 611 | }, 612 | { 613 | "output_type": "stream", 614 | "name": "stdout", 615 | "text": [ 616 | "XGBoost 模型已成功訓練!準確率為 94.44%,模型儲存於 `./model/wine_model.pkl`。\n", 617 | "\n" 618 | ] 619 | } 620 | ] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "source": [ 625 | "你應該就有得到正確訓練模型的訊息了。" 626 | ], 627 | "metadata": { 628 | "id": "bRD9Abo4_xXp" 629 | } 630 | } 631 | ] 632 | } -------------------------------------------------------------------------------- /Simon-LLM-Application-Use_Gemini_Model_to_build_AI_Tools_and_AI_Agents.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9e5ae75a", 6 | "metadata": { 7 | "id": "9e5ae75a" 8 | }, 9 | "source": [ 10 | "# 使用 LangChain 和 Google Generative AI 建構工具與代理" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "107129ee", 16 | "metadata": { 17 | "id": "107129ee" 18 | }, 19 | "source": [ 20 | "## 安裝必要套件" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "id": "2d2814c9", 27 | "metadata": { 28 | "colab": { 29 | "base_uri": "https://localhost:8080/" 30 | }, 31 | "id": "2d2814c9", 32 | "outputId": "084b5e75-abbf-4b71-fbcc-a7cad199e678" 33 | }, 34 | "outputs": [ 35 | { 36 | "output_type": "stream", 37 | "name": "stdout", 38 | "text": [ 39 | "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/2.4 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.6/2.4 MB\u001b[0m \u001b[31m18.0 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━\u001b[0m \u001b[32m1.8/2.4 MB\u001b[0m \u001b[31m22.6 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m24.3 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m15.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 40 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.8/41.8 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 41 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m26.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 42 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.5/49.5 kB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n", 43 | "\u001b[?25h" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "# 安裝 LangChain 和 Google Generative AI 的相關套件。\n", 49 | "# 若尚未安裝這些套件,可以執行此指令。\n", 50 | "!pip install -q langchain_community langchain-google-genai" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "id": "9f20b82a", 56 | "metadata": { 57 | "id": "9f20b82a" 58 | }, 59 | "source": [ 60 | "## 設定 Google API 金鑰" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 2, 66 | "id": "5ddc42d4", 67 | "metadata": { 68 | "id": "5ddc42d4" 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "# 設定 Google API 金鑰,方便與 Google 相關服務整合。\n", 73 | "# 確保用戶資料中已存有 `GOOGLE_API_KEY`。\n", 74 | "import os\n", 75 | "from google.colab import userdata\n", 76 | "\n", 77 | "os.environ['GOOGLE_API_KEY'] = userdata.get('GOOGLE_API_KEY')" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "id": "f9f85e76", 83 | "metadata": { 84 | "id": "f9f85e76" 85 | }, 86 | "source": [ 87 | "## 定義簡單的工具" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 3, 93 | "id": "86d69da5", 94 | "metadata": { 95 | "id": "86d69da5" 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "# 使用 LangChain 定義一個工具:乘法運算。\n", 100 | "# 此工具接收兩個整數並返回它們的乘積。\n", 101 | "from langchain_core.tools import tool\n", 102 | "\n", 103 | "@tool\n", 104 | "def multiply(first_int: int, second_int: int) -> int:\n", 105 | " \"\"\"Multiply two integers together.\"\"\"\n", 106 | " return first_int * second_int" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "id": "b280b798", 112 | "metadata": { 113 | "id": "b280b798" 114 | }, 115 | "source": [ 116 | "## 測試工具功能" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "id": "cd9df2e1", 123 | "metadata": { 124 | "colab": { 125 | "base_uri": "https://localhost:8080/" 126 | }, 127 | "id": "cd9df2e1", 128 | "outputId": "c9eb52d3-99da-4fe7-9918-f7e5a41ae1d2" 129 | }, 130 | "outputs": [ 131 | { 132 | "output_type": "execute_result", 133 | "data": { 134 | "text/plain": [ 135 | "20" 136 | ] 137 | }, 138 | "metadata": {}, 139 | "execution_count": 4 140 | } 141 | ], 142 | "source": [ 143 | "# 測試剛剛定義的 `multiply` 工具。\n", 144 | "# 使用字典格式提供參數,然後調用工具。\n", 145 | "multiply.invoke({'first_int': 4, 'second_int': 5})" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "61fd1db4", 151 | "metadata": { 152 | "id": "61fd1db4" 153 | }, 154 | "source": [ 155 | "## 初始化 Gemini 語言模型" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 5, 161 | "id": "01251db7", 162 | "metadata": { 163 | "id": "01251db7" 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "# 使用 Google Generative AI 的 Chat 模型進行初始化。\n", 168 | "from langchain_google_genai import ChatGoogleGenerativeAI\n", 169 | "\n", 170 | "llm = ChatGoogleGenerativeAI(\n", 171 | " model=\"gemini-1.5-flash\", # 指定模型\n", 172 | " temperature=0, # 控制生成內容的隨機性,0 代表固定輸出\n", 173 | ")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "392d4b45", 179 | "metadata": { 180 | "id": "392d4b45" 181 | }, 182 | "source": [ 183 | "## 將工具與模型結合" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 6, 189 | "id": "08b4d5f1", 190 | "metadata": { 191 | "id": "08b4d5f1" 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "# 將工具綁定到語言模型,讓模型可以在需要時調用這些工具。\n", 196 | "llm_with_tools = llm.bind_tools([multiply])" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "id": "495c52de", 202 | "metadata": { 203 | "id": "495c52de" 204 | }, 205 | "source": [ 206 | "## 
測試語言模型與工具的整合" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 7, 212 | "id": "b66d2aad", 213 | "metadata": { 214 | "colab": { 215 | "base_uri": "https://localhost:8080/" 216 | }, 217 | "id": "b66d2aad", 218 | "outputId": "1ec7e376-6fe6-475a-f373-35dcebda3c05" 219 | }, 220 | "outputs": [ 221 | { 222 | "output_type": "execute_result", 223 | "data": { 224 | "text/plain": [ 225 | "[{'name': 'multiply',\n", 226 | " 'args': {'second_int': 42.0, 'first_int': 5.0},\n", 227 | " 'id': '6f730aad-38d3-4166-bde3-74897bdb03e4',\n", 228 | " 'type': 'tool_call'}]" 229 | ] 230 | }, 231 | "metadata": {}, 232 | "execution_count": 7 233 | } 234 | ], 235 | "source": [ 236 | "# 使用自然語言輸入,測試模型是否可以正確調用工具來完成計算。\n", 237 | "msg = llm_with_tools.invoke(\"whats 5 times forty two\")\n", 238 | "msg.tool_calls # 檢視工具被調用時的參數" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "id": "11b17ab1", 244 | "metadata": { 245 | "id": "11b17ab1" 246 | }, 247 | "source": [ 248 | "## 建立多步操作鏈" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 8, 254 | "id": "65c7d722", 255 | "metadata": { 256 | "colab": { 257 | "base_uri": "https://localhost:8080/" 258 | }, 259 | "id": "65c7d722", 260 | "outputId": "3aba1626-7dd2-459c-fba6-7f824c65069a" 261 | }, 262 | "outputs": [ 263 | { 264 | "output_type": "execute_result", 265 | "data": { 266 | "text/plain": [ 267 | "92" 268 | ] 269 | }, 270 | "metadata": {}, 271 | "execution_count": 8 272 | } 273 | ], 274 | "source": [ 275 | "# 建立一個包含多步處理的操作鏈,使用模型輸出來觸發工具的執行。\n", 276 | "from operator import itemgetter\n", 277 | "\n", 278 | "chain = llm_with_tools | (lambda x: x.tool_calls[0]['args']) | multiply\n", 279 | "chain.invoke(\"What's four times 23\") # 測試新的處理鏈" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "id": "99c589a0", 285 | "metadata": { 286 | "id": "99c589a0" 287 | }, 288 | "source": [ 289 | "## 定義更多工具" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 9, 295 | "id": "65fe48b5", 296 | "metadata": { 297 | "id": "65fe48b5" 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "# 定義加法和次方運算工具,擴展功能。\n", 302 | "@tool\n", 303 | "def add(first_int: int, second_int: int) -> int:\n", 304 | " \"Add two integers.\"\n", 305 | " return first_int + second_int\n", 306 | "\n", 307 | "@tool\n", 308 | "def exponentiate(base: int, exponent: int) -> int:\n", 309 | " \"Exponentiate the base to the exponent power.\"\n", 310 | " return base**exponent\n", 311 | "\n", 312 | "# 將工具集合起來以供後續使用。\n", 313 | "tools = [multiply, add, exponentiate]" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "id": "206db4ea", 319 | "metadata": { 320 | "id": "206db4ea" 321 | }, 322 | "source": [ 323 | "## 構建代理" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 10, 329 | "id": "df3e184d", 330 | "metadata": { 331 | "colab": { 332 | "base_uri": "https://localhost:8080/" 333 | }, 334 | "id": "df3e184d", 335 | "outputId": "159acd9e-2180-4128-9580-acc2d230ce1e" 336 | }, 337 | "outputs": [ 338 | { 339 | "output_type": "stream", 340 | "name": "stderr", 341 | "text": [ 342 | "/usr/local/lib/python3.10/dist-packages/langsmith/client.py:241: LangSmithMissingAPIKeyWarning: API key must be provided when using hosted LangSmith API\n", 343 | " warnings.warn(\n" 344 | ] 345 | }, 346 | { 347 | "output_type": "stream", 348 | "name": "stdout", 349 | "text": [ 350 | "================================\u001b[1m System Message \u001b[0m================================\n", 351 | "\n", 352 | "You are a helpful 
assistant\n", 353 | "\n", 354 | "=============================\u001b[1m Messages Placeholder \u001b[0m=============================\n", 355 | "\n", 356 | "\u001b[33;1m\u001b[1;3m{chat_history}\u001b[0m\n", 357 | "\n", 358 | "================================\u001b[1m Human Message \u001b[0m=================================\n", 359 | "\n", 360 | "\u001b[33;1m\u001b[1;3m{input}\u001b[0m\n", 361 | "\n", 362 | "=============================\u001b[1m Messages Placeholder \u001b[0m=============================\n", 363 | "\n", 364 | "\u001b[33;1m\u001b[1;3m{agent_scratchpad}\u001b[0m\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "# 使用 LangChain 提供的工具,創建可以使用這些工具的代理。\n", 370 | "from langchain import hub\n", 371 | "from langchain.agents import AgentExecutor, create_tool_calling_agent\n", 372 | "\n", 373 | "# 獲取預定義的提示模板,可以替換為自定義提示。\n", 374 | "prompt = hub.pull('hwchase17/openai-tools-agent')\n", 375 | "prompt.pretty_print()\n", 376 | "\n", 377 | "# 創建工具調用代理,並傳入定義的工具和提示。\n", 378 | "agent = create_tool_calling_agent(llm, tools, prompt)" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "id": "28ff9772", 384 | "metadata": { 385 | "id": "28ff9772" 386 | }, 387 | "source": [ 388 | "## 創建代理執行器" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 11, 394 | "id": "a1b7e55d", 395 | "metadata": { 396 | "id": "a1b7e55d" 397 | }, 398 | "outputs": [], 399 | "source": [ 400 | "# 創建代理執行器,將代理與工具結合,並啟用詳細模式。\n", 401 | "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "id": "402fb06d", 407 | "metadata": { 408 | "id": "402fb06d" 409 | }, 410 | "source": [ 411 | "## 測試代理執行器" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 25, 417 | "id": "6d309fd3", 418 | "metadata": { 419 | "colab": { 420 | "base_uri": "https://localhost:8080/" 421 | }, 422 | "id": "6d309fd3", 423 | "outputId": "e81012a4-c7c3-4256-b64a-8e75e38f4585" 424 | }, 425 | "outputs": [ 426 | { 427 | "output_type": "stream", 428 | "name": "stdout", 429 | "text": [ 430 | "\n", 431 | "\n", 432 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 433 | "\u001b[32;1m\u001b[1;3m\n", 434 | "Invoking: `exponentiate` with `{'exponent': 2.0, 'base': 243.0}`\n", 435 | "\n", 436 | "\n", 437 | "\u001b[0m\u001b[38;5;200m\u001b[1;3m59049\u001b[0m\u001b[32;1m\u001b[1;3m\n", 438 | "Invoking: `add` with `{'second_int': -15330.0, 'first_int': 59049.0}`\n", 439 | "\n", 440 | "\n", 441 | "\u001b[0m\u001b[33;1m\u001b[1;3m43719\u001b[0m\u001b[32;1m\u001b[1;3m243 的平方是 59049。59049 減去 15330 等於 43719。\n", 442 | "\u001b[0m\n", 443 | "\n", 444 | "\u001b[1m> Finished chain.\u001b[0m\n", 445 | "243 的平方是 59049。59049 減去 15330 等於 43719。\n", 446 | "\n" 447 | ] 448 | } 449 | ], 450 | "source": [ 451 | "# 測試執行器,讓它完成複雜的計算任務:\n", 452 | "# 3 的 5 次方 × (12 + 3) 的平方\n", 453 | "result = agent_executor.invoke(\n", 454 | " {\"input\": \"兩百四十三的平方減去15330等於多少?\"}\n", 455 | ")\n", 456 | "\n", 457 | "# 輸出最終結果\n", 458 | "print(result['output'])" 459 | ] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "source": [ 464 | "# 新增工具:了解時間和時間差" 465 | ], 466 | "metadata": { 467 | "id": "0LhfVjmfor90" 468 | }, 469 | "id": "0LhfVjmfor90" 470 | }, 471 | { 472 | "cell_type": "code", 473 | "source": [ 474 | "from datetime import datetime\n", 475 | "from langchain_core.tools import tool\n", 476 | "\n", 477 | "# 工具 1:獲取當前時間\n", 478 | "@tool\n", 479 | "def current_time() -> str:\n", 480 | " \"\"\"Get the current date and time as a string.\"\"\"\n", 481 | " now = 
datetime.now()\n", 482 | " return now.strftime(\"%Y-%m-%d %H:%M:%S\")\n", 483 | "\n", 484 | "# 工具 2:計算兩個日期的差異天數\n", 485 | "@tool\n", 486 | "def date_difference(date1: str, date2: str) -> int:\n", 487 | " \"\"\"\n", 488 | " Calculate the difference in days between two dates.\n", 489 | " The dates must be in the format 'YYYY-MM-DD'.\n", 490 | " \"\"\"\n", 491 | " try:\n", 492 | " d1 = datetime.strptime(date1, \"%Y-%m-%d\")\n", 493 | " d2 = datetime.strptime(date2, \"%Y-%m-%d\")\n", 494 | " return abs((d2 - d1).days)\n", 495 | " except ValueError as e:\n", 496 | " raise ValueError(\"Invalid date format. Please use 'YYYY-MM-DD'.\") from e\n" 497 | ], 498 | "metadata": { 499 | "id": "HpJ0Bnj1l5Ss" 500 | }, 501 | "id": "HpJ0Bnj1l5Ss", 502 | "execution_count": 26, 503 | "outputs": [] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "source": [ 508 | "# 將新工具添加到工具列表\n", 509 | "tools.extend([current_time, date_difference])\n", 510 | "\n", 511 | "# 更新代理執行器\n", 512 | "agent = create_tool_calling_agent(llm, tools, prompt)\n", 513 | "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n" 514 | ], 515 | "metadata": { 516 | "id": "cCWbTbOKl83N" 517 | }, 518 | "id": "cCWbTbOKl83N", 519 | "execution_count": 27, 520 | "outputs": [] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "source": [ 525 | "# 使用自然語言輸入,測試模型是否可以正確調用工具來完成計算。\n", 526 | "result = llm.invoke(\"現在時間是?\")\n", 527 | "print(result.content)" 528 | ], 529 | "metadata": { 530 | "colab": { 531 | "base_uri": "https://localhost:8080/", 532 | "height": 36 533 | }, 534 | "id": "7LTmrK4Jmw6E", 535 | "outputId": "2d998302-bf95-416e-9e89-111cae1e7e7c" 536 | }, 537 | "id": "7LTmrK4Jmw6E", 538 | "execution_count": 34, 539 | "outputs": [ 540 | { 541 | "output_type": "execute_result", 542 | "data": { 543 | "text/plain": [ 544 | "'我不知道現在時間。我是一個大型語言模型,沒有存取即時資訊,例如時間。 請查看你的電腦或手機的時鐘。\\n'" 545 | ], 546 | "application/vnd.google.colaboratory.intrinsic+json": { 547 | "type": "string" 548 | } 549 | }, 550 | "metadata": {}, 551 | "execution_count": 34 552 | } 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "source": [ 558 | "result = agent_executor.invoke({\"input\": \"現在時間是?\"})\n", 559 | "print(result['output'])\n" 560 | ], 561 | "metadata": { 562 | "colab": { 563 | "base_uri": "https://localhost:8080/" 564 | }, 565 | "id": "3ZPT1wP3l-0K", 566 | "outputId": "f1cb8901-9f12-4539-b6eb-4e9652472d48" 567 | }, 568 | "id": "3ZPT1wP3l-0K", 569 | "execution_count": 28, 570 | "outputs": [ 571 | { 572 | "output_type": "stream", 573 | "name": "stdout", 574 | "text": [ 575 | "\n", 576 | "\n", 577 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 578 | "\u001b[32;1m\u001b[1;3m\n", 579 | "Invoking: `current_time` with `{}`\n", 580 | "\n", 581 | "\n", 582 | "\u001b[0m\u001b[36;1m\u001b[1;3m2024-11-19 08:45:30\u001b[0m\u001b[32;1m\u001b[1;3m現在時間是 2024年11月19日 08:45:30。\n", 583 | "\u001b[0m\n", 584 | "\n", 585 | "\u001b[1m> Finished chain.\u001b[0m\n", 586 | "現在時間是 2024年11月19日 08:45:30。\n", 587 | "\n" 588 | ] 589 | } 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "source": [ 595 | "%%time\n", 596 | "\n", 597 | "# 使用自然語言輸入,測試模型是否可以正確調用工具來完成計算。\n", 598 | "result = llm.invoke(\"請問2024年2月18日和12月19日相差幾天?\")\n", 599 | "print(result.content)" 600 | ], 601 | "metadata": { 602 | "colab": { 603 | "base_uri": "https://localhost:8080/" 604 | }, 605 | "id": "GureTaeqm3L1", 606 | "outputId": "47415d70-7651-4d68-a7b1-e8d2e909f04a" 607 | }, 608 | "id": "GureTaeqm3L1", 609 | "execution_count": 45, 610 | "outputs": [ 611 | { 612 | "output_type": 
"stream", 613 | "name": "stdout", 614 | "text": [ 615 | "從2024年2月18日到2024年12月19日,相差 **305 天**。\n", 616 | "\n", 617 | "CPU times: user 228 ms, sys: 43.4 ms, total: 271 ms\n", 618 | "Wall time: 38.5 s\n" 619 | ] 620 | } 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "source": [ 626 | "%%time\n", 627 | "\n", 628 | "result = agent_executor.invoke({\n", 629 | " \"input\": \"請問2024年2月18日和12月19日相差幾天?\"\n", 630 | "})\n", 631 | "print(result['output'])" 632 | ], 633 | "metadata": { 634 | "colab": { 635 | "base_uri": "https://localhost:8080/" 636 | }, 637 | "id": "giERbGRKmHF-", 638 | "outputId": "9ecff309-874e-4293-b051-50bf4e317a0f" 639 | }, 640 | "id": "giERbGRKmHF-", 641 | "execution_count": 47, 642 | "outputs": [ 643 | { 644 | "output_type": "stream", 645 | "name": "stdout", 646 | "text": [ 647 | "\n", 648 | "\n", 649 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 650 | "\u001b[32;1m\u001b[1;3m\n", 651 | "Invoking: `date_difference` with `{'date2': '2024-12-19', 'date1': '2024-02-18'}`\n", 652 | "\n", 653 | "\n", 654 | "\u001b[0m\u001b[33;1m\u001b[1;3m305\u001b[0m\u001b[32;1m\u001b[1;3m2024年2月18日和12月19日相差305天。\n", 655 | "\u001b[0m\n", 656 | "\n", 657 | "\u001b[1m> Finished chain.\u001b[0m\n", 658 | "2024年2月18日和12月19日相差305天。\n", 659 | "\n", 660 | "CPU times: user 128 ms, sys: 31.5 ms, total: 160 ms\n", 661 | "Wall time: 14.2 s\n" 662 | ] 663 | } 664 | ] 665 | } 666 | ], 667 | "metadata": { 668 | "colab": { 669 | "provenance": [] 670 | }, 671 | "language_info": { 672 | "name": "python" 673 | }, 674 | "kernelspec": { 675 | "name": "python3", 676 | "display_name": "Python 3" 677 | } 678 | }, 679 | "nbformat": 4, 680 | "nbformat_minor": 5 681 | } -------------------------------------------------------------------------------- /Simon_LLM_Application_Google_Vertex_AI_GenAISDK_Full_Method.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "include_colab_link": true 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "source": [ 31 | "# Google Gen AI SDK\n", 32 | "\n", 33 | "google-genai 是用於與 Google 的 Generative AI API 互動的初始 Python 用戶端程式庫。\n", 34 | "\n", 35 | "Google Gen AI Python SDK 為開發人員提供了一個接口,可以將 Google 的生成模型整合到他們的 Python 應用程式中。它支援Gemini 開發者 API和Vertex AI API。" 36 | ], 37 | "metadata": { 38 | "id": "w3-A0eIq7s73" 39 | } 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "source": [ 44 | "## I. 
Installation and Configure\n" 45 | ], 46 | "metadata": { 47 | "id": "X__fXs-M7z9t" 48 | } 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "source": [ 53 | "#### 透過 pip 安裝 google-genai python package" 54 | ], 55 | "metadata": { 56 | "id": "k9CiucOKH6Se" 57 | } 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "collapsed": true, 64 | "id": "WqYS-AoA7kdL" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "!pip install --upgrade -q google-genai" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "source": [ 74 | "#### 設定Google Cloud 項目\n", 75 | "\n", 76 | "設定您的Google Cloud 專案並啟用Vertex AI API。\n", 77 | "\n", 78 | "- Step 1: In the Google Cloud console, on the project selector page, select or create a Google Cloud project.\n", 79 | "
[Go to project selector](https://console.cloud.google.com/projectselector2/home/dashboard?hl=zh-tw)\n", 80 | "\n", 81 | "- Step 2: Make sure that [billing is enabled](https://cloud.google.com/billing/docs/how-to/verify-billing-enabled?hl=zh-tw#confirm_billing_is_enabled_on_a_project) for your Google Cloud project.\n", 82 | "\n", 83 | "- Step 3: Enable the Vertex AI API.\n", 84 | "
[Enable the API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com&hl=zh-cn)" 85 | ], 86 | "metadata": { 87 | "id": "fjne11MxxWcB" 88 | } 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "source": [ 93 | "#### Vertex AI 設定" 94 | ], 95 | "metadata": { 96 | "id": "IVBJkXDYHwMY" 97 | } 98 | }, 99 | { 100 | "cell_type": "code", 101 | "source": [ 102 | "# gcloud CLI 進行身份驗證\n", 103 | "!gcloud auth application-default login" 104 | ], 105 | "metadata": { 106 | "id": "pqlArhn3GqQd", 107 | "outputId": "70481cba-3244-4df9-e351-155d58249ecf", 108 | "colab": { 109 | "base_uri": "https://localhost:8080/" 110 | } 111 | }, 112 | "execution_count": null, 113 | "outputs": [ 114 | { 115 | "output_type": "stream", 116 | "name": "stdout", 117 | "text": [ 118 | "Go to the following link in your browser, and complete the sign-in prompts:\n", 119 | "\n", 120 | " https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com&redirect_uri=https%3A%2F%2Fsdk.cloud.google.com%2Fapplicationdefaultauthcode.html&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fsqlservice.login&state=XfjW1uCCL5zAUr4ERMcflrJx0f2D9q&prompt=consent&token_usage=remote&access_type=offline&code_challenge=pEhxzM3exbdVVHIeJc4HEv1SfiDgrmMiLIiJBi61M3Y&code_challenge_method=S256\n", 121 | "\n", 122 | "Once finished, enter the verification code provided in your browser: 4/0Ab_5qlnh1RmX3_iBcBcEObD1mTnz7s3xc_iGdt9iGg4kOq0dOqpG4yIqfAn85HbYE-ymVQ\n", 123 | "\n", 124 | "Credentials saved to file: [/content/.config/application_default_credentials.json]\n", 125 | "\n", 126 | "These credentials will be used by any library that requests Application Default Credentials (ADC).\n", 127 | "\n", 128 | "Quota project \"leafy-bond-456001-r7\" was added to ADC which can be used by Google client libraries for billing and quota. Note that some services may still bill the project owning the resource.\n" 129 | ] 130 | } 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "source": [ 136 | "## II. Import and create a client" 137 | ], 138 | "metadata": { 139 | "id": "O4V-eXd-76IA" 140 | } 141 | }, 142 | { 143 | "cell_type": "code", 144 | "source": [ 145 | "import os\n", 146 | "vertex_ai_project = 'fill-in-the-project-id-here' #@param {type:\"string\"}\n", 147 | "\n", 148 | "os.environ[\"GOOGLE_CLOUD_PROJECT\"] = vertex_ai_project\n", 149 | "os.environ[\"GOOGLE_CLOUD_LOCATION\"] = \"global\"\n", 150 | "os.environ[\"GOOGLE_GENAI_USE_VERTEXAI\"] = \"True\"" 151 | ], 152 | "metadata": { 153 | "id": "8BhtTw9N7509" 154 | }, 155 | "execution_count": null, 156 | "outputs": [] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "source": [ 161 | "## III. 
API 測試" 162 | ], 163 | "metadata": { 164 | "id": "kNxa7sC__6Q8" 165 | } 166 | }, 167 | { 168 | "cell_type": "code", 169 | "source": [ 170 | "from google import genai\n", 171 | "from google.genai.types import HttpOptions\n", 172 | "\n", 173 | "client = genai.Client(http_options=HttpOptions(api_version=\"v1\"))\n", 174 | "\n", 175 | "response = client.models.generate_content(\n", 176 | " model='gemini-2.5-flash-preview-04-17', contents='請問為何天空是藍色的?'\n", 177 | ")\n", 178 | "print(response.text)" 179 | ], 180 | "metadata": { 181 | "colab": { 182 | "base_uri": "https://localhost:8080/" 183 | }, 184 | "id": "tDiXAroK_9Vi", 185 | "outputId": "89d3bc4c-b335-4f61-988b-7c192a04657b" 186 | }, 187 | "execution_count": null, 188 | "outputs": [ 189 | { 190 | "output_type": "stream", 191 | "name": "stdout", 192 | "text": [ 193 | "天空之所以呈現藍色,主要是因為太陽光進入地球大氣層時,與空氣中的氣體分子(主要是氮分子和氧分子)以及微小塵埃顆粒發生了「**瑞利散射**」(Rayleigh scattering)的現象。\n", 194 | "\n", 195 | "以下是詳細的解釋步驟:\n", 196 | "\n", 197 | "1. **太陽光是白色光:** 太陽光看起來是白色的,但它實際上是由不同波長(也就是不同顏色)的光混合而成,就像彩虹一樣,包含了紅、橙、黃、綠、藍、靛、紫等各種顏色。不同顏色的光有不同的波長,其中紅光的波長最長,紫光的波長最短,藍光的波長屬於較短的一端。\n", 198 | "\n", 199 | "2. **光線通過大氣層:** 當太陽光穿過地球的大氣層時,會與空氣中的氣體分子和懸浮微粒發生作用。\n", 200 | "\n", 201 | "3. **瑞利散射是關鍵:** 當光線遇到的大小遠小於其波長的粒子時(例如可見光遇到空氣中的氮分子和氧分子),就會發生瑞利散射。瑞利散射有一個重要的特性:**散射的強度與光線波長的四次方成反比**。\n", 202 | "\n", 203 | "4. **藍光被強烈散射:** 由於藍光和紫光的波長比較短,根據瑞利散射的原理,它們被空氣分子散射的強度要遠遠大於波長較長的紅光、橙光等。簡單來說,藍光(和紫光)更容易被“彈開”,向四面八方散開。\n", 204 | "\n", 205 | "5. **我們看到的是散射光:** 當我們抬頭看向天空的任何方向(而不是直接看向太陽時),我們看到的主要是這些被空氣分子散射開來的光線。因為藍光被最有效地散射到四面八方,所以大部分進入我們眼睛的光是藍色的,我們就覺得天空是藍色的。\n", 206 | "\n", 207 | "6. **為什麼不是紫色?** 雖然紫光的波長比藍光更短,散射更強烈,但我們看到的顏色主要是藍色,原因有幾個:一是太陽光譜中藍光的能量相對比紫光高一些;二是人眼的視網膜對藍光比紫光更敏感。因此,雖然紫光和藍光都被強烈散射,我們最終感覺到的顏色以藍色為主。\n", 208 | "\n", 209 | "**總結來說:** 天空的藍色並不是天空本身發出的光,而是太陽光中的藍光,在穿透地球大氣層時,被空氣中的氣體分子以瑞利散射的方式強力散射開來,充滿了整個天空,並進入我們的眼睛,使我們看到了藍色的天空。\n" 210 | ] 211 | } 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "source": [ 217 | "import requests\n", 218 | "import subprocess\n", 219 | "import json\n", 220 | "\n", 221 | "# 取得 gcloud access token\n", 222 | "access_token = subprocess.check_output(\n", 223 | " ['gcloud', 'auth', 'print-access-token']\n", 224 | ").decode('utf-8').strip()\n", 225 | "\n", 226 | "# 設定專案與模型資訊\n", 227 | "model_id = \"gemini-2.0-flash-001\"\n", 228 | "url = f'https://us-central1-aiplatform.googleapis.com/v1/projects/{vertex_ai_project}/locations/us-central1/publishers/google/models/{model_id}:generateContent'\n", 229 | "\n", 230 | "# 設定請求 headers\n", 231 | "headers = {\n", 232 | " 'Authorization': f'Bearer {access_token}',\n", 233 | " 'Content-Type': 'application/json'\n", 234 | "}\n", 235 | "\n", 236 | "# 設定請求 payload\n", 237 | "payload = {\n", 238 | " \"contents\": {\n", 239 | " \"role\": \"user\",\n", 240 | " \"parts\": [\n", 241 | " {\n", 242 | " \"text\": \"請問:「庭院深深深幾許」總共有幾種排列方式?\"\n", 243 | " }\n", 244 | " ]\n", 245 | " }\n", 246 | "}\n", 247 | "\n", 248 | "# 發送 POST 請求\n", 249 | "response = requests.post(url, headers=headers, data=json.dumps(payload))\n", 250 | "\n", 251 | "# 輸出回應結果\n", 252 | "if response.status_code == 200:\n", 253 | " print(\"Response:\")\n", 254 | " print(response.json())\n", 255 | " print(\"answer:\")\n", 256 | " print(response.json()['candidates'][0]['content']['parts'][0]['text'])\n", 257 | "else:\n", 258 | " print(f\"Error: {response.status_code}\")\n", 259 | " print(response.text)\n" 260 | ], 261 | "metadata": { 262 | "colab": { 263 | "base_uri": "https://localhost:8080/" 264 | }, 265 | "id": "ACp0w7QAMUaT", 266 | "outputId": 
"3540963c-a073-4b11-d100-dd0f609472b8" 267 | }, 268 | "execution_count": null, 269 | "outputs": [ 270 | { 271 | "output_type": "stream", 272 | "name": "stdout", 273 | "text": [ 274 | "Response:\n", 275 | "{'candidates': [{'content': {'role': 'model', 'parts': [{'text': '「庭院深深深幾許」這六個字,其中「深」出現了三次,其他字各出現一次。因此,我們要計算這六個字有多少種不同的排列方式,可以使用以下公式:\\n\\n總排列數 = 6! / (3! * 1! * 1! * 1!) = 6! / 3!\\n\\n其中:\\n\\n* 6! (6 階乘) 是假設所有字都不同的情況下的排列數,即 6 * 5 * 4 * 3 * 2 * 1 = 720\\n* 3! (3 階乘) 是因為「深」字重複了三次,需要除以重複排列的次數,即 3 * 2 * 1 = 6\\n\\n因此,總排列數 = 720 / 6 = 120\\n\\n所以,「庭院深深深幾許」總共有 **120** 種不同的排列方式。\\n'}]}, 'finishReason': 'STOP', 'avgLogprobs': -0.1817999296781549}], 'usageMetadata': {'promptTokenCount': 16, 'candidatesTokenCount': 209, 'totalTokenCount': 225, 'trafficType': 'ON_DEMAND', 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 16}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 209}]}, 'modelVersion': 'gemini-2.0-flash-001', 'createTime': '2025-05-07T17:15:15.762268Z', 'responseId': 'I5UbaJzDLoG-nvgP_Zr8wAU'}\n", 276 | "answer:\n", 277 | "「庭院深深深幾許」這六個字,其中「深」出現了三次,其他字各出現一次。因此,我們要計算這六個字有多少種不同的排列方式,可以使用以下公式:\n", 278 | "\n", 279 | "總排列數 = 6! / (3! * 1! * 1! * 1!) = 6! / 3!\n", 280 | "\n", 281 | "其中:\n", 282 | "\n", 283 | "* 6! (6 階乘) 是假設所有字都不同的情況下的排列數,即 6 * 5 * 4 * 3 * 2 * 1 = 720\n", 284 | "* 3! (3 階乘) 是因為「深」字重複了三次,需要除以重複排列的次數,即 3 * 2 * 1 = 6\n", 285 | "\n", 286 | "因此,總排列數 = 720 / 6 = 120\n", 287 | "\n", 288 | "所以,「庭院深深深幾許」總共有 **120** 種不同的排列方式。\n", 289 | "\n" 290 | ] 291 | } 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "source": [ 297 | "## V. 類型配置" 298 | ], 299 | "metadata": { 300 | "id": "9TXqUaaNBEIr" 301 | } 302 | }, 303 | { 304 | "cell_type": "code", 305 | "source": [ 306 | "from google.genai.types import GenerateContentConfig, Part\n", 307 | "\n", 308 | "response = client.models.generate_content(\n", 309 | " model='gemini-2.0-flash-001',\n", 310 | " contents=Part.from_text(text='請問為何天空是藍色的?'),\n", 311 | " config=GenerateContentConfig(\n", 312 | " temperature=0,\n", 313 | " top_p=0.95,\n", 314 | " top_k=20,\n", 315 | " candidate_count=1,\n", 316 | " seed=5,\n", 317 | " max_output_tokens=100,\n", 318 | " stop_sequences=['STOP!'],\n", 319 | " presence_penalty=0.0,\n", 320 | " frequency_penalty=0.0,\n", 321 | " ),\n", 322 | ")\n", 323 | "\n", 324 | "print(response.text)" 325 | ], 326 | "metadata": { 327 | "colab": { 328 | "base_uri": "https://localhost:8080/" 329 | }, 330 | "id": "nmbOiACcBQ4T", 331 | "outputId": "41de856b-a310-4a59-e487-8e3be9038ddb" 332 | }, 333 | "execution_count": null, 334 | "outputs": [ 335 | { 336 | "output_type": "stream", 337 | "name": "stdout", 338 | "text": [ 339 | "天空之所以是藍色的,主要是因為一種叫做**瑞利散射 (Rayleigh scattering)** 的物理現象。以下是詳細的解釋:\n", 340 | "\n", 341 | "* **太陽光是混合光:** 太陽光並非單一顏色,而是由紅、橙、黃、綠、藍、靛、紫等各種顏色的光混合而成。\n", 342 | "\n", 343 | "* **大氣層中的微粒:** 地球的大氣層中充滿了各種微小的粒子,例如氮氣、氧\n" 344 | ] 345 | } 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "source": [ 351 | "## V. 
Function Calling\n", 352 | "\n", 353 | "此處示範,統整所有行政院最新新聞消息的 function calling。\n", 354 | "\n", 355 | "API 位置:https://opendata.ey.gov.tw/api/index.html" 356 | ], 357 | "metadata": { 358 | "id": "pDMXxw9gCAQq" 359 | } 360 | }, 361 | { 362 | "cell_type": "code", 363 | "source": [ 364 | "from typing import List, Union\n", 365 | "from datetime import datetime, date\n", 366 | "import requests\n", 367 | "\n", 368 | "def normalize_date(date_str: Union[str, date]) -> str:\n", 369 | " \"\"\"將輸入轉為 'YYYY/MM/DD' 字串格式\"\"\"\n", 370 | " if isinstance(date_str, date):\n", 371 | " return date_str.strftime('%Y/%m/%d')\n", 372 | " try:\n", 373 | " parsed = datetime.strptime(date_str, '%Y-%m-%d')\n", 374 | " return parsed.strftime('%Y/%m/%d')\n", 375 | " except ValueError:\n", 376 | " return date_str # 假設已是 YYYY/MM/DD\n", 377 | "\n", 378 | "def get_ey_news(start_date: str, end_date: str) -> List[dict]:\n", 379 | " \"\"\"取得行政院公開新聞資料,日期會強制轉為 'YYYY/MM/DD'\"\"\"\n", 380 | " start_date_fmt = normalize_date(start_date)\n", 381 | " end_date_fmt = normalize_date(end_date)\n", 382 | "\n", 383 | " url = \"https://opendata.ey.gov.tw/api/ExecutiveYuan/NewsEy\"\n", 384 | " params = {\n", 385 | " \"StartDate\": start_date_fmt,\n", 386 | " \"EndDate\": end_date_fmt,\n", 387 | " \"MaxSize\": 10,\n", 388 | " \"IsRemoveHtmlTag\": True\n", 389 | " }\n", 390 | " response = requests.get(url, params=params)\n", 391 | " response.raise_for_status()\n", 392 | " return response.json()\n" 393 | ], 394 | "metadata": { 395 | "id": "Rp2vkqe0CCqU" 396 | }, 397 | "execution_count": null, 398 | "outputs": [] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "source": [ 403 | "output = get_ey_news(\"2025/04/01\", \"2025/04/30\")\n", 404 | "output[0]" 405 | ], 406 | "metadata": { 407 | "colab": { 408 | "base_uri": "https://localhost:8080/" 409 | }, 410 | "id": "qnzCdFFfFbRm", 411 | "outputId": "81eb961e-aa42-43ab-def6-8e12bda6d88d" 412 | }, 413 | "execution_count": null, 414 | "outputs": [ 415 | { 416 | "output_type": "execute_result", 417 | "data": { 418 | "text/plain": [ 419 | "{'標題': '主持政院性平會議 卓揆:持續改善性別薪資差距、均衡大專校院師資性別比例、增設性別友善廁所',\n", 420 | " '內容': '

行政院長卓榮泰今(30)日主持「行政院性別平等會第32次委員會議」時表示,政府致力推動性別平等政策及完善相關軟硬體設備,包括定期揭露各行各業的性別薪資資料,積極推動改善策略;促使大專校院系所逐步達成更均衡的師資性別比例,以及透過增設性別友善廁所等各項措施,以營造尊重多元包容與公義的性別平等社會。
\\xa0

卓院長致詞時指出,全面提升性別平等及多元友善的社會環境,一直是政府在各項施政的重要課題,並致力推動性別平等相關政策及完善軟硬體設備。卓院長感謝委員今日撥冗出席提供建言,期盼政府與民間一起合作,讓公務機關擁有更正確方向及具體目標,落實推動性別平等各項措施,共同營造尊重、多元、包容、公義的性別平等社會。
\\xa0

隨後,卓院長在聽取行政院主計總處「運用公務大數據觀察性別薪資概況與分析」報告後表示,為協助各部會研析改善性別薪資差距的具體策進作為,請主計總處提供公務大數據中各類性別統計分析資料,協助勞動部、金融監督管理委員會及各主管機關後續資料運用事宜,並評估增列不利處境者的性別薪資統計,如身心障礙人士、原住民等,以利研擬更全面的策進作為。
\\xa0

卓院長請勞動部強化公司同工同酬的自我檢核及改善,並協同主計總處及各主管機關,定期揭露各行各業的性別薪資資料,積極推動改善策略。卓院長指出,政府自明(115)年起,要求實收資本額100億元以上的上市櫃公司,揭露男性及女性非主管職務,全時員工的薪資平均數及薪資中位數資訊,請金管會加強宣導,展現政府改善性別薪資差距的決心;另請金管會針對上市櫃公司所揭露的統計數據,適時對外詳細說明,讓民眾更了解真實性別薪資概況。
\\xa0

針對教育部、國防部、內政部「大專校院系所專任教師任一性別比例過低之具體改進策略」報告,卓院長表示,性別平等是政府施政重要核心價值,請教育部、國防部及內政部,針對僅有單一性別專任教師的系所,瞭解原因,視情形依《聯合國消除對婦女一切形式歧視公約》(CEDAW)規定,評估訂定「暫行特別措施」,並請勞動部協助確認後續相關精進招聘規定的合宜性。
\\xa0

卓院長亦指示教育部、國防部及內政部蒐集國外推動相關經驗,確保主管大專校院所招聘教師在專業能力符合要求,並兼顧學生受教權的前提下,積極打造性別友善、尊重多元且更具包容性的職場環境,並依所訂目標、期程及策略,逐步達成更為均衡的師資性別比例。
\\xa0

有關環境部「性別友善廁所倍增行動方案辦理情形」報告,卓院長表示,為打造性別友善與包容性的生活環境,請環境部持續評估各地設置性別友善廁所的需求情形,並兼顧衡平性,針對較為不足的縣市進行輔導,並加強督導查核新增設廁所是否符合性平友善規範。
\\xa0

此外,卓院長請環境部蒐集社會各界意見,精進作法及改善策略,以提高公廁的多元友善度,推動過程中亦請加強社會大眾對性別友善廁所的認知,消除社會大眾使用疑慮;同時強化鼓勵私部門設置性別友善廁所參與的機制,期望公私部門共同營造尊重多元、包容性的友善公共空間。
\\xa0

針對委員在會中提及新住民女性遭不法金融剝削,其個資及照片遭上網公開一事,卓院長請內政部、法務部嚴查不法行為;請勞動部加強保障新住民及外籍移工等弱勢族群工作權益;請金管會向新住民加強宣導普惠金融措施,並請林明昕政務委員協調統籌相關部會妥予應處。此外,行政院會已於上(3)月通過「個人資料保護委員會組織法」草案,俟該法三讀通過及「個人資料保護委員會」成立後,將有助強化及落實個人資料保護工作。

',\n", 421 | " '上版日期': '2025/04/30',\n", 422 | " '來源網址': 'https://www.ey.gov.tw/Page/9277F759E41CCD91/65c458a2-a533-4216-8165-2fc6eb037c91'}" 423 | ] 424 | }, 425 | "metadata": {}, 426 | "execution_count": 8 427 | } 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "source": [ 433 | "response = client.models.generate_content(\n", 434 | " model='gemini-2.0-flash-001',\n", 435 | " contents=\"幫我找 2025/04/01 至 2025/04/30 行政院發布的新聞標題和發佈時間\",\n", 436 | " config=GenerateContentConfig(tools=[get_ey_news]),\n", 437 | ")\n", 438 | "\n", 439 | "print(response.text)" 440 | ], 441 | "metadata": { 442 | "colab": { 443 | "base_uri": "https://localhost:8080/" 444 | }, 445 | "id": "vo9avx7CCpwk", 446 | "outputId": "885af06b-2ad1-4624-e52e-a5d528ba6d5f" 447 | }, 448 | "execution_count": null, 449 | "outputs": [ 450 | { 451 | "output_type": "stream", 452 | "name": "stdout", 453 | "text": [ 454 | "2025/04/01 至 2025/04/30 行政院發布的新聞標題和發佈時間如下:\n", 455 | "\n", 456 | "* **2025/04/30**: 主持政院性平會議 卓揆:持續改善性別薪資差距、均衡大專校院師資性別比例、增設性別友善廁所\n", 457 | "* **2025/04/30**: 表揚國土保育有功人員 卓揆期勉持續守護國人健康、維護環境正義、確保國家資源永續發展\n", 458 | "* **2025/04/28**: 政府已啟動「精準打擊」機制 全力壓制詐騙犯罪及降低民眾財產損失金額\n", 459 | "* **2025/04/28**: 卓揆拜會立法院各黨團 盼朝野共同支持並儘速通過「因應國際情勢強化經濟社會及國土安全韌性特別條例」\n", 460 | "* **2025/04/24**: 卓揆赴教宗方濟各靈堂追思 緬懷教宗一生勤儉、關愛世人\n", 461 | "* **2025/04/24**: 卓揆拍板通過「因應國際情勢強化經濟社會及國土安全韌性特別條例」草案 編列4,100億元支持產業、安定就業、照顧民生、強化韌性\n", 462 | "* **2025/04/24**: 卓揆指示最短時間內全數提出解凍案 盼立法院儘速審議 讓國家快速向前行\n", 463 | "* **2025/04/24**: 卓揆:持續精進打詐工作 守護民眾財產安全、維繫社會安定\n", 464 | "* **2025/04/24**: 卓揆:放寬支持方案申請要件以貼近實際需求 及時支持受美國關稅政策影響產業、廠商及勞工\n", 465 | "* **2025/04/24**: 政院通過「113年度中央政府總決算暨附屬單位決算及綜計表」及「中央政府前瞻基礎建設計畫第4期特別決算」 連8年總決算歲入歲出有賸餘 嚴守財政、開源節流\n" 466 | ] 467 | } 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "source": [ 473 | "## VII. Pydantic-JSON 回應架構" 474 | ], 475 | "metadata": { 476 | "id": "gh0bJ1lLFiCL" 477 | } 478 | }, 479 | { 480 | "cell_type": "code", 481 | "source": [ 482 | "from pydantic import BaseModel\n", 483 | "\n", 484 | "\n", 485 | "class CountryInfo(BaseModel):\n", 486 | " name: str\n", 487 | " population: int\n", 488 | " capital: str\n", 489 | " continent: str\n", 490 | " gdp: int\n", 491 | " official_language: str\n", 492 | " total_area_sq_mi: int\n", 493 | "\n", 494 | "\n", 495 | "response = client.models.generate_content(\n", 496 | " model='gemini-2.0-flash-001',\n", 497 | " contents='請給我台灣的相關資訊',\n", 498 | " config=GenerateContentConfig(\n", 499 | " response_mime_type='application/json',\n", 500 | " response_schema=CountryInfo,\n", 501 | " ),\n", 502 | ")\n", 503 | "print(response.text)" 504 | ], 505 | "metadata": { 506 | "colab": { 507 | "base_uri": "https://localhost:8080/" 508 | }, 509 | "id": "Xj1OpsPmFhuX", 510 | "outputId": "7ae4d516-065f-4244-cfb7-d789fe790098" 511 | }, 512 | "execution_count": null, 513 | "outputs": [ 514 | { 515 | "output_type": "stream", 516 | "name": "stdout", 517 | "text": [ 518 | "{\n", 519 | " \"name\": \"Taiwan\",\n", 520 | " \"population\": 23819742,\n", 521 | " \"capital\": \"Taipei\",\n", 522 | " \"continent\": \"Asia\",\n", 523 | " \"gdp\": 759100000000,\n", 524 | " \"official_language\": \"Mandarin Chinese\",\n", 525 | " \"total_area_sq_mi\": 13974\n", 526 | "}\n" 527 | ] 528 | } 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "source": [ 534 | "## VIII. 
Streaming Output" 535 | ], 536 | "metadata": { 537 | "id": "kP4_R0lbGAew" 538 | } 539 | }, 540 | { 541 | "cell_type": "code", 542 | "source": [ 543 | "for chunk in client.models.generate_content_stream(\n", 544 | " model='gemini-2.0-flash-001', contents='請用100字講一個大野狼的故事'\n", 545 | "):\n", 546 | " print(chunk.text, end='')" 547 | ], 548 | "metadata": { 549 | "colab": { 550 | "base_uri": "https://localhost:8080/" 551 | }, 552 | "id": "ZonBfGbJGADQ", 553 | "outputId": "1383e8d2-c5af-4ab5-f709-5ff8ef240aaa" 554 | }, 555 | "execution_count": null, 556 | "outputs": [ 557 | { 558 | "output_type": "stream", 559 | "name": "stdout", 560 | "text": [ 561 | "從前,森林裡住著一隻飢餓的大野狼。牠偽裝成老奶奶,想騙小紅帽進屋吃掉。小紅帽不疑有他,進了屋子,卻發現「奶奶」的眼睛、耳朵和嘴巴都大的可怕!大野狼露出了真面目,正要撲向小紅帽時,一位勇敢的獵人衝進來,救了小紅帽,並把大野狼趕出了森林。從此以後,小紅帽再也不敢輕易相信陌生人了。\n" 562 | ] 563 | } 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "source": [ 569 | "[ 只給 Vertex AI 用戶使用 ] 計算 token 數量" 570 | ], 571 | "metadata": { 572 | "id": "nEkKcNBTIzni" 573 | } 574 | }, 575 | { 576 | "cell_type": "code", 577 | "source": [ 578 | "response = client.models.compute_tokens(\n", 579 | " model='gemini-2.0-flash-001',\n", 580 | " contents='請問天空為何是藍色的?',\n", 581 | ")\n", 582 | "print(response)" 583 | ], 584 | "metadata": { 585 | "id": "lxqVjXNRI0AL", 586 | "outputId": "9c1119b4-dab3-475d-ed30-316d05c976d2", 587 | "colab": { 588 | "base_uri": "https://localhost:8080/" 589 | } 590 | }, 591 | "execution_count": null, 592 | "outputs": [ 593 | { 594 | "output_type": "stream", 595 | "name": "stdout", 596 | "text": [ 597 | "tokens_info=[TokensInfo(role='user', token_ids=[160287, 90420, 171755, 235427, 238752, 35219, 235544], tokens=[b'\\xe8\\xab\\x8b\\xe5\\x95\\x8f', b'\\xe5\\xa4\\xa9\\xe7\\xa9\\xba', b'\\xe7\\x82\\xba\\xe4\\xbd\\x95', b'\\xe6\\x98\\xaf', b'\\xe8\\x97\\x8d', b'\\xe8\\x89\\xb2\\xe7\\x9a\\x84', b'\\xef\\xbc\\x9f'])]\n" 598 | ] 599 | } 600 | ] 601 | } 602 | ] 603 | } -------------------------------------------------------------------------------- /Simon_LLM_Application_VLLM_Tool_Google_Gemma3_Model_Service.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "dfsDR_omdNea" 7 | }, 8 | "source": [ 9 | "# Use VLLM to deploy on Google Gemma 3 Model\n", 10 | "\n", 11 | "本筆記本示範如何使用 vLLM 部署 Gemma 模型並進行查詢。vLLM 是一個快速且易於使用的大型語言模型推論與服務框架,並內建支援 Gemma 3 模型的部署。" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "id": "o6Pwu9IkZk7_" 18 | }, 19 | "source": [ 20 | "# 設定\n", 21 | "\n", 22 | "1. [一定要有] 至少 T4 GPU\n", 23 | "2. HuggingFace Access Token: 你可以到 [Hugging Face access token](https://huggingface.co/docs/hub/en/security-tokens) 去產生一組 Token 後,放入 Colab secret 'HF_TOKEN' 中。" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "gJaQ-OVoPKCo" 30 | }, 31 | "source": [ 32 | "## Step 1: 安裝 Vllm 套件\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Tue Nov 18 01:08:33 2025 \n", 45 | "+-----------------------------------------------------------------------------------------+\n", 46 | "| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |\n", 47 | "|-----------------------------------------+------------------------+----------------------+\n", 48 | "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. 
ECC |\n", 49 | "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", 50 | "| | | MIG M. |\n", 51 | "|=========================================+========================+======================|\n", 52 | "| 0 NVIDIA L4 Off | 00000000:00:03.0 Off | 0 |\n", 53 | "| N/A 45C P8 11W / 72W | 0MiB / 23034MiB | 0% Default |\n", 54 | "| | | N/A |\n", 55 | "+-----------------------------------------+------------------------+----------------------+\n", 56 | " \n", 57 | "+-----------------------------------------------------------------------------------------+\n", 58 | "| Processes: |\n", 59 | "| GPU GI CI PID Type Process name GPU Memory |\n", 60 | "| ID ID Usage |\n", 61 | "|=========================================================================================|\n", 62 | "| No running processes found |\n", 63 | "+-----------------------------------------------------------------------------------------+\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "!nvidia-smi" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 1, 74 | "metadata": { 75 | "colab": { 76 | "base_uri": "https://localhost:8080/" 77 | }, 78 | "collapsed": true, 79 | "executionInfo": { 80 | "elapsed": 167592, 81 | "status": "ok", 82 | "timestamp": 1756968040142, 83 | "user": { 84 | "displayName": "Yu-Wei Simon Liu (Simon Liu)", 85 | "userId": "07932650701621055368" 86 | }, 87 | "user_tz": -480 88 | }, 89 | "id": "DHrOMaOAPSAM", 90 | "outputId": "14428e99-48a6-4439-f880-183eb9173f70" 91 | }, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m438.2/438.2 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", 98 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m180.0/180.0 kB\u001b[0m \u001b[31m16.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 99 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.5/45.5 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 100 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.0/111.0 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 101 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.4/45.4 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 102 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.9/3.9 MB\u001b[0m \u001b[31m79.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n", 103 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.3/2.3 MB\u001b[0m \u001b[31m91.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 104 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.2/117.2 MB\u001b[0m \u001b[31m11.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", 105 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.7/8.7 MB\u001b[0m \u001b[31m124.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", 106 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m96.2/96.2 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 107 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.0/15.0 MB\u001b[0m \u001b[31m122.1 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m0:01\u001b[0m\n", 108 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.5/6.5 MB\u001b[0m \u001b[31m131.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", 109 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.0/3.0 MB\u001b[0m \u001b[31m111.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 110 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.4/71.4 MB\u001b[0m \u001b[31m31.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", 111 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m72.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 112 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m388.0/388.0 kB\u001b[0m \u001b[31m30.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 113 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m285.7/285.7 kB\u001b[0m \u001b[31m25.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 114 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m213.6/213.6 kB\u001b[0m \u001b[31m19.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 115 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m180.7/180.7 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 116 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.6/71.6 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 117 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m456.8/456.8 kB\u001b[0m \u001b[31m29.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 118 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m108.3/108.3 kB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 119 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.4/42.4 MB\u001b[0m \u001b[31m56.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", 120 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 121 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m331.1/331.1 kB\u001b[0m \u001b[31m29.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 122 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m517.7/517.7 kB\u001b[0m \u001b[31m38.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 123 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.3/6.3 MB\u001b[0m \u001b[31m107.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", 124 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.4/4.4 MB\u001b[0m \u001b[31m62.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", 125 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m959.8/959.8 kB\u001b[0m \u001b[31m61.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 126 | "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", 127 | "ipython 7.34.0 requires jedi>=0.16, which is not installed.\u001b[0m\u001b[31m\n", 128 | "\u001b[0m" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "!pip install -q vllm" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": { 139 | "id": "wn2lC5hVPUxy" 140 | }, 141 | "source": [ 142 | "## Method 1: 使用 VLLM Python SDK 來使用 Google Gemma 3\n", 143 | "\n", 144 | "注意:使用完 Method 1,請務必重啟工作階段,才能夠釋放 GPU 資源" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 3, 150 | "metadata": { 151 | "colab": { 152 | "base_uri": "https://localhost:8080/" 153 | }, 154 | "id": "q8o4A6QKPp3d", 155 | "outputId": "5dcc599a-f827-49db-df8b-47cf1935c393" 156 | }, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "INFO 11-18 01:09:14 [__init__.py:216] Automatically detected platform cuda.\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "from vllm import LLM" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "import os\n", 177 | "\n", 178 | "os.environ[\"HF_TOKEN\"] = \"\" # Replace with your actual Hugging Face token" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 9, 184 | "metadata": { 185 | "colab": { 186 | "base_uri": "https://localhost:8080/", 187 | "height": 453, 188 | "referenced_widgets": [ 189 | "7a6cf131b5e74c6ca8e32570d4c9e17d", 190 | "04007b832a4a466c9ead23ae37b06eba", 191 | "5037af64a0f94d0c9864af9a59cbe59a", 192 | "ec56826bf045403cb3b84efb0415d34c", 193 | "059b8ce56b0e43ca9c409d52e57db662", 194 | "f1a791d37151488d88bd5b76b1f07195", 195 | "9192a80c14bb49f3b5e45f684fd26ccf", 196 | "3ee345574f9148ed8c6a66e8bf497529", 197 | "c63da91ed2f34bad9d873f05713ddb92", 198 | "eaef7f3693ed433581f38fb8576b334f", 199 | "0283be179808435894898d53d2b0045f" 200 | ] 201 | }, 202 | "id": "k8y8SD1XPzAr", 203 | "outputId": "be49ee72-3aaa-4a0f-c460-3afde8def8ba" 204 | }, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "INFO 11-18 01:23:51 [utils.py:233] non-default args: {'dtype': 'bfloat16', 'disable_log_stats': True, 'model': 'google/gemma-3-4b-it'}\n" 211 | ] 212 | }, 213 | { 214 | "data": { 215 | "application/vnd.jupyter.widget-view+json": { 216 | "model_id": "73b5c4c3bf934360a8648e3d614db018", 217 | "version_major": 2, 218 | "version_minor": 0 219 | }, 220 | "text/plain": [ 221 | "config.json: 0%| | 0.00/855 [00:00] 1.29K --.-KB/s in 0s \n", 90 | "\n", 91 | "2024-11-20 06:24:49 (530 MB/s) - ‘中華隊_bat_data_with_chinese_names.csv’ saved [1325/1325]\n", 92 | "\n", 93 | "--2024-11-20 06:24:49-- https://huggingface.co/datasets/Simon-Liu/premier-12-chinese-taipei-performance-data/resolve/main/%E4%B8%AD%E8%8F%AF%E9%9A%8A_field_data_with_chinese_names.csv\n", 94 | "Resolving huggingface.co (huggingface.co)... 3.171.171.6, 3.171.171.128, 3.171.171.65, ...\n", 95 | "Connecting to huggingface.co (huggingface.co)|3.171.171.6|:443... connected.\n", 96 | "HTTP request sent, awaiting response... 
200 OK\n", 97 | "Length: 948 [text/plain]\n", 98 | "Saving to: ‘中華隊_field_data_with_chinese_names.csv’\n", 99 | "\n", 100 | "中華隊_field_data_w 100%[===================>] 948 --.-KB/s in 0s \n", 101 | "\n", 102 | "2024-11-20 06:24:49 (321 MB/s) - ‘中華隊_field_data_with_chinese_names.csv’ saved [948/948]\n", 103 | "\n", 104 | "--2024-11-20 06:24:49-- https://huggingface.co/datasets/Simon-Liu/premier-12-chinese-taipei-performance-data/resolve/main/%E4%B8%AD%E8%8F%AF%E9%9A%8A_pitch_data_with_chinese_names.csv\n", 105 | "Resolving huggingface.co (huggingface.co)... 3.171.171.6, 3.171.171.128, 3.171.171.65, ...\n", 106 | "Connecting to huggingface.co (huggingface.co)|3.171.171.6|:443... connected.\n", 107 | "HTTP request sent, awaiting response... 200 OK\n", 108 | "Length: 1150 (1.1K) [text/plain]\n", 109 | "Saving to: ‘中華隊_pitch_data_with_chinese_names.csv’\n", 110 | "\n", 111 | "中華隊_pitch_data_w 100%[===================>] 1.12K --.-KB/s in 0s \n", 112 | "\n", 113 | "2024-11-20 06:24:49 (404 MB/s) - ‘中華隊_pitch_data_with_chinese_names.csv’ saved [1150/1150]\n", 114 | "\n" 115 | ] 116 | } 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "source": [ 122 | "# Model" 123 | ], 124 | "metadata": { 125 | "id": "zEPJ8OsgOqrH" 126 | }, 127 | "id": "zEPJ8OsgOqrH" 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 3, 132 | "id": "8ef2e0b3", 133 | "metadata": { 134 | "id": "8ef2e0b3" 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "## 設定 Google API\n", 139 | "# 設定 Google API 金鑰,準備初始化 Google 生成式 AI 模型。\n", 140 | "import os\n", 141 | "from google.colab import userdata\n", 142 | "\n", 143 | "os.environ['GOOGLE_API_KEY'] = userdata.get('GOOGLE_API_KEY')" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 4, 149 | "id": "df0118b5", 150 | "metadata": { 151 | "id": "df0118b5" 152 | }, 153 | "outputs": [], 154 | "source": [ 155 | "## 初始化生成式 AI 模型\n", 156 | "# 使用 Google Generative AI 的模型來處理查詢。\n", 157 | "from langchain_google_genai import ChatGoogleGenerativeAI\n", 158 | "\n", 159 | "# 初始化語言模型\n", 160 | "llm = ChatGoogleGenerativeAI(\n", 161 | " model=\"gemini-1.5-flash-8b\",\n", 162 | " temperature=0,\n", 163 | ")" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "source": [ 169 | "# LLM with tool - Function Calling" 170 | ], 171 | "metadata": { 172 | "id": "C2Quv8SqNsiL" 173 | }, 174 | "id": "C2Quv8SqNsiL" 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 5, 179 | "id": "23b6b25a", 180 | "metadata": { 181 | "id": "23b6b25a" 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "## 定義打擊成績查詢工具\n", 186 | "# 使用 LangChain 的工具功能定義搜尋特定選手打擊數據的工具。\n", 187 | "import pandas as pd\n", 188 | "import json\n", 189 | "from langchain_core.tools import tool\n", 190 | "\n", 191 | "@tool\n", 192 | "def search_bat_player_stats(player_name: str) -> str:\n", 193 | " \"\"\"\n", 194 | " Search for player statistics by name in the CSV file.\n", 195 | " Returns raw data in JSON format.\n", 196 | " \"\"\"\n", 197 | " try:\n", 198 | " # 讀取 CSV 文件\n", 199 | " file_path = \"/content/中華隊_bat_data_with_chinese_names.csv\"\n", 200 | " df = pd.read_csv(file_path)\n", 201 | "\n", 202 | " # 搜尋球員\n", 203 | " player_data = df[df['player_chinese'] == player_name]\n", 204 | "\n", 205 | " if player_data.empty:\n", 206 | " return json.dumps({\"error\": \"Player not found.\"})\n", 207 | "\n", 208 | " # 將結果轉換為 JSON 格式\n", 209 | " result = player_data.to_dict(orient=\"records\")\n", 210 | " return json.dumps(result, ensure_ascii=False)\n", 211 | "\n", 212 | " except 
Exception as e:\n", 213 | " return json.dumps({\"error\": str(e)})" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 6, 219 | "id": "e3065f83", 220 | "metadata": { 221 | "colab": { 222 | "base_uri": "https://localhost:8080/", 223 | "height": 87 224 | }, 225 | "id": "e3065f83", 226 | "outputId": "92034f51-e166-4436-a9ff-0a48c6e800dc" 227 | }, 228 | "outputs": [ 229 | { 230 | "output_type": "stream", 231 | "name": "stdout", 232 | "text": [ 233 | "CPU times: user 15.3 ms, sys: 2.82 ms, total: 18.1 ms\n", 234 | "Wall time: 49.5 ms\n" 235 | ] 236 | }, 237 | { 238 | "output_type": "execute_result", 239 | "data": { 240 | "text/plain": [ 241 | "'[{\"球員\": \"CHANG\\\\nCheng-Yu\", \"AB\": 7, \"R\": 1, \"H\": 1, \"2B\": 0, \"3B\": 0, \"HR\": 0, \"RBI\": 1, \"TB\": 1, \"AVG\": 0.143, \"SLG\": 0.143, \"OBP\": 0.143, \"OPS\": 0.286, \"BB\": 0, \"HBP\": 0, \"SO\": 1, \"GDP\": 0, \"SF\": 0, \"SH\": 0, \"SB\": 1, \"CS\": 0, \"player_chinese\": \"張政禹\"}]'" 242 | ], 243 | "application/vnd.google.colaboratory.intrinsic+json": { 244 | "type": "string" 245 | } 246 | }, 247 | "metadata": {}, 248 | "execution_count": 6 249 | } 250 | ], 251 | "source": [ 252 | "## 測試打擊成績查詢工具\n", 253 | "# 測試 `search_bat_player_stats` 工具,查詢張政禹的打擊數據。\n", 254 | "%%time\n", 255 | "\n", 256 | "search_bat_player_stats.invoke({\"player_name\": \"張政禹\"})" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 7, 262 | "id": "89610ada", 263 | "metadata": { 264 | "id": "89610ada" 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "## 綁定工具與語言模型\n", 269 | "# 將 `search_bat_player_stats` 工具綁定到語言模型,方便整合操作。\n", 270 | "llm_with_tools = llm.bind_tools([search_bat_player_stats])" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 8, 276 | "id": "11301ebd", 277 | "metadata": { 278 | "colab": { 279 | "base_uri": "https://localhost:8080/" 280 | }, 281 | "id": "11301ebd", 282 | "outputId": "e44a28bd-97d7-4a73-f40e-40d0f0678da3" 283 | }, 284 | "outputs": [ 285 | { 286 | "output_type": "stream", 287 | "name": "stdout", 288 | "text": [ 289 | "CPU times: user 15.1 ms, sys: 1.9 ms, total: 17.1 ms\n", 290 | "Wall time: 515 ms\n" 291 | ] 292 | }, 293 | { 294 | "output_type": "execute_result", 295 | "data": { 296 | "text/plain": [ 297 | "[{'name': 'search_bat_player_stats',\n", 298 | " 'args': {'player_name': '張政禹'},\n", 299 | " 'id': 'fd8d881d-2906-4b40-91a9-c9c044626e29',\n", 300 | " 'type': 'tool_call'}]" 301 | ] 302 | }, 303 | "metadata": {}, 304 | "execution_count": 8 305 | } 306 | ], 307 | "source": [ 308 | "## 測試語言模型與工具的整合\n", 309 | "# 測試整合後的工具,詢問張政禹的比賽成績並檢視結果。\n", 310 | "%%time\n", 311 | "\n", 312 | "msg = llm_with_tools.invoke(\"張政禹選手的成績?\")\n", 313 | "msg.tool_calls" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "source": [ 319 | "# AI Agent with three tools" 320 | ], 321 | "metadata": { 322 | "id": "QXIX-MHfOzlu" 323 | }, 324 | "id": "QXIX-MHfOzlu" 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 9, 329 | "id": "21b692fa", 330 | "metadata": { 331 | "colab": { 332 | "base_uri": "https://localhost:8080/" 333 | }, 334 | "id": "21b692fa", 335 | "outputId": "a6ceac91-86eb-4239-b727-b71ee885e5e8" 336 | }, 337 | "outputs": [ 338 | { 339 | "output_type": "stream", 340 | "name": "stdout", 341 | "text": [ 342 | "================================\u001b[1m Human Message \u001b[0m=================================\n", 343 | "\n", 344 | "\n", 345 | "================================ System Message ================================\n", 346 | "\n", 347 | 
"這是一個可以查詢 2024 12強棒球賽,中華隊投球、打擊、守備數據庫,\n", 348 | "你是一個專業的數據查詢和分析助手。你可以使用工具來查詢數據並幫助用戶完成額外計算。\n", 349 | "\n", 350 | "工具分成:\n", 351 | "1. 打擊數據成績\n", 352 | "2. 投球數據成績\n", 353 | "3. 守備數據成績\n", 354 | "\n", 355 | "當用戶詢問問題時:\n", 356 | "- 首先使用工具查詢棒球員的相關數據。\n", 357 | "- 然後完成所需的計算。\n", 358 | "- 最後以自然語言回答用戶的問題。\n", 359 | "\n", 360 | "現在準備好處理用戶的請求。\n", 361 | "\n", 362 | "================================ Human Message =================================\n", 363 | "\n", 364 | "\u001b[33;1m\u001b[1;3m{input}\u001b[0m\n", 365 | "\n", 366 | "============================= Messages Placeholder =============================\n", 367 | "\n", 368 | "\u001b[33;1m\u001b[1;3m{agent_scratchpad}\u001b[0m\n", 369 | "\n" 370 | ] 371 | } 372 | ], 373 | "source": [ 374 | "## 定義自訂 Prompt 模板\n", 375 | "# 定義用於生成式 AI 和工具的 Prompt,包含佔位符以便替換。\n", 376 | "from langchain_core.prompts import ChatPromptTemplate\n", 377 | "\n", 378 | "# 定義包含佔位符的字串模板\n", 379 | "template = \"\"\"\n", 380 | "================================ System Message ================================\n", 381 | "\n", 382 | "這是一個可以查詢 2024 12強棒球賽,中華隊投球、打擊、守備數據庫,\n", 383 | "你是一個專業的數據查詢和分析助手。你可以使用工具來查詢數據並幫助用戶完成額外計算。\n", 384 | "\n", 385 | "工具分成:\n", 386 | "1. 打擊數據成績\n", 387 | "2. 投球數據成績\n", 388 | "3. 守備數據成績\n", 389 | "\n", 390 | "當用戶詢問問題時:\n", 391 | "- 首先使用工具查詢棒球員的相關數據。\n", 392 | "- 然後完成所需的計算。\n", 393 | "- 最後以自然語言回答用戶的問題。\n", 394 | "\n", 395 | "現在準備好處理用戶的請求。\n", 396 | "\n", 397 | "================================ Human Message =================================\n", 398 | "\n", 399 | "{input}\n", 400 | "\n", 401 | "============================= Messages Placeholder =============================\n", 402 | "\n", 403 | "{agent_scratchpad}\n", 404 | "\"\"\"\n", 405 | "\n", 406 | "# 使用 from_template 方法將字串轉換為 ChatPromptTemplate\n", 407 | "prompt = ChatPromptTemplate.from_template(template)\n", 408 | "prompt.pretty_print()" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 10, 414 | "id": "75987beb", 415 | "metadata": { 416 | "id": "75987beb" 417 | }, 418 | "outputs": [], 419 | "source": [ 420 | "## 定義守備成績查詢工具\n", 421 | "# 定義搜尋守備成績數據的工具,擴展數據查詢能力。\n", 422 | "import pandas as pd\n", 423 | "import json\n", 424 | "from langchain_core.tools import tool\n", 425 | "\n", 426 | "@tool\n", 427 | "def search_field_player_stats(player_name: str) -> str:\n", 428 | " \"\"\"\n", 429 | " Search for player statistics by name in the CSV file.\n", 430 | " Returns raw data in JSON format.\n", 431 | " \"\"\"\n", 432 | " try:\n", 433 | " # 讀取 CSV 文件\n", 434 | " file_path = \"/content/中華隊_field_data_with_chinese_names.csv\"\n", 435 | " df = pd.read_csv(file_path)\n", 436 | "\n", 437 | " # 搜尋球員\n", 438 | " player_data = df[df['player_chinese'] == player_name]\n", 439 | "\n", 440 | " if player_data.empty:\n", 441 | " return json.dumps({\"error\": \"Player not found.\"})\n", 442 | "\n", 443 | " # 將結果轉換為 JSON 格式\n", 444 | " result = player_data.to_dict(orient=\"records\")\n", 445 | " return json.dumps(result, ensure_ascii=False)\n", 446 | "\n", 447 | " except Exception as e:\n", 448 | " return json.dumps({\"error\": str(e)})" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 11, 454 | "id": "4fb96021", 455 | "metadata": { 456 | "id": "4fb96021" 457 | }, 458 | "outputs": [], 459 | "source": [ 460 | "## 定義投球成績查詢工具\n", 461 | "# 定義搜尋投球成績數據的工具,提供完整的投手數據支持。\n", 462 | "import pandas as pd\n", 463 | "import json\n", 464 | "from langchain_core.tools import tool\n", 465 | "\n", 466 | "@tool\n", 467 | "def search_pitch_player_stats(player_name: str) -> str:\n", 468 | " 
\"\"\"\n", 469 | " Search for player statistics by name in the CSV file.\n", 470 | " Returns raw data in JSON format.\n", 471 | " \"\"\"\n", 472 | " try:\n", 473 | " # 讀取 CSV 文件\n", 474 | " file_path = \"/content/中華隊_pitch_data_with_chinese_names.csv\"\n", 475 | " df = pd.read_csv(file_path)\n", 476 | "\n", 477 | " # 搜尋球員\n", 478 | " player_data = df[df['player_chinese'] == player_name]\n", 479 | "\n", 480 | " if player_data.empty:\n", 481 | " return json.dumps({\"error\": \"Player not found.\"})\n", 482 | "\n", 483 | " # 將結果轉換為 JSON 格式\n", 484 | " result = player_data.to_dict(orient=\"records\")\n", 485 | " return json.dumps(result, ensure_ascii=False)\n", 486 | "\n", 487 | " except Exception as e:\n", 488 | " return json.dumps({\"error\": str(e)})" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 12, 494 | "id": "cb4ce413", 495 | "metadata": { 496 | "id": "cb4ce413" 497 | }, 498 | "outputs": [], 499 | "source": [ 500 | "## 載入代理工具執行模組\n", 501 | "# 載入 LangChain 的代理執行功能模組。\n", 502 | "from langchain.agents import AgentExecutor, create_tool_calling_agent\n", 503 | "\n", 504 | "## 整合所有工具\n", 505 | "# 將打擊、守備和投球數據查詢工具整合到工具清單中。\n", 506 | "tools = [search_bat_player_stats, search_field_player_stats, search_pitch_player_stats]\n", 507 | "\n", 508 | "## 創建工具代理\n", 509 | "# 建立代理工具系統,將語言模型和工具清單整合以提供查詢能力。\n", 510 | "# Construct the tool calling agent\n", 511 | "agent = create_tool_calling_agent(llm, tools, prompt)\n", 512 | "\n", 513 | "## 建立代理執行器\n", 514 | "# 設定代理執行器,允許使用工具代理執行複雜查詢。\n", 515 | "# Create an agent executor by passing in the agent and tools\n", 516 | "agent_executor = AgentExecutor(\n", 517 | " agent=agent,\n", 518 | " tools=tools,\n", 519 | " verbose=True,\n", 520 | " return_intermediate_steps=True,\n", 521 | " max_iterations=5 # Example limit\n", 522 | ")" 523 | ] 524 | }, 525 | { 526 | "cell_type": "markdown", 527 | "source": [ 528 | "# User can ask the question here" 529 | ], 530 | "metadata": { 531 | "id": "w3B5pFk5O-uX" 532 | }, 533 | "id": "w3B5pFk5O-uX" 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 13, 538 | "id": "ed8bcf40", 539 | "metadata": { 540 | "colab": { 541 | "base_uri": "https://localhost:8080/" 542 | }, 543 | "id": "ed8bcf40", 544 | "outputId": "4506a447-4788-48e7-96bb-f71956c97ed3" 545 | }, 546 | "outputs": [ 547 | { 548 | "output_type": "stream", 549 | "name": "stdout", 550 | "text": [ 551 | "\n", 552 | "\n", 553 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 554 | "\u001b[32;1m\u001b[1;3m\n", 555 | "Invoking: `search_pitch_player_stats` with `{'player_name': '王志煊'}`\n", 556 | "\n", 557 | "\n", 558 | "\u001b[0m\u001b[38;5;200m\u001b[1;3m[{\"球員\": \"WANG\\nChih-Hsuan\", \"W\": 1, \"L\": 0, \"ERA\": 0.0, \"APP\": 2, \"GS\": 0, \"SV\": 0, \"IP\": 2.0, \"H\": 1, \"R\": 0, \"ER\": 0, \"BB\": 0, \"SO\": 2, \"2B\": 1, \"HR\": 0, \"AB\": 7, \"BAVG\": 0.143, \"HB\": 0, \"SHA\": 0, \"SHA.1\": 0, \"GO\": 1, \"FO\": 3, \"WHIP\": 0.5, \"player_chinese\": \"王志煊\"}]\u001b[0m\u001b[32;1m\u001b[1;3m\n", 559 | "Invoking: `search_pitch_player_stats` with `{'player_name': '王志煊'}`\n", 560 | "\n", 561 | "\n", 562 | "\u001b[0m\u001b[38;5;200m\u001b[1;3m[{\"球員\": \"WANG\\nChih-Hsuan\", \"W\": 1, \"L\": 0, \"ERA\": 0.0, \"APP\": 2, \"GS\": 0, \"SV\": 0, \"IP\": 2.0, \"H\": 1, \"R\": 0, \"ER\": 0, \"BB\": 0, \"SO\": 2, \"2B\": 1, \"HR\": 0, \"AB\": 7, \"BAVG\": 0.143, \"HB\": 0, \"SHA\": 0, \"SHA.1\": 0, \"GO\": 1, \"FO\": 3, \"WHIP\": 0.5, \"player_chinese\": \"王志煊\"}]\u001b[0m\u001b[32;1m\u001b[1;3m王志煊的投球成績:\n", 563 | "\n", 
564 | "投球次數:2\n", 565 | "勝場:1\n", 566 | "敗場:0\n", 567 | "自責分率:0.0\n", 568 | "投球局數:2.0\n", 569 | "被安打:1\n", 570 | "失分:0\n", 571 | "自責分:0\n", 572 | "保送:0\n", 573 | "三振:2\n", 574 | "滾地球:1\n", 575 | "飛球:3\n", 576 | "\n", 577 | "滾飛比:1/3 = 0.33\n", 578 | "\u001b[0m\n", 579 | "\n", 580 | "\u001b[1m> Finished chain.\u001b[0m\n", 581 | "CPU times: user 116 ms, sys: 12.1 ms, total: 128 ms\n", 582 | "Wall time: 1.8 s\n" 583 | ] 584 | } 585 | ], 586 | "source": [ 587 | "## 測試代理執行器\n", 588 | "# 使用代理執行器查詢王志煊的投球成績,並計算滾飛比。\n", 589 | "%%time\n", 590 | "\n", 591 | "result = agent_executor.invoke(\n", 592 | " {\n", 593 | " \"input\": \"請問王志煊的投球成績,並幫我計算滾飛比?\"\n", 594 | " }\n", 595 | ")" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": 14, 601 | "id": "2794a750", 602 | "metadata": { 603 | "colab": { 604 | "base_uri": "https://localhost:8080/" 605 | }, 606 | "id": "2794a750", 607 | "outputId": "f5cecacd-bd62-4ac3-c70c-7279abf04819" 608 | }, 609 | "outputs": [ 610 | { 611 | "output_type": "stream", 612 | "name": "stdout", 613 | "text": [ 614 | "王志煊的投球成績:\n", 615 | "\n", 616 | "投球次數:2\n", 617 | "勝場:1\n", 618 | "敗場:0\n", 619 | "自責分率:0.0\n", 620 | "投球局數:2.0\n", 621 | "被安打:1\n", 622 | "失分:0\n", 623 | "自責分:0\n", 624 | "保送:0\n", 625 | "三振:2\n", 626 | "滾地球:1\n", 627 | "飛球:3\n", 628 | "\n", 629 | "滾飛比:1/3 = 0.33\n", 630 | "\n" 631 | ] 632 | } 633 | ], 634 | "source": [ 635 | "## 列印結果\n", 636 | "# 輸出查詢結果到控制台。\n", 637 | "print(result['output'])" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": 15, 643 | "id": "af732759", 644 | "metadata": { 645 | "colab": { 646 | "base_uri": "https://localhost:8080/" 647 | }, 648 | "id": "af732759", 649 | "outputId": "de2815d7-8ebc-422e-e318-bb0bd63661d0" 650 | }, 651 | "outputs": [ 652 | { 653 | "output_type": "stream", 654 | "name": "stdout", 655 | "text": [ 656 | "\n", 657 | "\n", 658 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 659 | "\u001b[32;1m\u001b[1;3m\n", 660 | "Invoking: `search_bat_player_stats` with `{'player_name': '林安可'}`\n", 661 | "\n", 662 | "\n", 663 | "\u001b[0m\u001b[36;1m\u001b[1;3m[{\"球員\": \"LIN\\nAn-Ko\", \"AB\": 8, \"R\": 0, \"H\": 2, \"2B\": 1, \"3B\": 0, \"HR\": 0, \"RBI\": 0, \"TB\": 3, \"AVG\": 0.25, \"SLG\": 0.375, \"OBP\": 0.333, \"OPS\": 0.708, \"BB\": 0, \"HBP\": 1, \"SO\": 2, \"GDP\": 0, \"SF\": 0, \"SH\": 0, \"SB\": 0, \"CS\": 0, \"player_chinese\": \"林安可\"}]\u001b[0m\u001b[32;1m\u001b[1;3m林安可這次比賽打擊成績為:打數 8 球,安打 2 支,包含 1 支二壘打,打擊率 0.25,上壘率 0.333,長打率 0.375,OPS 值 0.708。此外,他被擊中 1 次,三振 2 次。\n", 664 | "\u001b[0m\n", 665 | "\n", 666 | "\u001b[1m> Finished chain.\u001b[0m\n", 667 | "CPU times: user 53.9 ms, sys: 5.2 ms, total: 59.1 ms\n", 668 | "Wall time: 1.06 s\n" 669 | ] 670 | } 671 | ], 672 | "source": [ 673 | "## 測試其他查詢\n", 674 | "# 測試代理執行器查詢林安可的打擊狀況,檢視結果。\n", 675 | "%%time\n", 676 | "\n", 677 | "result = agent_executor.invoke(\n", 678 | " {\n", 679 | " \"input\": \"請問林安可這次賽會的打擊狀況如何?\"\n", 680 | " }\n", 681 | ")" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": 16, 687 | "id": "b72718b2", 688 | "metadata": { 689 | "colab": { 690 | "base_uri": "https://localhost:8080/" 691 | }, 692 | "id": "b72718b2", 693 | "outputId": "4cfba872-89fe-40f1-ae8e-063690197a48" 694 | }, 695 | "outputs": [ 696 | { 697 | "output_type": "stream", 698 | "name": "stdout", 699 | "text": [ 700 | "林安可這次比賽打擊成績為:打數 8 球,安打 2 支,包含 1 支二壘打,打擊率 0.25,上壘率 0.333,長打率 0.375,OPS 值 0.708。\n", 701 | "此外,他被擊中 1 次,三振 2 次。\n", 702 | "\n", 703 | "\n" 704 | ] 705 | } 706 | ], 707 | "source": [ 708 | "## 
格式化並列印輸出\n", 709 | "# 將輸出格式化後列印,提高結果的可讀性。\n", 710 | "# 印出結果\n", 711 | "print(result['output'].replace('。', '。\\n'))" 712 | ] 713 | }, 714 | { 715 | "cell_type": "markdown", 716 | "source": [ 717 | "# Use Google Mesop python package application UI." 718 | ], 719 | "metadata": { 720 | "id": "K84vYo6VPHqK" 721 | }, 722 | "id": "K84vYo6VPHqK" 723 | }, 724 | { 725 | "cell_type": "code", 726 | "source": [ 727 | "!pip install mesop" 728 | ], 729 | "metadata": { 730 | "colab": { 731 | "base_uri": "https://localhost:8080/" 732 | }, 733 | "id": "4QGGuf6-PISx", 734 | "outputId": "d30f6f8a-486a-4fdc-f448-f8605dea8292" 735 | }, 736 | "id": "4QGGuf6-PISx", 737 | "execution_count": 17, 738 | "outputs": [ 739 | { 740 | "output_type": "stream", 741 | "name": "stdout", 742 | "text": [ 743 | "Collecting mesop\n", 744 | " Downloading mesop-0.12.9-py3-none-any.whl.metadata (1.0 kB)\n", 745 | "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from mesop) (1.4.0)\n", 746 | "Collecting deepdiff==6.* (from mesop)\n", 747 | " Downloading deepdiff-6.7.1-py3-none-any.whl.metadata (6.1 kB)\n", 748 | "Requirement already satisfied: flask in /usr/local/lib/python3.10/dist-packages (from mesop) (3.0.3)\n", 749 | "Requirement already satisfied: msgpack in /usr/local/lib/python3.10/dist-packages (from mesop) (1.1.0)\n", 750 | "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from mesop) (4.25.5)\n", 751 | "Requirement already satisfied: pydantic in /usr/local/lib/python3.10/dist-packages (from mesop) (2.9.2)\n", 752 | "Requirement already satisfied: python-dotenv in /usr/local/lib/python3.10/dist-packages (from mesop) (1.0.1)\n", 753 | "Collecting watchdog (from mesop)\n", 754 | " Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)\n", 755 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.3/44.3 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 756 | "\u001b[?25hRequirement already satisfied: werkzeug>=3.0.6 in /usr/local/lib/python3.10/dist-packages (from mesop) (3.1.3)\n", 757 | "Collecting ordered-set<4.2.0,>=4.0.2 (from deepdiff==6.*->mesop)\n", 758 | " Downloading ordered_set-4.1.0-py3-none-any.whl.metadata (5.3 kB)\n", 759 | "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=3.0.6->mesop) (3.0.2)\n", 760 | "Requirement already satisfied: Jinja2>=3.1.2 in /usr/local/lib/python3.10/dist-packages (from flask->mesop) (3.1.4)\n", 761 | "Requirement already satisfied: itsdangerous>=2.1.2 in /usr/local/lib/python3.10/dist-packages (from flask->mesop) (2.2.0)\n", 762 | "Requirement already satisfied: click>=8.1.3 in /usr/local/lib/python3.10/dist-packages (from flask->mesop) (8.1.7)\n", 763 | "Requirement already satisfied: blinker>=1.6.2 in /usr/local/lib/python3.10/dist-packages (from flask->mesop) (1.9.0)\n", 764 | "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic->mesop) (0.7.0)\n", 765 | "Requirement already satisfied: pydantic-core==2.23.4 in /usr/local/lib/python3.10/dist-packages (from pydantic->mesop) (2.23.4)\n", 766 | "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic->mesop) (4.12.2)\n", 767 | "Downloading mesop-0.12.9-py3-none-any.whl (8.1 MB)\n", 768 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.1/8.1 MB\u001b[0m \u001b[31m41.8 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 769 | "\u001b[?25hDownloading deepdiff-6.7.1-py3-none-any.whl (76 kB)\n", 770 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.6/76.6 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 771 | "\u001b[?25hDownloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl (79 kB)\n", 772 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.1/79.1 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 773 | "\u001b[?25hDownloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)\n", 774 | "Installing collected packages: watchdog, ordered-set, deepdiff, mesop\n", 775 | "Successfully installed deepdiff-6.7.1 mesop-0.12.9 ordered-set-4.1.0 watchdog-6.0.0\n" 776 | ] 777 | } 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "source": [ 783 | "import mesop as me\n", 784 | "import mesop.labs as mel\n", 785 | "\n", 786 | "me.colab_run()" 787 | ], 788 | "metadata": { 789 | "colab": { 790 | "base_uri": "https://localhost:8080/" 791 | }, 792 | "id": "8KJZ90FGPL2Y", 793 | "outputId": "f9c9bfbc-85e6-43c6-efae-1beec7046dc9" 794 | }, 795 | "id": "8KJZ90FGPL2Y", 796 | "execution_count": 18, 797 | "outputs": [ 798 | { 799 | "output_type": "stream", 800 | "name": "stdout", 801 | "text": [ 802 | "\n", 803 | "\u001b[32mRunning server on: http://localhost:32123\u001b[0m\n", 804 | " * Serving Flask app 'mesop.server.server'\n", 805 | " * Debug mode: off\n" 806 | ] 807 | }, 808 | { 809 | "output_type": "stream", 810 | "name": "stderr", 811 | "text": [ 812 | "INFO:werkzeug:\u001b[31m\u001b[1mWARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.\u001b[0m\n", 813 | " * Running on all addresses (::)\n", 814 | " * Running on http://[::1]:32123\n", 815 | " * Running on http://[::1]:32123\n" 816 | ] 817 | } 818 | ] 819 | }, 820 | { 821 | "cell_type": "code", 822 | "source": [ 823 | "@me.page(path=\"/chat\")\n", 824 | "def chat():\n", 825 | " mel.chat(transform)\n", 826 | "\n", 827 | "def transform(prompt: str, history: list[mel.ChatMessage]) -> str:\n", 828 | " result = agent_executor.invoke(\n", 829 | " {\n", 830 | " \"input\": prompt\n", 831 | " }\n", 832 | " )\n", 833 | "\n", 834 | " return result['output']" 835 | ], 836 | "metadata": { 837 | "id": "4H4PJNlXPLz_" 838 | }, 839 | "id": "4H4PJNlXPLz_", 840 | "execution_count": 20, 841 | "outputs": [] 842 | }, 843 | { 844 | "cell_type": "code", 845 | "source": [ 846 | "me.colab_show(path=\"/chat\", height = '400')" 847 | ], 848 | "metadata": { 849 | "colab": { 850 | "base_uri": "https://localhost:8080/", 851 | "height": 421 852 | }, 853 | "id": "UNtgDQJzPLuh", 854 | "outputId": "7af6bfef-80b7-49a9-c795-09992378ee1d" 855 | }, 856 | "id": "UNtgDQJzPLuh", 857 | "execution_count": 21, 858 | "outputs": [ 859 | { 860 | "output_type": "display_data", 861 | "data": { 862 | "text/plain": [ 863 | "" 864 | ], 865 | "application/javascript": [ 866 | "(async (port, path, width, height, cache, element) => {\n", 867 | " if (!google.colab.kernel.accessAllowed && !cache) {\n", 868 | " return;\n", 869 | " }\n", 870 | " element.appendChild(document.createTextNode(''));\n", 871 | " const url = await google.colab.kernel.proxyPort(port, {cache});\n", 872 | " const iframe = document.createElement('iframe');\n", 873 | " iframe.src = new URL(path, url).toString();\n", 874 | " iframe.height = height;\n", 875 | " iframe.width = width;\n", 876 | " iframe.style.border = 0;\n", 877 
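`me.colab_run()` and `me.colab_show()` above are Colab-specific helpers. As a rough sketch (not part of the notebook), the same chat page can live in a standalone script served by the Mesop CLI; `build_agent_executor()` below is a hypothetical factory standing in for the agent built in the earlier cells:

# chat_app.py - hedged sketch of running the same Mesop chat UI outside Colab.
import mesop as me
import mesop.labs as mel

agent_executor = build_agent_executor()  # hypothetical: recreate the LangChain agent from earlier cells

@me.page(path="/chat")
def chat():
    mel.chat(transform)

def transform(prompt: str, history: list[mel.ChatMessage]) -> str:
    result = agent_executor.invoke({"input": prompt})
    return result["output"]

# Launch locally with:  mesop chat_app.py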
| " iframe.allow = [\n", 878 | " 'accelerometer',\n", 879 | " 'autoplay',\n", 880 | " 'camera',\n", 881 | " 'clipboard-read',\n", 882 | " 'clipboard-write',\n", 883 | " 'gyroscope',\n", 884 | " 'magnetometer',\n", 885 | " 'microphone',\n", 886 | " 'serial',\n", 887 | " 'usb',\n", 888 | " 'xr-spatial-tracking',\n", 889 | " ].join('; ');\n", 890 | " element.appendChild(iframe);\n", 891 | " })(32123, \"/chat\", \"100%\", \"400\", false, window.element)" 892 | ] 893 | }, 894 | "metadata": {} 895 | } 896 | ] 897 | } 898 | ], 899 | "metadata": { 900 | "colab": { 901 | "provenance": [] 902 | }, 903 | "kernelspec": { 904 | "display_name": "Python 3", 905 | "name": "python3" 906 | }, 907 | "language_info": { 908 | "name": "python" 909 | } 910 | }, 911 | "nbformat": 4, 912 | "nbformat_minor": 5 913 | } -------------------------------------------------------------------------------- /Simon-LLM-Application-Gemma-2b-LORA-Fine-Tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "G3MMAcssHTML" 7 | }, 8 | "source": [ 9 | "\n", 10 | "" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "SDEExiAk4fLb" 17 | }, 18 | "source": [ 19 | "# Fine-tune Gemma models in Keras using LoRA\n", 20 | "\n", 21 | "改編自 Google Cloud 官方範例:https://www.apache.org/licenses/LICENSE-2.0" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": { 27 | "id": "lSGRSsRPgkzK" 28 | }, 29 | "source": [ 30 | "## Overview\n", 31 | "\n", 32 | "Gemma 是⼀系列輕量級的開放式 LLM 模型,建構於與打造 Gemini 模型相同的研發技術之上。\n", 33 | "\n", 34 | "大型語言模型 (LLMs) 像 Gemma 已被證明能有效執行各種自然語言處理 (NLP) 任務。LLM 首先會透過自我監督的方式,在大量文本資料集上進行預訓練。預訓練可幫助 LLM 學習通用知識,例如詞語之間的統計關係。然後,LLM 可以使用特定領域的資料進行微調,以執行下游任務 (例如情緒分析)。\n", 35 | "\n", 36 | "LLMs 的模型非常龐大。對於大多數應用程式來說,完整微調 (更新模型中的所有參數) 並非必要,因為典型的微調資料集通常遠遠小於預訓練資料集。\n", 37 | "\n", 38 | "低秩適應 (LoRA) 是一種微調技術,可以大幅減少下游任務的可訓練參數量。該技術透過凍結模型權重並引入少量新權重的方式來實現這一目標。LoRA 微調速度更快、記憶體消耗更少,生成的模型權重更小 (幾百 MB),同時還能維持模型輸出质量。\n", 39 | "\n", 40 | "這個 Colab ,我將引導您使用 KerasNLP 執行 LoRA 微調,使用 Gemma 2B 模型以及 Databricks Dolly 15k 資料集 。此資料集包含 15,000 個由人類生成的高品質提示/回覆配對,專為微調 LLM 而設計。" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": { 46 | "id": "w1q6-W_mKIT-" 47 | }, 48 | "source": [ 49 | "## Setup" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "lyhHCMfoRZ_v" 56 | }, 57 | "source": [ 58 | "### 取得 Gemma 訪問權限\n", 59 | "\n", 60 | "要完成本教學,您首先需要完成 [Gemma 設定](https://ai.google.dev/gemma/docs/setup) 中的設定指示。Gemma 設定指示將指導您完成以下步驟:\n", 61 | "\n", 62 | "* 在 [kaggle.com](https://kaggle.com) 上獲得 Gemma 的訪問權限。\n", 63 | "* 選擇一個具有足夠資源的 Colab 運行時環境,以運行 Gemma 2B 模型。\n", 64 | "* 生成並配置您的 Kaggle 使用者名稱和 API 金鑰。\n", 65 | "\n", 66 | "完成 Gemma 設定後,請進入下一個部分,為您的 Colab 環境設置環境變數。" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "AZ5Qo0fxRZ1V" 73 | }, 74 | "source": [ 75 | "### 選擇運行時環境\n", 76 | "\n", 77 | "要完成本教學,您需要選擇一個具有足夠資源的 Colab 運行時環境來執行 Gemma 模型。在這裡,您可以使用 T4 GPU:\n", 78 | "\n", 79 | "1. 在 Colab 視窗的右上角,選擇 ▾(**額外的連線選項**)。\n", 80 | "2. 選擇 **變更運行時類型**。\n", 81 | "3. 
在 **硬體加速器** 下,選擇 **T4 GPU**。\n" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": { 87 | "id": "hsPC0HRkJl0K" 88 | }, 89 | "source": [ 90 | "### 配置您的 API 金鑰\n", 91 | "\n", 92 | "要使用 Gemma,您必須提供您的 Kaggle 使用者名稱和 Kaggle API 金鑰。\n", 93 | "\n", 94 | "要生成 Kaggle API 金鑰,請前往 Kaggle 使用者檔案的 **帳戶** 標籤頁並選擇 **Create New Token**。這將會下載一個包含您的 API 認證的 `kaggle.json` 檔案。\n", 95 | "\n", 96 | "在 Colab 中,選擇左側面板中的 **Secrets**(🔑),然後新增您的 Kaggle 使用者名稱和 Kaggle API 金鑰。將您的使用者名稱儲存為 `KAGGLE_USERNAME`,並將您的 API 金鑰儲存為 `KAGGLE_KEY`。\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": { 102 | "id": "7iOF6Yo-wUEC" 103 | }, 104 | "source": [ 105 | "### Set environment variables\n", 106 | "\n", 107 | "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`." 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 1, 113 | "metadata": { 114 | "executionInfo": { 115 | "elapsed": 2836, 116 | "status": "ok", 117 | "timestamp": 1725684340358, 118 | "user": { 119 | "displayName": "劉育維", 120 | "userId": "07932650701621055368" 121 | }, 122 | "user_tz": -480 123 | }, 124 | "id": "0_EdOg9DPK6Q" 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "import os\n", 129 | "from google.colab import userdata\n", 130 | "\n", 131 | "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", 132 | "# vars as appropriate for your system.\n", 133 | "\n", 134 | "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", 135 | "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "CuEUAKJW1QkQ" 142 | }, 143 | "source": [ 144 | "### Install dependencies\n", 145 | "\n", 146 | "Install Keras, KerasNLP, and other dependencies." 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 2, 152 | "metadata": { 153 | "colab": { 154 | "base_uri": "https://localhost:8080/" 155 | }, 156 | "executionInfo": { 157 | "elapsed": 22473, 158 | "status": "ok", 159 | "timestamp": 1725684362827, 160 | "user": { 161 | "displayName": "劉育維", 162 | "userId": "07932650701621055368" 163 | }, 164 | "user_tz": -480 165 | }, 166 | "id": "1eeBtYqJsZPG", 167 | "outputId": "a3f22767-ffdc-4c85-f690-f252db7cd8f2" 168 | }, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m572.2/572.2 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 175 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m31.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 176 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 177 | "\u001b[?25h" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "# Install Keras 3 last. 
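`userdata.get()` above relies on the Colab Secrets panel. A small sketch, assuming you run outside Colab with a `kaggle.json` downloaded from your Kaggle account page (the `~/.kaggle/kaggle.json` path is the standard Kaggle convention, not something defined in this notebook):

# Hedged alternative: load Kaggle credentials from ~/.kaggle/kaggle.json
# when google.colab.userdata is not available.
import json
import os
from pathlib import Path

creds = json.loads((Path.home() / ".kaggle" / "kaggle.json").read_text())
os.environ["KAGGLE_USERNAME"] = creds["username"]
os.environ["KAGGLE_KEY"] = creds["key"]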
See https://keras.io/getting_started/ for more details.\n", 183 | "!pip install -q -U keras-nlp\n", 184 | "!pip install -q -U \"keras>=3\"" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": { 190 | "id": "rGLS-l5TxIR4" 191 | }, 192 | "source": [ 193 | "### Select a backend\n", 194 | "\n", 195 | "Keras 是一個高階、多框架的深度學習 API,旨在簡化使用並提升使用便利性。使用 Keras 3,您可以在以下三種後端之一上運行工作流程:TensorFlow、JAX 或 PyTorch。\n", 196 | "\n", 197 | "在本教學中,請將後端配置為 JAX。\n" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 3, 203 | "metadata": { 204 | "executionInfo": { 205 | "elapsed": 5, 206 | "status": "ok", 207 | "timestamp": 1725684362828, 208 | "user": { 209 | "displayName": "劉育維", 210 | "userId": "07932650701621055368" 211 | }, 212 | "user_tz": -480 213 | }, 214 | "id": "yn5uy8X8sdD0" 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"torch\" or \"tensorflow\".\n", 219 | "# Avoid memory fragmentation on JAX backend.\n", 220 | "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\"" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": { 226 | "id": "hZs8XXqUKRmi" 227 | }, 228 | "source": [ 229 | "### Import packages\n", 230 | "\n", 231 | "Import Keras and KerasNLP." 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 4, 237 | "metadata": { 238 | "executionInfo": { 239 | "elapsed": 8664, 240 | "status": "ok", 241 | "timestamp": 1725684371488, 242 | "user": { 243 | "displayName": "劉育維", 244 | "userId": "07932650701621055368" 245 | }, 246 | "user_tz": -480 247 | }, 248 | "id": "FYHyPUA9hKTf" 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "import keras\n", 253 | "import keras_nlp" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": { 259 | "id": "9T7xe_jzslv4" 260 | }, 261 | "source": [ 262 | "## Load Dataset\n", 263 | "\n", 264 | "本次,我們使用的文件是:[erhwenkuo/dolly-15k-chinese-zhtw](https://huggingface.co/datasets/erhwenkuo/dolly-15k-chinese-zhtw)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 5, 270 | "metadata": { 271 | "colab": { 272 | "base_uri": "https://localhost:8080/" 273 | }, 274 | "executionInfo": { 275 | "elapsed": 4371, 276 | "status": "ok", 277 | "timestamp": 1725684375850, 278 | "user": { 279 | "displayName": "劉育維", 280 | "userId": "07932650701621055368" 281 | }, 282 | "user_tz": -480 283 | }, 284 | "id": "xRaNCPUXKoa7", 285 | "outputId": "693d6010-8c9d-436e-bd74-a7f13c98d384" 286 | }, 287 | "outputs": [ 288 | { 289 | "name": "stdout", 290 | "output_type": "stream", 291 | "text": [ 292 | "--2024-09-07 04:46:10-- https://huggingface.co/datasets/erhwenkuo/dolly-15k-chinese-zhtw/resolve/main/data/train-00000-of-00001-839cf763a52639ec.parquet\n", 293 | "Resolving huggingface.co (huggingface.co)... 3.165.160.61, 3.165.160.12, 3.165.160.59, ...\n", 294 | "Connecting to huggingface.co (huggingface.co)|3.165.160.61|:443... connected.\n", 295 | "HTTP request sent, awaiting response... 
302 Found\n", 296 | "Location: https://cdn-lfs.huggingface.co/repos/7f/a4/7fa4bedfecc28e6c287d997e1cf54c95a43f10cd85204e4590a9e456dfe93acf/54fcd90b0dfa518bac827f1016031be9a53684e9594d74afbd4b2e51b28c44ab?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27train-00000-of-00001-839cf763a52639ec.parquet%3B+filename%3D%22train-00000-of-00001-839cf763a52639ec.parquet%22%3B&Expires=1725943571&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyNTk0MzU3MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy83Zi9hNC83ZmE0YmVkZmVjYzI4ZTZjMjg3ZDk5N2UxY2Y1NGM5NWE0M2YxMGNkODUyMDRlNDU5MGE5ZTQ1NmRmZTkzYWNmLzU0ZmNkOTBiMGRmYTUxOGJhYzgyN2YxMDE2MDMxYmU5YTUzNjg0ZTk1OTRkNzRhZmJkNGIyZTUxYjI4YzQ0YWI%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=jRRnPpkcpFhAhfBlqSOGFpj1PHvbWAd3IPAhhvHohyQRHsVU0pQ7xNraDSgb%7E%7E8L41ejAQ4oxOlsgJIfDUTnotlVS0%7EDskFXUFSy1VgrFsykOOwie4WxSDpEu1xzNXpAb4f44R6j%7Es2ZD3BXWVg%7EAHqgV9jUjra%7ErhNX8mKtvSsC2cUR9zZaPKsNH3rPPffOlpGxvG%7EI%7ENAVzFOG05e%7Evr57oeroZ5UgRmQ4ajX--AcTOTWXYXTSJnOfH49gJbz7iBFsy%7ES40e9kz7xGP0Py5EO3GEn4-Vbji4PP-EPTU38-V1rVCq%7EIUTxm-Vhr5rk5N9bhWujqZMSckrBNWOBGSw__&Key-Pair-Id=K3ESJI6DHPFC7 [following]\n", 297 | "--2024-09-07 04:46:11-- https://cdn-lfs.huggingface.co/repos/7f/a4/7fa4bedfecc28e6c287d997e1cf54c95a43f10cd85204e4590a9e456dfe93acf/54fcd90b0dfa518bac827f1016031be9a53684e9594d74afbd4b2e51b28c44ab?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27train-00000-of-00001-839cf763a52639ec.parquet%3B+filename%3D%22train-00000-of-00001-839cf763a52639ec.parquet%22%3B&Expires=1725943571&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyNTk0MzU3MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy83Zi9hNC83ZmE0YmVkZmVjYzI4ZTZjMjg3ZDk5N2UxY2Y1NGM5NWE0M2YxMGNkODUyMDRlNDU5MGE5ZTQ1NmRmZTkzYWNmLzU0ZmNkOTBiMGRmYTUxOGJhYzgyN2YxMDE2MDMxYmU5YTUzNjg0ZTk1OTRkNzRhZmJkNGIyZTUxYjI4YzQ0YWI%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=jRRnPpkcpFhAhfBlqSOGFpj1PHvbWAd3IPAhhvHohyQRHsVU0pQ7xNraDSgb%7E%7E8L41ejAQ4oxOlsgJIfDUTnotlVS0%7EDskFXUFSy1VgrFsykOOwie4WxSDpEu1xzNXpAb4f44R6j%7Es2ZD3BXWVg%7EAHqgV9jUjra%7ErhNX8mKtvSsC2cUR9zZaPKsNH3rPPffOlpGxvG%7EI%7ENAVzFOG05e%7Evr57oeroZ5UgRmQ4ajX--AcTOTWXYXTSJnOfH49gJbz7iBFsy%7ES40e9kz7xGP0Py5EO3GEn4-Vbji4PP-EPTU38-V1rVCq%7EIUTxm-Vhr5rk5N9bhWujqZMSckrBNWOBGSw__&Key-Pair-Id=K3ESJI6DHPFC7\n", 298 | "Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.138.94.14, 108.138.94.122, 108.138.94.25, ...\n", 299 | "Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.138.94.14|:443... connected.\n", 300 | "HTTP request sent, awaiting response... 
200 OK\n", 301 | "Length: 7492947 (7.1M) [binary/octet-stream]\n", 302 | "Saving to: ‘databricks-dolly-15k-zhtw.parquet’\n", 303 | "\n", 304 | "databricks-dolly-15 100%[===================>] 7.15M 2.07MB/s in 3.5s \n", 305 | "\n", 306 | "2024-09-07 04:46:15 (2.07 MB/s) - ‘databricks-dolly-15k-zhtw.parquet’ saved [7492947/7492947]\n", 307 | "\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "!wget -O databricks-dolly-15k-zhtw.parquet https://huggingface.co/datasets/erhwenkuo/dolly-15k-chinese-zhtw/resolve/main/data/train-00000-of-00001-839cf763a52639ec.parquet" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": { 318 | "id": "45UpBDfBgf0I" 319 | }, 320 | "source": [ 321 | "預處理資料。本教學使用 1000 個訓練範例的子集來加快 notebook 的執行速度。若需更高品質的微調,請考慮使用更多的訓練資料。\n" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 6, 327 | "metadata": { 328 | "executionInfo": { 329 | "elapsed": 4289, 330 | "status": "ok", 331 | "timestamp": 1725684380127, 332 | "user": { 333 | "displayName": "劉育維", 334 | "userId": "07932650701621055368" 335 | }, 336 | "user_tz": -480 337 | }, 338 | "id": "ZiS-KU9osh_N" 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "import pandas as pd\n", 343 | "\n", 344 | "# Read the Parquet file into a DataFrame\n", 345 | "df = pd.read_parquet(\"databricks-dolly-15k-zhtw.parquet\")\n", 346 | "\n", 347 | "data = []\n", 348 | "\n", 349 | "# Iterate over each row in the DataFrame\n", 350 | "for _, row in df.iterrows():\n", 351 | " # Check if there is context, and skip if true\n", 352 | " if row[\"context\"]:\n", 353 | " continue\n", 354 | " # Format the entire example as a single string\n", 355 | " template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", 356 | " data.append(template.format(instruction=row[\"instruction\"], response=row[\"response\"]))\n", 357 | "\n", 358 | "# Only use 1000 training examples, to keep it fast\n", 359 | "data = data[:1000]" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": { 365 | "id": "7RCE3fdGhDE5" 366 | }, 367 | "source": [ 368 | "## 載入模型\n", 369 | "\n", 370 | "KerasNLP 提供了許多熱門[模型架構](https://keras.io/api/keras_nlp/models/)的實作。在本教學中,您將使用 `GemmaCausalLM` 建立一個模型,這是一個用於因果語言建模的端對端 Gemma 模型。因果語言模型會根據先前的標記預測下一個標記。\n", 371 | "\n", 372 | "使用 `from_preset` 方法來建立模型:\n" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 7, 378 | "metadata": { 379 | "colab": { 380 | "base_uri": "https://localhost:8080/", 381 | "height": 385 382 | }, 383 | "executionInfo": { 384 | "elapsed": 55171, 385 | "status": "ok", 386 | "timestamp": 1725684435282, 387 | "user": { 388 | "displayName": "劉育維", 389 | "userId": "07932650701621055368" 390 | }, 391 | "user_tz": -480 392 | }, 393 | "id": "vz5zLEyLstfn", 394 | "outputId": "1729630d-270d-4e75-bda1-f22897ec2a49" 395 | }, 396 | "outputs": [ 397 | { 398 | "data": { 399 | "text/html": [ 400 | "
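The preprocessing cell above drops examples that carry a `context` field, formats each remaining row with an Instruction/Response template, and keeps 1,000 examples. A hedged alternative sketch using the Hugging Face `datasets` library instead of downloading the parquet file manually (the dataset id comes from the link above; the column names mirror the pandas code):

# Hedged alternative to the pandas-based preprocessing, using `datasets`.
from datasets import load_dataset

ds = load_dataset("erhwenkuo/dolly-15k-chinese-zhtw", split="train")

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
data = [
    template.format(instruction=row["instruction"], response=row["response"])
    for row in ds
    if not row["context"]   # skip examples that include extra context, as above
][:1000]                    # same 1,000-example subset to keep training fast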
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
401 |        "
\n" 402 | ], 403 | "text/plain": [ 404 | "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" 405 | ] 406 | }, 407 | "metadata": {}, 408 | "output_type": "display_data" 409 | }, 410 | { 411 | "data": { 412 | "text/html": [ 413 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
414 |        "┃ Tokenizer (type)                                                                                Vocab # ┃\n",
415 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
416 |        "│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │\n",
417 |        "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
418 |        "
\n" 419 | ], 420 | "text/plain": [ 421 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 422 | "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", 423 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 424 | "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", 425 | "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" 426 | ] 427 | }, 428 | "metadata": {}, 429 | "output_type": "display_data" 430 | }, 431 | { 432 | "data": { 433 | "text/html": [ 434 | "
Model: \"gemma_causal_lm\"\n",
435 |        "
\n" 436 | ], 437 | "text/plain": [ 438 | "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" 439 | ] 440 | }, 441 | "metadata": {}, 442 | "output_type": "display_data" 443 | }, 444 | { 445 | "data": { 446 | "text/html": [ 447 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
448 |        "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
449 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
450 |        "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
451 |        "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
452 |        "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
453 |        "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
454 |        "│ gemma_backbone                │ (None, None, 2304)        │   2,614,341,888 │ padding_mask[0][0],        │\n",
455 |        "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
456 |        "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
457 |        "│ token_embedding               │ (None, None, 256000)      │     589,824,000 │ gemma_backbone[0][0]       │\n",
458 |        "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
459 |        "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
460 |        "
\n" 461 | ], 462 | "text/plain": [ 463 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 464 | "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", 465 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 466 | "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", 467 | "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", 468 | "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", 469 | "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", 470 | "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", 471 | "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", 472 | "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", 473 | "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", 474 | "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", 475 | "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" 476 | ] 477 | }, 478 | "metadata": {}, 479 | "output_type": "display_data" 480 | }, 481 | { 482 | "data": { 483 | "text/html": [ 484 | "
 Total params: 2,614,341,888 (9.74 GB)\n",
485 |        "
\n" 486 | ], 487 | "text/plain": [ 488 | "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" 489 | ] 490 | }, 491 | "metadata": {}, 492 | "output_type": "display_data" 493 | }, 494 | { 495 | "data": { 496 | "text/html": [ 497 | "
 Trainable params: 2,614,341,888 (9.74 GB)\n",
498 |        "
\n" 499 | ], 500 | "text/plain": [ 501 | "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" 502 | ] 503 | }, 504 | "metadata": {}, 505 | "output_type": "display_data" 506 | }, 507 | { 508 | "data": { 509 | "text/html": [ 510 | "
 Non-trainable params: 0 (0.00 B)\n",
511 |        "
\n" 512 | ], 513 | "text/plain": [ 514 | "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" 515 | ] 516 | }, 517 | "metadata": {}, 518 | "output_type": "display_data" 519 | } 520 | ], 521 | "source": [ 522 | "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_2b_en\")\n", 523 | "gemma_lm.summary()" 524 | ] 525 | }, 526 | { 527 | "cell_type": "markdown", 528 | "metadata": { 529 | "id": "Nl4lvPy5zA26" 530 | }, 531 | "source": [ 532 | "`from_preset` 方法會從預設的架構和權重來實例化模型。在上述程式碼中,字串 \"gemma2_2b_en\" 指定了預設的架構,即擁有 20 億個參數的 Gemma 模型。\n", 533 | "\n", 534 | "注意:還有一個具有 70 億參數的 Gemma 模型可供使用。若要在 Colab 中運行更大的模型,您需要付費方案中提供的高級 GPU 訪問權限。或者,您可以在 Kaggle 或 Google Cloud 上進行 [Gemma 7B 模型的分佈式調優](https://ai.google.dev/gemma/docs/distributed_tuning)。\n" 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": { 540 | "id": "G_L6A5J-1QgC" 541 | }, 542 | "source": [ 543 | "## 微調前的推理\n", 544 | "\n", 545 | "在本節中,您將使用各種提示來查詢模型,觀察它的回應。\n" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 8, 551 | "metadata": { 552 | "colab": { 553 | "base_uri": "https://localhost:8080/" 554 | }, 555 | "executionInfo": { 556 | "elapsed": 25333, 557 | "status": "ok", 558 | "timestamp": 1725684460605, 559 | "user": { 560 | "displayName": "劉育維", 561 | "userId": "07932650701621055368" 562 | }, 563 | "user_tz": -480 564 | }, 565 | "id": "ZwQz3xxxKciD", 566 | "outputId": "b7a3d40a-c39a-4995-95e0-09f82c5db9f5" 567 | }, 568 | "outputs": [ 569 | { 570 | "name": "stdout", 571 | "output_type": "stream", 572 | "text": [ 573 | "Instruction:\n", 574 | "愛麗絲的父母有三個女兒:艾米、傑西,第三個女兒叫什麼名字?\t\n", 575 | "\n", 576 | "Response:\n", 577 | "艾莉絲的母親有一個女兒叫傑西。\n", 578 | "\n", 579 | "Instruction:\n", 580 | "艾米和傑西是愛麗絲的兩個姐姐。\t\n", 581 | "\n", 582 | "Question:\n", 583 | "艾米和傑西分別是愛麗絲的大姐姐和小姐姐。\n", 584 | "\n", 585 | "Instruction:\n", 586 | "傑西是愛麗絲的妹妹。\t\n", 587 | "\n", 588 | "Response:\n", 589 | "艾米和傑西都是愛麗絲的姐姐,但是艾米是大姐姐,傑西是小姐姐。\n" 590 | ] 591 | } 592 | ], 593 | "source": [ 594 | "prompt = template.format(\n", 595 | " instruction=\"愛麗絲的父母有三個女兒:艾米、傑西,第三個女兒叫什麼名字?\t\",\n", 596 | " response=\"\",\n", 597 | ")\n", 598 | "sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)\n", 599 | "gemma_lm.compile(sampler=sampler)\n", 600 | "print(gemma_lm.generate(prompt, max_length=256))" 601 | ] 602 | }, 603 | { 604 | "cell_type": "markdown", 605 | "metadata": { 606 | "id": "Pt7Nr6a7tItO" 607 | }, 608 | "source": [ 609 | "## LoRA 微調\n", 610 | "\n", 611 | "為了讓模型提供更好的回應,使用 Databricks Dolly 15k 資料集進行低秩適應(LoRA)微調。\n", 612 | "\n", 613 | "LoRA 的秩(rank)決定了可訓練矩陣的維度,這些矩陣會被添加到大型語言模型(LLM)的原始權重中。它控制了微調調整的表現力和精確性。\n", 614 | "\n", 615 | "較高的秩意味著可以進行更詳細的更改,但也會增加可訓練參數的數量。較低的秩則意味著較少的計算負擔,但可能導致適應精度較低。\n", 616 | "\n", 617 | "本教學使用的 LoRA 秩為 4。在實務中,建議從相對較小的秩開始(例如 4、8、16),這樣在實驗中計算更為高效。訓練模型並評估其在任務上的性能改進,然後在隨後的試驗中逐漸增加秩,看是否進一步提升性能。\n" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 9, 623 | "metadata": { 624 | "colab": { 625 | "base_uri": "https://localhost:8080/", 626 | "height": 385 627 | }, 628 | "executionInfo": { 629 | "elapsed": 27, 630 | "status": "ok", 631 | "timestamp": 1725684460606, 632 | "user": { 633 | "displayName": "劉育維", 634 | "userId": "07932650701621055368" 635 | }, 636 | "user_tz": -480 637 | }, 638 | "id": "RCucu6oHz53G", 639 | "outputId": "62028177-1ab7-4c4d-d0c5-785d4a396873" 640 | }, 641 | "outputs": [ 642 | { 643 | "data": { 644 | "text/html": [ 645 | "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
646 |        "
\n" 647 | ], 648 | "text/plain": [ 649 | "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" 650 | ] 651 | }, 652 | "metadata": {}, 653 | "output_type": "display_data" 654 | }, 655 | { 656 | "data": { 657 | "text/html": [ 658 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
659 |        "┃ Tokenizer (type)                                                                                Vocab # ┃\n",
660 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
661 |        "│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │\n",
662 |        "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
663 |        "
\n" 664 | ], 665 | "text/plain": [ 666 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 667 | "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", 668 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 669 | "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", 670 | "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" 671 | ] 672 | }, 673 | "metadata": {}, 674 | "output_type": "display_data" 675 | }, 676 | { 677 | "data": { 678 | "text/html": [ 679 | "
Model: \"gemma_causal_lm\"\n",
680 |        "
\n" 681 | ], 682 | "text/plain": [ 683 | "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" 684 | ] 685 | }, 686 | "metadata": {}, 687 | "output_type": "display_data" 688 | }, 689 | { 690 | "data": { 691 | "text/html": [ 692 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
693 |        "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
694 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
695 |        "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
696 |        "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
697 |        "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
698 |        "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
699 |        "│ gemma_backbone                │ (None, None, 2304)        │   2,617,270,528 │ padding_mask[0][0],        │\n",
700 |        "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
701 |        "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
702 |        "│ token_embedding               │ (None, None, 256000)      │     589,824,000 │ gemma_backbone[0][0]       │\n",
703 |        "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
704 |        "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
705 |        "
\n" 706 | ], 707 | "text/plain": [ 708 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 709 | "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", 710 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 711 | "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", 712 | "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", 713 | "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", 714 | "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", 715 | "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,617,270,528\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", 716 | "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", 717 | "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", 718 | "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", 719 | "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", 720 | "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" 721 | ] 722 | }, 723 | "metadata": {}, 724 | "output_type": "display_data" 725 | }, 726 | { 727 | "data": { 728 | "text/html": [ 729 | "
 Total params: 2,617,270,528 (9.75 GB)\n",
730 |        "
\n" 731 | ], 732 | "text/plain": [ 733 | "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n" 734 | ] 735 | }, 736 | "metadata": {}, 737 | "output_type": "display_data" 738 | }, 739 | { 740 | "data": { 741 | "text/html": [ 742 | "
 Trainable params: 2,928,640 (11.17 MB)\n",
743 |        "
\n" 744 | ], 745 | "text/plain": [ 746 | "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n" 747 | ] 748 | }, 749 | "metadata": {}, 750 | "output_type": "display_data" 751 | }, 752 | { 753 | "data": { 754 | "text/html": [ 755 | "
 Non-trainable params: 2,614,341,888 (9.74 GB)\n",
756 |        "
\n" 757 | ], 758 | "text/plain": [ 759 | "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" 760 | ] 761 | }, 762 | "metadata": {}, 763 | "output_type": "display_data" 764 | } 765 | ], 766 | "source": [ 767 | "# Enable LoRA for the model and set the LoRA rank to 4.\n", 768 | "gemma_lm.backbone.enable_lora(rank=4)\n", 769 | "gemma_lm.summary()" 770 | ] 771 | }, 772 | { 773 | "cell_type": "markdown", 774 | "metadata": { 775 | "id": "hQQ47kcdpbZ9" 776 | }, 777 | "source": [ 778 | "請注意,啟用 LoRA 會大幅減少可訓練參數的數量(從 26 億減少到 290 萬)。\n" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 10, 784 | "metadata": { 785 | "colab": { 786 | "base_uri": "https://localhost:8080/" 787 | }, 788 | "executionInfo": { 789 | "elapsed": 1909922, 790 | "status": "ok", 791 | "timestamp": 1725686370505, 792 | "user": { 793 | "displayName": "劉育維", 794 | "userId": "07932650701621055368" 795 | }, 796 | "user_tz": -480 797 | }, 798 | "id": "_Peq7TnLtHse", 799 | "outputId": "bf8af61d-3c6c-4485-fc56-d36dc9fe2950" 800 | }, 801 | "outputs": [ 802 | { 803 | "name": "stdout", 804 | "output_type": "stream", 805 | "text": [ 806 | "Epoch 1/2\n", 807 | "\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m983s\u001b[0m 941ms/step - loss: 1.1238 - sparse_categorical_accuracy: 0.4821\n", 808 | "Epoch 2/2\n", 809 | "\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m923s\u001b[0m 896ms/step - loss: 0.9916 - sparse_categorical_accuracy: 0.5133\n" 810 | ] 811 | }, 812 | { 813 | "data": { 814 | "text/plain": [ 815 | "" 816 | ] 817 | }, 818 | "execution_count": 10, 819 | "metadata": {}, 820 | "output_type": "execute_result" 821 | } 822 | ], 823 | "source": [ 824 | "# Limit the input sequence length to 256 (to control memory usage).\n", 825 | "gemma_lm.preprocessor.sequence_length = 256\n", 826 | "# Use AdamW (a common optimizer for transformer models).\n", 827 | "optimizer = keras.optimizers.AdamW(\n", 828 | " learning_rate=5e-5,\n", 829 | " weight_decay=0.01,\n", 830 | ")\n", 831 | "# Exclude layernorm and bias terms from decay.\n", 832 | "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", 833 | "\n", 834 | "gemma_lm.compile(\n", 835 | " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", 836 | " optimizer=optimizer,\n", 837 | " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", 838 | ")\n", 839 | "gemma_lm.fit(data, epochs=1, batch_size=1)" 840 | ] 841 | }, 842 | { 843 | "cell_type": "markdown", 844 | "metadata": { 845 | "id": "bx3m8f1dB7nk" 846 | }, 847 | "source": [ 848 | "### 關於在 NVIDIA GPU 上進行混合精度微調的注意事項\n", 849 | "\n", 850 | "建議在微調時使用全精度。在 NVIDIA GPU 上進行微調時,請注意,您可以使用混合精度(`keras.mixed_precision.set_global_policy('mixed_bfloat16')`)來加快訓練速度,同時對訓練品質的影響最小。混合精度微調會消耗更多的記憶體,因此僅適用於較大的 GPU。\n", 851 | "\n", 852 | "在推理時,使用半精度(`keras.config.set_floatx(\"bfloat16\")`)即可節省記憶體,而混合精度則不適用。\n" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "execution_count": 11, 858 | "metadata": { 859 | "executionInfo": { 860 | "elapsed": 19, 861 | "status": "ok", 862 | "timestamp": 1725686370506, 863 | "user": { 864 | "displayName": "劉育維", 865 | "userId": "07932650701621055368" 866 | }, 867 | "user_tz": -480 868 | }, 869 | "id": "T0lHxEDX03gp" 870 | }, 871 | "outputs": [], 872 | "source": [ 873 | "# Uncomment the line below if you want to enable mixed precision training on GPUs\n", 874 | "# keras.mixed_precision.set_global_policy('mixed_bfloat16')" 
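The summary above shows why enabling LoRA with rank 4 leaves only about 2.9 M trainable parameters out of 2.6 B: each adapted weight matrix W of shape (d, k) stays frozen, and only the low-rank factors A (d, r) and B (r, k) are trained. A generic arithmetic sketch (the layer shape below is illustrative, not the exact Gemma 2B projection size):

# Why LoRA shrinks the trainable-parameter count: train r*(d+k) values per
# adapted matrix instead of d*k. The shape here is illustrative only.
d, k, r = 2304, 2048, 4

full_finetune_params = d * k        # updating the whole matrix: 4,718,592
lora_params = r * (d + k)           # training A (d x r) and B (r x k): 17,408

print(full_finetune_params, lora_params, full_finetune_params // lora_params)
# 4718592 17408 271  -> roughly 270x fewer trainable values for this one matrix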
875 | ] 876 | }, 877 | { 878 | "cell_type": "markdown", 879 | "metadata": { 880 | "id": "4yd-1cNw1dTn" 881 | }, 882 | "source": [ 883 | "## 微調後的推理\n", 884 | "\n", 885 | "微調後的回應將遵循提示中提供的指令。\n" 886 | ] 887 | }, 888 | { 889 | "cell_type": "markdown", 890 | "metadata": { 891 | "id": "H55JYJ1a1Kos" 892 | }, 893 | "source": [ 894 | "### Europe Trip Prompt" 895 | ] 896 | }, 897 | { 898 | "cell_type": "code", 899 | "execution_count": 18, 900 | "metadata": { 901 | "colab": { 902 | "base_uri": "https://localhost:8080/" 903 | }, 904 | "executionInfo": { 905 | "elapsed": 19369, 906 | "status": "ok", 907 | "timestamp": 1725687364198, 908 | "user": { 909 | "displayName": "劉育維", 910 | "userId": "07932650701621055368" 911 | }, 912 | "user_tz": -480 913 | }, 914 | "id": "Y7cDJHy8WfCB", 915 | "outputId": "64ffba1b-386b-4ca2-def4-9e7954961750" 916 | }, 917 | "outputs": [ 918 | { 919 | "name": "stdout", 920 | "output_type": "stream", 921 | "text": [ 922 | "Instruction:\n", 923 | "哪一種是魚類?Tope還是Rope?\t\n", 924 | "\n", 925 | "Response:\n", 926 | "rope\n" 927 | ] 928 | } 929 | ], 930 | "source": [ 931 | "prompt = template.format(\n", 932 | " instruction=\"哪一種是魚類?Tope還是Rope?\t\",\n", 933 | " response=\"\",\n", 934 | ")\n", 935 | "sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)\n", 936 | "gemma_lm.compile(sampler=sampler)\n", 937 | "print(gemma_lm.generate(prompt, max_length=256))" 938 | ] 939 | }, 940 | { 941 | "cell_type": "markdown", 942 | "metadata": { 943 | "id": "I8kFG12l0mVe" 944 | }, 945 | "source": [ 946 | "請注意,為了示範用途,本教學僅在資料集的一小部分上進行一次 epoch 的微調,且使用較低的 LoRA 秩值。若要從微調後的模型中獲得更好的回應,您可以嘗試:\n", 947 | "\n", 948 | "1. 增加微調資料集的大小\n", 949 | "2. 增加訓練步數(epochs)\n", 950 | "3. 設定更高的 LoRA 秩值\n", 951 | "4. 修改超參數值,例如 `learning_rate`(學習率)和 `weight_decay`(權重衰減)。\n" 952 | ] 953 | }, 954 | { 955 | "cell_type": "markdown", 956 | "metadata": { 957 | "id": "gSsRdeiof_rJ" 958 | }, 959 | "source": [ 960 | "## 總結與下一步\n", 961 | "\n", 962 | "本教學介紹了如何使用 KerasNLP 對 Gemma 模型進行 LoRA 微調。接下來可以查看以下文件:\n", 963 | "\n", 964 | "* 學習如何 [使用 Gemma 模型生成文本](https://ai.google.dev/gemma/docs/get_started)。\n", 965 | "* 學習如何進行 [Gemma 模型的分佈式微調和推理](https://ai.google.dev/gemma/docs/distributed_tuning)。\n", 966 | "* 學習如何 [使用 Vertex AI 配合 Gemma 開放模型](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma)。\n", 967 | "* 學習如何 [使用 KerasNLP 微調 Gemma 並部署到 Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)。\n" 968 | ] 969 | } 970 | ], 971 | "metadata": { 972 | "accelerator": "GPU", 973 | "colab": { 974 | "provenance": [] 975 | }, 976 | "kernelspec": { 977 | "display_name": "Python 3", 978 | "name": "python3" 979 | } 980 | }, 981 | "nbformat": 4, 982 | "nbformat_minor": 0 983 | } 984 | --------------------------------------------------------------------------------
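To make the improvement suggestions near the end of the fine-tuning notebook concrete (larger dataset, more epochs, higher LoRA rank, tuned learning_rate/weight_decay), here is a hedged sketch of a longer, higher-rank run. It reuses only calls that already appear in the notebook (from_preset, enable_lora, AdamW, compile, fit); the specific values are illustrative rather than tuned recommendations, and `data` is assumed to be rebuilt with a larger cap than 1,000 examples.

# Illustrative follow-up run for the suggestions above: more data, more epochs,
# a higher LoRA rank, and adjusted optimizer hyperparameters (values are examples only).
import keras
import keras_nlp

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")  # fresh model
gemma_lm.backbone.enable_lora(rank=8)          # higher LoRA rank
gemma_lm.preprocessor.sequence_length = 256

optimizer = keras.optimizers.AdamW(
    learning_rate=1e-4,                        # tweaked learning rate ...
    weight_decay=0.004,                        # ... and weight decay
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=3, batch_size=1)     # larger `data`, more epochs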