├── .github
│   ├── CONTRIBUTING.md
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yml
│   │   ├── feature-request.yml
│   │   ├── report-bug.yml
│   │   ├── report-docker.yml
│   │   ├── report-localhost.yml
│   │   ├── report-others.yml
│   │   └── report-server.yml
│   ├── pull_request_template.md
│   └── workflows
│       ├── Build_Docker.yml
│       └── Release_docker.yml
├── .gitignore
├── CITATION.cff
├── ChuanhuChatbot.py
├── Dockerfile
├── LICENSE
├── README.md
├── config_example.json
├── configs
│   └── ds_config_chatbot.json
├── locale
│   ├── en_US.json
│   ├── extract_locale.py
│   ├── ja_JP.json
│   ├── ko_KR.json
│   ├── ru_RU.json
│   ├── sv_SE.json
│   ├── vi_VN.json
│   └── zh_CN.json
├── modules
│   ├── __init__.py
│   ├── config.py
│   ├── index_func.py
│   ├── models
│   │   ├── Azure.py
│   │   ├── ChatGLM.py
│   │   ├── ChuanhuAgent.py
│   │   ├── Claude.py
│   │   ├── DALLE3.py
│   │   ├── ERNIE.py
│   │   ├── GoogleGemini.py
│   │   ├── GoogleGemma.py
│   │   ├── GooglePaLM.py
│   │   ├── Groq.py
│   │   ├── LLaMA.py
│   │   ├── MOSS.py
│   │   ├── Ollama.py
│   │   ├── OpenAIInstruct.py
│   │   ├── OpenAIVision.py
│   │   ├── Qwen.py
│   │   ├── StableLM.py
│   │   ├── XMChat.py
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── configuration_moss.py
│   │   ├── inspurai.py
│   │   ├── midjourney.py
│   │   ├── minimax.py
│   │   ├── modeling_moss.py
│   │   ├── models.py
│   │   ├── spark.py
│   │   └── tokenization_moss.py
│   ├── overwrites.py
│   ├── pdf_func.py
│   ├── presets.py
│   ├── repo.py
│   ├── shared.py
│   ├── train_func.py
│   ├── utils.py
│   ├── webui.py
│   └── webui_locale.py
├── readme
│   ├── README_en.md
│   ├── README_ja.md
│   ├── README_ko.md
│   └── README_ru.md
├── requirements.txt
├── requirements_advanced.txt
├── run_Linux.sh
├── run_Windows.bat
├── run_macOS.command
├── templates
│   ├── 1 中文提示词.json
│   ├── 2 English Prompts.csv
│   ├── 3 繁體提示詞.json
│   ├── 4 川虎的Prompts.json
│   ├── 5 日本語Prompts.json
│   └── 6 Russian Prompts.json
└── web_assets
    ├── chatbot.png
    ├── favicon.ico
    ├── html
    │   ├── appearance_switcher.html
    │   ├── billing_info.html
    │   ├── chatbot_header_btn.html
    │   ├── chatbot_more.html
    │   ├── chatbot_placeholder.html
    │   ├── close_btn.html
    │   ├── footer.html
    │   ├── func_nav.html
    │   ├── header_title.html
    │   ├── update.html
    │   └── web_config.html
    ├── icon
    │   ├── any-icon-512.png
    │   └── mask-icon-512.png
    ├── javascript
    │   ├── ChuanhuChat.js
    │   ├── chat-history.js
    │   ├── chat-list.js
    │   ├── external-scripts.js
    │   ├── fake-gradio.js
    │   ├── file-input.js
    │   ├── localization.js
    │   ├── message-button.js
    │   ├── sliders.js
    │   ├── updater.js
    │   ├── user-info.js
    │   ├── utils.js
    │   └── webui.js
    ├── manifest.json
    ├── model_logos
    │   ├── claude-3.jpg
    │   ├── gemini.svg
    │   ├── meta.webp
    │   ├── openai-black.webp
    │   └── openai-green.webp
    ├── stylesheet
    │   ├── ChuanhuChat.css
    │   ├── chatbot.css
    │   ├── custom-components.css
    │   ├── markdown.css
    │   └── override-gradio.css
    └── user.png
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # How to Contribute
 2 |
 3 | Thank you for your interest in **Chuanhu Chat**, and for taking the time to contribute to our project!
 4 |
 5 | Before you start, you can read the short tips below. For more information, follow the links.
 6 |
 7 | ## New to GitHub?
 8 |
 9 | If you are new to GitHub, the following resources can help you start contributing to open-source projects:
10 |
11 | - [Finding ways to contribute to open source on GitHub](https://docs.github.com/en/get-started/exploring-projects-on-github/finding-ways-to-contribute-to-open-source-on-github)
12 | - [Set up Git](https://docs.github.com/en/get-started/quickstart/set-up-git)
13 | - [GitHub flow](https://docs.github.com/en/get-started/quickstart/github-flow)
14 | - [Collaborating with pull requests](https://docs.github.com/en/github/collaborating-with-pull-requests)
15 |
16 | ## Submitting Issues
17 |
18 | Yes, submitting an issue is itself one way of contributing to the project! But only a well-formed issue actually helps.
19 |
20 | Our [FAQ](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题) describes how to file an issue that is not a duplicate, when to open an issue, and when to ask in the discussions instead.
21 |
22 | **Please note that issues are not a comment section for the project.**
23 |
24 | > **Note**
25 | >
26 | > Also, please mind the difference between a "question" and a "problem".
27 | > If you need to report an actual technical problem, fault, or bug in the project itself, you are welcome to open a new issue. However, if you have simply run into something you cannot solve yourself and want to ask other users or us for help (a question), the best choice is to start a new thread in the discussions. If you are unsure, please consider asking in the discussions first.
28 | >
29 | > For now, we assume by default that what you post in an issue is a question, but we would like to stop seeing questions like "How do I do this?" in the issues QAQ.
30 |
31 | ## Submitting a Pull Request
32 |
33 | If you are up to it, you can modify the source code of this project and submit a pull request! Once it is merged, your name will appear in CONTRIBUTORS~
34 |
35 | Our [contribution guide](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/贡献指南) spells out what to do at every step~ If you want to submit changes to the source code, go take a look~
36 |
37 | > **Note**
38 | >
39 | > We will not force you to follow our conventions, but we hope you can make our work a little lighter.
40 |
41 | ## Joining Discussions
42 |
43 | The discussions area is where we talk.
44 |
45 | If you want to help, have a great new idea, or would like to share your usage tips, please join our Discussions! Many users also post their questions there; if you can answer them, we will be immensely grateful!
46 |
47 | -----
48 |
49 | Thank you again for reading this far, and for contributing to our project!
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled:
2 | contact_links:
3 | - name: 讨论区
4 | url: https://github.com/GaiZhenbiao/ChuanhuChatGPT/discussions
5 | about: 如果遇到疑问,请优先前往讨论区提问~
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
1 | name: 功能请求
2 | description: "请求更多功能!"
3 | title: "[功能请求]: "
4 | labels: ["feature request"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: 您可以请求更多功能!麻烦您花些时间填写以下信息~
9 | - type: textarea
10 | attributes:
11 | label: 相关问题
12 | description: 该功能请求是否与某个问题相关?
13 | placeholder: 发送信息后有概率ChatGPT返回error,刷新后又要重新打一遍文字,较为麻烦
14 | validations:
15 | required: false
16 | - type: textarea
17 | attributes:
18 | label: 可能的解决办法
19 | description: 如果可以,给出一个解决思路~ 或者,你希望实现什么功能?
20 | placeholder: 发送失败后在输入框或聊天气泡保留发送的文本
21 | validations:
22 | required: true
23 | - type: checkboxes
24 | attributes:
25 | label: 帮助开发
26 | description: 如果您能帮助开发并提交一个pull request,那再好不过了!
27 | 参考:[贡献指南](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/贡献指南)
28 | options:
29 | - label: 我愿意协助开发!
30 | required: false
31 | - type: textarea
32 | attributes:
33 | label: 补充说明
34 | description: |
35 | 链接?参考资料?任何更多背景信息!
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/report-bug.yml:
--------------------------------------------------------------------------------
1 | name: 报告BUG
2 | description: "报告一个bug,且您确信这是bug而不是您的问题"
3 | title: "[Bug]: "
4 | labels: ["bug"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | 感谢提交 issue! 请尽可能完整填写以下信息,帮助我们更好地定位问题~
10 | **在一切开始之前,请确保您已经阅读过 [常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题) 页面**。
11 | 如果您确信这是一个我们的 bug,而不是因为您的原因部署失败,欢迎提交该issue!
12 | 如果您不能确定这是bug还是您的问题,请选择 [其他类型的issue模板](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues/new/choose)。
13 |
14 | ------
15 | - type: checkboxes
16 | attributes:
17 | label: 这个bug是否已存在现有issue了?
18 | description: 请搜索全部issue和[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)以查看您想报告的issue是否已存在。
19 | options:
20 | - label: 我确认没有已有issue,且已阅读**常见问题**。
21 | required: true
22 | - type: textarea
23 | id: what-happened
24 | attributes:
25 | label: 错误表现
26 | description: 请描述您遇到的bug。
27 | 提示:如果可以,也请提供错误的截图,如本地部署的网页截图与终端错误报告的截图。
28 | 如果可以,也请提供`.json`格式的对话记录。
29 | placeholder: 发生什么事了?
30 | validations:
31 | required: true
32 | - type: textarea
33 | attributes:
34 | label: 复现操作
35 | description: 你之前干了什么,然后出现了bug呢?
36 | placeholder: |
37 | 1. 正常完成本地部署
38 | 2. 选取GPT3.5-turbo模型,正确填写API
39 | 3. 在对话框中要求 ChatGPT “以LaTeX格式输出三角函数”
40 | 4. ChatGPT 输出部分内容后程序被自动终止
41 | validations:
42 | required: true
43 | - type: textarea
44 | id: logs
45 | attributes:
46 | label: 错误日志
47 | description: 请将终端中的主要错误报告粘贴至此处。
48 | render: shell
49 | - type: textarea
50 | attributes:
51 | label: 运行环境
52 | description: |
53 | 网页底部会列出您运行环境的版本信息,请务必填写。以下是一个例子:
54 | - **OS**: Windows11 22H2
55 | - **Browser**: Chrome
56 | - **Gradio version**: 3.22.1
57 | - **Python version**: 3.11.1
58 | value: |
59 | - OS:
60 | - Browser:
61 | - Gradio version:
62 | - Python version:
63 | validations:
64 | required: false
65 | - type: checkboxes
66 | attributes:
67 | label: 帮助解决
68 | description: 如果您能够并愿意协助解决该问题,向我们提交一个pull request,那再好不过了!
69 | 参考:[贡献指南](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/贡献指南)
70 | options:
71 | - label: 我愿意协助解决!
72 | required: false
73 | - type: textarea
74 | attributes:
75 | label: 补充说明
76 | description: 链接?参考资料?任何更多背景信息!
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/report-docker.yml:
--------------------------------------------------------------------------------
1 | name: Docker部署错误
2 | description: "报告使用 Docker 部署时的问题或错误"
3 | title: "[Docker]: "
4 | labels: ["question","docker deployment"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | 感谢提交 issue! 请尽可能完整填写以下信息,帮助我们更好地定位问题~
10 | **在一切开始之前,请确保您已经阅读过 [常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题) 页面**,查看它是否已经对您的问题做出了解答。
11 | 如果没有,请检索 [issue](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues) 与 [discussion](https://github.com/GaiZhenbiao/ChuanhuChatGPT/discussions) ,查看有没有相同或类似的问题。
12 |
13 | ------
14 | - type: checkboxes
15 | attributes:
16 | label: 是否已存在现有反馈与解答?
17 | description: 请搜索issue、discussion和[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)以查看您想报告的issue是否已存在。
18 | options:
19 | - label: 我确认没有已有issue或discussion,且已阅读**常见问题**。
20 | required: true
21 | - type: checkboxes
22 | attributes:
23 | label: 是否是一个代理配置相关的疑问?
24 | description: 请不要提交代理配置相关的issue。如有疑问请前往 [讨论区](https://github.com/GaiZhenbiao/ChuanhuChatGPT/discussions)。
25 | options:
26 | - label: 我确认这不是一个代理配置相关的疑问。
27 | required: true
28 | - type: textarea
29 | id: what-happened
30 | attributes:
31 | label: 错误描述
32 | description: 请描述您遇到的错误或问题。
33 | 提示:如果可以,也请提供错误的截图,如本地部署的网页截图与终端错误报告的截图。
34 | 如果可以,也请提供`.json`格式的对话记录。
35 | placeholder: 发生什么事了?
36 | validations:
37 | required: true
38 | - type: textarea
39 | attributes:
40 | label: 复现操作
41 | description: 你之前干了什么,然后出现了错误呢?
42 | placeholder: |
43 | 1. 正常完成本地部署
44 | 2. 选取GPT3.5-turbo模型,正确填写API
45 | 3. 在对话框中要求 ChatGPT “以LaTeX格式输出三角函数”
46 | 4. ChatGPT 输出部分内容后程序被自动终止
47 | validations:
48 | required: true
49 | - type: textarea
50 | id: logs
51 | attributes:
52 | label: 错误日志
53 | description: 请将终端中的主要错误报告粘贴至此处。
54 | render: shell
55 | - type: textarea
56 | attributes:
57 | label: 运行环境
58 | description: |
59 | 网页底部会列出您运行环境的版本信息,请务必填写。以下是一个例子:
60 | - **OS**: Linux/amd64
61 | - **Docker version**: 1.8.2
62 | - **Gradio version**: 3.22.1
63 | - **Python version**: 3.11.1
64 | value: |
65 | - OS:
66 | - Docker version:
67 | - Gradio version:
68 | - Python version:
69 | validations:
70 | required: false
71 | - type: textarea
72 | attributes:
73 | label: 补充说明
74 | description: 链接?参考资料?任何更多背景信息!
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/report-localhost.yml:
--------------------------------------------------------------------------------
1 | name: 本地部署错误
2 | description: "报告本地部署时的问题或错误(小白首选)"
3 | title: "[本地部署]: "
4 | labels: ["question","localhost deployment"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | 感谢提交 issue! 请尽可能完整填写以下信息,帮助我们更好地定位问题~
10 | **在一切开始之前,请确保您已经阅读过 [常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题) 页面**,查看它是否已经对您的问题做出了解答。
11 | 如果没有,请检索 [issue](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues) 与 [discussion](https://github.com/GaiZhenbiao/ChuanhuChatGPT/discussions) ,查看有没有相同或类似的问题。
12 |
13 | **另外,请不要再提交 `Something went wrong Expecting value: line 1 column 1 (char 0)` 和 代理配置 相关的问题,请再看一遍 [常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题) 页,实在不行请前往 discussion。**
14 |
15 | ------
16 | - type: checkboxes
17 | attributes:
18 | label: 是否已存在现有反馈与解答?
19 | description: 请搜索issue、discussion和[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)以查看您想报告的issue是否已存在。
20 | options:
21 | - label: 我确认没有已有issue或discussion,且已阅读**常见问题**。
22 | required: true
23 | - type: checkboxes
24 | attributes:
25 | label: 是否是一个代理配置相关的疑问?
26 | description: 请不要提交代理配置相关的issue。如有疑问请前往 [讨论区](https://github.com/GaiZhenbiao/ChuanhuChatGPT/discussions)。
27 | options:
28 | - label: 我确认这不是一个代理配置相关的疑问。
29 | required: true
30 | - type: textarea
31 | id: what-happened
32 | attributes:
33 | label: 错误描述
34 | description: 请描述您遇到的错误或问题。
35 | 提示:如果可以,也请提供错误的截图,如本地部署的网页截图与终端错误报告的截图。
36 | 如果可以,也请提供`.json`格式的对话记录。
37 | placeholder: 发生什么事了?
38 | validations:
39 | required: true
40 | - type: textarea
41 | attributes:
42 | label: 复现操作
43 | description: 你之前干了什么,然后出现了错误呢?
44 | placeholder: |
45 | 1. 正常完成本地部署
46 | 2. 选取GPT3.5-turbo模型,正确填写API
47 | 3. 在对话框中要求 ChatGPT “以LaTeX格式输出三角函数”
48 | 4. ChatGPT 输出部分内容后程序被自动终止
49 | validations:
50 | required: true
51 | - type: textarea
52 | id: logs
53 | attributes:
54 | label: 错误日志
55 | description: 请将终端中的主要错误报告粘贴至此处。
56 | render: shell
57 | - type: textarea
58 | attributes:
59 | label: 运行环境
60 | description: |
61 | 网页底部会列出您运行环境的版本信息,请务必填写。以下是一个例子:
62 | - **OS**: Windows11 22H2
63 | - **Browser**: Chrome
64 | - **Gradio version**: 3.22.1
65 | - **Python version**: 3.11.1
66 | value: |
67 | - OS:
68 | - Browser:
69 | - Gradio version:
70 | - Python version:
71 | render: markdown
72 | validations:
73 | required: false
74 | - type: textarea
75 | attributes:
76 | label: 补充说明
77 | description: 链接?参考资料?任何更多背景信息!
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/report-others.yml:
--------------------------------------------------------------------------------
1 | name: 其他错误
2 | description: "报告其他问题(如 Hugging Face 中的 Space 等)"
3 | title: "[其他]: "
4 | labels: ["question"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | 感谢提交 issue! 请尽可能完整填写以下信息,帮助我们更好地定位问题~
10 | **在一切开始之前,请确保您已经阅读过 [常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题) 页面**,查看它是否已经对您的问题做出了解答。
11 | 如果没有,请检索 [issue](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues) 与 [discussion](https://github.com/GaiZhenbiao/ChuanhuChatGPT/discussions) ,查看有没有相同或类似的问题。
12 |
13 | ------
14 | - type: checkboxes
15 | attributes:
16 | label: 是否已存在现有反馈与解答?
17 | description: 请搜索issue、discussion和[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)以查看您想报告的issue是否已存在。
18 | options:
19 | - label: 我确认没有已有issue或discussion,且已阅读**常见问题**。
20 | required: true
21 | - type: textarea
22 | id: what-happened
23 | attributes:
24 | label: 错误描述
25 | description: 请描述您遇到的错误或问题。
26 | 提示:如果可以,也请提供错误的截图,如本地部署的网页截图与终端错误报告的截图。
27 | 如果可以,也请提供`.json`格式的对话记录。
28 | placeholder: 发生什么事了?
29 | validations:
30 | required: true
31 | - type: textarea
32 | attributes:
33 | label: 复现操作
34 | description: 你之前干了什么,然后出现了错误呢?
35 | placeholder: |
36 | 1. 正常完成本地部署
37 | 2. 选取GPT3.5-turbo模型,正确填写API
38 | 3. 在对话框中要求 ChatGPT “以LaTeX格式输出三角函数”
39 | 4. ChatGPT 输出部分内容后程序被自动终止
40 | validations:
41 | required: true
42 | - type: textarea
43 | id: logs
44 | attributes:
45 | label: 错误日志
46 | description: 请将终端中的主要错误报告粘贴至此处。
47 | render: shell
48 | - type: textarea
49 | attributes:
50 | label: 运行环境
51 | description: |
52 | 网页底部会列出您运行环境的版本信息,请务必填写。以下是一个例子:
53 | - **OS**: Windows11 22H2
54 | - **Browser**: Chrome
55 | - **Gradio version**: 3.22.1
56 | - **Python version**: 3.11.1
57 | value: |
58 | - OS:
59 | - Browser:
60 | - Gradio version:
61 | - Python version:
62 | (或您的其他运行环境信息)
63 | validations:
64 | required: false
65 | - type: textarea
66 | attributes:
67 | label: 补充说明
68 | description: 链接?参考资料?任何更多背景信息!
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/report-server.yml:
--------------------------------------------------------------------------------
1 | name: 服务器部署错误
2 | description: "报告在远程服务器上部署时的问题或错误"
3 | title: "[远程部署]: "
4 | labels: ["question","server deployment"]
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | 感谢提交 issue! 请尽可能完整填写以下信息,帮助我们更好地定位问题~
10 | **在一切开始之前,请确保您已经阅读过 [常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题) 页面**,查看它是否已经对您的问题做出了解答。
11 | 如果没有,请检索 [issue](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues) 与 [discussion](https://github.com/GaiZhenbiao/ChuanhuChatGPT/discussions) ,查看有没有相同或类似的问题。
12 |
13 | ------
14 | - type: checkboxes
15 | attributes:
16 | label: 是否已存在现有反馈与解答?
17 | description: 请搜索issue、discussion和[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)以查看您想报告的issue是否已存在。
18 | options:
19 | - label: 我确认没有已有issue或discussion,且已阅读**常见问题**。
20 | required: true
21 | - type: checkboxes
22 | attributes:
23 | label: 是否是一个代理配置相关的疑问?
24 | description: 请不要提交代理配置相关的issue。如有疑问请前往 [讨论区](https://github.com/GaiZhenbiao/ChuanhuChatGPT/discussions)。
25 | options:
26 | - label: 我确认这不是一个代理配置相关的疑问。
27 | required: true
28 | - type: textarea
29 | id: what-happened
30 | attributes:
31 | label: 错误描述
32 | description: 请描述您遇到的错误或问题。
33 | 提示:如果可以,也请提供错误的截图,如本地部署的网页截图与终端错误报告的截图。
34 | 如果可以,也请提供`.json`格式的对话记录。
35 | placeholder: 发生什么事了?
36 | validations:
37 | required: true
38 | - type: textarea
39 | attributes:
40 | label: 复现操作
41 | description: 你之前干了什么,然后出现了错误呢?
42 | placeholder: |
43 | 1. 正常完成本地部署
44 | 2. 选取GPT3.5-turbo模型,正确填写API
45 | 3. 在对话框中要求 ChatGPT “以LaTeX格式输出三角函数”
46 | 4. ChatGPT 输出部分内容后程序被自动终止
47 | validations:
48 | required: true
49 | - type: textarea
50 | id: logs
51 | attributes:
52 | label: 错误日志
53 | description: 请将终端中的主要错误报告粘贴至此处。
54 | render: shell
55 | - type: textarea
56 | attributes:
57 | label: 运行环境
58 | description: |
59 | 网页底部会列出您运行环境的版本信息,请务必填写。以下是一个例子:
60 | - **OS**: Windows11 22H2
61 | - **Docker version**: 1.8.2
62 | - **Gradio version**: 3.22.1
63 | - **Python version**: 3.11.1
64 | value: |
65 | - OS:
66 | - Server:
67 | - Gradio version:
68 | - Python version:
69 | validations:
70 | required: false
71 | - type: textarea
72 | attributes:
73 | label: 补充说明
74 | description: 链接?参考资料?任何更多背景信息!
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 |
13 |
14 | ## 作者自述
15 | ### 描述
16 | 描述您的 pull request 所做的更改。
17 | 另外请附上相关程序运行时的截图(before & after),以直观地展现您的更改达成的效果。
18 |
19 | ### 相关问题
20 | (如有)请列出与此拉取请求相关的issue。
21 |
22 | ### 补充信息
23 | (如有)请提供任何其他信息或说明,有助于其他贡献者理解您的更改。
24 | 如果您提交的是 draft pull request,也请在这里写明开发进度。
25 |
26 |
27 |
33 |
--------------------------------------------------------------------------------
/.github/workflows/Build_Docker.yml:
--------------------------------------------------------------------------------
1 | name: Build Docker when Push
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 |
8 | jobs:
9 | docker:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Checkout
13 | uses: actions/checkout@v4
14 |
15 | - name: Set commit SHA
16 | run: echo "COMMIT_SHA=$(echo ${{ github.sha }} | cut -c 1-7)" >> ${GITHUB_ENV}
17 |
18 | - name: Set up QEMU
19 | uses: docker/setup-qemu-action@v2
20 |
21 | - name: Set up Docker Buildx
22 | uses: docker/setup-buildx-action@v3
23 |
24 | - name: Login to GitHub Container Registry
25 | uses: docker/login-action@v2
26 | with:
27 | registry: ghcr.io
28 | username: ${{ github.repository_owner }}
29 | password: ${{ secrets.MY_TOKEN }}
30 |
31 | - name: Owner names
32 | run: |
33 | GITOWNER=$(echo ${{ github.repository_owner }} | tr '[:upper:]' '[:lower:]')
34 | echo "GITOWNER=$GITOWNER" >> ${GITHUB_ENV}
35 |
36 | - name: Build and export
37 | uses: docker/build-push-action@v5
38 | with:
39 | context: .
40 | platforms: linux/amd64,linux/arm64
41 | push: false
42 | tags: |
43 | ghcr.io/${{ env.GITOWNER }}/chuanhuchatgpt:latest
44 | ghcr.io/${{ env.GITOWNER }}/chuanhuchatgpt:${{ github.sha }}
45 | outputs: type=oci,dest=/tmp/myimage-${{ env.COMMIT_SHA }}.tar
46 |
47 | - name: Upload artifact
48 | uses: actions/upload-artifact@v3
49 | with:
50 | name: chuanhuchatgpt-${{ env.COMMIT_SHA }}
51 | path: /tmp/myimage-${{ env.COMMIT_SHA }}.tar
52 |
--------------------------------------------------------------------------------
/.github/workflows/Release_docker.yml:
--------------------------------------------------------------------------------
1 | name: Build and Push Docker when Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | jobs:
9 | docker:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Checkout
13 | uses: actions/checkout@v3
14 | with:
15 | ref: ${{ github.event.release.target_commitish }}
16 |
17 | - name: Set release tag
18 | run: |
19 | echo "RELEASE_TAG=${{ github.event.release.tag_name }}" >> ${GITHUB_ENV}
20 |
21 | - name: Set up QEMU
22 | uses: docker/setup-qemu-action@v2
23 |
24 | - name: Set up Docker Buildx
25 | uses: docker/setup-buildx-action@v2
26 |
27 | - name: Login to Docker Hub
28 | uses: docker/login-action@v2
29 | with:
30 | username: ${{ secrets.DOCKERHUB_USERNAME }}
31 | password: ${{ secrets.DOCKERHUB_TOKEN }}
32 |
33 | - name: Login to GitHub Container Registry
34 | uses: docker/login-action@v2
35 | with:
36 | registry: ghcr.io
37 | username: ${{ github.repository_owner }}
38 | password: ${{ secrets.MY_TOKEN }}
39 |
40 | - name: Owner names
41 | run: |
42 | GITOWNER=$(echo ${{ github.repository_owner }} | tr '[:upper:]' '[:lower:]')
43 | echo "GITOWNER=$GITOWNER" >> ${GITHUB_ENV}
44 |
45 | - name: Build and push
46 | uses: docker/build-push-action@v4
47 | with:
48 | context: .
49 | platforms: linux/amd64,linux/arm64
50 | push: true
51 | tags: |
52 | ghcr.io/${{ env.GITOWNER }}/chuanhuchatgpt:latest
53 | ghcr.io/${{ env.GITOWNER }}/chuanhuchatgpt:${{ env.RELEASE_TAG }}
54 | ${{ secrets.DOCKERHUB_USERNAME }}/chuanhuchatgpt:latest
55 | ${{ secrets.DOCKERHUB_USERNAME }}/chuanhuchatgpt:${{ env.RELEASE_TAG }}
56 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | history/
30 | index/
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # Mac system file
134 | **/.DS_Store
135 |
136 | #vscode
137 | .vscode
138 |
139 | # Config files / model files
140 | api_key.txt
141 | config.json
142 | auth.json
143 | .models/
144 | models/*
145 | lora/
146 | .idea
147 | templates/*
148 | files/
149 | tmp/
150 |
151 | scripts/
152 | include/
153 | pyvenv.cfg
154 |
155 | create_release.sh
156 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: Chuanhu Chat
3 | message: >-
4 | If you use this software, please cite it using these
5 | metadata.
6 | type: software
7 | authors:
8 | - given-names: Chuanhu
9 | orcid: https://orcid.org/0000-0001-8954-8598
10 | - given-names: MZhao
11 | orcid: https://orcid.org/0000-0003-2298-6213
12 | - given-names: Keldos
13 | orcid: https://orcid.org/0009-0005-0357-272X
14 | repository-code: 'https://github.com/GaiZhenbiao/ChuanhuChatGPT'
15 | url: 'https://github.com/GaiZhenbiao/ChuanhuChatGPT'
16 | abstract: This software provides a light and easy-to-use interface for ChatGPT API and many LLMs.
17 | license: GPL-3.0
18 | commit: c6c08bc62ef80e37c8be52f65f9b6051a7eea1fa
19 | version: '20230709'
20 | date-released: '2023-07-09'
21 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10-slim-buster as builder
2 |
3 | # Install build essentials, Rust, and additional dependencies
4 | RUN apt-get update \
5 | && apt-get install -y build-essential curl cmake pkg-config libssl-dev \
6 | && apt-get clean \
7 | && rm -rf /var/lib/apt/lists/* \
8 | && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
9 |
10 | # Add Cargo to PATH
11 | ENV PATH="/root/.cargo/bin:${PATH}"
12 |
13 | # Upgrade pip
14 | RUN pip install --upgrade pip
15 |
16 | COPY requirements.txt .
17 | COPY requirements_advanced.txt .
18 |
19 | # Install Python packages
20 | RUN pip install --user --no-cache-dir -r requirements.txt
21 |
22 | # Uncomment the following line if you want to install advanced requirements
23 | # RUN pip install --user --no-cache-dir -r requirements_advanced.txt
24 |
25 | FROM python:3.10-slim-buster
26 | LABEL maintainer="iskoldt"
27 |
28 | # Copy Rust and Cargo from builder
29 | COPY --from=builder /root/.cargo /root/.cargo
30 | COPY --from=builder /root/.rustup /root/.rustup
31 |
32 | # Copy Python packages from builder
33 | COPY --from=builder /root/.local /root/.local
34 |
35 | # Set up environment
36 | ENV PATH=/root/.local/bin:/root/.cargo/bin:$PATH
37 | ENV RUSTUP_HOME=/root/.rustup
38 | ENV CARGO_HOME=/root/.cargo
39 |
40 | COPY . /app
41 | WORKDIR /app
42 | ENV dockerrun=yes
43 | CMD ["sh", "-c", "python3 -u ChuanhuChatbot.py 2>&1 | tee /var/log/application.log"]
44 | EXPOSE 7860
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
6 | 川虎 Chat 🐯 Chuanhu Chat
13 | A light and easy-to-use web UI, with many extra features, for ChatGPT and many other LLMs
25 | Supports DeepSeek R1 & GPT-4 · Q&A over your files · Local LLM deployment · Web search · Agent assistant · Fine-tuning support
27 | Video tutorial · 2.0 intro video · Online demo · One-click deployment
38 | [](https://github.com/GaiZhenbiao/ChuanhuChatGPT/assets/51039745/0eee1598-c2fd-41c6-bda9-7b059a3ce6e7?autoplay=1)
39 |
40 | ## Table of Contents
41 |
42 | | [Supported Models](#supported-models) | [Tips](#tips) | [Installation](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程) | [FAQ](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题) | [Buy the author a coke 🥤](#donation) | [Join the Telegram group](https://t.me/tkdifferent) |
43 | | --- | --- | --- | --- | --- | --- |
44 |
45 | ## ✨ Major 5.0 Update!
46 |
47 |
48 |
49 |
50 | New! A brand-new user interface! So polished it hardly looks like Gradio, with frosted-glass effects!
51 |
52 | New! Adapted for mobile (including full-screen phones with a notch or punch-hole), with a clearer visual hierarchy.
53 |
54 | New! Chat history moved to the left side for easier use, with search (regex supported), delete, and rename.
55 |
56 | New! The LLM can now name chat histories automatically (enable it in the settings or the config file).
57 |
58 | New! Chuanhu Chat can now be installed as a PWA for a more native experience! Supported in Chrome/Edge/Safari and other browsers.
59 |
60 | New! Icons adapted for every platform, easier on the eyes.
61 |
62 | New! Supports fine-tuning GPT 3.5!
63 |
64 |
65 | ## Supported Models
66 |
67 | | Models via API | Notes | Locally deployed models | Notes |
68 | | :---: | --- | :---: | --- |
69 | | [ChatGPT (GPT-4, GPT-4o, o1)](https://chat.openai.com) | Supports fine-tuning gpt-3.5 | [ChatGLM](https://github.com/THUDM/ChatGLM-6B) ([ChatGLM2](https://github.com/THUDM/ChatGLM2-6B)) ([ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b)) ||
70 | | [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) | | [LLaMA](https://github.com/facebookresearch/llama) | Supports LoRA models |
71 | | [Google Gemini Pro](https://ai.google.dev/gemini-api/docs/api-key?hl=zh-cn) | | [StableLM](https://github.com/Stability-AI/StableLM) ||
72 | | [iFLYTEK Spark](https://xinghuo.xfyun.cn) | | [MOSS](https://github.com/OpenLMLab/MOSS) ||
73 | | [Inspur Yuan 1.0](https://air.inspur.com/home) | | [Qwen (Tongyi Qianwen)](https://github.com/QwenLM/Qwen/tree/main) ||
74 | | [MiniMax](https://api.minimax.chat/) || [DeepSeek](https://platform.deepseek.com) ||
75 | | [XMChat](https://github.com/MILVLG/xmchat) | No streaming support |||
76 | | [Midjourney](https://www.midjourney.com/) | No streaming support |||
77 | | [Claude](https://www.anthropic.com/) | ✨ Claude 3 Opus and Sonnet are supported now; Haiku will be supported as soon as it is released |||
78 | | DALL·E 3 ||||
79 |
80 | ## Tips
81 |
82 | ### 💪 Power features
83 | - **Chuanhu Assistant**: similar to AutoGPT, solves your problem fully automatically;
84 | - **Web search**: ChatGPT's data too old? Give your LLM wings to the internet;
85 | - **Knowledge base**: let ChatGPT speed-read for you! Answers questions based on your files.
86 | - **Local LLM deployment**: one-click deployment to get your own large language model.
87 | - **GPT 3.5 fine-tuning**: supports fine-tuning GPT 3.5 to make ChatGPT more personalized.
88 | - **[Custom models](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B-Custom-Models)**: flexibly define custom models, for example to connect to a local inference service.
89 |
90 | ### 🤖 System Prompt
91 | - Setting preconditions via the System Prompt is an effective way to do role-play;
92 | - Chuanhu Chat ships with prompt templates: click `Load Prompt Template`, pick a template collection first, then choose the prompt you want below.
93 |
94 | ### 💬 Basic conversation
95 | - If an answer is unsatisfying, use the `Regenerate` button to try again, or simply `Delete this round of conversation`;
96 | - The input box supports line breaks: press Shift + Enter;
97 | - Press the ↑ ↓ arrow keys in the input box to move quickly through your send history;
98 | - Creating a new conversation every time is tedious; try the `Single-turn conversation` feature;
99 | - The small buttons next to an answer bubble not only give you `one-click copy` but also `view raw Markdown`;
100 | - Specify the answer language so that ChatGPT always replies in a given language.
101 |
102 | ### 📜 Chat history
103 | - Conversation history is saved automatically, so you will not lose it after asking;
104 | - History is isolated per user; nobody but you can see it;
105 | - Rename a history entry to make it easy to find later;
106 | - New! Histories are named as if by magic: let the LLM understand the conversation and name the history for you!
107 | - New! Search your histories, with regular-expression support!
108 |
109 | ### 🖼️ A small-and-beautiful experience
110 | - The home-grown Small-and-Beautiful theme gives you a small and beautiful experience;
111 | - Automatic light/dark switching keeps you comfortable from morning to night;
112 | - LaTeX / tables / code blocks are rendered perfectly, with syntax highlighting;
113 | - New! Non-linear animations and frosted-glass effects, so polished it hardly looks like Gradio!
114 | - New! Adapted for Windows / macOS / Linux / iOS / Android, from icons to full-screen fit, for the experience that suits you best!
115 | - New! Install it as a PWA for a more native feel!
116 |
117 | ### 👨‍💻 Geek features
118 | - New! Supports fine-tuning gpt-3.5!
119 | - Plenty of tunable LLM parameters;
120 | - Supports changing the api-host;
121 | - Supports custom proxies;
122 | - Supports load balancing across multiple API keys.
123 |
124 | ### ⚒️ Deployment
125 | - Deploy to a server: set `"server_name": "0.0.0.0", "server_port": <your port>,` in `config.json`.
126 | - Get a public link: set `"share": true,` in `config.json`. Note that the program must be running for the public link to be reachable.
127 | - Use it on Hugging Face: it is recommended to **Duplicate the Space** from the top-right corner before using it, so the app may respond a little faster.
128 |
129 | ## Getting Started
130 |
131 | Run the following commands in a terminal:
132 |
133 | ```shell
134 | git clone https://github.com/GaiZhenbiao/ChuanhuChatGPT.git
135 | cd ChuanhuChatGPT
136 | pip install -r requirements.txt
137 | ```
138 |
139 | Then make a copy of `config_example.json` in the project folder, rename it to `config.json`, and fill in your `API-Key` and other settings there (a short sketch of this step follows at the end of this section).
140 |
141 | ```shell
142 | python ChuanhuChatbot.py
143 | ```
144 |
145 | A browser window will open automatically, and you can then use **Chuanhu Chat** to talk to ChatGPT or other models.
146 |
147 | > **Note**
148 | >
149 | > For detailed installation and usage instructions, see [the wiki page of this project](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程).
150 |
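Below is a minimal sketch of that configuration step, assuming it is run from the repository root. It uses `commentjson`, the same library `locale/extract_locale.py` uses to read `config.json`, because the example config contains `//`-style comments; the key names come from `config_example.json` and every concrete value is a placeholder, not part of the project itself.

```python
# Sketch only — bootstrap config.json from the shipped example.
import commentjson  # tolerates the //-comments used in config_example.json

with open("config_example.json", "r", encoding="utf-8") as f:
    config = commentjson.load(f)

config["openai_api_key"] = "sk-..."   # placeholder: paste your real key here
# Optional server-deployment keys described in the Deployment tips above:
# config["server_name"] = "0.0.0.0"
# config["server_port"] = 7860
# config["share"] = True

# Writing the file back produces plain JSON; the // comments are not preserved.
with open("config.json", "w", encoding="utf-8") as f:
    commentjson.dump(config, f, ensure_ascii=False, indent=4)
```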
151 | ## Troubleshooting
152 |
153 | Before looking anything up, first try to **manually pull the latest changes of this project** (step 1) and **update the dependency libraries** (step 2), then try again. The steps are:
154 |
155 | 1. Click the `Download ZIP` button on the web page to download the latest code and extract it over your copy, or run
156 | ```shell
157 | git pull https://github.com/GaiZhenbiao/ChuanhuChatGPT.git main -f
158 | ```
159 | 2. Try installing the dependencies again (the project may have introduced new ones)
160 | ```
161 | pip install -r requirements.txt
162 | ```
163 |
164 | In many cases this alone solves the problem.
165 |
166 | If the problem persists, please consult this page: [FAQ](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)
167 |
168 | That page lists **almost every** problem you may run into, including how to configure a proxy and what to do when something goes wrong. **Please read it carefully.**
169 |
170 | ## Learn More
171 |
172 | For more information, please check our [wiki](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki):
173 |
174 | - [Want to contribute?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/贡献指南)
175 | - [Project changelog?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/更新日志)
176 | - [License for derivative work?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用许可)
177 | - [How to cite the project?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用许可#如何引用该项目)
178 |
179 | ## Starchart
180 |
181 | [](https://star-history.com/#GaiZhenbiao/ChuanhuChatGPT&Date)
182 |
183 | ## Contributors
184 |
185 |
186 |
187 |
188 |
189 | ## Donation
190 |
191 | 🐯 If you find this software helpful, feel free to buy the author a coke or a coffee~
192 |
193 | Contact the author: please DM me via [my bilibili account](https://space.bilibili.com/29125536).
194 |
195 |
196 |
197 |
198 |
--------------------------------------------------------------------------------
/config_example.json:
--------------------------------------------------------------------------------
1 | {
2 | // 各配置具体说明,见 [https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程#配置-configjson]
3 |
4 | //== API 配置 ==
5 | "openai_api_key": "", // 你的 OpenAI API Key,一般必填,若空缺则需在图形界面中填入API Key
6 | "deepseek_api_key": "", // 你的 DeepSeek API Key,用于 DeepSeek Chat 和 Reasoner(R1) 对话模型
7 | "google_genai_api_key": "", // 你的 Google Gemini API Key ,用于 Google Gemini 对话模型
8 | "google_genai_api_host": "generativelanguage.googleapis.com", // 你的 Google Gemini API Host 地址,一般无需更改
9 | "xmchat_api_key": "", // 你的 xmchat API Key,用于 XMChat 对话模型
10 | "minimax_api_key": "", // 你的 MiniMax API Key,用于 MiniMax 对话模型
11 | "minimax_group_id": "", // 你的 MiniMax Group ID,用于 MiniMax 对话模型
12 | "midjourney_proxy_api_base": "https://xxx/mj", // 你的 https://github.com/novicezk/midjourney-proxy 代理地址
13 | "midjourney_proxy_api_secret": "", // 你的 MidJourney Proxy API Secret,用于鉴权访问 api,可选
14 | "midjourney_discord_proxy_url": "", // 你的 MidJourney Discord Proxy URL,用于对生成对图进行反代,可选
15 | "midjourney_temp_folder": "./tmp", // 你的 MidJourney 临时文件夹,用于存放生成的图片,填空则关闭自动下载切图(直接显示MJ的四宫格图)
16 | "spark_appid": "", // 你的 讯飞星火大模型 API AppID,用于讯飞星火大模型对话模型
17 | "spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
18 | "spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
19 | "claude_api_secret":"",// 你的 Claude API Secret,用于 Claude 对话模型
20 | "ernie_api_key": "",// 你的文心一言在百度云中的API Key,用于文心一言对话模型
21 | "ernie_secret_key": "",// 你的文心一言在百度云中的Secret Key,用于文心一言对话模型
22 | "ollama_host": "", // 你的 Ollama Host,用于 Ollama 对话模型
23 | "huggingface_auth_token": "", // 你的 Hugging Face API Token,用于访问有限制的模型
24 | "groq_api_key": "", // 你的 Groq API Key,用于 Groq 对话模型(https://console.groq.com/)
25 |
26 | //== Azure ==
27 | "openai_api_type": "openai", // 可选项:azure, openai
28 | "azure_openai_api_key": "", // 你的 Azure OpenAI API Key,用于 Azure OpenAI 对话模型
29 | "azure_openai_api_base_url": "", // 你的 Azure Base URL
30 | "azure_openai_api_version": "2023-05-15", // 你的 Azure OpenAI API 版本
31 | "azure_deployment_name": "", // 你的 Azure OpenAI Chat 模型 Deployment 名称
32 | "azure_embedding_deployment_name": "", // 你的 Azure OpenAI Embedding 模型 Deployment 名称
33 | "azure_embedding_model_name": "text-embedding-ada-002", // 你的 Azure OpenAI Embedding 模型名称
34 |
35 | //== 基础配置 ==
36 | "language": "auto", // 界面语言,可选"auto", "zh_CN", "en_US", "ja_JP", "ko_KR", "sv_SE", "ru_RU", "vi_VN"
37 | "users": [], // 用户列表,[["用户名1", "密码1"], ["用户名2", "密码2"], ...]
38 | "admin_list": [], // 管理员列表,["用户名1", "用户名2", ...] 只有管理员可以重启服务
39 | "local_embedding": false, //是否在本地编制索引
40 | "hide_history_when_not_logged_in": false, //未登录情况下是否不展示对话历史
41 | "check_update": true, //是否启用检查更新
42 | "default_model": "GPT3.5 Turbo", // 默认模型
43 | "chat_name_method_index": 2, // 选择对话名称的方法。0: 使用日期时间命名;1: 使用第一条提问命名,2: 使用模型自动总结
44 | "bot_avatar": "default", // 机器人头像,可填写本地或网络图片链接,或者"none"(不显示头像)
45 | "user_avatar": "default", // 用户头像,可填写本地或网络图片链接,或者"none"(不显示头像)
46 |
47 | //== API 用量 ==
48 | "show_api_billing": false, //是否显示OpenAI API用量(启用需要填写sensitive_id)
49 | "sensitive_id": "", // 你 OpenAI 账户的 Sensitive ID,用于查询 API 用量
50 | "usage_limit": 120, // 该 OpenAI API Key 的当月限额,单位:美元,用于计算百分比和显示上限
51 | "legacy_api_usage": false, // 是否使用旧版 API 用量查询接口(OpenAI现已关闭该接口,但是如果你在使用第三方 API,第三方可能仍然支持此接口)
52 |
53 | //== 川虎助理设置 ==
54 | "GOOGLE_CSE_ID": "", //谷歌搜索引擎ID,用于川虎助理Pro模式,获取方式请看 https://stackoverflow.com/questions/37083058/programmatically-searching-google-in-python-using-custom-search
55 | "GOOGLE_API_KEY": "", //谷歌API Key,用于川虎助理Pro模式
56 | "WOLFRAM_ALPHA_APPID": "", //Wolfram Alpha API Key,用于川虎助理Pro模式,获取方式请看 https://products.wolframalpha.com/api/
57 | "SERPAPI_API_KEY": "", //SerpAPI API Key,用于川虎助理Pro模式,获取方式请看 https://serpapi.com/
58 |
59 | //== 文档处理与显示 ==
60 | "latex_option": "default", // LaTeX 公式渲染策略,可选"default", "strict", "all"或者"disabled"
61 | "advance_docs": {
62 | "pdf": {
63 | "two_column": false, // 是否认为PDF是双栏的
64 | "formula_ocr": true // 是否使用OCR识别PDF中的公式
65 | }
66 | },
67 |
68 | //== 高级配置 ==
69 | // 是否多个API Key轮换使用
70 | "multi_api_key": false,
71 | "hide_my_key": false, // 如果你想在UI中隐藏 API 密钥输入框,将此值设置为 true
72 | // "available_models": ["GPT3.5 Turbo", "GPT4 Turbo", "GPT4 Vision"], // 可用的模型列表,将覆盖默认的可用模型列表
73 | // "extra_models": ["模型名称3", "模型名称4", ...], // 额外的模型,将添加到可用的模型列表之后
74 | // "extra_model_metadata": {
75 | // "GPT-3.5 Turbo Keldos": {
76 | // "model_name": "gpt-3.5-turbo",
77 | // "description": "GPT-3.5 Turbo is a large language model trained by OpenAI. It is the latest version of the GPT series of models, and is known for its ability to generate human-like text.",
78 | // "model_type": "OpenAI",
79 | // "multimodal": false,
80 | // "api_host": "https://www.example.com",
81 | // "token_limit": 4096,
82 | // "max_generation": 4096,
83 | // },
84 | // }
85 | // "api_key_list": [
86 | // "sk-xxxxxxxxxxxxxxxxxxxxxxxx1",
87 | // "sk-xxxxxxxxxxxxxxxxxxxxxxxx2",
88 | // "sk-xxxxxxxxxxxxxxxxxxxxxxxx3"
89 | // ],
90 | // "rename_model": "GPT-4o-mini", //指定默认命名模型
91 | // 自定义OpenAI API Base
92 | // "openai_api_base": "https://api.openai.com",
93 | // 自定义使用代理(请替换代理URL)
94 | // "https_proxy": "http://127.0.0.1:1079",
95 | // "http_proxy": "http://127.0.0.1:1079",
96 | // 自定义端口、自定义ip(请替换对应内容)
97 | // "server_name": "0.0.0.0",
98 | // "server_port": 7860,
99 | // 如果要share到gradio,设置为true
100 | // "share": false,
101 | //如果不想自动打开浏览器,设置为false
102 | //"autobrowser": false
103 | }
104 |
--------------------------------------------------------------------------------
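A note on the file above: it is JSON with `//`-style comments, which the standard `json` module cannot parse; `locale/extract_locale.py` (shown further below) reads `config.json` with `commentjson` for exactly this reason. A minimal, illustrative sketch of loading the example config, assuming it is run from the repository root:

```python
# Sketch only: read the commented example config with commentjson,
# which accepts the //-style comments that plain json.load rejects.
import commentjson

with open("config_example.json", "r", encoding="utf-8") as f:
    config = commentjson.load(f)

print(config["openai_api_type"])   # -> "openai"
print(config["default_model"])     # -> "GPT3.5 Turbo"
```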
/configs/ds_config_chatbot.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": false
4 | },
5 | "bf16": {
6 | "enabled": true
7 | },
8 | "comms_logger": {
9 | "enabled": false,
10 | "verbose": false,
11 | "prof_all": false,
12 | "debug": false
13 | },
14 | "steps_per_print": 20000000000000000,
15 | "train_micro_batch_size_per_gpu": 1,
16 | "wall_clock_breakdown": false
17 | }
18 |
--------------------------------------------------------------------------------
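The DeepSpeed settings above (bf16 enabled, micro batch size 1) are the kind of file that is normally handed to a trainer as-is. The project's own training code (`modules/train_func.py`) is not included in this excerpt, so the following is only a generic, illustrative sketch of how such a config is typically wired up through Hugging Face `transformers`; every concrete name below is a placeholder.

```python
# Generic sketch, not the project's actual training code.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",                         # placeholder output directory
    per_device_train_batch_size=1,               # matches "train_micro_batch_size_per_gpu": 1
    bf16=True,                                   # matches "bf16": {"enabled": true}
    deepspeed="configs/ds_config_chatbot.json",  # hand the JSON above to DeepSpeed
)
```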
/locale/extract_locale.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 | import re
5 | import sys
6 |
7 | import aiohttp
8 | import commentjson
9 | import commentjson as json
10 |
11 | asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
12 |
13 | with open("config.json", "r", encoding="utf-8") as f:
14 | config = commentjson.load(f)
15 | api_key = config["openai_api_key"]
16 | url = config["openai_api_base"] + "/v1/chat/completions" if "openai_api_base" in config else "https://api.openai.com/v1/chat/completions"
17 |
18 |
19 | def get_current_strings():
20 | pattern = r'i18n\s*\(\s*["\']([^"\']*(?:\)[^"\']*)?)["\']\s*\)'
21 |
22 | # Load the .py files
23 | contents = ""
24 | for dirpath, dirnames, filenames in os.walk("."):
25 | for filename in filenames:
26 | if filename.endswith(".py"):
27 | filepath = os.path.join(dirpath, filename)
28 | with open(filepath, 'r', encoding='utf-8') as f:
29 | contents += f.read()
30 | # Matching with regular expressions
31 | matches = re.findall(pattern, contents, re.DOTALL)
32 | data = {match.strip('()"'): '' for match in matches}
33 | fixed_data = {} # fix some keys
34 | for key, value in data.items():
35 | if "](" in key and key.count("(") != key.count(")"):
36 | fixed_data[key+")"] = value
37 | else:
38 | fixed_data[key] = value
39 |
40 | return fixed_data
41 |
42 |
43 | def get_locale_strings(filename):
44 | try:
45 | with open(filename, "r", encoding="utf-8") as f:
46 | locale_strs = json.load(f)
47 | except FileNotFoundError:
48 | locale_strs = {}
49 | return locale_strs
50 |
51 |
52 | def sort_strings(existing_translations):
53 | # Sort the merged data
54 | sorted_translations = {}
55 | # Add entries with (NOT USED) in their values
56 | for key, value in sorted(existing_translations.items(), key=lambda x: x[0]):
57 | if "(🔴NOT USED)" in value:
58 | sorted_translations[key] = value
59 | # Add entries with empty values
60 | for key, value in sorted(existing_translations.items(), key=lambda x: x[0]):
61 | if value == "":
62 | sorted_translations[key] = value
63 | # Add the rest of the entries
64 | for key, value in sorted(existing_translations.items(), key=lambda x: x[0]):
65 | if value != "" and "(NOT USED)" not in value:
66 | sorted_translations[key] = value
67 |
68 | return sorted_translations
69 |
70 |
71 | async def auto_translate(text, language):
72 |     headers = {
73 |         "Content-Type": "application/json",
74 |         "Authorization": f"Bearer {api_key}",
75 |     }
76 |     payload = {
77 |         "model": "gpt-3.5-turbo",
78 |         "temperature": 0,
79 |         "messages": [
80 |             {
81 |                 "role": "system",
82 |                 "content": f"You are a translation program;\nYour job is to translate user input into {language};\nThe content you are translating is a string in the App;\nDo not explain emoji;\nIf the input is only an emoji, simply return the original emoji;\nPlease ensure that the translation results are concise and easy to understand."
83 |             },
84 |             {"role": "user", "content": f"{text}"}
85 |         ],
86 |     }
87 |
88 | async with aiohttp.ClientSession() as session:
89 | async with session.post(url, headers=headers, json=payload) as response:
90 | data = await response.json()
91 | return data["choices"][0]["message"]["content"]
92 |
93 |
94 | async def main(auto=False):
95 | current_strs = get_current_strings()
96 | locale_files = []
97 |     # Walk the locale directory and collect all .json files
98 | for dirpath, dirnames, filenames in os.walk("locale"):
99 | for filename in filenames:
100 | if filename.endswith(".json"):
101 | locale_files.append(os.path.join(dirpath, filename))
102 |
103 |
104 | for locale_filename in locale_files:
105 | if "zh_CN" in locale_filename:
106 | continue
107 | try:
108 | locale_strs = get_locale_strings(locale_filename)
109 | except json.decoder.JSONDecodeError:
110 | import traceback
111 | traceback.print_exc()
112 | logging.error(f"Error decoding {locale_filename}")
113 | continue
114 |
115 | # Add new keys
116 | new_keys = []
117 | for key in current_strs:
118 | if key not in locale_strs:
119 | new_keys.append(key)
120 | locale_strs[key] = ""
121 | print(f"{locale_filename[7:-5]}'s new str: {len(new_keys)}")
122 | # Add (NOT USED) to invalid keys
123 | for key in locale_strs:
124 | if key not in current_strs:
125 | locale_strs[key] = "(🔴NOT USED)" + locale_strs[key]
126 | print(f"{locale_filename[7:-5]}'s invalid str: {len(locale_strs) - len(current_strs)}")
127 |
128 | locale_strs = sort_strings(locale_strs)
129 |
130 | if auto:
131 | tasks = []
132 | non_translated_keys = []
133 | for key in locale_strs:
134 | if locale_strs[key] == "":
135 | non_translated_keys.append(key)
136 | tasks.append(auto_translate(key, locale_filename[7:-5]))
137 | results = await asyncio.gather(*tasks)
138 | for key, result in zip(non_translated_keys, results):
139 | locale_strs[key] = "(🟡REVIEW NEEDED)" + result
140 | print(f"{locale_filename[7:-5]}'s auto translated str: {len(non_translated_keys)}")
141 |
142 | with open(locale_filename, 'w', encoding='utf-8') as f:
143 | json.dump(locale_strs, f, ensure_ascii=False, indent=4)
144 |
145 |
146 | if __name__ == "__main__":
147 | auto = False
148 | if len(sys.argv) > 1 and sys.argv[1] == "--auto":
149 | auto = True
150 | asyncio.run(main(auto))
151 |
--------------------------------------------------------------------------------
/locale/zh_CN.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpt3.5turbo_description": "GPT-3.5 Turbo 是由 OpenAI 开发的一款仅限文本的大型语言模型。它基于 GPT-3 模型,并已经在大量数据上进行了微调。最新版本的 GPT-3.5 Turbo 进行了性能和精度优化,支持最大 16k tokens 的上下文窗口和最大 4096 tokens 的响应长度。此模型始终使用可用的最新版本的 GPT-3.5 Turbo。",
3 | "gpt3.5turbo_instruct_description": "GPT3.5 Turbo Instruct 是 OpenAI 开发的文本补全模型,具有与 GPT-3 时代模型相似的功能。它兼容旧版的 Completions 端点,但不兼容 Chat Completions。该模型的上下文窗口为 4096 个 tokens。",
4 | "gpt3.5turbo_16k_description": "旧版的 GPT-3.5 Turbo 模型,具有 16k tokens 的上下文窗口。",
5 | "gpt4_description": "GPT-4 是 OpenAI 开发的一款仅限文本的大型语言模型。它具有 8192 个 tokens 的上下文窗口和 4096 个 tokens 的最大响应长度。该模型始终使用可用的最新版本的 GPT-4。建议使用 GPT-4 Turbo 以获得更好的性能、更快的速度和更低的成本。",
6 | "gpt4_32k_description": "GPT-4 32K 是 OpenAI 开发的一个仅限文本的大型语言模型。它具有 32,000 tokens 的上下文窗口和 4,096 tokens 的最大响应长度。这个模型从未广泛推出,建议使用 GPT-4 Turbo。",
7 | "gpt4turbo_description": "GPT-4 Turbo 是由 OpenAI 开发的一款多模态大型语言模型。它在广泛的自然语言处理任务上提供最先进的性能,包括文本生成、翻译、摘要、视觉问题回答等。GPT-4 Turbo 拥有最大 128k tokens 的上下文窗口和最大 4096 tokens 的响应长度。此模型始终使用可用的最新版本的 GPT-4 Turbo。",
8 | "claude3_haiku_description": "Claude3 Haiku 是由 Anthropic 开发的一款多模态大型语言模型。它是 Claude 3 模型家族中最快、最紧凑的模型,旨在实现近乎即时的响应速度,但是性能不如 Sonnet 和 Opus。Claude3 Haiku 有最大 200k tokens 的上下文窗口和最大 4096 tokens 的响应长度。此模型始终使用可用的最新版本的 Claude3 Haiku。",
9 | "claude3_sonnet_description": "Claude3 Sonnet 是由 Anthropic 开发的一款多模态大型语言模型。它在智能与速度之间保持最佳平衡,适用于企业工作负载和大规模 AI 部署。Claude3 Sonnet 拥有最大 200k tokens 的上下文窗口和最大 4096 tokens 的响应长度。此模型始终使用可用的最新版本的 Claude3 Sonnet。",
10 | "claude3_opus_description": "Claude3 Opus 是由 Anthropic 开发的一款多模态大型语言模型。它是 Claude 3 模型家族中最智能、最大的模型,能够在高度复杂的任务上呈现最顶尖的性能,呈现出类似人类的理解能力。Claude3 Opus 拥有最大 200k tokens 的上下文窗口和最大 4096 tokens 的响应长度。此模型始终使用可用的最新版本的 Claude3 Opus。",
11 | "groq_llama3_8b_description": "采用 [Groq](https://console.groq.com/) 的 LLaMA 3 8B。Groq 是一个非常快速的语言模型推理服务。",
12 | "groq_llama3_70b_description": "采用 [Groq](https://console.groq.com/) 的 LLaMA 3 70B。Groq 是一个非常快速的语言模型推理服务。",
13 | "groq_mixtral_8x7b_description": "采用 [Groq](https://console.groq.com/) 的 Mixtral 8x7B。Groq 是一个非常快速的语言模型推理服务。",
14 | "groq_gemma_7b_description": "采用 [Groq](https://console.groq.com/) 的 Gemma 7B。Groq 是一个非常快速的语言模型推理服务。",
15 | "chuanhu_description": "一个能使用多种工具解决复杂问题的智能体。",
16 | "gpt_default_slogan": "今天能帮您些什么?",
17 | "claude_default_slogan": "我能帮您什么忙?",
18 | "chuanhu_slogan": "川虎今天能帮你做些什么?",
19 | "chuanhu_question_1": "今天杭州天气如何?",
20 | "chuanhu_question_2": "最近 Apple 发布了什么新品?",
21 | "chuanhu_question_3": "现在显卡的价格如何?",
22 | "chuanhu_question_4": "TikTok 上有什么新梗?",
23 | "gpt4o_description": "OpenAI 的最先进的多模态旗舰模型,比 GPT-4 Turbo 更便宜、更快。",
24 | "gpt4omini_description": "OpenAI 的经济实惠且智能的小型模型,适用于快速、轻量级任务。",
25 | "o1_description": "o1 系列的大型语言模型通过强化学习训练,能够执行复杂的推理任务。o1 模型在回答之前会进行思考,产生一长串内部思维链,然后再回应用户。",
26 | "no_permission_to_update_description": "你没有权限更新。请联系管理员。管理员的配置方式为在配置文件 config.json 中的 admin_list 中添加用户名。"
27 | }
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GaiZhenbiao/ChuanhuChatGPT/550fd86b9411bf8afe73d783a0d90d074e118be4/modules/__init__.py
--------------------------------------------------------------------------------
/modules/index_func.py:
--------------------------------------------------------------------------------
1 | import PyPDF2
2 | from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
3 | from langchain_community.vectorstores import FAISS
4 | from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
5 | from tqdm import tqdm
6 |
7 | from modules.config import local_embedding
8 | from modules.utils import *
9 |
10 |
11 | def get_documents(file_src):
12 | from langchain.schema import Document
13 | from langchain.text_splitter import TokenTextSplitter
14 |
15 | text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=30)
16 |
17 | documents = []
18 | logging.debug("Loading documents...")
19 | logging.debug(f"file_src: {file_src}")
20 | for file in file_src:
21 | filepath = file.name
22 | filename = os.path.basename(filepath)
23 | file_type = os.path.splitext(filename)[1]
24 | logging.info(f"loading file: {filename}")
25 | texts = None
26 | try:
27 | if file_type == ".pdf":
28 | logging.debug("Loading PDF...")
29 | try:
30 | from modules.config import advance_docs
31 | from modules.pdf_func import parse_pdf
32 |
33 | two_column = advance_docs["pdf"].get("two_column", False)
34 | pdftext = parse_pdf(filepath, two_column).text
35 |                 except Exception:
36 | pdftext = ""
37 | with open(filepath, "rb") as pdfFileObj:
38 | pdfReader = PyPDF2.PdfReader(pdfFileObj)
39 | for page in tqdm(pdfReader.pages):
40 | pdftext += page.extract_text()
41 | texts = [Document(page_content=pdftext, metadata={"source": filepath})]
42 | elif file_type == ".docx":
43 | logging.debug("Loading Word...")
44 | from langchain.document_loaders import \
45 | UnstructuredWordDocumentLoader
46 |
47 | loader = UnstructuredWordDocumentLoader(filepath)
48 | texts = loader.load()
49 | elif file_type == ".pptx":
50 | logging.debug("Loading PowerPoint...")
51 | from langchain.document_loaders import \
52 | UnstructuredPowerPointLoader
53 |
54 | loader = UnstructuredPowerPointLoader(filepath)
55 | texts = loader.load()
56 | elif file_type == ".epub":
57 | logging.debug("Loading EPUB...")
58 | from langchain.document_loaders import UnstructuredEPubLoader
59 |
60 | loader = UnstructuredEPubLoader(filepath)
61 | texts = loader.load()
62 | elif file_type == ".xlsx":
63 | logging.debug("Loading Excel...")
64 | text_list = excel_to_string(filepath)
65 | texts = []
66 | for elem in text_list:
67 | texts.append(
68 | Document(page_content=elem, metadata={"source": filepath})
69 | )
70 | elif file_type in [
71 | ".jpg",
72 | ".jpeg",
73 | ".png",
74 | ".heif",
75 | ".heic",
76 | ".webp",
77 | ".bmp",
78 | ".gif",
79 | ".tiff",
80 | ".tif",
81 | ]:
82 | raise gr.Warning(
83 | i18n("不支持的文件: ")
84 | + filename
85 | + i18n(",请使用 .pdf, .docx, .pptx, .epub, .xlsx 等文档。")
86 | )
87 | else:
88 | logging.debug("Loading text file...")
89 | from langchain.document_loaders import TextLoader
90 |
91 | loader = TextLoader(filepath, "utf8")
92 | texts = loader.load()
93 | except Exception as e:
94 | import traceback
95 |
96 | logging.error(f"Error loading file: {filename}")
97 | traceback.print_exc()
98 |
99 | if texts is not None:
100 | texts = text_splitter.split_documents(texts)
101 | documents.extend(texts)
102 | logging.debug("Documents loaded.")
103 | return documents
104 |
105 |
106 | def construct_index(
107 | api_key,
108 | file_src,
109 | max_input_size=4096,
110 | num_outputs=5,
111 | max_chunk_overlap=20,
112 | chunk_size_limit=600,
113 | embedding_limit=None,
114 | separator=" ",
115 | load_from_cache_if_possible=True,
116 | ):
117 | if api_key:
118 | os.environ["OPENAI_API_KEY"] = api_key
119 | else:
120 |         # Due to an unfortunate design choice in one of the dependencies, an API key must be present here
121 | os.environ["OPENAI_API_KEY"] = "sk-xxxxxxx"
122 | logging.debug(f"api base: {os.environ.get('OPENAI_API_BASE', None)}")
123 | chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
124 | embedding_limit = None if embedding_limit == 0 else embedding_limit
125 | separator = " " if separator == "" else separator
126 |
127 | index_name = get_file_hash(file_src)
128 | index_path = f"./index/{index_name}"
129 | if local_embedding:
130 | embeddings = HuggingFaceEmbeddings(
131 | model_name="sentence-transformers/distiluse-base-multilingual-cased-v2"
132 | )
133 | else:
134 | if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
135 | embeddings = OpenAIEmbeddings(
136 | openai_api_base=os.environ.get("OPENAI_API_BASE", None),
137 | openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key),
138 | model="text-embedding-3-large",
139 | )
140 | else:
141 | embeddings = AzureOpenAIEmbeddings(
142 | deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"],
143 | openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
144 | model=os.environ["AZURE_EMBEDDING_MODEL_NAME"],
145 | azure_endpoint=os.environ["AZURE_OPENAI_API_BASE_URL"],
146 | openai_api_type="azure",
147 | )
148 | if os.path.exists(index_path) and load_from_cache_if_possible:
149 | logging.info(i18n("找到了缓存的索引文件,加载中……"))
150 | return FAISS.load_local(
151 | index_path, embeddings, allow_dangerous_deserialization=True
152 | )
153 | else:
154 | documents = get_documents(file_src)
155 | logging.debug(i18n("构建索引中……"))
156 | if documents:
157 | with retrieve_proxy():
158 | index = FAISS.from_documents(documents, embeddings)
159 | else:
160 | raise Exception(i18n("没有找到任何支持的文档。"))
161 | logging.debug(i18n("索引构建完成!"))
162 | os.makedirs("./index", exist_ok=True)
163 | index.save_local(index_path)
164 | logging.debug(i18n("索引已保存至本地!"))
165 | return index
166 |
--------------------------------------------------------------------------------
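For orientation, a minimal, illustrative sketch of calling `construct_index` from the file above by hand. In the app, `file_src` comes from Gradio file uploads, so the tiny wrapper object and the file path below are placeholders; the returned object is a langchain FAISS vector store, so the usual `similarity_search` API applies.

```python
# Sketch only — in the app, construct_index receives Gradio upload objects.
from types import SimpleNamespace

from modules.index_func import construct_index

files = [SimpleNamespace(name="docs/manual.pdf")]  # placeholder path; .name is what get_documents() reads
index = construct_index(api_key="sk-...", file_src=files)

# Query the FAISS vector store returned by construct_index for relevant chunks.
for doc in index.similarity_search("How do I configure a proxy?", k=3):
    print(doc.metadata["source"], doc.page_content[:80])
```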
/modules/models/Azure.py:
--------------------------------------------------------------------------------
1 | from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
2 | import os
3 |
4 | from .base_model import Base_Chat_Langchain_Client
5 |
6 | # load_config_to_environ(["azure_openai_api_key", "azure_api_base_url", "azure_openai_api_version", "azure_deployment_name"])
7 |
8 | class Azure_OpenAI_Client(Base_Chat_Langchain_Client):
9 | def setup_model(self):
10 |         # implement this to set up the model, then return it
11 | return AzureChatOpenAI(
12 | openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
13 | openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
14 | deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
15 | openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
16 | openai_api_type="azure",
17 | streaming=True
18 | )
19 |
--------------------------------------------------------------------------------
/modules/models/ChatGLM.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import os
5 | import platform
6 |
7 | import gc
8 | import torch
9 | import colorama
10 |
11 | from ..index_func import *
12 | from ..presets import *
13 | from ..utils import *
14 | from .base_model import BaseLLMModel
15 |
16 |
17 | class ChatGLM_Client(BaseLLMModel):
18 | def __init__(self, model_name, user_name="") -> None:
19 | super().__init__(model_name=model_name, user=user_name)
20 | import torch
21 | from transformers import AutoModel, AutoTokenizer
22 | global CHATGLM_TOKENIZER, CHATGLM_MODEL
23 | self.deinitialize()
24 | if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
25 | system_name = platform.system()
26 | model_path = None
27 | if os.path.exists("models"):
28 | model_dirs = os.listdir("models")
29 | if model_name in model_dirs:
30 | model_path = f"models/{model_name}"
31 | if model_path is not None:
32 | model_source = model_path
33 | else:
34 | model_source = f"THUDM/{model_name}"
35 | CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
36 | model_source, trust_remote_code=True
37 | )
38 | quantified = False
39 | if "int4" in model_name:
40 | quantified = True
41 | model = AutoModel.from_pretrained(
42 | model_source, trust_remote_code=True
43 | )
44 | if torch.cuda.is_available():
45 | # run on CUDA
46 | logging.info("CUDA is available, using CUDA")
47 | model = model.half().cuda()
48 |             # MPS acceleration still has some issues; only use it on macOS for locally downloaded, non-quantized models
49 | elif system_name == "Darwin" and model_path is not None and not quantified:
50 | logging.info("Running on macOS, using MPS")
51 | # running on macOS and model already downloaded
52 | model = model.half().to("mps")
53 | else:
54 | logging.info("GPU is not available, using CPU")
55 | model = model.float()
56 | model = model.eval()
57 | CHATGLM_MODEL = model
58 |
59 | def _get_glm3_style_input(self):
60 | history = self.history
61 | query = history.pop()["content"]
62 | return history, query
63 |
64 | def _get_glm2_style_input(self):
65 | history = [x["content"] for x in self.history]
66 | query = history.pop()
67 | logging.debug(colorama.Fore.YELLOW +
68 | f"{history}" + colorama.Fore.RESET)
69 | assert (
70 | len(history) % 2 == 0
71 | ), f"History should be even length. current history is: {history}"
72 | history = [[history[i], history[i + 1]]
73 | for i in range(0, len(history), 2)]
74 | return history, query
75 |
76 | def _get_glm_style_input(self):
77 | if "glm2" in self.model_name:
78 | return self._get_glm2_style_input()
79 | else:
80 | return self._get_glm3_style_input()
81 |
82 | def get_answer_at_once(self):
83 | history, query = self._get_glm_style_input()
84 | response, _ = CHATGLM_MODEL.chat(
85 | CHATGLM_TOKENIZER, query, history=history)
86 | return response, len(response)
87 |
88 | def get_answer_stream_iter(self):
89 | history, query = self._get_glm_style_input()
90 | for response, history in CHATGLM_MODEL.stream_chat(
91 | CHATGLM_TOKENIZER,
92 | query,
93 | history,
94 | max_length=self.token_upper_limit,
95 | top_p=self.top_p,
96 | temperature=self.temperature,
97 | ):
98 | yield response
99 |
100 | def deinitialize(self):
101 | # free GPU memory
102 | global CHATGLM_MODEL, CHATGLM_TOKENIZER
103 | CHATGLM_MODEL = None
104 | CHATGLM_TOKENIZER = None
105 | gc.collect()
106 | torch.cuda.empty_cache()
107 | logging.info("ChatGLM model deinitialized")
108 |
--------------------------------------------------------------------------------
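
The glm2-style helper above folds the flat role/content history into (user, assistant) pairs plus a trailing query. A small self-contained sketch of that transformation, using made-up messages, for reference:

```python
# Illustrative only: mimic _get_glm2_style_input() on a hypothetical history.
history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Summarize this repo"},
]

contents = [m["content"] for m in history]
query = contents.pop()                      # the last user turn becomes the query
assert len(contents) % 2 == 0               # remaining turns must pair up
pairs = [[contents[i], contents[i + 1]] for i in range(0, len(contents), 2)]

print(pairs)   # [['Hello', 'Hi, how can I help?']]
print(query)   # 'Summarize this repo'
```
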
/modules/models/Claude.py:
--------------------------------------------------------------------------------
1 | from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
2 | from ..presets import *
3 | from ..utils import *
4 |
5 | from .base_model import BaseLLMModel
6 |
7 |
8 | class Claude_Client(BaseLLMModel):
9 | def __init__(self, model_name, api_secret) -> None:
10 | super().__init__(model_name=model_name)
11 | self.api_secret = api_secret
12 | if None in [self.api_secret]:
13 | raise Exception("请在配置文件或者环境变量中设置Claude的API Secret")
14 | self.claude_client = Anthropic(api_key=self.api_secret, base_url=self.api_host)
15 |
16 | def _get_claude_style_history(self):
17 | history = []
18 | image_buffer = []
19 | image_count = 0
20 | for message in self.history:
21 | if message["role"] == "user":
22 | content = []
23 | if image_buffer:
24 | if image_count == 1:
25 | content.append(
26 | {
27 | "type": "image",
28 | "source": {
29 | "type": "base64",
30 | "media_type": f"image/{self.get_image_type(image_buffer[0])}",
31 | "data": self.get_base64_image(image_buffer[0]),
32 | },
33 | },
34 | )
35 | else:
36 | image_buffer_length = len(image_buffer)
37 | for idx, image in enumerate(image_buffer):
38 | content.append(
39 | {"type": "text", "text": f"Image {image_count - image_buffer_length + idx + 1}:"},
40 | )
41 | content.append(
42 | {
43 | "type": "image",
44 | "source": {
45 | "type": "base64",
46 | "media_type": f"image/{self.get_image_type(image)}",
47 | "data": self.get_base64_image(image),
48 | },
49 | },
50 | )
51 | if content:
52 | content.append({"type": "text", "text": message["content"]})
53 | history.append(construct_user(content))
54 | image_buffer = []
55 | else:
56 | history.append(message)
57 | elif message["role"] == "assistant":
58 | history.append(message)
59 | elif message["role"] == "image":
60 | image_buffer.append(message["content"])
61 | image_count += 1
62 | # history with base64 data replaced with "#base64#"
63 | # history_for_display = history.copy()
64 | # for message in history_for_display:
65 | # if message["role"] == "user":
66 | # if type(message["content"]) == list:
67 | # for content in message["content"]:
68 | # if content["type"] == "image":
69 | # content["source"]["data"] = "#base64#"
70 | # logging.info(f"History for Claude: {history_for_display}")
71 | return history
72 |
73 | def get_answer_stream_iter(self):
74 | system_prompt = self.system_prompt
75 | history = self._get_claude_style_history()
76 |
77 | try:
78 | with self.claude_client.messages.stream(
79 | model=self.model_name,
80 | max_tokens=self.max_generation_token,
81 | messages=history,
82 | system=system_prompt,
83 | ) as stream:
84 | partial_text = ""
85 | for text in stream.text_stream:
86 | partial_text += text
87 | yield partial_text
88 | except Exception as e:
89 | yield i18n(GENERAL_ERROR_MSG) + ": " + str(e)
90 |
91 | def get_answer_at_once(self):
92 | system_prompt = self.system_prompt
93 | history = self._get_claude_style_history()
94 |
95 | response = self.claude_client.messages.create(
96 | model=self.model_name,
97 | max_tokens=self.max_generation_token,
98 | messages=history,
99 | system=system_prompt,
100 | )
101 | if response is not None:
102 | return response.content[0].text, response.usage.output_tokens
103 | else:
104 | return i18n("获取资源错误"), 0
105 |
--------------------------------------------------------------------------------
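
For reference, a sketch of the message shape the history builder above produces when one image precedes a user turn. The base64 payload is elided and the text is hypothetical; the field names follow the Anthropic Messages API exactly as used in the code above.

```python
# Hypothetical single-turn history as passed to claude_client.messages.create/stream.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": "<base64 bytes elided>",
                },
            },
            {"type": "text", "text": "What is in this picture?"},
        ],
    }
]
```
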
/modules/models/DALLE3.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .base_model import BaseLLMModel
3 | from .. import shared
4 | import requests
5 | from ..presets import *
6 | from ..config import retrieve_proxy, sensitive_id
7 |
8 | class OpenAI_DALLE3_Client(BaseLLMModel):
9 | def __init__(self, model_name, api_key, user_name="") -> None:
10 | super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
11 | if self.api_host is not None:
12 | self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.format_openai_host(self.api_host)
13 | else:
14 | self.api_host, self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.state.api_host, shared.state.chat_completion_url, shared.state.images_completion_url, shared.state.openai_api_base, shared.state.balance_api_url, shared.state.usage_api_url
15 | self._refresh_header()
16 |
17 | def _get_dalle3_prompt(self):
18 | prompt = self.history[-1]["content"]
19 | if prompt.endswith("--raw"):
20 | prompt = "I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS:" + prompt
21 | return prompt
22 |
23 | def get_answer_at_once(self, stream=False):
24 | prompt = self._get_dalle3_prompt()
25 | headers = {
26 | "Content-Type": "application/json",
27 | "Authorization": f"Bearer {self.api_key}"
28 | }
29 | payload = {
30 | "model": self.model_name,
31 | "prompt": prompt,
32 | "n": 1,
33 | "size": "1024x1024",
34 | "quality": "standard",
35 | }
36 | if stream:
37 | timeout = TIMEOUT_STREAMING
38 | else:
39 | timeout = TIMEOUT_ALL
40 |
41 | if self.images_completion_url != IMAGES_COMPLETION_URL:
42 | logging.debug(f"使用自定义API URL: {self.images_completion_url}")
43 |
44 | with retrieve_proxy():
45 | try:
46 | response = requests.post(
47 | self.images_completion_url,
48 | headers=headers,
49 | json=payload,
50 | stream=stream,
51 | timeout=timeout,
52 | )
53 | response.raise_for_status()  # raise an exception for non-2xx HTTP status codes
54 | response_data = response.json()
55 | image_url = response_data['data'][0]['url']
56 | img_tag = f'<img src="{image_url}" />'  # wrap the returned image URL in an HTML <img> tag for the chat display
57 | revised_prompt = response_data['data'][0].get('revised_prompt', '')
58 | return img_tag + revised_prompt, 0
59 | except requests.exceptions.RequestException as e:
60 | return str(e), 0
61 |
62 | def _refresh_header(self):
63 | self.headers = {
64 | "Content-Type": "application/json",
65 | "Authorization": f"Bearer {sensitive_id}",
66 | }
--------------------------------------------------------------------------------
/modules/models/ERNIE.py:
--------------------------------------------------------------------------------
1 | from ..presets import *
2 | from ..utils import *
3 |
4 | from .base_model import BaseLLMModel
5 |
6 |
7 | class ERNIE_Client(BaseLLMModel):
8 | def __init__(self, model_name, api_key, secret_key) -> None:
9 | super().__init__(model_name=model_name)
10 | self.api_key = api_key
11 | self.api_secret = secret_key
12 | if None in [self.api_secret, self.api_key]:
13 | raise Exception("请在配置文件或者环境变量中设置文心一言的API Key 和 Secret Key")
14 |
15 | if self.model_name == "ERNIE-Bot-turbo":
16 | self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token="
17 | elif self.model_name == "ERNIE-Bot":
18 | self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="
19 | elif self.model_name == "ERNIE-Bot-4":
20 | self.ERNIE_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token="
21 |
22 | def get_access_token(self):
23 | """
24 | Generate an auth signature (access token) using the AK and SK
25 | :return: access_token, or None on error
26 | """
27 | url = "https://aip.baidubce.com/oauth/2.0/token?client_id=" + self.api_key + "&client_secret=" + self.api_secret + "&grant_type=client_credentials"
28 |
29 | payload = json.dumps("")
30 | headers = {
31 | 'Content-Type': 'application/json',
32 | 'Accept': 'application/json'
33 | }
34 |
35 | response = requests.request("POST", url, headers=headers, data=payload)
36 |
37 | return response.json()["access_token"]
38 | def get_answer_stream_iter(self):
39 | url = self.ERNIE_url + self.get_access_token()
40 | system_prompt = self.system_prompt
41 | history = self.history
42 | if system_prompt is not None:
43 | history = [construct_system(system_prompt), *history]
44 |
45 | # drop messages whose role is "system" from history
46 | history = [i for i in history if i["role"] != "system"]
47 |
48 | payload = json.dumps({
49 | "messages":history,
50 | "stream": True
51 | })
52 | headers = {
53 | 'Content-Type': 'application/json'
54 | }
55 |
56 | response = requests.request("POST", url, headers=headers, data=payload, stream=True)
57 |
58 | if response.status_code == 200:
59 | partial_text = ""
60 | for line in response.iter_lines():
61 | if len(line) == 0:
62 | continue
63 | line = json.loads(line[5:])
64 | partial_text += line['result']
65 | yield partial_text
66 | else:
67 | yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
68 |
69 |
70 | def get_answer_at_once(self):
71 | url = self.ERNIE_url + self.get_access_token()
72 | system_prompt = self.system_prompt
73 | history = self.history
74 | if system_prompt is not None:
75 | history = [construct_system(system_prompt), *history]
76 |
77 | # drop messages whose role is "system" from history
78 | history = [i for i in history if i["role"] != "system"]
79 |
80 | payload = json.dumps({
81 | "messages": history,
82 | "stream": False  # this call is parsed as a single JSON response below, so do not request streaming
83 | })
84 | headers = {
85 | 'Content-Type': 'application/json'
86 | }
87 |
88 | response = requests.request("POST", url, headers=headers, data=payload, stream=True)
89 |
90 | if response.status_code == 200:
91 |
92 | return str(response.json()["result"]),len(response.json()["result"])
93 | else:
94 | return "获取资源错误", 0
95 |
96 |
97 |
--------------------------------------------------------------------------------
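
A standalone sketch of the token request `get_access_token()` performs above. The AK/SK values are placeholders; the endpoint and response field are the ones the client uses.

```python
import requests

API_KEY = "your-ak"       # placeholder
SECRET_KEY = "your-sk"    # placeholder

# Same OAuth endpoint the client above calls; a successful response contains "access_token".
url = (
    "https://aip.baidubce.com/oauth/2.0/token"
    f"?client_id={API_KEY}&client_secret={SECRET_KEY}&grant_type=client_credentials"
)
resp = requests.post(url, headers={"Content-Type": "application/json", "Accept": "application/json"})
print(resp.json().get("access_token"))
```
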
/modules/models/GoogleGemma.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from threading import Thread
3 | import os
4 | import torch
5 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6 |
7 | from ..presets import *
8 | from .base_model import BaseLLMModel
9 | GEMMA_TOKENIZER, GEMMA_MODEL = None, None  # module-level cache shared across client instances
10 |
11 | class GoogleGemmaClient(BaseLLMModel):
12 | def __init__(self, model_name, api_key, user_name="") -> None:
13 | super().__init__(model_name=model_name, user=user_name)
14 |
15 | global GEMMA_TOKENIZER, GEMMA_MODEL
16 | # self.deinitialize()
17 | self.default_max_generation_token = self.token_upper_limit
18 | self.max_generation_token = self.token_upper_limit
19 | if GEMMA_TOKENIZER is None or GEMMA_MODEL is None:
20 | model_path = None
21 | if os.path.exists("models"):
22 | model_dirs = os.listdir("models")
23 | if model_name in model_dirs:
24 | model_path = f"models/{model_name}"
25 | if model_path is not None:
26 | model_source = model_path
27 | else:
28 | if os.path.exists(
29 | os.path.join("models", MODEL_METADATA[model_name]["model_name"])
30 | ):
31 | model_source = os.path.join(
32 | "models", MODEL_METADATA[model_name]["model_name"]
33 | )
34 | else:
35 | try:
36 | model_source = MODEL_METADATA[model_name]["repo_id"]
37 | except KeyError:
38 | model_source = model_name
39 | dtype = torch.bfloat16
40 | GEMMA_TOKENIZER = AutoTokenizer.from_pretrained(
41 | model_source, use_auth_token=os.environ["HF_AUTH_TOKEN"]
42 | )
43 | GEMMA_MODEL = AutoModelForCausalLM.from_pretrained(
44 | model_source,
45 | device_map="auto",
46 | torch_dtype=dtype,
47 | trust_remote_code=True,
48 | resume_download=True,
49 | use_auth_token=os.environ["HF_AUTH_TOKEN"],
50 | )
51 |
52 | def deinitialize(self):
53 | global GEMMA_TOKENIZER, GEMMA_MODEL
54 | GEMMA_TOKENIZER = None
55 | GEMMA_MODEL = None
56 | self.clear_cuda_cache()
57 | logging.info("GEMMA deinitialized")
58 |
59 | def _get_gemma_style_input(self):
60 | global GEMMA_TOKENIZER
61 | # messages = [{"role": "system", "content": self.system_prompt}, *self.history] # system prompt is not supported
62 | messages = self.history
63 | prompt = GEMMA_TOKENIZER.apply_chat_template(
64 | messages, tokenize=False, add_generation_prompt=True
65 | )
66 | inputs = GEMMA_TOKENIZER.encode(
67 | prompt, add_special_tokens=True, return_tensors="pt"
68 | )
69 | return inputs
70 |
71 | def get_answer_at_once(self):
72 | global GEMMA_TOKENIZER, GEMMA_MODEL
73 | inputs = self._get_gemma_style_input()
74 | outputs = GEMMA_MODEL.generate(
75 | input_ids=inputs.to(GEMMA_MODEL.device),
76 | max_new_tokens=self.max_generation_token,
77 | )
78 | generated_token_count = outputs.shape[1] - inputs.shape[1]
79 | outputs = GEMMA_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
80 | outputs = outputs.split("model\n")[-1][:-5]
81 | self.clear_cuda_cache()
82 | return outputs, generated_token_count
83 |
84 | def get_answer_stream_iter(self):
85 | global GEMMA_TOKENIZER, GEMMA_MODEL
86 | inputs = self._get_gemma_style_input()
87 | streamer = TextIteratorStreamer(
88 | GEMMA_TOKENIZER, timeout=10.0, skip_prompt=True, skip_special_tokens=True
89 | )
90 | input_kwargs = dict(
91 | input_ids=inputs.to(GEMMA_MODEL.device),
92 | max_new_tokens=self.max_generation_token,
93 | streamer=streamer,
94 | )
95 | t = Thread(target=GEMMA_MODEL.generate, kwargs=input_kwargs)
96 | t.start()
97 |
98 | partial_text = ""
99 | for new_text in streamer:
100 | partial_text += new_text
101 | yield partial_text
102 | self.clear_cuda_cache()
103 |
--------------------------------------------------------------------------------
/modules/models/GooglePaLM.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseLLMModel
2 | import google.generativeai as palm
3 |
4 |
5 | class Google_PaLM_Client(BaseLLMModel):
6 | def __init__(self, model_name, api_key, user_name="") -> None:
7 | super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
8 |
9 | def _get_palm_style_input(self):
10 | new_history = []
11 | for item in self.history:
12 | if item["role"] == "user":
13 | new_history.append({'author': '1', 'content': item["content"]})
14 | else:
15 | new_history.append({'author': '0', 'content': item["content"]})
16 | return new_history
17 |
18 | def get_answer_at_once(self):
19 | palm.configure(api_key=self.api_key)
20 | messages = self._get_palm_style_input()
21 | response = palm.chat(context=self.system_prompt, messages=messages,
22 | temperature=self.temperature, top_p=self.top_p, model=self.model_name)
23 | if response.last is not None:
24 | return response.last, len(response.last)
25 | else:
26 | reasons = '\n\n'.join(
27 | reason['reason'].name for reason in response.filters)
28 | return "由于下面的原因,Google 拒绝返回 PaLM 的回答:\n\n" + reasons, 0
29 |
--------------------------------------------------------------------------------
/modules/models/Groq.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import textwrap
4 | import uuid
5 |
6 | import os
7 | from groq import Groq
8 | import gradio as gr
9 | import PIL
10 | import requests
11 |
12 | from modules.presets import i18n
13 |
14 | from ..index_func import construct_index
15 | from ..utils import count_token, construct_system
16 | from .base_model import BaseLLMModel
17 |
18 |
19 | class Groq_Client(BaseLLMModel):
20 | def __init__(self, model_name, api_key, user_name="") -> None:
21 | super().__init__(
22 | model_name=model_name,
23 | user=user_name,
24 | config={
25 | "api_key": api_key
26 | }
27 | )
28 | self.client = Groq(
29 | api_key=os.environ.get("GROQ_API_KEY"),
30 | base_url=self.api_host,
31 | )
32 |
33 | def _get_groq_style_input(self):
34 | messages = [construct_system(self.system_prompt), *self.history]
35 | return messages
36 |
37 | def get_answer_at_once(self):
38 | messages = self._get_groq_style_input()
39 | chat_completion = self.client.chat.completions.create(
40 | messages=messages,
41 | model=self.model_name,
42 | )
43 | return chat_completion.choices[0].message.content, chat_completion.usage.total_tokens
44 |
45 |
46 | def get_answer_stream_iter(self):
47 | messages = self._get_groq_style_input()
48 | completion = self.client.chat.completions.create(
49 | model=self.model_name,
50 | messages=messages,
51 | temperature=self.temperature,
52 | max_tokens=self.max_generation_token,
53 | top_p=self.top_p,
54 | stream=True,
55 | stop=self.stop_sequence,
56 | )
57 |
58 | partial_text = ""
59 | for chunk in completion:
60 | partial_text += chunk.choices[0].delta.content or ""
61 | yield partial_text
62 |
--------------------------------------------------------------------------------
/modules/models/LLaMA.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import os
5 | from llama_cpp import Llama
6 |
7 | from ..index_func import *
8 | from ..presets import *
9 | from ..utils import *
10 | from .base_model import BaseLLMModel, download
11 |
12 | SYS_PREFIX = "<<SYS>>\n"
13 | SYS_POSTFIX = "\n<</SYS>>\n\n"
14 | INST_PREFIX = "<s>[INST] "
15 | INST_POSTFIX = " "
16 | OUTPUT_PREFIX = "[/INST] "
17 | OUTPUT_POSTFIX = "</s>"
18 |
19 |
20 | class LLaMA_Client(BaseLLMModel):
21 | def __init__(self, model_name, lora_path=None, user_name="") -> None:
22 | super().__init__(model_name=model_name, user=user_name)
23 |
24 | self.max_generation_token = 1000
25 | if model_name in MODEL_METADATA:
26 | path_to_model = download(
27 | MODEL_METADATA[model_name]["repo_id"],
28 | MODEL_METADATA[model_name]["filelist"][0],
29 | )
30 | else:
31 | dir_to_model = os.path.join("models", model_name)
32 | # look for any .gguf file in the dir_to_model directory and its subdirectories
33 | path_to_model = None
34 | for root, dirs, files in os.walk(dir_to_model):
35 | for file in files:
36 | if file.endswith(".gguf"):
37 | path_to_model = os.path.join(root, file)
38 | break
39 | if path_to_model is not None:
40 | break
41 | self.system_prompt = ""
42 |
43 | if lora_path is not None:
44 | lora_path = os.path.join("lora", lora_path)
45 | self.model = Llama(model_path=path_to_model, lora_path=lora_path)
46 | else:
47 | self.model = Llama(model_path=path_to_model)
48 |
49 | def _get_llama_style_input(self):
50 | context = []
51 | for conv in self.history:
52 | if conv["role"] == "system":
53 | context.append(SYS_PREFIX + conv["content"] + SYS_POSTFIX)
54 | elif conv["role"] == "user":
55 | context.append(
56 | INST_PREFIX + conv["content"] + INST_POSTFIX + OUTPUT_PREFIX
57 | )
58 | else:
59 | context.append(conv["content"] + OUTPUT_POSTFIX)
60 | return "".join(context)
61 | # for conv in self.history:
62 | # if conv["role"] == "system":
63 | # context.append(conv["content"])
64 | # elif conv["role"] == "user":
65 | # context.append(
66 | # conv["content"]
67 | # )
68 | # else:
69 | # context.append(conv["content"])
70 | # return "\n\n".join(context)+"\n\n"
71 |
72 | def get_answer_at_once(self):
73 | context = self._get_llama_style_input()
74 | response = self.model(
75 | context,
76 | max_tokens=self.max_generation_token,
77 | stop=[],
78 | echo=False,
79 | stream=False,
80 | )
81 | text = response["choices"][0]["text"]  # llama_cpp returns a completion dict when stream=False
82 | return text, len(text)
83 | def get_answer_stream_iter(self):
84 | context = self._get_llama_style_input()
85 | iter = self.model(
86 | context,
87 | max_tokens=self.max_generation_token,
88 | stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX, OUTPUT_POSTFIX],
89 | echo=False,
90 | stream=True,
91 | )
92 | partial_text = ""
93 | for i in iter:
94 | response = i["choices"][0]["text"]
95 | partial_text += response
96 | yield partial_text
97 |
--------------------------------------------------------------------------------
/modules/models/Ollama.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import textwrap
4 | import uuid
5 |
6 | from ollama import Client
7 |
8 | from modules.presets import i18n
9 |
10 | from ..index_func import construct_index
11 | from ..utils import count_token
12 | from .base_model import BaseLLMModel
13 |
14 |
15 | class OllamaClient(BaseLLMModel):
16 | def __init__(self, model_name, user_name="", ollama_host="", backend_model="") -> None:
17 | super().__init__(model_name=model_name, user=user_name)
18 | self.backend_model = backend_model
19 | self.ollama_host = ollama_host
20 | self.update_token_limit()
21 |
22 | def get_model_list(self):
23 | client = Client(host=self.ollama_host)
24 | return client.list()
25 |
26 | def update_token_limit(self):
27 | lower_model_name = self.backend_model.lower()
28 | if "mistral" in lower_model_name:
29 | self.token_upper_limit = 8*1024
30 | elif "gemma" in lower_model_name:
31 | self.token_upper_limit = 8*1024
32 | elif "codellama" in lower_model_name:
33 | self.token_upper_limit = 4*1024
34 | elif "llama2-chinese" in lower_model_name:
35 | self.token_upper_limit = 4*1024
36 | elif "llama2" in lower_model_name:
37 | self.token_upper_limit = 4*1024
38 | elif "mixtral" in lower_model_name:
39 | self.token_upper_limit = 32*1024
40 | elif "llava" in lower_model_name:
41 | self.token_upper_limit = 4*1024
42 |
43 | def get_answer_stream_iter(self):
44 | if self.backend_model == "":
45 | yield i18n("请先选择Ollama后端模型\n\n"); return  # in a generator the hint must be yielded; a plain return would not show it
46 | client = Client(host=self.ollama_host)
47 | response = client.chat(model=self.backend_model, messages=self.history,stream=True)
48 | partial_text = ""
49 | for chunk in response:
50 | content = chunk['message']['content']
51 | partial_text += content
52 | yield partial_text
53 | self.all_token_counts[-1] = count_token(partial_text)
54 | yield partial_text
55 |
--------------------------------------------------------------------------------
/modules/models/OpenAIInstruct.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 |
3 | client = OpenAI()
4 | from .base_model import BaseLLMModel
5 | from .. import shared
6 | from ..config import retrieve_proxy
7 |
8 |
9 | class OpenAI_Instruct_Client(BaseLLMModel):
10 | def __init__(self, model_name, api_key, user_name="") -> None:
11 | super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
12 |
13 | def _get_instruct_style_input(self):
14 | return "".join([item["content"] for item in self.history])
15 |
16 | @shared.state.switching_api_key
17 | def get_answer_at_once(self):
18 | prompt = self._get_instruct_style_input()
19 | with retrieve_proxy():
20 | response = client.completions.create(
21 | model=self.model_name,
22 | prompt=prompt,
23 | temperature=self.temperature,
24 | top_p=self.top_p,
25 | )
26 | return response.choices[0].text.strip(), response.usage.total_tokens
27 |
--------------------------------------------------------------------------------
/modules/models/Qwen.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | import os
3 | from transformers.generation import GenerationConfig
4 | import logging
5 | import colorama
6 | from .base_model import BaseLLMModel
7 | from ..presets import MODEL_METADATA
8 |
9 |
10 | class Qwen_Client(BaseLLMModel):
11 | def __init__(self, model_name, user_name="") -> None:
12 | super().__init__(model_name=model_name, user=user_name)
13 | model_source = None
14 | if os.path.exists("models"):
15 | model_dirs = os.listdir("models")
16 | if model_name in model_dirs:
17 | model_source = f"models/{model_name}"
18 | if model_source is None:
19 | try:
20 | model_source = MODEL_METADATA[model_name]["repo_id"]
21 | except KeyError:
22 | model_source = model_name
23 | self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True, resume_download=True)
24 | self.model = AutoModelForCausalLM.from_pretrained(model_source, device_map="cuda", trust_remote_code=True, resume_download=True).eval()
25 |
26 | def generation_config(self):
27 | return GenerationConfig.from_dict({
28 | "chat_format": "chatml",
29 | "do_sample": True,
30 | "eos_token_id": 151643,
31 | "max_length": self.token_upper_limit,
32 | "max_new_tokens": 512,
33 | "max_window_size": 6144,
34 | "pad_token_id": 151643,
35 | "top_k": 0,
36 | "top_p": self.top_p,
37 | "transformers_version": "4.33.2",
38 | "trust_remote_code": True,
39 | "temperature": self.temperature,
40 | })
41 |
42 | def _get_glm_style_input(self):
43 | history = [x["content"] for x in self.history]
44 | query = history.pop()
45 | logging.debug(colorama.Fore.YELLOW +
46 | f"{history}" + colorama.Fore.RESET)
47 | assert (
48 | len(history) % 2 == 0
49 | ), f"History should be even length. current history is: {history}"
50 | history = [[history[i], history[i + 1]]
51 | for i in range(0, len(history), 2)]
52 | return history, query
53 |
54 | def get_answer_at_once(self):
55 | history, query = self._get_glm_style_input()
56 | self.model.generation_config = self.generation_config()
57 | response, history = self.model.chat(self.tokenizer, query, history=history)
58 | return response, len(response)
59 |
60 | def get_answer_stream_iter(self):
61 | history, query = self._get_glm_style_input()
62 | self.model.generation_config = self.generation_config()
63 | for response in self.model.chat_stream(
64 | self.tokenizer,
65 | query,
66 | history,
67 | ):
68 | yield response
69 |
--------------------------------------------------------------------------------
/modules/models/StableLM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
3 | import time
4 | import numpy as np
5 | from torch.nn import functional as F
6 | import os
7 | from .base_model import BaseLLMModel
8 | from threading import Thread
9 |
10 | STABLELM_MODEL = None
11 | STABLELM_TOKENIZER = None
12 |
13 |
14 | class StopOnTokens(StoppingCriteria):
15 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
16 | stop_ids = [50278, 50279, 50277, 1, 0]
17 | for stop_id in stop_ids:
18 | if input_ids[0][-1] == stop_id:
19 | return True
20 | return False
21 |
22 |
23 | class StableLM_Client(BaseLLMModel):
24 | def __init__(self, model_name, user_name="") -> None:
25 | super().__init__(model_name=model_name, user=user_name)
26 | global STABLELM_MODEL, STABLELM_TOKENIZER
27 | print("Starting to load StableLM into memory")
28 | if model_name == "StableLM":
29 | model_name = "stabilityai/stablelm-tuned-alpha-7b"
30 | else:
31 | model_name = f"models/{model_name}"
32 | if STABLELM_MODEL is None:
33 | STABLELM_MODEL = AutoModelForCausalLM.from_pretrained(
34 | model_name, torch_dtype=torch.float16).cuda()
35 | if STABLELM_TOKENIZER is None:
36 | STABLELM_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
37 | self.generator = pipeline(
38 | 'text-generation', model=STABLELM_MODEL, tokenizer=STABLELM_TOKENIZER, device=0)
39 | print("Successfully loaded StableLM into memory")
40 | self.system_prompt = """StableAssistant
41 | - StableAssistant is A helpful and harmless Open Source AI Language Model developed by Stability and CarperAI.
42 | - StableAssistant is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
43 | - StableAssistant is more than just an information source, StableAssistant is also able to write poetry, short stories, and make jokes.
44 | - StableAssistant will refuse to participate in anything that could harm a human."""
45 | self.max_generation_token = 1024
46 | self.top_p = 0.95
47 | self.temperature = 1.0
48 |
49 | def _get_stablelm_style_input(self):
50 | history = self.history + [{"role": "assistant", "content": ""}]
51 | print(history)
52 | messages = self.system_prompt + \
53 | "".join(["".join(["<|USER|>"+history[i]["content"], "<|ASSISTANT|>"+history[i + 1]["content"]])
54 | for i in range(0, len(history), 2)])
55 | return messages
56 |
57 | def _generate(self, text, bad_text=None):
58 | stop = StopOnTokens()
59 | result = self.generator(text, max_new_tokens=self.max_generation_token, num_return_sequences=1, num_beams=1, do_sample=True,
60 | temperature=self.temperature, top_p=self.top_p, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
61 | return result[0]["generated_text"].replace(text, "")
62 |
63 | def get_answer_at_once(self):
64 | messages = self._get_stablelm_style_input()
65 | return self._generate(messages), len(messages)
66 |
67 | def get_answer_stream_iter(self):
68 | stop = StopOnTokens()
69 | messages = self._get_stablelm_style_input()
70 |
71 | # model_inputs = tok([messages], return_tensors="pt")['input_ids'].cuda()[:, :4096-1024]
72 | model_inputs = STABLELM_TOKENIZER(
73 | [messages], return_tensors="pt").to("cuda")
74 | streamer = TextIteratorStreamer(
75 | STABLELM_TOKENIZER, timeout=10., skip_prompt=True, skip_special_tokens=True)
76 | generate_kwargs = dict(
77 | model_inputs,
78 | streamer=streamer,
79 | max_new_tokens=self.max_generation_token,
80 | do_sample=True,
81 | top_p=self.top_p,
82 | top_k=1000,
83 | temperature=self.temperature,
84 | num_beams=1,
85 | stopping_criteria=StoppingCriteriaList([stop])
86 | )
87 | t = Thread(target=STABLELM_MODEL.generate, kwargs=generate_kwargs)
88 | t.start()
89 |
90 | partial_text = ""
91 | for new_text in streamer:
92 | partial_text += new_text
93 | yield partial_text
94 |
--------------------------------------------------------------------------------
/modules/models/XMChat.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import base64
4 | import json
5 | import logging
6 | import os
7 | import uuid
8 | from io import BytesIO
9 |
10 | import requests
11 | from PIL import Image
12 |
13 | from ..index_func import *
14 | from ..presets import *
15 | from ..utils import *
16 | from .base_model import BaseLLMModel
17 |
18 |
19 | class XMChat(BaseLLMModel):
20 | def __init__(self, api_key, user_name=""):
21 | super().__init__(model_name="xmchat", user=user_name)
22 | self.api_key = api_key
23 | self.session_id = None
24 | self.reset()
25 | self.image_bytes = None
26 | self.image_path = None
27 | self.xm_history = []
28 | self.url = "https://xmbot.net/web"
29 | if self.api_host is not None:
30 | self.url = self.api_host
31 | self.last_conv_id = None
32 |
33 | def reset(self, remain_system_prompt=False):
34 | self.session_id = str(uuid.uuid4())
35 | self.last_conv_id = None
36 | return super().reset()
37 |
38 | def image_to_base64(self, image_path):
39 | # open and load the image
40 | img = Image.open(image_path)
41 |
42 | # get the image width and height
43 | width, height = img.size
44 |
45 | # compute the scale ratio so that the longest side does not exceed max_dimension (2048) pixels
46 | max_dimension = 2048
47 | scale_ratio = min(max_dimension / width, max_dimension / height)
48 |
49 | if scale_ratio < 1:
50 | # resize the image by the scale ratio
51 | new_width = int(width * scale_ratio)
52 | new_height = int(height * scale_ratio)
53 | img = img.resize((new_width, new_height), Image.LANCZOS)
54 |
55 | # convert the image to JPEG binary data
56 | buffer = BytesIO()
57 | if img.mode == "RGBA":
58 | img = img.convert("RGB")
59 | img.save(buffer, format='JPEG')
60 | binary_image = buffer.getvalue()
61 |
62 | # Base64-encode the binary data
63 | base64_image = base64.b64encode(binary_image).decode('utf-8')
64 |
65 | return base64_image
66 |
67 | def try_read_image(self, filepath):
68 | def is_image_file(filepath):
69 | # check whether the file is an image
70 | valid_image_extensions = [
71 | ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
72 | file_extension = os.path.splitext(filepath)[1].lower()
73 | return file_extension in valid_image_extensions
74 |
75 | if is_image_file(filepath):
76 | logging.info(f"读取图片文件: {filepath}")
77 | self.image_bytes = self.image_to_base64(filepath)
78 | self.image_path = filepath
79 | else:
80 | self.image_bytes = None
81 | self.image_path = None
82 |
83 | def like(self):
84 | if self.last_conv_id is None:
85 | return "点赞失败,你还没发送过消息"
86 | data = {
87 | "uuid": self.last_conv_id,
88 | "appraise": "good"
89 | }
90 | requests.post(self.url, json=data)
91 | return "👍点赞成功,感谢反馈~"
92 |
93 | def dislike(self):
94 | if self.last_conv_id is None:
95 | return "点踩失败,你还没发送过消息"
96 | data = {
97 | "uuid": self.last_conv_id,
98 | "appraise": "bad"
99 | }
100 | requests.post(self.url, json=data)
101 | return "👎点踩成功,感谢反馈~"
102 |
103 | def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
104 | fake_inputs = real_inputs
105 | display_append = ""
106 | limited_context = False
107 | return limited_context, fake_inputs, display_append, real_inputs, chatbot
108 |
109 | def handle_file_upload(self, files, chatbot, language):
110 | """if the model accepts multi modal input, implement this function"""
111 | if files:
112 | for file in files:
113 | if file.name:
114 | logging.info(f"尝试读取图像: {file.name}")
115 | self.try_read_image(file.name)
116 | if self.image_path is not None:
117 | chatbot = chatbot + [((self.image_path,), None)]
118 | if self.image_bytes is not None:
119 | logging.info("使用图片作为输入")
120 | # XMChat can actually only handle one image per conversation, hence the reset below
121 | self.reset()
122 | conv_id = str(uuid.uuid4())
123 | data = {
124 | "user_id": self.api_key,
125 | "session_id": self.session_id,
126 | "uuid": conv_id,
127 | "data_type": "imgbase64",
128 | "data": self.image_bytes
129 | }
130 | response = requests.post(self.url, json=data)
131 | response = json.loads(response.text)
132 | logging.info(f"图片回复: {response['data']}")
133 | return None, chatbot, None
134 |
135 | def get_answer_at_once(self):
136 | question = self.history[-1]["content"]
137 | conv_id = str(uuid.uuid4())
138 | self.last_conv_id = conv_id
139 | data = {
140 | "user_id": self.api_key,
141 | "session_id": self.session_id,
142 | "uuid": conv_id,
143 | "data_type": "text",
144 | "data": question
145 | }
146 | response = requests.post(self.url, json=data)
147 | try:
148 | response = json.loads(response.text)
149 | return response["data"], len(response["data"])
150 | except Exception as e:
151 | return response.text, len(response.text)
152 |
--------------------------------------------------------------------------------
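
A trimmed, standalone sketch of the resize-and-encode step that `image_to_base64()` above implements; the input path is a placeholder.

```python
import base64
from io import BytesIO
from PIL import Image

def encode_image(path: str, max_dimension: int = 2048) -> str:
    img = Image.open(path)
    scale = min(max_dimension / img.width, max_dimension / img.height)
    if scale < 1:  # only shrink, never enlarge
        img = img.resize((int(img.width * scale), int(img.height * scale)), Image.LANCZOS)
    if img.mode == "RGBA":
        img = img.convert("RGB")  # JPEG has no alpha channel
    buf = BytesIO()
    img.save(buf, format="JPEG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")

# encode_image("example.png")  # placeholder path
```
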
/modules/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GaiZhenbiao/ChuanhuChatGPT/550fd86b9411bf8afe73d783a0d90d074e118be4/modules/models/__init__.py
--------------------------------------------------------------------------------
/modules/models/configuration_moss.py:
--------------------------------------------------------------------------------
1 | """ Moss model configuration"""
2 |
3 | from transformers.utils import logging
4 | from transformers.configuration_utils import PretrainedConfig
5 |
6 |
7 | logger = logging.get_logger(__name__)
8 |
9 |
10 | class MossConfig(PretrainedConfig):
11 | r"""
12 | This is the configuration class to store the configuration of a [`MossModel`]. It is used to instantiate a
13 | Moss model according to the specified arguments, defining the model architecture. Instantiating a configuration
14 | with the defaults will yield a similar configuration to that of the Moss
15 | [fnlp/moss-moon-003-base](https://huggingface.co/fnlp/moss-moon-003-base) architecture. Configuration objects
16 | inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from
17 | [`PretrainedConfig`] for more information.
18 |
19 | Args:
20 | vocab_size (`int`, *optional*, defaults to 107008):
21 | Vocabulary size of the Moss model. Defines the number of different tokens that can be represented by the
22 | `inputs_ids` passed when calling [`MossModel`].
23 | n_positions (`int`, *optional*, defaults to 2048):
24 | The maximum sequence length that this model might ever be used with. Typically set this to something large
25 | just in case (e.g., 512 or 1024 or 2048).
26 | n_embd (`int`, *optional*, defaults to 4096):
27 | Dimensionality of the embeddings and hidden states.
28 | n_layer (`int`, *optional*, defaults to 28):
29 | Number of hidden layers in the Transformer encoder.
30 | n_head (`int`, *optional*, defaults to 16):
31 | Number of attention heads for each attention layer in the Transformer encoder.
32 | rotary_dim (`int`, *optional*, defaults to 64):
33 | Number of dimensions in the embedding that Rotary Position Embedding is applied to.
34 | n_inner (`int`, *optional*, defaults to None):
35 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
36 | activation_function (`str`, *optional*, defaults to `"gelu_new"`):
37 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
38 | resid_pdrop (`float`, *optional*, defaults to 0.1):
39 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
40 | embd_pdrop (`int`, *optional*, defaults to 0.1):
41 | The dropout ratio for the embeddings.
42 | attn_pdrop (`float`, *optional*, defaults to 0.1):
43 | The dropout ratio for the attention.
44 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
45 | The epsilon to use in the layer normalization layers.
46 | initializer_range (`float`, *optional*, defaults to 0.02):
47 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
48 | use_cache (`bool`, *optional*, defaults to `True`):
49 | Whether or not the model should return the last key/values attentions (not used by all models).
50 |
51 | Example:
52 |
53 | ```python
54 | >>> from modeling_moss import MossModel
55 | >>> from configuration_moss import MossConfig
56 |
57 | >>> # Initializing a moss-moon-003-base configuration
58 | >>> configuration = MossConfig()
59 |
60 | >>> # Initializing a model (with random weights) from the configuration
61 | >>> model = MossModel(configuration)
62 |
63 | >>> # Accessing the model configuration
64 | >>> configuration = model.config
65 | ```"""
66 |
67 | model_type = "moss"
68 | attribute_map = {
69 | "max_position_embeddings": "n_positions",
70 | "hidden_size": "n_embd",
71 | "num_attention_heads": "n_head",
72 | "num_hidden_layers": "n_layer",
73 | }
74 |
75 | def __init__(
76 | self,
77 | vocab_size=107008,
78 | n_positions=2048,
79 | n_ctx=2048,
80 | n_embd=4096,
81 | n_layer=28,
82 | n_head=16,
83 | rotary_dim=64,
84 | n_inner=None,
85 | activation_function="gelu_new",
86 | resid_pdrop=0.0,
87 | embd_pdrop=0.0,
88 | attn_pdrop=0.0,
89 | layer_norm_epsilon=1e-5,
90 | initializer_range=0.02,
91 | use_cache=True,
92 | bos_token_id=106028,
93 | eos_token_id=106068,
94 | tie_word_embeddings=False,
95 | **kwargs,
96 | ):
97 | self.vocab_size = vocab_size
98 | self.n_ctx = n_ctx
99 | self.n_positions = n_positions
100 | self.n_embd = n_embd
101 | self.n_layer = n_layer
102 | self.n_head = n_head
103 | self.n_inner = n_inner
104 | self.rotary_dim = rotary_dim
105 | self.activation_function = activation_function
106 | self.resid_pdrop = resid_pdrop
107 | self.embd_pdrop = embd_pdrop
108 | self.attn_pdrop = attn_pdrop
109 | self.layer_norm_epsilon = layer_norm_epsilon
110 | self.initializer_range = initializer_range
111 | self.use_cache = use_cache
112 |
113 | self.bos_token_id = bos_token_id
114 | self.eos_token_id = eos_token_id
115 |
116 | super().__init__(
117 | bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
118 | )
119 |
--------------------------------------------------------------------------------
/modules/models/minimax.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import colorama
5 | import requests
6 | import logging
7 |
8 | from modules.models.base_model import BaseLLMModel
9 | from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n
10 |
11 | group_id = os.environ.get("MINIMAX_GROUP_ID", "")
12 |
13 |
14 | class MiniMax_Client(BaseLLMModel):
15 | """
16 | MiniMax Client
17 | API documentation: https://api.minimax.chat/document/guides/chat
18 | """
19 |
20 | def __init__(self, model_name, api_key, user_name="", system_prompt=None):
21 | super().__init__(model_name=model_name, user=user_name)
22 | self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
23 | self.history = []
24 | self.api_key = api_key
25 | self.system_prompt = system_prompt
26 | self.headers = {
27 | "Authorization": f"Bearer {api_key}",
28 | "Content-Type": "application/json"
29 | }
30 |
31 | def get_answer_at_once(self):
32 | # MiniMax temperature is in (0, 1] while the base model uses [0, 2]; MiniMax 0.9 corresponds to base 1.0, so convert
33 | temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10
34 |
35 | request_body = {
36 | "model": self.model_name.replace('minimax-', ''),
37 | "temperature": temperature,
38 | "skip_info_mask": True,
39 | 'messages': [{"sender_type": "USER", "text": self.history[-1]['content']}]
40 | }
41 | if self.n_choices:
42 | request_body['beam_width'] = self.n_choices
43 | if self.system_prompt:
44 | request_body['prompt'] = self.system_prompt
45 | if self.max_generation_token:
46 | request_body['tokens_to_generate'] = self.max_generation_token
47 | if self.top_p:
48 | request_body['top_p'] = self.top_p
49 |
50 | response = requests.post(self.url, headers=self.headers, json=request_body)
51 |
52 | res = response.json()
53 | answer = res['reply']
54 | total_token_count = res["usage"]["total_tokens"]
55 | return answer, total_token_count
56 |
57 | def get_answer_stream_iter(self):
58 | response = self._get_response(stream=True)
59 | if response is not None:
60 | iter = self._decode_chat_response(response)
61 | partial_text = ""
62 | for i in iter:
63 | partial_text += i
64 | yield partial_text
65 | else:
66 | yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
67 |
68 | def _get_response(self, stream=False):
69 | minimax_api_key = self.api_key
70 | history = self.history
71 | logging.debug(colorama.Fore.YELLOW +
72 | f"{history}" + colorama.Fore.RESET)
73 | headers = {
74 | "Content-Type": "application/json",
75 | "Authorization": f"Bearer {minimax_api_key}",
76 | }
77 |
78 | temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10
79 |
80 | messages = []
81 | for msg in self.history:
82 | if msg['role'] == 'user':
83 | messages.append({"sender_type": "USER", "text": msg['content']})
84 | else:
85 | messages.append({"sender_type": "BOT", "text": msg['content']})
86 |
87 | request_body = {
88 | "model": self.model_name.replace('minimax-', ''),
89 | "temperature": temperature,
90 | "skip_info_mask": True,
91 | 'messages': messages
92 | }
93 | if self.n_choices:
94 | request_body['beam_width'] = self.n_choices
95 | if self.system_prompt:
96 | lines = self.system_prompt.splitlines()
97 | if lines[0].find(":") != -1 and len(lines[0]) < 20:
98 | request_body["role_meta"] = {
99 | "user_name": lines[0].split(":")[0],
100 | "bot_name": lines[0].split(":")[1]
101 | }
102 | lines.pop(0)  # drop the role line that was parsed into role_meta above
103 | request_body["prompt"] = "\n".join(lines)
104 | if self.max_generation_token:
105 | request_body['tokens_to_generate'] = self.max_generation_token
106 | else:
107 | request_body['tokens_to_generate'] = 512
108 | if self.top_p:
109 | request_body['top_p'] = self.top_p
110 |
111 | if stream:
112 | timeout = TIMEOUT_STREAMING
113 | request_body['stream'] = True
114 | request_body['use_standard_sse'] = True
115 | else:
116 | timeout = TIMEOUT_ALL
117 | try:
118 | response = requests.post(
119 | self.url,
120 | headers=headers,
121 | json=request_body,
122 | stream=stream,
123 | timeout=timeout,
124 | )
125 | except:
126 | return None
127 |
128 | return response
129 |
130 | def _decode_chat_response(self, response):
131 | error_msg = ""
132 | for chunk in response.iter_lines():
133 | if chunk:
134 | chunk = chunk.decode()
135 | chunk_length = len(chunk)
136 | print(chunk)
137 | try:
138 | chunk = json.loads(chunk[6:])
139 | except json.JSONDecodeError:
140 | print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
141 | error_msg += chunk
142 | continue
143 | if chunk_length > 6 and "delta" in chunk["choices"][0]:
144 | if "finish_reason" in chunk["choices"][0] and chunk["choices"][0]["finish_reason"] == "stop":
145 | self.all_token_counts.append(chunk["usage"]["total_tokens"] - sum(self.all_token_counts))
146 | break
147 | try:
148 | yield chunk["choices"][0]["delta"]
149 | except Exception as e:
150 | logging.error(f"Error: {e}")
151 | continue
152 | if error_msg:
153 | try:
154 | error_msg = json.loads(error_msg)
155 | if 'base_resp' in error_msg:
156 | status_code = error_msg['base_resp']['status_code']
157 | status_msg = error_msg['base_resp']['status_msg']
158 | raise Exception(f"{status_code} - {status_msg}")
159 | except json.JSONDecodeError:
160 | pass
161 | raise Exception(error_msg)
162 |
--------------------------------------------------------------------------------
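
The temperature conversion used twice above maps the app's [0, 2] range onto MiniMax's (0, 1], with base 1.0 landing on 0.9. A quick worked check of that mapping, with values chosen for illustration:

```python
def to_minimax_temperature(t: float) -> float:
    # base range [0, 2] -> MiniMax range (0, 1]; base 1.0 maps to 0.9
    return t * 0.9 if t <= 1 else 0.9 + (t - 1) / 10

for t in (0.5, 1.0, 1.5, 2.0):
    print(t, "->", round(to_minimax_temperature(t), 3))
# 0.5 -> 0.45, 1.0 -> 0.9, 1.5 -> 0.95, 2.0 -> 1.0
```
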
/modules/models/spark.py:
--------------------------------------------------------------------------------
1 | import _thread as thread
2 | import base64
3 | import datetime
4 | import hashlib
5 | import hmac
6 | import json
7 | from collections import deque
8 | from urllib.parse import urlparse
9 | import ssl
10 | from datetime import datetime
11 | from time import mktime
12 | from urllib.parse import urlencode
13 | from wsgiref.handlers import format_date_time
14 | from threading import Condition
15 | import websocket
16 | import logging
17 |
18 | from .base_model import BaseLLMModel, CallbackToIterator
19 |
20 |
21 | class Ws_Param(object):
22 | # adapted from the official demo
23 | # initialization
24 | def __init__(self, APPID, APIKey, APISecret, Spark_url):
25 | self.APPID = APPID
26 | self.APIKey = APIKey
27 | self.APISecret = APISecret
28 | self.host = urlparse(Spark_url).netloc
29 | self.path = urlparse(Spark_url).path
30 | self.Spark_url = Spark_url
31 |
32 | # build the request URL
33 | def create_url(self):
34 | # generate an RFC 1123 formatted timestamp
35 | now = datetime.now()
36 | date = format_date_time(mktime(now.timetuple()))
37 |
38 | # assemble the string to be signed
39 | signature_origin = "host: " + self.host + "\n"
40 | signature_origin += "date: " + date + "\n"
41 | signature_origin += "GET " + self.path + " HTTP/1.1"
42 |
43 | # sign it with HMAC-SHA256
44 | signature_sha = hmac.new(
45 | self.APISecret.encode("utf-8"),
46 | signature_origin.encode("utf-8"),
47 | digestmod=hashlib.sha256,
48 | ).digest()
49 |
50 | signature_sha_base64 = base64.b64encode(
51 | signature_sha).decode(encoding="utf-8")
52 |
53 | authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
54 |
55 | authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
56 | encoding="utf-8"
57 | )
58 |
59 | # collect the auth parameters into a dict
60 | v = {"authorization": authorization, "date": date, "host": self.host}
61 | # append the auth parameters to build the final URL
62 | url = self.Spark_url + "?" + urlencode(v)
63 | # this is the URL used to establish the connection; when following this demo you can print it and compare it with the URL your own code generates from the same parameters
64 | return url
65 |
66 |
67 | class Spark_Client(BaseLLMModel):
68 | def __init__(self, model_name, appid, api_key, api_secret, user_name="") -> None:
69 | super().__init__(model_name=model_name, user=user_name)
70 | self.api_key = api_key
71 | self.appid = appid
72 | self.api_secret = api_secret
73 | if None in [self.api_key, self.appid, self.api_secret]:
74 | raise Exception("请在配置文件或者环境变量中设置讯飞的API Key、APP ID和API Secret")
75 | self.spark_url = f"wss://spark-api.xf-yun.com{self.metadata['path']}"
76 | self.domain = self.metadata['domain']
77 |
78 | # handle websocket errors
79 | def on_error(self, ws, error):
80 | ws.iterator.callback("出现了错误:" + str(error))
81 |
82 | # handle websocket close
83 | def on_close(self, ws, one, two):
84 | pass
85 |
86 | # handle websocket connection established
87 | def on_open(self, ws):
88 | thread.start_new_thread(self.run, (ws,))
89 |
90 | def run(self, ws, *args):
91 | data = json.dumps(
92 | self.gen_params()
93 | )
94 | ws.send(data)
95 |
96 | # handle incoming websocket messages
97 | def on_message(self, ws, message):
98 | ws.iterator.callback(message)
99 |
100 | def gen_params(self):
101 | """
102 | Build the request parameters from the appid and the user's prompt
103 | """
104 | data = {
105 | "header": {"app_id": self.appid, "uid": "1234"},
106 | "parameter": {
107 | "chat": {
108 | "domain": self.domain,
109 | "random_threshold": self.temperature,
110 | "max_tokens": 4096,
111 | "auditing": "default",
112 | }
113 | },
114 | "payload": {"message": {"text": self.history}},
115 | }
116 | return data
117 |
118 | def get_answer_stream_iter(self):
119 | wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.spark_url)
120 | websocket.enableTrace(False)
121 | wsUrl = wsParam.create_url()
122 | ws = websocket.WebSocketApp(
123 | wsUrl,
124 | on_message=self.on_message,
125 | on_error=self.on_error,
126 | on_close=self.on_close,
127 | on_open=self.on_open,
128 | )
129 | ws.appid = self.appid
130 | ws.domain = self.domain
131 |
132 | # Initialize the CallbackToIterator
133 | ws.iterator = CallbackToIterator()
134 |
135 | # Start the WebSocket connection in a separate thread
136 | thread.start_new_thread(
137 | ws.run_forever, (), {"sslopt": {"cert_reqs": ssl.CERT_NONE}}
138 | )
139 |
140 | # Iterate over the CallbackToIterator instance
141 | answer = ""
142 | total_tokens = 0
143 | for message in ws.iterator:
144 | data = json.loads(message)
145 | code = data["header"]["code"]
146 | if code != 0:
147 | ws.close()
148 | raise Exception(f"请求错误: {code}, {data}")
149 | else:
150 | choices = data["payload"]["choices"]
151 | status = choices["status"]
152 | content = choices["text"][0]["content"]
153 | if "usage" in data["payload"]:
154 | total_tokens = data["payload"]["usage"]["text"]["total_tokens"]
155 | answer += content
156 | if status == 2:
157 | ws.iterator.finish() # Finish the iterator when the status is 2
158 | ws.close()
159 | yield answer, total_tokens
160 |
--------------------------------------------------------------------------------
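
A self-contained sketch of the signature step `create_url()` performs above. The host, path and secret below are placeholders; the signing string and HMAC-SHA256 step follow the code above.

```python
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from wsgiref.handlers import format_date_time

api_secret = "placeholder-secret"                  # placeholder
host, path = "spark-api.xf-yun.com", "/v3.5/chat"  # placeholders

date = format_date_time(mktime(datetime.now().timetuple()))  # RFC 1123 timestamp
signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"

signature = base64.b64encode(
    hmac.new(api_secret.encode("utf-8"), signature_origin.encode("utf-8"), hashlib.sha256).digest()
).decode("utf-8")
print(signature)
```
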
/modules/pdf_func.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 | import pdfplumber
3 | import logging
4 | from langchain.docstore.document import Document
5 |
6 | def prepare_table_config(crop_page):
7 | """Prepare the table-detection boundaries; page must be the original (uncropped) page
8 |
9 | From https://github.com/jsvine/pdfplumber/issues/242
10 | """
11 | page = crop_page.root_page # root/parent
12 | cs = page.curves + page.edges
13 | def curves_to_edges():
14 | """See https://github.com/jsvine/pdfplumber/issues/127"""
15 | edges = []
16 | for c in cs:
17 | edges += pdfplumber.utils.rect_to_edges(c)
18 | return edges
19 | edges = curves_to_edges()
20 | return {
21 | "vertical_strategy": "explicit",
22 | "horizontal_strategy": "explicit",
23 | "explicit_vertical_lines": edges,
24 | "explicit_horizontal_lines": edges,
25 | "intersection_y_tolerance": 10,
26 | }
27 |
28 | def get_text_outside_table(crop_page):
29 | ts = prepare_table_config(crop_page)
30 | if len(ts["explicit_vertical_lines"]) == 0 or len(ts["explicit_horizontal_lines"]) == 0:
31 | return crop_page
32 |
33 | ### Get the bounding boxes of the tables on the page.
34 | bboxes = [table.bbox for table in crop_page.root_page.find_tables(table_settings=ts)]
35 | def not_within_bboxes(obj):
36 | """Check if the object is in any of the table's bbox."""
37 | def obj_in_bbox(_bbox):
38 | """See https://github.com/jsvine/pdfplumber/blob/stable/pdfplumber/table.py#L404"""
39 | v_mid = (obj["top"] + obj["bottom"]) / 2
40 | h_mid = (obj["x0"] + obj["x1"]) / 2
41 | x0, top, x1, bottom = _bbox
42 | return (h_mid >= x0) and (h_mid < x1) and (v_mid >= top) and (v_mid < bottom)
43 | return not any(obj_in_bbox(__bbox) for __bbox in bboxes)
44 |
45 | return crop_page.filter(not_within_bboxes)
46 | # Please use LaTeX for formulas: wrap inline formulas with $ and display formulas with $$
47 |
48 | extract_words = lambda page: page.extract_words(keep_blank_chars=True, y_tolerance=0, x_tolerance=1, extra_attrs=["fontname", "size", "object_type"])
49 | # dict_keys(['text', 'x0', 'x1', 'top', 'doctop', 'bottom', 'upright', 'direction', 'fontname', 'size'])
50 |
51 | def get_title_with_cropped_page(first_page):
52 | title = []  # collect the title
53 | x0, top, x1, bottom = first_page.bbox  # page bounding box
54 |
55 | for word in extract_words(first_page):
56 | word = SimpleNamespace(**word)
57 |
58 | if word.size >= 14:
59 | title.append(word.text)
60 | title_bottom = word.bottom
61 | elif word.text == "Abstract": # the Abstract section starts here
62 | top = word.top
63 |
64 | user_info = [i["text"] for i in extract_words(first_page.within_bbox((x0,title_bottom,x1,bottom)))]
65 | # crop away the upper part; within_bbox keeps fully-contained objects, crop keeps partially-contained ones
66 | return title, user_info, first_page.within_bbox((x0,top,x1,bottom))
67 |
68 | def get_column_cropped_pages(pages, two_column=True):
69 | new_pages = []
70 | for page in pages:
71 | if two_column:
72 | left = page.within_bbox((0, 0, page.width/2, page.height),relative=True)
73 | right = page.within_bbox((page.width/2, 0, page.width, page.height), relative=True)
74 | new_pages.append(left)
75 | new_pages.append(right)
76 | else:
77 | new_pages.append(page)
78 |
79 | return new_pages
80 |
81 | def parse_pdf(filename, two_column = True):
82 | level = logging.getLogger().level
83 | if level == logging.getLevelName("DEBUG"):
84 | logging.getLogger().setLevel("INFO")
85 |
86 | with pdfplumber.open(filename) as pdf:
87 | title, user_info, first_page = get_title_with_cropped_page(pdf.pages[0])
88 | new_pages = get_column_cropped_pages([first_page] + pdf.pages[1:], two_column)
89 |
90 | chapters = []
91 | # tuple (chapter_name, [pageid] (start,stop), chapter_text)
92 | create_chapter = lambda page_start,name_top,name_bottom: SimpleNamespace(
93 | name=[],
94 | name_top=name_top,
95 | name_bottom=name_bottom,
96 | record_chapter_name = True,
97 |
98 | page_start=page_start,
99 | page_stop=None,
100 |
101 | text=[],
102 | )
103 | cur_chapter = None
104 |
105 | # iterate over the PDF document page by page
106 | for idx, page in enumerate(new_pages):
107 | page = get_text_outside_table(page)
108 |
109 | # iterate over the page text line by line
110 | for word in extract_words(page):
111 | word = SimpleNamespace(**word)
112 |
113 | # check whether the line is printed in a large font (>= 11 pt); if so, treat it as the start of a new chapter
114 | if word.size >= 11: # a chapter name appears
115 | if cur_chapter is None:
116 | cur_chapter = create_chapter(page.page_number, word.top, word.bottom)
117 | elif not cur_chapter.record_chapter_name or (cur_chapter.name_bottom != word.bottom and cur_chapter.name_top != word.top):
118 | # stop extending the previous chapter name; this word starts a new chapter
119 | cur_chapter.page_stop = page.page_number # stop id
120 | chapters.append(cur_chapter)
121 | # reset the current chapter info
122 | cur_chapter = create_chapter(page.page_number, word.top, word.bottom)
123 |
124 | # print(word.size, word.top, word.bottom, word.text)
125 | cur_chapter.name.append(word.text)
126 | else:
127 | cur_chapter.record_chapter_name = False # the chapter name has ended
128 | cur_chapter.text.append(word.text)
129 | else:
130 | # handle the last chapter
131 | cur_chapter.page_stop = page.page_number # stop id
132 | chapters.append(cur_chapter)
133 |
134 | for i in chapters:
135 | logging.info(f"section: {i.name} pages:{i.page_start, i.page_stop} word-count:{len(i.text)}")
136 | logging.debug(" ".join(i.text))
137 |
138 | title = " ".join(title)
139 | user_info = " ".join(user_info)
140 | text = f"Article Title: {title}, Information:{user_info}\n"
141 | for idx, chapter in enumerate(chapters):
142 | chapter.name = " ".join(chapter.name)
143 | text += f"The {idx}th Chapter {chapter.name}: " + " ".join(chapter.text) + "\n"
144 |
145 | logging.getLogger().setLevel(level)
146 | return Document(page_content=text, metadata={"title": title})
147 |
148 |
149 | if __name__ == '__main__':
150 | # Test code
151 | z = parse_pdf("./build/test.pdf")
152 | print(z.metadata["title"])
153 | print(z.page_content[:500])
154 |
155 |
--------------------------------------------------------------------------------
/modules/shared.py:
--------------------------------------------------------------------------------
1 | from modules.presets import CHAT_COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST, OPENAI_API_BASE, IMAGES_COMPLETION_URL
2 | import os
3 | import queue
4 | import openai
5 |
6 | def format_openai_host(api_host: str):
7 | api_host = api_host.rstrip("/")
8 | if not api_host.startswith("http"):
9 | api_host = f"https://{api_host}"
10 | if api_host.endswith("/v1"):
11 | api_host = api_host[:-3]
12 | chat_completion_url = f"{api_host}/v1/chat/completions"
13 | images_completion_url = f"{api_host}/v1/images/generations"
14 | openai_api_base = f"{api_host}/v1"
15 | balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
16 | usage_api_url = f"{api_host}/dashboard/billing/usage"
17 | return chat_completion_url, images_completion_url, openai_api_base, balance_api_url, usage_api_url
18 |
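# A quick sketch of what format_openai_host derives for a self-hosted gateway;
# "example-proxy.com/v1/" is a made-up host, not taken from any shipped config:
#
#   format_openai_host("example-proxy.com/v1/")
#   # -> ("https://example-proxy.com/v1/chat/completions",
#   #     "https://example-proxy.com/v1/images/generations",
#   #     "https://example-proxy.com/v1",
#   #     "https://example-proxy.com/dashboard/billing/credit_grants",
#   #     "https://example-proxy.com/dashboard/billing/usage")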
19 | class State:
20 | interrupted = False
21 | multi_api_key = False
22 | chat_completion_url = CHAT_COMPLETION_URL
23 | balance_api_url = BALANCE_API_URL
24 | usage_api_url = USAGE_API_URL
25 | openai_api_base = OPENAI_API_BASE
26 | images_completion_url = IMAGES_COMPLETION_URL
27 | api_host = API_HOST
28 |
29 | def interrupt(self):
30 | self.interrupted = True
31 |
32 | def recover(self):
33 | self.interrupted = False
34 |
35 | def set_api_host(self, api_host: str):
36 | self.api_host = api_host
37 | self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = format_openai_host(api_host)
38 | os.environ["OPENAI_API_BASE"] = self.openai_api_base
39 |
40 | def reset_api_host(self):
41 | self.chat_completion_url = CHAT_COMPLETION_URL
42 | self.images_completion_url = IMAGES_COMPLETION_URL
43 | self.balance_api_url = BALANCE_API_URL
44 | self.usage_api_url = USAGE_API_URL
45 | self.api_host = API_HOST
46 | os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}"
47 | return API_HOST
48 |
49 | def reset_all(self):
50 | self.interrupted = False
51 | self.chat_completion_url = CHAT_COMPLETION_URL
52 |
53 | def set_api_key_queue(self, api_key_list):
54 | self.multi_api_key = True
55 | self.api_key_queue = queue.Queue()
56 | for api_key in api_key_list:
57 | self.api_key_queue.put(api_key)
58 |
59 | def switching_api_key(self, func):
60 | if not hasattr(self, "api_key_queue"):
61 | return func
62 |
63 | def wrapped(*args, **kwargs):
64 | api_key = self.api_key_queue.get()
65 | args[0].api_key = api_key
66 | ret = func(*args, **kwargs)
67 | self.api_key_queue.put(api_key)
68 | return ret
69 |
70 | return wrapped
71 |
72 |
73 | state = State()
74 |
75 | modules_path = os.path.dirname(os.path.realpath(__file__))
76 | chuanhu_path = os.path.dirname(modules_path)
77 | assets_path = os.path.join(chuanhu_path, "web_assets")
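# A minimal sketch of how the shared state object is meant to be used for API-key
# rotation; the key strings and completion_fn below are placeholders, not real
# credentials or an actual function from this project:
#
#   state.set_api_key_queue(["sk-key-a", "sk-key-b"])
#   completion_fn = state.switching_api_key(completion_fn)
#   # every call now takes a key from the queue, assigns it to the model instance
#   # passed as the first positional argument, and returns the key to the queue
#   # after the wrapped function finishes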
--------------------------------------------------------------------------------
/modules/train_func.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import traceback
4 |
5 | import gradio as gr
6 | import ujson as json
7 | import commentjson
8 | import openpyxl
9 | from openai import OpenAI
10 |
11 | import modules.presets as presets
12 | from modules.utils import get_file_hash, count_token
13 | from modules.presets import i18n
14 |
15 | # single OpenAI client, authenticated via the OPENAI_API_KEY environment variable
16 | client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
17 | def excel_to_jsonl(filepath, preview=False):
18 | # Open the Excel file
19 | workbook = openpyxl.load_workbook(filepath)
20 |
21 | # Get the first worksheet
22 | sheet = workbook.active
23 |
24 | # Collect every row of data
25 | data = []
26 | for row in sheet.iter_rows(values_only=True):
27 | data.append(row)
28 |
29 | # Build a list of dicts keyed by the header row
30 | headers = data[0]
31 | jsonl = []
32 | for row in data[1:]:
33 | row_data = dict(zip(headers, row))
34 | if any(row_data.values()):
35 | jsonl.append(row_data)
36 | formatted_jsonl = []
37 | for i in jsonl:
38 | if "提问" in i and "答案" in i:
39 | if "系统" in i :
40 | formatted_jsonl.append({
41 | "messages":[
42 | {"role": "system", "content": i["系统"]},
43 | {"role": "user", "content": i["提问"]},
44 | {"role": "assistant", "content": i["答案"]}
45 | ]
46 | })
47 | else:
48 | formatted_jsonl.append({
49 | "messages":[
50 | {"role": "user", "content": i["提问"]},
51 | {"role": "assistant", "content": i["答案"]}
52 | ]
53 | })
54 | else:
55 | logging.warning(f"跳过一行数据,因为没有找到提问和答案: {i}")
56 | return formatted_jsonl
57 |
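# For reference, excel_to_jsonl reads the first worksheet and expects a header row
# containing 提问 (user prompt) and 答案 (answer), plus an optional 系统 (system) column;
# an illustrative row ("你是翻译助手", "你好", "Hello") would be converted to:
#
#   {"messages": [{"role": "system", "content": "你是翻译助手"},
#                 {"role": "user", "content": "你好"},
#                 {"role": "assistant", "content": "Hello"}]}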
58 | def jsonl_save_to_disk(jsonl, filepath):
59 | file_hash = get_file_hash(file_paths = [filepath])
60 | os.makedirs("files", exist_ok=True)
61 | save_path = f"files/{file_hash}.jsonl"
62 | with open(save_path, "w", encoding="utf-8") as f:
63 | f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in jsonl]))
64 | return save_path
65 |
66 | def estimate_cost(ds):
67 | dialogues = []
68 | for l in ds:
69 | for m in l["messages"]:
70 | dialogues.append(m["content"])
71 | dialogues = "\n".join(dialogues)
72 | tokens = count_token(dialogues)
73 | return f"Token 数约为 {tokens},预估每轮(epoch)费用约为 {tokens / 1000 * 0.008} 美元。"
74 |
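# The string returned above prices each training epoch at $0.008 per 1,000 tokens,
# so a dataset counting, say, 120,000 tokens would be estimated at
# 120000 / 1000 * 0.008 = 0.96 dollars per epoch (the token count is only an example).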
75 |
76 | def handle_dataset_selection(file_src):
77 | logging.info(f"Loading dataset {file_src.name}...")
78 | preview = ""
79 | if file_src.name.endswith(".jsonl"):
80 | with open(file_src.name, "r") as f:
81 | ds = [json.loads(l) for l in f.readlines()]
82 | else:
83 | ds = excel_to_jsonl(file_src.name)
84 | preview = ds[0]
85 |
86 | return preview, gr.update(interactive=True), estimate_cost(ds)
87 |
88 | def upload_to_openai(file_src):
89 | dspath = file_src.name
90 | msg = ""
91 | logging.info(f"Uploading dataset {dspath}...")
92 | if dspath.endswith(".xlsx"):
93 | jsonl = excel_to_jsonl(dspath)
94 | dspath = jsonl_save_to_disk(jsonl, dspath)
95 | try:
96 | uploaded = client.files.create(file=open(dspath, "rb"),
97 | purpose='fine-tune')
98 | return uploaded.id, f"上传成功"
99 | except Exception as e:
100 | traceback.print_exc()
101 | return "", f"上传失败,原因:{ e }"
102 |
103 | def build_event_description(id, status, trained_tokens, name=i18n("暂时未知")):
104 | # convert to markdown
105 | return f"""
106 | #### 训练任务 {id}
107 |
108 | 模型名称:{name}
109 |
110 | 状态:{status}
111 |
112 | 已经训练了 {trained_tokens} 个token
113 | """
114 |
115 | def start_training(file_id, suffix, epochs):
116 | try:
117 | job = client.fine_tuning.jobs.create(training_file=file_id, model="gpt-3.5-turbo", suffix=suffix, hyperparameters={"n_epochs": epochs})
118 | return build_event_description(job.id, job.status, job.trained_tokens)
119 | except Exception as e:
120 | traceback.print_exc()
121 | if "is not ready" in str(e):
122 | return "训练出错,因为文件还没准备好。OpenAI 需要一点时间准备文件,过几分钟再来试试。"
123 | return f"训练失败,原因:{ e }"
124 |
125 | def get_training_status():
126 | active_jobs = [build_event_description(job.id, job.status, job.trained_tokens, job.fine_tuned_model) for job in client.fine_tuning.jobs.list().data if job.status != "cancelled"]
127 | return "\n\n".join(active_jobs), gr.update(interactive=True) if len(active_jobs) > 0 else gr.update(interactive=False)
128 |
129 | def handle_dataset_clear():
130 | return gr.update(value=None), gr.update(interactive=False)
131 |
132 | def add_to_models():
133 | succeeded_jobs = [job for job in client.fine_tuning.jobs.list().data if job.status == "succeeded"]
134 | extra_models = [job.fine_tuned_model for job in succeeded_jobs]
135 | for i in extra_models:
136 | if i not in presets.MODELS:
137 | presets.MODELS.append(i)
138 |
139 | with open('config.json', 'r') as f:
140 | data = commentjson.load(f)
141 | if 'extra_models' in data:
142 | for i in extra_models:
143 | if i not in data['extra_models']:
144 | data['extra_models'].append(i)
145 | else:
146 | data['extra_models'] = extra_models
147 | if 'extra_model_metadata' in data:
148 | for i in extra_models:
149 | if i not in data['extra_model_metadata']:
150 | data['extra_model_metadata'][i] = {"model_name": i, "model_type": "OpenAIVision"}
151 | else:
152 | data['extra_model_metadata'] = {i: {"model_name": i, "model_type": "OpenAIVision"} for i in extra_models}
153 | with open('config.json', 'w') as f:
154 | commentjson.dump(data, f, indent=4)
155 |
156 | return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
157 |
158 | def cancel_all_jobs():
159 | jobs = [job for job in client.fine_tuning.jobs.list().data if job.status not in ["cancelled", "succeeded"]]
160 | for job in jobs:
161 | client.fine_tuning.jobs.cancel(job.id)
162 | return f"成功取消了 {len(jobs)} 个训练任务。"
163 |
--------------------------------------------------------------------------------
/modules/webui.py:
--------------------------------------------------------------------------------
1 |
2 | from collections import namedtuple
3 | import os
4 | import gradio as gr
5 |
6 | from . import shared
7 |
8 | # with open("./assets/ChuanhuChat.js", "r", encoding="utf-8") as f, \
9 | # open("./assets/external-scripts.js", "r", encoding="utf-8") as f1:
10 | # customJS = f.read()
11 | # externalScripts = f1.read()
12 |
13 |
14 | def get_html(filename):
15 | path = os.path.join(shared.chuanhu_path, "web_assets", "html", filename)
16 | if os.path.exists(path):
17 | with open(path, encoding="utf8") as file:
18 | return file.read()
19 | return ""
20 |
21 | def webpath(fn):
22 | if fn.startswith(shared.assets_path):
23 | web_path = os.path.relpath(fn, shared.chuanhu_path).replace('\\', '/')
24 | else:
25 | web_path = os.path.abspath(fn)
26 | return f'file={web_path}?{os.path.getmtime(fn)}'
27 |
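# webpath turns a local asset path into a "file=" reference with the file's mtime
# appended as a cache-busting query string, e.g. with a hypothetical install path:
#
#   webpath("/app/web_assets/stylesheet/ChuanhuChat.css")
#   # -> "file=web_assets/stylesheet/ChuanhuChat.css?1700000000.0"
#   # (assuming chuanhu_path is "/app"; the mtime value is illustrative)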
28 | ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
29 |
30 | def javascript_html():
31 | head = ""
32 | for script in list_scripts("javascript", ".js"):
33 | head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
34 | for script in list_scripts("javascript", ".mjs"):
35 | head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
36 | return head
37 |
38 | def css_html():
39 | head = ""
40 | for cssfile in list_scripts("stylesheet", ".css"):
41 | head += f'<link rel="stylesheet" href="{webpath(cssfile.path)}">'
42 | return head
43 |
44 | def list_scripts(scriptdirname, extension):
45 | scripts_list = []
46 | scripts_dir = os.path.join(shared.chuanhu_path, "web_assets", scriptdirname)
47 | if os.path.exists(scripts_dir):
48 | for filename in sorted(os.listdir(scripts_dir)):
49 | scripts_list.append(ScriptFile(shared.assets_path, filename, os.path.join(scripts_dir, filename)))
50 | scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
51 | return scripts_list
52 |
53 |
54 | def reload_javascript():
55 | js = javascript_html()
56 | js += ''
57 | js += ''
58 | js += ''
59 |
60 | meta = """
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 | """
73 | css = css_html()
74 |
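# template_response below wraps Gradio's original HTML response and splices the
# collected meta tags, scripts and stylesheets into the rendered page; the original
# response class is presumably saved as GradioTemplateResponseOriginal elsewhere in
# this module.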
75 | def template_response(*args, **kwargs):
76 | res = GradioTemplateResponseOriginal(*args, **kwargs)
77 | res.body = res.body.replace(b'</head>', f'{meta}{js}</head>'.encode("utf8"))
78 | # res.body = res.body.replace(b'', f'{js}'.encode("utf8"))
79 | res.body = res.body.replace(b'