├── assets ├── app_1.png ├── app_2.png ├── gradio.png ├── app_gradio.png ├── cogagent_function.jpg ├── cogagent_function_cn.jpg ├── cogagent_workflow_cn.png └── cogagent_workflow_en.png ├── app ├── requirements.txt ├── README.md ├── README_en.md ├── register.py ├── vllm_openai_server.py ├── openai_demo.py └── client.py ├── requirements.txt ├── finetune ├── configs │ ├── ds_zero_2.json │ ├── ds_zero_3.json │ ├── sft.yaml │ └── lora.yaml ├── README.md ├── README_en.md └── finetune.py ├── .github ├── ISSUE_TEMPLATE │ ├── feature-request.yaml │ └── bug_report.yaml └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── MODEL_LICENSE ├── Action_space.md ├── inference ├── cli_demo.py └── web_demo.py ├── LICENSE ├── Action_space_en.md ├── README_zh.md └── README.md /assets/app_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zai-org/CogAgent/HEAD/assets/app_1.png -------------------------------------------------------------------------------- /assets/app_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zai-org/CogAgent/HEAD/assets/app_2.png -------------------------------------------------------------------------------- /assets/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zai-org/CogAgent/HEAD/assets/gradio.png -------------------------------------------------------------------------------- /assets/app_gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zai-org/CogAgent/HEAD/assets/app_gradio.png -------------------------------------------------------------------------------- /assets/cogagent_function.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zai-org/CogAgent/HEAD/assets/cogagent_function.jpg 
--------------------------------------------------------------------------------
/assets/cogagent_function_cn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zai-org/CogAgent/HEAD/assets/cogagent_function_cn.jpg
--------------------------------------------------------------------------------
/assets/cogagent_workflow_cn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zai-org/CogAgent/HEAD/assets/cogagent_workflow_cn.png
--------------------------------------------------------------------------------
/assets/cogagent_workflow_en.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zai-org/CogAgent/HEAD/assets/cogagent_workflow_en.png
--------------------------------------------------------------------------------
/app/requirements.txt:
--------------------------------------------------------------------------------
1 | # If you want to run client.py, you need to install the following packages
2 | pyautogui>=0.9.54
3 | openai>=1.58.0
4 | pyperclip>=1.9.0
5 | gradio>=5.9.0
6 | vllm>=0.6.6
7 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers>=4.47.0
2 | torch>=2.5.1
3 | torchvision>=0.20.0
4 | huggingface-hub>=0.25.1
5 | sentencepiece>=0.2.0
6 | jinja2>=3.1.4
7 | pydantic>=2.9.2
8 | timm>=1.0.9
9 | tiktoken>=0.8.0
10 | numpy==1.26.4
11 | accelerate>=1.1.1
12 | sentence_transformers>=3.1.1
13 | gradio>=5.23.2
14 | openai>=1.70.0
15 | einops>=0.8.0
16 | pillow>=10.4.0
17 | sse-starlette>=2.1.3
18 | bitsandbytes>=0.43.2
19 | spaces>=0.31.1
20 | 
21 | 
22 | # Used when fine-tuning the model
23 | # nltk>=3.9.1
24 | # jieba>=0.42.1
25 | # ruamel.yaml>=0.18.10
26 | # datasets
27 | # peft>0.15.1
28 | # rouge_chinese>=1.0.3
-------------------------------------------------------------------------------- /finetune/configs/ds_zero_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "allgather_partitions": true, 16 | "allgather_bucket_size": 5e8, 17 | "overlap_comm": true, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 5e8, 20 | "contiguous_gradients": true 21 | }, 22 | 23 | "gradient_accumulation_steps": "auto", 24 | "gradient_clipping": "auto", 25 | "steps_per_print": 2000, 26 | "train_batch_size": "auto", 27 | "train_micro_batch_size_per_gpu": "auto", 28 | "wall_clock_breakdown": false 29 | } -------------------------------------------------------------------------------- /finetune/configs/ds_zero_3.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "bf16": { 5 | "enabled": "auto" 6 | }, 7 | "optimizer": { 8 | "type": "AdamW", 9 | "params": { 10 | "lr": "auto", 11 | "betas": "auto", 12 | "eps": "auto", 13 | "weight_decay": "auto" 14 | } 15 | }, 16 | "zero_optimization": { 17 | "stage": 3, 18 | "allgather_partitions": true, 19 | "allgather_bucket_size": 5e8, 20 | "reduce_scatter": true, 21 | "contiguous_gradients": true, 22 | "overlap_comm": true, 23 | "sub_group_size": 1e9, 24 | "reduce_bucket_size": "auto", 25 | "stage3_prefetch_bucket_size": "auto", 26 | "stage3_param_persistence_threshold": "auto", 27 | "stage3_max_live_parameters": 1e9, 28 | "stage3_max_reuse_distance": 1e9, 29 | "stage3_gather_16bit_weights_on_model_save": true 30 | } 31 | } -------------------------------------------------------------------------------- 
/.github/ISSUE_TEMPLATE/feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a request for a new CogAgent feature / 提交一个新的 CogAgent 的功能建议 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request / 功能建议 11 | description: | 12 | A brief description of the functional proposal. Links to corresponding papers and code are desirable. 13 | 对功能建议的简述。最好提供对应的论文和代码链接 14 | 15 | - type: textarea 16 | id: motivation 17 | validations: 18 | required: true 19 | attributes: 20 | label: Motivation / 动机 21 | description: | 22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. 23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 24 | 25 | - type: textarea 26 | id: contribution 27 | validations: 28 | required: true 29 | attributes: 30 | label: Your contribution / 您的贡献 31 | description: | 32 | 33 | Your PR link or any other link you can help with. 
34 | 您的PR链接或者其他您能提供帮助的链接。 -------------------------------------------------------------------------------- /finetune/configs/sft.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.jsonl 3 | val_file: dev.jsonl 4 | test_file: dev.jsonl 5 | num_proc: 1 6 | 7 | combine: True 8 | freezeV: True 9 | max_input_length: 512 10 | max_output_length: 512 11 | 12 | training_args: 13 | # see `transformers.Seq2SeqTrainingArguments` 14 | output_dir: ./output 15 | max_steps: 3000 16 | # needed to be fit for the dataset 17 | learning_rate: 5e-5 18 | # settings for data loading 19 | per_device_train_batch_size: 1 20 | dataloader_num_workers: 16 21 | remove_unused_columns: false 22 | # settings for saving checkpoints 23 | save_strategy: steps 24 | save_steps: 500 25 | # settings for logging 26 | log_level: info 27 | logging_strategy: steps 28 | logging_steps: 10 29 | # settings for evaluation 30 | per_device_eval_batch_size: 4 31 | eval_strategy: steps 32 | eval_steps: 500 33 | # settings for optimizer 34 | # adam_epsilon: 1e-6 35 | # uncomment the following line to detect nan or inf values 36 | # debug: underflow_overflow 37 | predict_with_generate: true 38 | generation_config: 39 | max_new_tokens: 512 40 | # set your absolute deepspeed path here 41 | deepspeed: configs/ds_zero_3.json 42 | -------------------------------------------------------------------------------- /finetune/configs/lora.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.jsonl 3 | val_file: dev.jsonl 4 | test_file: dev.jsonl 5 | num_proc: 1 6 | 7 | combine: True 8 | freezeV: False 9 | max_input_length: 512 10 | max_output_length: 512 11 | 12 | training_args: 13 | # see `transformers.Seq2SeqTrainingArguments` 14 | output_dir: ./output 15 | max_steps: 3000 16 | # needed to be fit for the dataset 17 | learning_rate: 5e-4 18 | # settings for data loading 19 | 
per_device_train_batch_size: 1 20 | dataloader_num_workers: 16 21 | remove_unused_columns: false 22 | # settings for saving checkpoints 23 | save_strategy: steps 24 | save_steps: 500 25 | # settings for logging 26 | log_level: info 27 | logging_strategy: steps 28 | logging_steps: 10 29 | # settings for evaluation 30 | per_device_eval_batch_size: 4 31 | eval_strategy: steps 32 | eval_steps: 500 33 | # settings for optimizer 34 | # adam_epsilon: 1e-6 35 | # uncomment the following line to detect nan or inf values 36 | # debug: underflow_overflow 37 | predict_with_generate: true 38 | # see `transformers.GenerationConfig` 39 | generation_config: 40 | max_new_tokens: 512 41 | # set your absolute deepspeed path here 42 | # deepspeed: configs/ds_zero_3.json 43 | peft_config: 44 | peft_type: LORA 45 | task_type: CAUSAL_LM 46 | r: 8 47 | lora_alpha: 32 48 | lora_dropout: 0.1 49 | target_modules: ["query_key_value"] 50 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines. 4 | 5 | ## What We Accept 6 | 7 | + This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks). 8 | + This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below. 9 | + This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below. 10 | 11 | ## Code Style Guide 12 | 13 | Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. 
You can clean up your code following the steps below:
14 | 
15 | 1. Install the required dependencies:
16 | ```shell
17 | pip install ruff pre-commit
18 | ```
19 | 2. Then, run the following command:
20 | ```shell
21 | pre-commit run --all-files
22 | ```
23 | If your code complies with the standards, you should not see any errors.
24 | 
25 | ## Naming Conventions
26 | 
27 | - Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
28 | - Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`.
29 | 
--------------------------------------------------------------------------------
/app/README.md:
--------------------------------------------------------------------------------
1 | ## Deploying the Agent Demo
2 | 
3 | ### Device Check
4 | 
5 | The test environment for this Demo is as follows:
6 | 
7 | ```
8 | macOS Sequoia: Version 15.0.1 (24A348)
9 | Memory: 16GB
10 | Python Version: 3.13.1 / 3.10.16
11 | ```
12 | 
13 | We have not tested configurations that differ from the above (for example, Windows or desktop Linux). The `pyautogui` dependency supports Windows, so developers are welcome to try it and extend the Demo themselves.
14 | 
15 | 
16 | ### Installing the Environment (Client Side)
17 | 
18 | Before this step, make sure you have installed all dependencies in the `requirements.txt` at the root of this project; they are required for the server side to run properly. Then follow the steps below to install the client-side dependencies.
19 | 
20 | 1. Install the `tkinter` library in your macOS environment. You can use the following commands:
21 | 
22 | ```shell
23 | brew install python-tk
24 | pip install -r requirements.txt
25 | ```
26 | 
27 | To verify that the system Python has the `tkinter` library installed correctly, run the following in the terminal:
28 | 
29 | ```shell
30 | /opt/homebrew/bin/python3 -m tkinter
31 | ```
32 | 
33 | Normally it should print:
34 | 
35 | ```
36 | 2024-12-14 15:29:04.041 Python[7161:122540731] +[IMKClient subclass]: chose IMKClient_Legacy
37 | 2024-12-14 15:29:04.041 Python[7161:122540731] +[IMKInputSession subclass]: chose IMKInputSession_Legacy
38 | ```
39 | 
40 | 2.
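Create a virtual environment. Do not use `conda`, `virtualenv`, or similar tools, because they can prevent the `tkinter` library from working properly.
41 | Copy the following command into the terminal and run it, replacing `/Users/zr/Code/CogAgent/venv` with your actual path:
42 | 
43 | ```shell
44 | /opt/homebrew/bin/python3 -m venv --copies /Users/zr/Code/CogAgent/venv
45 | ```
46 | 
47 | Verify that the virtual environment was created successfully by running the following in the terminal:
48 | 
49 | ```shell
50 | /Users/zr/Code/CogAgent/venv/bin/python3 -m tkinter
51 | ```
52 | 
53 | Make sure to create a `caches` folder in this directory to store the images saved while the model runs.
54 | 
55 | ```shell
56 | mkdir caches
57 | ```
58 | 
59 | 3. Make sure your computer has granted sufficient permissions. In general, the software that runs your code needs permission for screenshots, screen recording, and simulated keyboard and mouse input. The table below shows how to enable these permissions on a Mac.
60 | 
61 | | Enable Screen Recording Permission | Enable Keyboard and Mouse Permissions |
62 | |---------------------------|---------------------------|
63 | | ![1](../assets/app_1.png) | ![2](../assets/app_2.png) |
64 | 
65 | Here we use `Pycharm` to run `client.py`, so both `Pycharm` and `terminal` need these permissions. If you only run the code in the terminal, only
66 | `terminal` needs them. The same applies to `VSCode` and other IDEs.
67 | 
68 | 
69 | ### Running the Server
70 | 
71 | Start the service on the remote server:
72 | 
73 | ```shell
74 | python openai_demo.py --model_path THUDM/cogagent-9b-20241220 --host 0.0.0.0 --port 7870
75 | ```
76 | 
77 | Or start the remote service with vLLM:
78 | 
79 | ```shell
80 | python vllm_openai_server.py --model_path THUDM/cogagent-9b-20241220 --host 0.0.0.0 --port 7870
81 | ```
82 | 
83 | This starts a server that mimics the `OpenAI` API format, listening by default at http://0.0.0.0:7870 .
84 | 
85 | ### Running the Client
86 | 
87 | Before running the client, confirm the following:
88 | 
89 | - Make sure the server is running properly and that the server and your local machine are connected, for example through a tunnel.
90 | - Make sure the server is reachable from the public network, or that your local machine can reach it through a tunnel. In our code, the server is tunneled to local port `7870`,
91 | so the environment variable should be set to http://127.0.0.1:7870/v1 .
92 | - This Demo does not set an API key, so set the `api_key` parameter to `EMPTY`.
93 | 
94 | ```shell
95 | python client.py --api_key EMPTY --base_url http://127.0.0.1:7870/v1 --client_name 127.0.0.1 --client_port 7860 --model CogAgent
96 | ```
97 | 
98 | With the command above, you can run the client locally, connect to the server, and use the `cogagent-9b-20241220` model.
99 | The image below is a screenshot of the app after a normal start, with the model taking control of the computer (the small rocket in the lower right corner is the app).
100 | 
101 | ![img.png](../assets/app_gradio.png)
102 | 
103 | > We cannot guarantee the safety of the AI's behavior; please use it with caution. This example is for academic reference only, and we accept no legal liability arising from it.
104 | >
105 | > While the model is running, you can press `stop` at any time to force the current operation to stop.
106 | >
107 | > If you believe the model is running normally and without risk, do not touch the computer. The model determines click coordinates from real-time screenshots, which is required for it to work properly.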
108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve CogAgent / 提交一个 Bug 问题报告来帮助我们改进 CogAgent 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info / 系統信息 8 | description: Your operating environment / 您的运行环境信息 9 | placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)... 10 | validations: 11 | required: true 12 | 13 | - type: textarea 14 | id: who-can-help 15 | attributes: 16 | label: Who can help? / 谁可以帮助到您? 17 | description: | 18 | Your issue will be replied to more quickly if you can figure out the right person to tag with @ 19 | All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person. 20 | 21 | Please tag fewer than 3 people. 22 | 23 | 如果您能找到合适的标签 @,您的问题会更快得到回复。 24 | 所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。 25 | 26 | 标记的人数应该不超过 3 个人。 27 | 28 | If it's not a bug in these three subsections, you may not specify the helper. Our maintainer will find the right person in the development group to solve the problem. 29 | 30 | 如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。 31 | 32 | placeholder: "@Username ..." 
33 | 34 | - type: checkboxes 35 | id: information-scripts-examples 36 | attributes: 37 | label: Information / 问题信息 38 | description: 'The problem arises when using: / 问题出现在' 39 | options: 40 | - label: "The official example scripts / 官方的示例脚本" 41 | - label: "My own modified scripts / 我自己修改的脚本和任务" 42 | 43 | - type: textarea 44 | id: reproduction 45 | validations: 46 | required: true 47 | attributes: 48 | label: Reproduction / 复现过程 49 | description: | 50 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit. 51 | If you have code snippets, error messages, stack traces, please provide them here as well. 52 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 53 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. 54 | 55 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 56 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 57 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 58 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 59 | placeholder: | 60 | Steps to reproduce the behavior/复现Bug的步骤: 61 | 62 | 1. 63 | 2. 64 | 3. 65 | 66 | - type: textarea 67 | id: expected-behavior 68 | validations: 69 | required: true 70 | attributes: 71 | label: Expected behavior / 期待表现 72 | description: "A clear and concise description of what you would expect to happen. 
/简单描述您期望发生的事情。" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vscode/ 3 | draft* 4 | app/caches 5 | .DS_Store 6 | *.zip 7 | .test* 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in 
version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 118 | .pdm.toml 119 | .pdm-python 120 | .pdm-build/ 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123 | __pypackages__/
124 | 
125 | # Celery stuff
126 | celerybeat-schedule
127 | celerybeat.pid
128 | 
129 | # SageMath parsed files
130 | *.sage.py
131 | 
132 | # Environments
133 | .env
134 | .venv
135 | env/
136 | venv/
137 | ENV/
138 | env.bak/
139 | venv.bak/
140 | 
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 | 
145 | # Rope project settings
146 | .ropeproject
147 | 
148 | # mkdocs documentation
149 | /site
150 | 
151 | # mypy
152 | .mypy_cache/
153 | .dmypy.json
154 | dmypy.json
155 | 
156 | # Pyre type checker
157 | .pyre/
158 | 
159 | # pytype static type analyzer
160 | .pytype/
161 | 
162 | # Cython debug symbols
163 | cython_debug/
--------------------------------------------------------------------------------
/finetune/README.md:
--------------------------------------------------------------------------------
1 | # CogAgent Model Fine-tuning
2 | 
3 | Read this in [English](README_en.md)
4 | 
5 | In this demo, you will learn how to fine-tune the open-source CogAgent model. Please follow the steps in this document strictly to avoid unnecessary errors.
6 | 
7 | ## Multi-turn Fine-tuning Format
8 | 
9 | The multi-turn fine-tuning example follows the CogAgent conversation format and applies a different `loss_mask` to each role, so that the `loss` for multiple turns of replies is computed in a single pass.
10 | 
11 | For the data file, the example uses the following format:
12 | 
13 | For the cogagent-9b-20241220 model, organize your data in the following format, making sure that every sample has an image and that the History steps section of `user` includes all previous operations.
14 | The first sample below is an intermediate step, for which the model returns `<<一般操作>>` (general operation) or a similar tag; the second sample is a completed process, for which the model returns `<<END>>`.
15 | 
16 | ```json
17 | {
18 |   "messages": [
19 |     {
20 |       "role": "user",
21 |       "content": "Task: 在系统设置的桌面与程序坞部分,开启调度中心板块中“显示器具有单独空间”的选项\n(Platform: Mac)\n(Answer in Action-Operation-Sensitive format.)\nHistory steps: \n0. CLICK(box=[[588,946,616,985]], element_info='系统设置')\t点击屏幕下方的“系统设置”菜单,从而打开系统设置\n1. CLICK(box=[[227,561,297,585]], element_info='桌面与程序坞')\t点击左侧菜单中的“桌面与程序坞”选项,进入桌面与程序坞的设置界面。\n2.
SCROLL_DOWN(box=[[367,39,691,929]], step_count=75, element_info='滚动')\t在当前屏幕的右侧滚动区域中,向下滚动以查看更多选项。\n",
22 |       "image": "images/0000000000336.png"
23 |     },
24 |     {
25 |       "role": "assistant",
26 |       "content": "Action: 点击“调度中心”板块中的“显示器具有单独空间”选项,从而开启该功能。\nGrounded Operation:CLICK(box=[[655,842,671,857]], element_info='“显示器具有单独空间”的开关按钮')\n<<一般操作>>"
27 |     }
28 |   ]
29 | },
30 | {
31 |   "messages": [
32 |     {
33 |       "role": "user",
34 |       "content": "Task: 通过启动台打开系统设置,导航到“网络”部分,将DNS设置手动IP地址为8.8.4.4。\n(Platform: Mac)\n(Answer in Action-Operation-Sensitive format.)\nHistory steps: \n0. CLICK(box=[[7,6,27,31]], element_info='苹果标志')\t点击屏幕左上角的“苹果”图标,以便打开系统设置。\n1. CLICK(box=[[3,77,161,99]], element_info='系统设置…')\t点击屏幕左上角菜单栏中的“系统设置…”选项,以打开系统设置界面。\n2. CLICK(box=[[7,478,125,512]], element_info='网络')\t点击左侧菜单中的“网络”选项,进入网络设置界面。\n3. CLICK(box=[[155,217,459,270]], element_info='Wi-Fi, 已连接')\t点击屏幕中间偏上的“Wi-Fi”选项,进入Wi-Fi设置界面。\n4. CLICK(box=[[405,280,447,296]], element_info='详细信息…')\t点击Wi-Fi名称旁边的“详细信息…”按钮,以查看当前Wi-Fi的详细信息。\n5. CLICK(box=[[27,344,45,362]], element_info='DNS')\t点击左侧菜单中的“DNS”选项,进入DNS设置界面。\n6. CLICK(box=[[166,308,442,399]], element_info='8.8.8.8')\t点击DNS服务器地址输入框,准备输入新的DNS服务器地址。\n7. TYPE(box=[[163,308,443,322]], text='8.8.4.4', element_info='ip地址输入框')\t在DNS服务器地址输入框中输入8.8.4.4,以设置手动IP地址。\n8. CLICK(box=[[410,636,440,651]], element_info='好')\t点击界面右下角的“好”按钮,以确认并保存DNS设置的更改。\n",
35 |       "image": "images/0000000000482.png"
36 |     },
37 |     {
38 |       "role": "assistant",
39 |       "content": "Action: 已经完成了“通过启动台打开系统设置,导航到‘网络’部分,将DNS设置手动IP地址为8.8.4.4”这一任务,执行结束。\nGrounded Operation:END()\n<<END>>"
40 |     }
41 |   ]
42 | }
43 | ```
44 | 
45 | ## Configuration Files
46 | 
47 | The fine-tuning configuration files are located in the `configs` directory and include the following files:
48 | 
49 | 1. `ds_zero_2.json / ds_zero_3.json`: DeepSpeed configuration files.
50 | 
51 | 2.
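`lora.yaml / sft.yaml`: Configuration files for the different fine-tuning methods, covering model, optimizer, and training parameters. Some important parameters are explained below:
52 |    + data_config section
53 |       + train_file: Path to the training dataset file.
54 |       + val_file: Path to the validation dataset file.
55 |       + test_file: Path to the test dataset file.
56 |       + num_proc: Number of processes used when loading data.
57 |    + freezeV: Whether to freeze the parameters of the vision part.
58 |    + max_input_length: Maximum length of the input sequence. Note that during actual model inference, a fixed `1600` tokens of image encoding are also added.
59 |    + max_output_length: Maximum length of the output sequence.
60 |    + training_args section
61 |       + output_dir: Directory for saving the model and other outputs.
62 |       + max_steps: Maximum number of training steps.
63 |       + per_device_train_batch_size: Training batch size per device (e.g., GPU).
64 |       + dataloader_num_workers: Number of workers used when loading data.
65 |       + remove_unused_columns: Whether to remove unused columns from the data.
66 |       + save_strategy: Model saving strategy (e.g., save every N steps).
67 |       + save_steps: Number of steps between model saves.
68 |       + log_level: Log level (e.g., info).
69 |       + logging_strategy: Logging strategy.
70 |       + logging_steps: Number of steps between log entries.
71 |       + per_device_eval_batch_size: Evaluation batch size per device.
72 |       + eval_strategy: Evaluation strategy (e.g., evaluate every N steps).
73 |       + eval_steps: Number of steps between evaluations.
74 |       + predict_with_generate: Whether to use generation mode for prediction.
75 |       + generation_config section
76 |          + max_new_tokens: Maximum number of new tokens to generate.
77 |    + peft_config section
78 |       + peft_type: Parameter-efficient fine-tuning type (LORA and PREFIX_TUNING are supported).
79 |       + task_type: Task type; here it is causal language modeling (do not change).
80 |    + Lora parameters:
81 |       + r: Rank of LoRA.
82 |       + lora_alpha: Scaling factor of LoRA.
83 |       + lora_dropout: Dropout probability used in the LoRA layers.
84 | 
85 | ## Starting Fine-tuning
86 | 
87 | Use the following command for **single-node multi-GPU / multi-node multi-GPU** runs. This uses `deepspeed` for acceleration, so you need to install `deepspeed`.
88 | Prepare the CogAgent1.5 dataset yourself, then run:
89 | 
90 | ```shell
91 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/CogAgentData/ THUDM/cogagent-9b-20241220 configs/sft.yaml
92 | ```
93 | 
94 | Use the following command for a **single-node single-GPU** run.
95 | 
96 | ```shell
97 | python finetune.py data/CogAgentData/ THUDM/cogagent-9b-20241220 configs/lora.yaml
98 | ```
99 | 
100 | ## Fine-tuning from a Checkpoint
101 | 
102 | If you train as described above, each run starts from scratch. To continue from a partially trained model, add a fourth argument, which can be passed in two ways:
103 | 
104 | 1. `yes`: automatically resume training from the last saved checkpoint.
105 | 2.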
`XX`: a checkpoint number; for example, `600` resumes training from checkpoint 600.
106 | 
107 | For example, the following command continues fine-tuning from the last saved checkpoint:
108 | 
109 | ```shell
110 | python finetune.py data/CogAgentData/ THUDM/cogagent-9b-20241220 configs/lora.yaml yes
111 | ```
112 | 
113 | ## Fine-tuning on Huawei Ascend Devices
114 | 
115 | If you need to use an `Ascend NPU` device, such as the `ATLAS 300 A2`, uncomment the following lines:
116 | 
117 | ```python
118 | import torch_npu
119 | from torch_npu.contrib import transfer_to_npu
120 | ```
121 | 
122 | After that, the fine-tuning program runs normally.
123 | 
--------------------------------------------------------------------------------
/app/README_en.md:
--------------------------------------------------------------------------------
1 | # Deploying the Demo
2 | 
3 | ## System Check
4 | 
5 | The test environment for this Demo is as follows:
6 | 
7 | ```
8 | macOS Sequoia: Version 15.0.1 (24A348)
9 | Memory: 16GB
10 | Python Version: 3.13.1 / 3.10.16
11 | ```
12 | 
13 | For configurations different from the above (e.g., Windows OS, Linux desktop OS), we have not conducted any testing. The dependency library `pyautogui` supports Windows OS, and developers can try and enrich the Demo themselves.
14 | 
15 | ## Installation (Client Side)
16 | 
17 | Before proceeding with the following steps, please ensure that you have installed all dependencies listed in the `requirements.txt` on the project homepage. These dependencies guarantee the server-side runs properly. Next, follow the steps below to install the client-side dependencies.
18 | 
19 | 1. Install the `tkinter` library on your macOS system.
You can refer to the following commands: 20 | 21 | ```shell 22 | brew install python-tk 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | To verify that the system’s Python has installed the `tkinter` library correctly, type the following in the terminal: 27 | 28 | ```shell 29 | /opt/homebrew/bin/python3 -m tkinter 30 | ``` 31 | 32 | Normally, it should return: 33 | 34 | ``` 35 | 2024-12-14 15:29:04.041 Python[7161:122540731] +[IMKClient subclass]: chose IMKClient_Legacy 36 | 2024-12-14 15:29:04.041 Python[7161:122540731] +[IMKInputSession subclass]: chose IMKInputSession_Legacy 37 | ``` 38 | 39 | 2. Create a virtual environment. Please do not use `conda`, `virtualenv`, or similar tools, as they may cause the `tkinter` library to malfunction. Copy and run the following command in the terminal, replacing `/Users/zr/Code/CogAgent/venv` with your actual path: 40 | 41 | ```shell 42 | /opt/homebrew/bin/python3 -m venv --copies /Users/zr/Code/CogAgent/venv 43 | ``` 44 | 45 | You need to verify that the virtual environment has been created successfully. In the terminal, run: 46 | 47 | ```shell 48 | /Users/zr/Code/CogAgent/venv/bin/python3 -m tkinter 49 | ``` 50 | 51 | Please ensure that a `caches` folder is created in this directory to store images during model execution: 52 | 53 | ```shell 54 | mkdir caches 55 | ``` 56 | 57 | 3. Ensure that your Mac has granted sufficient permissions to your application. Typically, the software you use to run the code needs screen capture, screen recording, and keyboard/mouse operation permissions. Below is an example of how to enable these permissions on Mac: 58 | 59 | | Enable Screen Recording Permission | Enable Keyboard and Mouse Permissions | 60 | |------------------------------------------|----------------------------------------------| 61 | | ![1](../assets/app_1.png) | ![2](../assets/app_2.png) | 62 | 63 | Here, we use **Pycharm** to run `client.py`. Therefore, you need to grant permissions to **Pycharm** and **Terminal**. 
If you only run code in the terminal, you need to grant permissions to **Terminal** only. The same applies to **VSCode** or other IDEs.
64 | 
65 | ## Running the Server
66 | 
67 | Start the server on the remote machine:
68 | 
69 | ```shell
70 | python openai_demo.py --model_path THUDM/cogagent-9b-20241220 --host 0.0.0.0 --port 7870
71 | ```
72 | 
73 | or use vLLM to start the server:
74 | 
75 | ```shell
76 | python vllm_openai_server.py --model_path THUDM/cogagent-9b-20241220 --host 0.0.0.0 --port 7870
77 | ```
78 | 
79 | This command launches a server on the machine that mimics the `OpenAI` interface format. By default, it will be available at http://0.0.0.0:7870.
80 | 
81 | ## Running the Client
82 | 
83 | To run the client, please ensure the following:
84 | 
85 | - Confirm that the server is running properly, and verify that the remote server and your local machine are connected (through a direct connection or via any tunneling method).
86 | - Make sure the server can be accessed from outside networks, or that your machine can reach the server via port forwarding or other tunneling methods. In our code, the server is forwarded to your local port `7870`. Hence the environment variable should be set to `http://127.0.0.1:7870/v1`.
87 | - The Demo does not enforce an API key, so set `api_key` to `EMPTY`.
88 | 
89 | Run the client with:
90 | 
91 | ```shell
92 | python client.py --api_key EMPTY --base_url http://127.0.0.1:7870/v1 --client_name 127.0.0.1 --client_port 7860 --model cogagent1.5-9b
93 | ```
94 | 
95 | With the above command, you can run the client locally, connect to the server, and use the cogagent-9b-20241220 model.
96 | The image below is a screenshot of the app after a normal start, with the model taking control of the computer (the small rocket in the bottom right corner of the image represents the app).
97 | 
98 | ![img.png](../assets/app_gradio.png)
99 | 
100 | > We cannot guarantee the safety of AI behavior.
Please use this demo with caution. This example is for academic reference only, and we assume no legal responsibility arising from it. 101 | > 102 | > You can press `stop` at any time to forcefully terminate the current model operation. 103 | > 104 | > If you believe the model is operating normally without risk, please do not interact with the computer. The model relies on real-time screenshots to determine click coordinates, which is necessary for normal operation. 105 | -------------------------------------------------------------------------------- /app/register.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file registers all meta-operations that CogAgent1.5-9B Model can perform, 3 | along with the required keywords that must be included for each meta-operation. 4 | 5 | key: value 6 | meta-operation: keyword 7 | """ 8 | 9 | import pyautogui 10 | import pyperclip 11 | import time 12 | import os 13 | import platform 14 | 15 | pyautogui.FAILSAFE = True 16 | pyautogui.PAUSE = 1 17 | 18 | # TODO: Support other META_PARAMETER and other META_OPERATION 19 | META_PARAMETER = { 20 | # Meta-operations need to contain keywords 21 | "CLICK": ["box"], 22 | "DOUBLE_CLICK": ["box"], 23 | "RIGHT_CLICK": ["box"], 24 | "TYPE": ["box", "text"], 25 | "HOVER": ["box"], 26 | "SCROLL_DOWN": ["box"], 27 | "SCROLL_UP": ["box"], 28 | # "SCROLL_RIGHT": ["box"], 29 | # "SCROLL_LEFT": ["box"], 30 | "KEY_PRESS": ["key"], 31 | "LAUNCH": ["app"], 32 | # "QUOTE_TEXT": ["box"], 33 | # "QUOTE_CLIPBOARD": ["output"], 34 | # "TEXT_FORMAT": ["input"], 35 | # "LLM": ["prompt"], 36 | "END": [""], 37 | } 38 | 39 | 40 | def identify_os(): 41 | os_detail = platform.platform() 42 | if "mac" in os_detail.lower(): 43 | return "Mac" 44 | elif "windows" in os_detail.lower(): 45 | return "Win" 46 | else: 47 | raise ValueError( 48 | f"This {os_detail} operating system is not currently supported!" 
49 | ) 50 | 51 | 52 | def paste(text): 53 | pyperclip.copy(text) 54 | time.sleep(1) 55 | if identify_os() == "Mac": 56 | with pyautogui.hold("command"): 57 | pyautogui.press("v") 58 | elif identify_os() == "Win": 59 | with pyautogui.hold("ctrl"): 60 | pyautogui.press("v") 61 | 62 | 63 | def click(params): 64 | """ 65 | Meta-operation: CLICK 66 | CLICK: Simulate a left-click at the center position of the box. 67 | """ 68 | pyautogui.click(params["box"]) 69 | 70 | 71 | def double_click(params): 72 | """ 73 | Meta-operation: DOUBLE_CLICK 74 | DOUBLE_CLICK: Simulate a double-click at the center position of the box. 75 | """ 76 | pyautogui.doubleClick(params["box"]) 77 | 78 | 79 | def right_click(params): 80 | """ 81 | Meta-operation: RIGHT_CLICK 82 | RIGHT_CLICK: Simulate a right-click at the center position of the box. 83 | """ 84 | pyautogui.rightClick(params["box"]) 85 | 86 | 87 | def type_input(params): 88 | """ 89 | Meta-operation: TYPE 90 | TYPE: Enter the text via the clipboard at the current focus, then press Return (the box itself is not clicked here). 91 | """ 92 | paste(params["text"]) 93 | pyautogui.press("Return") 94 | 95 | 96 | def hover(params): 97 | """ 98 | Meta-operation: HOVER 99 | HOVER: Move the mouse to the center position of the box. 100 | """ 101 | pyautogui.moveTo(params["box"]) 102 | 103 | 104 | def scroll_down(params): 105 | """ 106 | Meta-operation: SCROLL_DOWN 107 | SCROLL_DOWN: Move the mouse to the center position of the box, then scroll the screen downward. 108 | """ 109 | pyautogui.moveTo(params["box"]) 110 | pyautogui.scroll(-10) 111 | 112 | 113 | def scroll_up(params): 114 | """ 115 | Meta-operation: SCROLL_UP 116 | SCROLL_UP: Move the mouse to the center position of the box, then scroll the screen up. 117 | """ 118 | pyautogui.moveTo(params["box"]) 119 | pyautogui.scroll(10) 120 | 121 | 122 | def key_press(params): 123 | """ 124 | Meta-operation: KEY_PRESS 125 | KEY_PRESS: Press a special key on the keyboard. eg: KEY_PRESS(key='Return').
126 | """ 127 | pyautogui.press(params["key"]) 128 | 129 | 130 | def end(params): 131 | print("Workflow Completed!") 132 | 133 | 134 | def launch(params): 135 | system_app_dir = "/System/Applications" # For Mac 136 | applications_dir = "/Applications" # For Mac 137 | applications = [os.path.join(applications_dir, app) for app in os.listdir(applications_dir) if app.endswith(".app")] 138 | system_apps = [os.path.join(system_app_dir, app) for app in os.listdir(system_app_dir) if app.endswith(".app")] 139 | all_apps = applications + system_apps 140 | for app_dir in all_apps: 141 | if params["app"][1:-1] in os.path.basename(app_dir): 142 | # Open the bundle from the directory it was actually found in 143 | os.system(f"open -a '{app_dir}'") 144 | 145 | META_OPERATION = { 146 | # Defining meta-operation functions 147 | "CLICK": click, 148 | "DOUBLE_CLICK": double_click, 149 | "RIGHT_CLICK": right_click, 150 | "TYPE": type_input, 151 | "HOVER": hover, 152 | "SCROLL_DOWN": scroll_down, 153 | "SCROLL_UP": scroll_up, 154 | # "SCROLL_RIGHT": ["box"], 155 | # "SCROLL_LEFT": ["box"], 156 | "KEY_PRESS": key_press, 157 | "LAUNCH": launch, 158 | # "QUOTE_TEXT": ["box"], 159 | # "QUOTE_CLIPBOARD": ["output"], 160 | # "TEXT_FORMAT": ["input"], 161 | # "LLM": ["prompt"], 162 | "END": end, 163 | } 164 | 165 | 166 | def locateOnScreen(image, screenshotIm): 167 | print(image, screenshotIm) 168 | start = time.time() 169 | while True: 170 | try: 171 | # the locateAll() function must handle cropping to return accurate coordinates, 172 | # so don't pass a region here. 173 | retVal = pyautogui.locate(image, screenshotIm) 174 | try: 175 | screenshotIm.fp.close() 176 | except AttributeError: 177 | # Screenshots on Windows won't have an fp since they came from 178 | # ImageGrab, not a file.
Screenshots on Linux will have fp set 179 | # to None since the file has been unlinked 180 | pass 181 | if retVal or time.time() - start > 0: 182 | return retVal 183 | except Exception: 184 | if time.time() - start > 0: 185 | return None 186 | 187 | 188 | def convert_to_meta_operation(Grounded_Operation): 189 | detailed_operation = {} 190 | if Grounded_Operation["operation"] in META_PARAMETER: 191 | detailed_operation["meta"] = Grounded_Operation["operation"] 192 | for value in META_PARAMETER[Grounded_Operation["operation"]]: 193 | if value in Grounded_Operation: 194 | if value == "box": 195 | # numbers = (x_min, y_min, x_max, y_max), each in the range 0-999 196 | numbers = Grounded_Operation["box"] 197 | box = [num / 1000 for num in numbers] 198 | # box = the same four values as fractions of the screen size 199 | width, height = pyautogui.size() 200 | # scale to pixels: x_min, y_min, x_max, y_max = left, top, right, bottom 201 | x_min, y_min, x_max, y_max = [ 202 | int(coord * width) if i % 2 == 0 else int(coord * height) 203 | for i, coord in enumerate(box) 204 | ] 205 | x, y = (x_min + x_max) / 2, (y_min + y_max) / 2 206 | detailed_operation[value] = (x, y) 207 | else: 208 | detailed_operation[value] = Grounded_Operation[value][1:-1] 209 | print(detailed_operation) 210 | return detailed_operation 211 | else: 212 | raise ValueError("Wrong operation or operation not registered!") 213 | 214 | 215 | def agent(Grounded_Operation): 216 | detailed_operation = convert_to_meta_operation(Grounded_Operation) 217 | META_OPERATION[detailed_operation["meta"]](detailed_operation) 218 | time.sleep(2) 219 | return detailed_operation["meta"] 220 | -------------------------------------------------------------------------------- /MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | The CogAgent License 2 | 3 | 1. Definitions 4 | 5 | "Licensor" refers to the CogAgent model team distributing its software. 6 | "Software" refers to the CogAgent model parameters provided under this license. 7 | 8 | 2.
License Grant 9 | 10 | Subject to the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. 11 | This license permits you to use all open-source models in this repository for free for academic research purposes. Users wishing to use the model for commercial purposes must complete registration [here](https://open.bigmodel.cn/mla/form). Registered users may use the model for commercial activities free of charge but must comply with all terms and conditions of this license. 12 | The above copyright statement and this license statement must be included in all copies or significant portions of the software. 13 | If you distribute or provide materials from THUDM/Zhipu AI regarding the CogAgent open-source model (or any derivative works thereof) or use any materials (including all open-source models in the CogAgent series) in products or services, you must: 14 | 15 | (A) Provide a copy of this agreement with any such materials from THUDM/Zhipu AI; 16 | (B) Prominently display “Built with CogAgent” on relevant websites, user interfaces, blog posts, about pages, or product documentation. 17 | If you use materials from THUDM/Zhipu AI's CogAgent open-source model to create, train, fine-tune, or otherwise improve distributed or publicly available AI models, you must also prepend “CogAgent” to the names of any such AI models. 18 | 19 | 3. Restrictions 20 | 21 | You must comply with applicable laws, regulations, ethical standards, and other requirements in your jurisdiction when using this software. You must independently obtain permissions, licenses, or other access rights required by third-party software/applications and make prudent and independent judgments on all operational decisions. 
You must not use the software or implement the following actions in an improper manner: 22 | (1) Use, copy, modify, merge, publish, distribute, or create derivative works of this software, in whole or in part, for any military or illegal purposes; 23 | (2) Engage in activities that harm national security, public interest, social morals, or infringe upon others' trade secrets, intellectual property, reputation, portrait rights, property rights, or other rights and interests; 24 | (3) Use the software for fraud, phishing, spamming, misleading, bullying, harassment, discrimination, hate promotion, or dissemination of false information; 25 | (4) Use the software to make automated high-risk decisions in fields such as health, education, credit, finance, or critical infrastructure management, which significantly impact individual or societal safety, rights, or welfare; 26 | (5) Rely on the software to perform major operations, including but not limited to monetary transactions, large purchases, placing irreversible orders, or publishing content detrimental to others' rights or social ethics; 27 | (6) Use the software in services requiring subject qualifications or professional review, or as a substitute for professional services in fields such as medicine, law, journalism, education, or financial investment; 28 | (7) Use the software dishonestly, claim or imply AI-generated content is human-created, or disguise human-created works as AI-generated content; 29 | (8) Engage in illegal network intrusion, disrupt normal network functionality, steal network data, or deliberately spread malicious programs or viruses that harm network security and order; 30 | (9) Collect personal information unlawfully or use the software in a way that infringes upon any third party’s personal information protection rights or privacy. 31 | The Licensor bears no responsibility for your actions while using this software, and you shall assume all resulting liabilities. 32 | 33 | 4. 
Disclaimer 34 | 35 | The software is provided "as is" without any express or implied warranties, including but not limited to warranties of merchantability, fitness for a particular purpose, or non-infringement. 36 | The Licensor does not guarantee the content or operations executed by the software are entirely accurate, reliable, functional, timely, secure, error-free, uninterrupted, or continuously stable. The Licensor is not liable for risks arising from your operational errors or software defects. 37 | In no event shall the authors or copyright holders be liable for any claims, damages, or other liabilities, whether in contract, tort, or otherwise, arising from, out of, or in connection with the software or the use or other dealings in the software. 38 | 39 | 5. Limitation of Liability 40 | 41 | To the maximum extent permitted by applicable law, in no event and under no legal theory shall the Licensor be liable for any direct, indirect, special, incidental, exemplary, or consequential damages, or any other commercial losses, even if the Licensor has been advised of the possibility of such damages. 42 | 43 | 6. Dispute Resolution 44 | 45 | This license is governed by and construed in accordance with the laws of the People's Republic of China. Any disputes arising out of or in connection with this license shall be submitted to the People's Court of Haidian District, Beijing. 46 | Please note that this license may be updated to a more comprehensive version. For any questions about the license or copyright, please contact us at license@zhipuai.cn or opensource@zhipuai.cn. 47 | 48 | CogAgent系列模型开源协议 49 | 50 | 1. 定义 51 | 52 | “许可方”是指分发其软件的 CogAgent系列 模型团队。 53 | “软件”是指根据本许可提供的 CogAgent系列 模型参数。 54 | 55 | 2. 
许可授予 56 | 57 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。 58 | 本许可仅允许您免费使用本仓库中的所有开源模型进行学术研究,对于希望将模型用于商业目的的用户,需在[这里](https://open.bigmodel.cn/mla/form)完成登记。经过登记的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。 59 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 60 | 如果您分发或提供 THUDM / 智谱AI 关于 CogAgent系列开源模型的材料(或其任何衍生作品),或使用其中任何材料(包括 CogAgent系列的所有开源模型)的产品或服务,您应: 61 | 62 | (A) 随任何此类 THUDM / 智谱AI 材料提供本协议的副本; 63 | (B) 在相关网站、用户界面、博客文章、关于页面或产品文档上突出显示 “Built with CogAgent”。 64 | 如果您使用 THUDM / 智谱AI的 CogAgent系列开源模型的材料来创建、训练、微调或以其他方式改进已分发或可用的 AI 模型,您还应在任何此类 AI 模型名称的开头添加 “CogAgent”。 65 | 66 | 3. 限制 67 | 68 | 您在使用中应遵循使用地所适用的法律法规政策、道德规范等要求,在操作应用程序时自行取得第三方软件/应用所需的操作权限、授权或其他准入要求并对所有操作决策进行独立审慎的判断。您不得以以下不当方式使用软件或利用软件实施以下行为: 69 | (1) 出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品; 70 | (2) 利用本软件从事任何危害国家安全和国家统一,危害社会公共利益及公序良俗,侵犯他人商业秘密、知识产权、名誉权、肖像权、财产权等权益的行为; 71 | (3) 用于欺诈、诈骗、发送垃圾短信/邮件、误导、欺凌、骚扰、歧视、宣扬仇恨、传播虚假信息等途径; 72 | (4) 利用软件实施任何决策行为,如在健康、教育、信贷、金融、关键基础设施管理等对个人及社会的安全、权利或福祉有重大影响的领域做出高风险的自动化决策; 73 | (5) 依赖本软件执行任何重大的操作,包括但不限于资金交易、大额消费、下单不可撤销的订单、发布有损他人权益或社会公德的消息等; 74 | (6) 用于任何对主体资格有要求或需要专业人员审查的服务中,或作为专业服务的替代品,包括但不限于医疗、律师、新闻、教育、投资理财等专业领域; 75 | (7) 以不诚实的方式使用,主张或声称人工智能的生成物是人类的作品,或将人类的作品伪装为人工智能的生成物; 76 | (8) 非法侵入网络、干扰网络正常功能、窃取网络数据、故意传播恶意程序或病毒等危害网络安全和网络秩序的活动; 77 | (9) 违法采集他人个人信息,或以可能侵犯任何第三方个人信息保护权及隐私的方式使用本软件。 78 | 许可方不对您使用本软件的行为承担任何责任,由此产生的责任将由您自行承担。 79 | 80 | 4. 免责声明 81 | 82 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 83 | 许可方不保证软件生成的内容及执行的操作百分百准确可靠、功能可用、及时、安全、无错误、不受干扰、无中断、持续稳定、不存在任何故障,AI并不能真正像人类一样理解您输入的内容及指令,如果由于您的操作失误或AI的缺陷导致的风险,许可方不承担相应的责任。 84 | 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,包括但不限于由软件或软件的使用引起、利用软件进行的交易或与软件相关引起的问题。 85 | 86 | 5. 责任限制 87 | 88 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 89 | 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 90 | 91 | 6. 
争议解决 92 | 93 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 94 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 或 opensource@zhipuai.cn 与我们联系。 95 | -------------------------------------------------------------------------------- /Action_space.md: -------------------------------------------------------------------------------- 1 | # Four Mouse Operations 2 | 3 | The mouse operations are left click (`CLICK`), left double-click (`DOUBLE_CLICK`), right click (`RIGHT_CLICK`), and hover (`HOVER`), for example 4 | `CLICK(box=[[387,248,727,317]], element_type='可点击文本', element_info='Click to add Title')`. The parameters supported by these four actions are listed in the table below: 5 | 6 | | Parameter | Required | Description | 7 | |--------------|----------|-------------| 8 | | box | Required | A rectangle on the screen written as [[a,b,c,d]], where a, b, c, and d are each three-digit numbers from 000 to 999.
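Let the screen width be w and the height be h, with the top-left corner of the screen at (0, 0). The top-left corner of the rectangle is (a / 1000 * w, b / 1000 * h), and its bottom-right corner is (c / 1000 * w, d / 1000 * h).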
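The operation is performed at the center of the rectangle. | 9 | | element_type | Optional | A description of the type of the element being operated on, for example "clickable text". | 10 | | element_info | Optional | A description of the content of the element being operated on, for example "Click to add Title". | 11 | 12 | # Text Input (TYPE) 13 | 14 | Text input means entering text at a given position, for example 15 | `TYPE(box=[[387,249,727,317]], text='CogAgent', element_type='文本输入框', element_info='CogAgent')`. Its supported parameters are listed in the table below: 16 | 17 | | Parameter | Required | Description | 18 | |--------------|----------|-------------| 19 | | box | Required | See the explanation under "Four Mouse Operations". | 20 | | element_type | Optional | See the explanation under "Four Mouse Operations". | 21 | | element_info | Optional | See the explanation under "Four Mouse Operations". | 22 | | text | Required | The text to enter. The value may contain variables of the form `__CogName_xxx__`; when the text input action is actually executed, each variable should be replaced with its actual value. See [here](#jump) for details. | 23 | 24 | # Four Scroll Operations 25 | 26 | The scroll operations are scroll up (`SCROLL_UP`), scroll down (`SCROLL_DOWN`), scroll left (`SCROLL_LEFT`), and scroll right (`SCROLL_RIGHT`), for example 27 | `SCROLL_DOWN(box=[[000,086,999,932]], element_type='滚动', element_info='滚动', step_count=5)`. The parameters supported by these four actions are listed in the table below: 28 | 29 | | Parameter | Required | Description | 30 | |--------------|----------|-------------| 31 | | box | Required | See the explanation under "Four Mouse Operations". | 32 | | element_type | Optional | See the explanation under "Four Mouse Operations". | 33 | | element_info | Optional | See the explanation under "Four Mouse Operations". | 34 | | step_count | Required | The number of scroll steps; one step corresponds to one notch of the mouse wheel. Note: system settings and the type of application both affect what one step actually does, so it is hard for the model to predict the required number of steps accurately. | 35 | 36 | # Key Press (KEY_PRESS) 37 | 38 | A key press means pressing and then releasing the given key, for example `KEY_PRESS(key='F11')`. This operation has a single required parameter, `key`, 39 | which names the key to press, such as a digit (0-9) or a letter (A-Z). In addition, `KEY_PRESS` supports the common keys listed in the table below. 40 | 41 | | | **Windows** | **macOS** | 42 | |--------------------|---------------------|-------------------------| 43 | | **Enter / Return** | Return | Return | 44 | | **Space** | Space | Space | 45 | | **Ctrl key (left/right)** | Lcontrol / Rcontrol | N/A | 46 | | **Alt key (left/right)** | Lmenu / Rmenu | N/A | 47 | | **Control key (left/right)** | N/A | Control / Right Control | 48 | | **Command key (left/right)** | N/A | Command / Right Command | 49 | | **Shift key (left/right)** | Lshift / Rshift | Shift / Right Shift | 50 | | **Up arrow** | Up | Up Arrow | 51 | | **Down arrow** | Down | Down Arrow |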
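52 | | **Left arrow** | Left | Left Arrow | 53 | | **Right arrow** | Right | Right Arrow | 54 | 55 | # Gestures (Key Combinations) 56 | 57 | Use a key combination, for example `ctrl+f` to search. The corresponding structured expression is 58 | `GESTURE(actions=[KEY_DOWN(key='Lcontrol'), KEY_PRESS(key='F'), KEY_UP(key='Lcontrol')])`. `GESTURE` takes a single parameter, 59 | `actions`, whose value is a list in which each element is one of the following three actions: 60 | 61 | 1. `KEY_DOWN`: press a key without releasing it; 62 | 63 | 2. `KEY_PRESS`: tap a key, that is, press and release it; 64 | 65 | 3. `KEY_UP`: release a key that is currently held down. 66 | 67 | # Launch an Application or Link (LAUNCH) 68 | 69 | Open an application directly, or open a link in the browser. The `LAUNCH` operation takes two parameters, `app` and `url`: the former is the name of the application to open and the latter is the link to open. If both are given, only 70 | `url` takes effect. For example: 71 | 72 | 1. `LAUNCH(app='设置', url='None')`: open the system settings (`设置` is the application name "Settings"); 73 | 74 | 2. `LAUNCH(app='None', url='baidu.com')`: open the Baidu home page. 75 | 76 | # Quote Text Content (QUOTE_TEXT) 77 | 78 | Recognize and process the text in a given region and store the result in a variable for later use. For example: 79 | 80 | 1. 81 | `QUOTE_TEXT(box=[[387,249,727,317]], element_type='文本', element_info='券后价:17.00', output='__CogName_商品价格__', result='17.00')`; 82 | 83 | 2. 84 | `QUOTE_TEXT(box=[[000,086,999,932]], auto_scroll=True, element_type='窗口', element_info='CogAgent技术报告博客', output='__CogName_技术报告__')`. 85 | 86 | Its supported parameters are listed in the table below: 87 | 88 | | Parameter | Required | Description | 89 | |--------------|----------|-------------| 90 | | box | Required | See the explanation under "Four Mouse Operations". | 91 | | element_type | Optional | See the explanation under "Four Mouse Operations". | 92 | | element_info | Optional | See the explanation under "Four Mouse Operations". | 93 | | output | Required | A variable name indicating where the quoted result is stored, in the format `__CogName_xxx__`. | 94 | | result | Optional | The quoted text. If the quoted text is too long, the `result` value contains an ellipsis, or the parameter is omitted altogether. In that case the CogAgent client-side application must call an OCR service itself to obtain the quoted text. | 95 | | auto_scroll | Optional | Defaults to `False`. If `auto_scroll` is true, the CogAgent client-side application must scroll the list down to its end by itself, collecting the list content as the quoted text. Set `auto_scroll` to true when the text to be quoted is very long. | 96 | 97 | # Calling a Large Language Model (LLM) and Using Variables 98 | 99 | Compose a prompt and call a large language model to compute a result. The parameters accepted by this action are as follows: 100 | 101 | | Parameter | Required | Description | 102 | |--------|----------|-------------| 103 | | prompt | Required | The prompt used to call the large language model; variable names appearing in it are replaced with their actual values. | 104 | | output |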
Required | A variable name indicating where the result of the large language model call is stored, in the format `__CogName_xxx__`. | 105 | | result | Optional | The result of the large language model call. If the result is too long, the `result` value contains an ellipsis, or the parameter is omitted altogether. In that case the CogAgent client-side application must call a large language model service itself to obtain the result. | 106 | 107 | For example, the following two operations summarize the entire content of this page: 108 | 109 | 1. Quote the entire content of the page. Because the page contains a lot of content, set `auto_scroll=True` when quoting. 110 | 111 | 1. `QUOTE_TEXT(box=[[000,086,999,932]], auto_scroll=True, element_type='窗口', element_info='CogAgent技术报告博客', output='__CogName_技术报告__')` 112 | 113 | 2. Call the large language model to summarize the technical report. The report content is already stored in the variable `__CogName_技术报告__`, so the variable should be used directly in the 114 | `prompt` parameter; when the large language model is called to generate the summary, `__CogName_技术报告__` must be replaced with its actual value. 115 | 1. `LLM(prompt='总结以下内容:__CogName_技术报告__', output='__CogName_技术报告总结__')` 116 | 117 | # Quote Clipboard Content (QUOTE_CLIPBOARD) 118 | 119 | Store the clipboard content in a variable for use in later steps. Many web pages and applications offer a "click to copy to clipboard" feature. `QUOTE_CLIPBOARD` 120 | lets the model quickly obtain and use the clipboard content. 121 | 122 | Here is an example of the structured expression: 123 | `QUOTE_CLIPBOARD(output='__CogName_快速排序代码__', result='def quick_sort(arr):\n\tif len(arr) <= 1:\n\t\treturn arr\n\t...')` 124 | 125 | # End (END) 126 | 127 | A special operation indicating that the task has been completed. 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /inference/cli_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import torch 5 | from PIL import Image, ImageDraw 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 7 | from typing import List 8 | 9 | 10 | def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str): 11 | """ 12 | Draws red bounding boxes on the given image and saves it. 13 | 14 | Parameters: 15 | - image (PIL.Image.Image): The image on which to draw the bounding boxes. 16 | - boxes (List[List[float]]): A list of bounding boxes, each defined as [x_min, y_min, x_max, y_max]. 17 | Coordinates are expected to be normalized (0 to 1). 18 | - save_path (str): The path to save the updated image. 19 | 20 | Description: 21 | Each box coordinate is a fraction of the image dimension.
This function converts them to actual pixel 22 | coordinates and draws a red rectangle to mark the area. The annotated image is then saved to the specified path. 23 | """ 24 | draw = ImageDraw.Draw(image) 25 | for box in boxes: 26 | x_min = int(box[0] * image.width) 27 | y_min = int(box[1] * image.height) 28 | x_max = int(box[2] * image.width) 29 | y_max = int(box[3] * image.height) 30 | draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3) 31 | image.save(save_path) 32 | 33 | 34 | def main(): 35 | """ 36 | A continuous interactive demo using the CogAgent1.5 model with selectable format prompts. 37 | The output_image_path is interpreted as a directory. For each round of interaction, 38 | the annotated image will be saved in the directory with the filename: 39 | {original_image_name_without_extension}_{round_number}.png 40 | 41 | Example: 42 | python cli_demo.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \ 43 | --output_image_path ./results --format_key status_action_op_sensitive 44 | """ 45 | 46 | parser = argparse.ArgumentParser( 47 | description="Continuous interactive demo with CogAgent model and selectable format." 48 | ) 49 | parser.add_argument( 50 | "--model_dir", required=True, help="Path or identifier of the model." 51 | ) 52 | parser.add_argument( 53 | "--platform", 54 | default="Mac", 55 | help="Platform information string (e.g., 'Mac', 'WIN').", 56 | ) 57 | parser.add_argument( 58 | "--max_length", type=int, default=4096, help="Maximum generation length." 59 | ) 60 | parser.add_argument( 61 | "--top_k", type=int, default=1, help="Top-k sampling parameter." 
62 | ) 63 | parser.add_argument( 64 | "--output_image_path", 65 | default="results", 66 | help="Directory to save the annotated images.", 67 | ) 68 | parser.add_argument( 69 | "--format_key", 70 | default="action_op_sensitive", 71 | help="Key to select the prompt format.", 72 | ) 73 | args = parser.parse_args() 74 | 75 | # Dictionary mapping format keys to format strings 76 | format_dict = { 77 | "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)", 78 | "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)", 79 | "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)", 80 | "status_action_op": "(Answer in Status-Action-Operation format.)", 81 | "action_op": "(Answer in Action-Operation format.)", 82 | } 83 | 84 | # Ensure the provided format_key is valid 85 | if args.format_key not in format_dict: 86 | raise ValueError( 87 | f"Invalid format_key. Available keys are: {list(format_dict.keys())}" 88 | ) 89 | 90 | # Ensure the output directory exists 91 | os.makedirs(args.output_image_path, exist_ok=True) 92 | 93 | # Load the tokenizer and model 94 | tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True) 95 | model = AutoModelForCausalLM.from_pretrained( 96 | args.model_dir, 97 | torch_dtype=torch.bfloat16, 98 | trust_remote_code=True, 99 | device_map="auto", 100 | # quantization_config=BitsAndBytesConfig(load_in_8bit=True), # For INT8 quantization 101 | # quantization_config=BitsAndBytesConfig(load_in_4bit=True), # For INT4 quantization 102 | ).eval() 103 | # Initialize platform and selected format strings 104 | platform_str = f"(Platform: {args.platform})\n" 105 | format_str = format_dict[args.format_key] 106 | 107 | # Initialize history lists 108 | history_step = [] 109 | history_action = [] 110 | 111 | round_num = 1 112 | while True: 113 | task = input("Please enter the task description ('exit' to quit): ") 114 | if task.lower() == "exit": 115 | break 116 | 117 | 
img_path = input("Please enter the image path: ") 118 | try: 119 | image = Image.open(img_path).convert("RGB") 120 | except Exception: 121 | print("Invalid image path. Please try again.") 122 | continue 123 | 124 | # Verify history lengths match 125 | if len(history_step) != len(history_action): 126 | raise ValueError("Mismatch in lengths of history_step and history_action.") 127 | 128 | # Format history steps for output 129 | history_str = "\nHistory steps: " 130 | for index, (step, action) in enumerate(zip(history_step, history_action)): 131 | history_str += f"\n{index}. {step}\t{action}" 132 | 133 | # Compose the query with task, platform, and selected format instructions 134 | query = f"Task: {task}{history_str}\n{platform_str}{format_str}" 135 | 136 | print(f"Round {round_num} query:\n{query}") 137 | 138 | inputs = tokenizer.apply_chat_template( 139 | [{"role": "user", "image": image, "content": query}], 140 | add_generation_prompt=True, 141 | tokenize=True, 142 | return_tensors="pt", 143 | return_dict=True, 144 | ).to(model.device) 145 | # Generation parameters 146 | gen_kwargs = { 147 | "max_length": args.max_length, 148 | "do_sample": True, 149 | "top_k": args.top_k, 150 | } 151 | 152 | # Generate response 153 | with torch.no_grad(): 154 | outputs = model.generate(**inputs, **gen_kwargs) 155 | outputs = outputs[:, inputs["input_ids"].shape[1]:] 156 | response = tokenizer.decode(outputs[0], skip_special_tokens=True) 157 | print(f"Model response:\n{response}") 158 | 159 | # Extract grounded operation and action 160 | grounded_pattern = r"Grounded Operation:\s*(.*)" 161 | action_pattern = r"Action:\s*(.*)" 162 | matches_history = re.search(grounded_pattern, response) 163 | matches_actions = re.search(action_pattern, response) 164 | 165 | if matches_history: 166 | grounded_operation = matches_history.group(1) 167 | history_step.append(grounded_operation) 168 | if matches_actions: 169 | action_operation = matches_actions.group(1) 170 | 
history_action.append(action_operation) 171 | 172 | # Extract bounding boxes from the response 173 | box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]" 174 | matches = re.findall(box_pattern, response) 175 | if matches: 176 | boxes = [[int(x) / 1000 for x in match] for match in matches] 177 | 178 | # Extract base name of the user's input image (without extension) 179 | base_name = os.path.splitext(os.path.basename(img_path))[0] 180 | # Construct the output file name with round number 181 | output_file_name = f"{base_name}_{round_num}.png" 182 | output_path = os.path.join(args.output_image_path, output_file_name) 183 | 184 | draw_boxes_on_image(image, boxes, output_path) 185 | print(f"Annotated image saved at: {output_path}") 186 | else: 187 | print("No bounding boxes found in the response.") 188 | 189 | round_num += 1 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /finetune/README_en.md: -------------------------------------------------------------------------------- 1 | # CogAgent model fine-tuning 2 | 3 | In this demo, you will experience how to fine-tune the CogAgent open source model. Please strictly follow the steps in 4 | the document to avoid unnecessary errors. 5 | 6 | ## Multiple rounds of fine-tuning format 7 | 8 | The multi-round fine-tuning example uses the CogAgent dialogue format convention, adding different `loss_mask` to 9 | different roles to calculate the `loss` for multiple rounds of replies in one calculation. 10 | 11 | For the data file, the sample uses the following format: 12 | 13 | For the cogagent-9b-20241220 model, you should organize the data in the following format, ensuring that each data entry has an image and the History steps section of `user` includes all previous operations. 14 | The first piece of data below is an intermediate step, and the model will return `<>`, etc. 
The second piece of data is the complete process, and the model will return `<>`. 15 | 16 | ```json 17 | { 18 | "messages": [ 19 | { 20 | "role": "user", 21 | "content": "Task: In the Desktop & Dock section of System Settings, turn on the option 'Displays have separate Spaces' in the Mission Control section\n(Platform: Mac)\n(Answer in Action-Operation-Sensitive format.)\nHistory steps: \n0. CLICK(box=[[588,946,616,985]], element_info='System Settings')\tClick the 'System Settings' menu at the bottom of the screen to open System Settings\n1. CLICK(box=[[227,561,297,585]], element_info='Desktop & Dock')\tClick the 'Desktop & Dock' option in the left menu to enter the Desktop & Dock settings interface. \n2. SCROLL_DOWN(box=[[367,39,691,929]], step_count=75, element_info='Scroll')\tIn the scroll area on the right side of the current screen, scroll down to view more options. \n", 22 | "image": "images/0000000000336.png" 23 | }, 24 | { 25 | "role": "assistant", 26 | "content": "Action: Click the 'Displays have separate spaces' option in the 'Schedule Center' section to enable this feature. \nGrounded Operation:CLICK(box=[[655,842,671,857]], element_info='```Displays have separate spaces``` switch button')\n<>" 27 | } 28 | ] 29 | }, 30 | { 31 | "messages": [ 32 | { 33 | "role": "user", 34 | "content": "Task: Open the system settings through the launch pad, navigate to the 'Network' section, and set the DNS manual IP address to 8.8.4.4.\n(Platform: Mac)\n(Answer in Action-Operation-Sensitive format.)\nHistory steps: \n0. CLICK(box=[[7,6,27,31]], element_info='Apple logo')\tClick the 'Apple' icon in the upper left corner of the screen to open the system settings.\n1. CLICK(box=[[3,77,161,99]], element_info='System settings…')\tClick the 'System settings…' option in the menu bar in the upper left corner of the screen to open the system settings interface.\n2. 
CLICK(box=[[7,478,125,512]], element_info='Network')\tClick the 'Network' option in the left menu to enter the network settings interface. \n3. CLICK(box=[[155,217,459,270]], element_info='Wi-Fi, Connected')\tClick the 'Wi-Fi' option in the upper middle of the screen to enter the Wi-Fi settings interface. \n4. CLICK(box=[[405,280,447,296]], element_info='Details...')\tClick the 'Details...' button next to the Wi-Fi name to view detailed information about the current Wi-Fi. \n5. CLICK(box=[[27,344,45,362]], element_info='DNS')\tClick the 'DNS' option in the left menu to enter the DNS settings interface. \n6. CLICK(box=[[166,308,442,399]], element_info='8.8.8.8')\tClick the DNS server address input box to enter a new DNS server address. \n7. TYPE(box=[[163,308,443,322]], text='8.8.4.4', element_info='ip address input box')\tEnter 8.8.4.4 in the DNS server>Address input box to set a manual IP address. \n8. CLICK(box=[[410,636,440,651]], element_info='OK')\tClick the 'OK' button in the lower right corner of the interface to confirm and save the changes to the DNS settings. \n", 35 | "image": "images/0000000000482.png" 36 | }, 37 | { 38 | "role": "assistant", 39 | "content": "Action: Left click\nGrounded Operation:CLICK(box=[[150,911,180,975]], element_info='lower')\n<>" 40 | } 41 | ] 42 | } 43 | ``` 44 | 45 | ## Configuration file 46 | 47 | The fine-tuning configuration files are located in the `configs` directory and include the following files: 48 | 1. `ds_zero_2.json / ds_zero_3.json`: DeepSpeed configuration files. 49 | 2. `lora.yaml / sft.yaml`: Configuration files for the different fine-tuning methods, including model parameters, 50 | optimizer parameters, training parameters, and more. Some of the important parameters are explained below: 51 | 52 | + Data Configuration Section 53 | + train_file: Path to the training dataset file. 54 | + val_file: Path to the validation dataset file. 55 | + test_file: Path to the test dataset file.
56 | + num_proc: Number of processes to use when loading data. 57 | 58 | + freezeV: Whether to freeze the vision model parameters. 59 | + max_input_length: Maximum input sequence length. Note that an additional 1600 tokens for image encoding results will 60 | be fixed during the model’s actual inference. 61 | + max_output_length: Maximum output sequence length. 62 | 63 | + Training Configuration Section 64 | + output_dir: Directory to save the model and other outputs. 65 | + max_steps: Maximum number of training steps. 66 | + per_device_train_batch_size: Training batch size per device (e.g., GPU). 67 | + dataloader_num_workers: Number of workers to use when loading data. 68 | + remove_unused_columns: Whether to remove unused columns from the data. 69 | + save_strategy: Model saving strategy (e.g., save every N steps). 70 | + save_steps: Number of steps after which the model will be saved. 71 | + log_level: Logging level (e.g., info). 72 | + logging_strategy: Logging strategy. 73 | + logging_steps: Number of steps after which logs will be recorded. 74 | + per_device_eval_batch_size: Evaluation batch size per device. 75 | + evaluation_strategy: Evaluation strategy (e.g., evaluate every N steps). 76 | + eval_steps: Number of steps after which evaluation will occur. 77 | + predict_with_generate: Whether to use generation mode for prediction. 78 | 79 | + Generation Configuration Section 80 | + max_new_tokens: Maximum number of new tokens to generate. 81 | 82 | + PEFT (Parameter Efficient Fine-Tuning) Configuration Section 83 | + peft_type: Type of parameter-efficient fine-tuning used (supports LORA and PREFIX_TUNING). 84 | + task_type: Task type, which is causal language modeling (do not modify). 85 | 86 | + LoRA Parameters: 87 | + r: Rank of the LoRA. 88 | + lora_alpha: Scaling factor for LoRA. 89 | + lora_dropout: Dropout probability used in the LoRA layer. 
90 | 91 | ## Start fine-tuning 92 | 93 | For **single-machine multi-GPU or multi-machine multi-GPU** training, run the command below. It uses `deepspeed` as 94 | the acceleration solution, so `deepspeed` must be installed. 95 | Prepare your own dataset for the cogagent-9b-20241220 model, then run the following command: 96 | 97 | ```shell 98 | OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/CogAgentData THUDM/cogagent-9b-20241220 configs/sft.yaml 99 | ``` 100 | 101 | For **single-machine single-GPU** training, run the following command. 102 | 103 | ```shell 104 | python finetune.py data/CogAgentData/ THUDM/cogagent-9b-20241220 configs/lora.yaml 105 | ``` 106 | 107 | ## Fine-tune from a saved point 108 | 109 | Training as described above always starts from the beginning. To continue from a 110 | partially trained model, add a fourth argument, which takes one of two forms: 111 | 112 | 1. `yes`: automatically resume training from the last saved checkpoint. 113 | 114 | 2. `XX`: a checkpoint number, for example `600`, to resume training from checkpoint 600. 115 | 116 | For example, the following command continues fine-tuning from the last saved point: 117 | 118 | ```shell 119 | python finetune.py data/CogAgent/ THUDM/cogagent-9b-20241220 configs/lora.yaml yes 120 | ``` 121 | 122 | ## Fine-Tuning with Huawei Ascend Computing Devices 123 | 124 | If you need to use `Ascend NPU` devices, such as `ATLAS 300 A2`, you should uncomment the following lines: 125 | 126 | ```python 127 | import torch_npu 128 | from torch_npu.contrib import transfer_to_npu 129 | ``` 130 | 131 | After that, you can run the fine-tuning program as expected.
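
As an alternative to commenting and uncommenting the two imports by hand, a guarded import is one possible approach. The sketch below is an assumption about how this could be done, not what `finetune.py` does; the `npu_available` helper and the `backend` variable are hypothetical names introduced here.

```python
import importlib.util


def npu_available() -> bool:
    """Return True when the torch_npu package can be found in this environment."""
    return importlib.util.find_spec("torch_npu") is not None


if npu_available():
    # Same two imports as above, executed only when torch_npu is installed.
    import torch_npu  # noqa: F401
    from torch_npu.contrib import transfer_to_npu  # noqa: F401
    backend = "npu"
else:
    backend = "default"

print(f"Fine-tuning backend: {backend}")
```

With a guard like this the same script can run unchanged on machines with and without the Ascend toolkit, at the cost of hiding an import error if `torch_npu` is expected but missing.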
132 | -------------------------------------------------------------------------------- /inference/web_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import torch 5 | from threading import Thread, Event 6 | from PIL import Image, ImageDraw 7 | import gradio as gr 8 | from transformers import ( 9 | AutoTokenizer, 10 | AutoModelForCausalLM, 11 | TextIteratorStreamer, 12 | ) 13 | from typing import List 14 | import spaces 15 | 16 | stop_event = Event() 17 | 18 | def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str): 19 | """ 20 | Draws red bounding boxes on the given image and saves it. 21 | 22 | Parameters: 23 | - image (PIL.Image.Image): The image on which to draw the bounding boxes. 24 | - boxes (List[List[float]]): A list of bounding boxes, each defined as [x_min, y_min, x_max, y_max]. 25 | Coordinates are expected to be normalized (0 to 1). 26 | - save_path (str): The path to save the updated image. 27 | 28 | Description: 29 | Each box coordinate is a fraction of the image dimension. This function converts them to actual pixel 30 | coordinates and draws a red rectangle to mark the area. The annotated image is then saved to the specified path. 
31 | """ 32 | draw = ImageDraw.Draw(image) 33 | for box in boxes: 34 | x_min = int(box[0] * image.width) 35 | y_min = int(box[1] * image.height) 36 | x_max = int(box[2] * image.width) 37 | y_max = int(box[3] * image.height) 38 | draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3) 39 | image.save(save_path) 40 | 41 | 42 | def preprocess_messages(history, img_path, platform_str, format_str): 43 | history_step = [] 44 | for task, model_msg in history: 45 | grounded_pattern = r"Grounded Operation:\s*(.*)" 46 | matches_history = re.search(grounded_pattern, model_msg) 47 | if matches_history: 48 | grounded_operation = matches_history.group(1) 49 | history_step.append(grounded_operation) 50 | 51 | history_str = "\nHistory steps: " 52 | if history_step: 53 | for i, step in enumerate(history_step): 54 | history_str += f"\n{i}. {step}" 55 | 56 | if history: 57 | task = history[-1][0] 58 | else: 59 | task = "No task provided" 60 | 61 | query = f"Task: {task}{history_str}\n{platform_str}{format_str}" 62 | image = Image.open(img_path).convert("RGB") 63 | return query, image 64 | 65 | 66 | @spaces.GPU() 67 | def predict(history, max_length, img_path, platform_str, format_str, output_dir): 68 | # Reset the stop_event at the start of prediction 69 | stop_event.clear() 70 | 71 | # Remember history length before this round (for rollback if stopped) 72 | prev_len = len(history) 73 | 74 | query, image = preprocess_messages(history, img_path, platform_str, format_str) 75 | inputs = tokenizer.apply_chat_template( 76 | [{"role": "user", "image": image, "content": query}], 77 | add_generation_prompt=True, 78 | tokenize=True, 79 | return_tensors="pt", 80 | return_dict=True, 81 | ).to(model.device) 82 | streamer = TextIteratorStreamer( 83 | tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True 84 | ) 85 | generate_kwargs = { 86 | "input_ids": inputs["input_ids"], 87 | "attention_mask": inputs["attention_mask"], 88 | "position_ids": inputs["position_ids"], 89 | 
"images": inputs["images"], 90 | "streamer": streamer, 91 | "max_length": max_length, 92 | "do_sample": True, 93 | "top_k": 1, 94 | } 95 | t = Thread(target=model.generate, kwargs=generate_kwargs) 96 | t.start() 97 | with torch.no_grad(): 98 | for new_token in streamer: 99 | # Check if stop event is set 100 | if stop_event.is_set(): 101 | # Stop generation immediately 102 | # Rollback the last round user input 103 | while len(history) > prev_len: 104 | history.pop() 105 | yield history, None 106 | return 107 | 108 | if new_token: 109 | history[-1][1] += new_token 110 | yield history, None 111 | 112 | 113 | response = history[-1][1] 114 | box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]" 115 | matches = re.findall(box_pattern, response) 116 | if matches: 117 | boxes = [[int(x) / 1000 for x in match] for match in matches] 118 | os.makedirs(output_dir, exist_ok=True) 119 | base_name = os.path.splitext(os.path.basename(img_path))[0] 120 | round_num = sum(1 for (u, m) in history if u and m) 121 | output_path = os.path.join(output_dir, f"{base_name}_{round_num}.png") 122 | image = Image.open(img_path).convert("RGB") 123 | draw_boxes_on_image(image, boxes, output_path) 124 | yield history, output_path 125 | else: 126 | yield history, None 127 | 128 | 129 | def user(task, history): 130 | return "", history + [[task, ""]] 131 | 132 | 133 | def undo_last_round(history, output_img): 134 | if history: 135 | history.pop() 136 | return history, None 137 | 138 | 139 | def clear_all_history(): 140 | return None, None 141 | 142 | 143 | def stop_now(): 144 | # Set the stop event to interrupt generation 145 | stop_event.set() 146 | # Returning no changes here, the changes to history and output_img are handled in predict 147 | return gr.update(), gr.update() 148 | 149 | 150 | def main(): 151 | parser = argparse.ArgumentParser(description="CogAgent Gradio Demo") 152 | parser.add_argument("--host", default="127.0.0.1", help="Host IP for the server.") 153 | 
parser.add_argument("--port", type=int, default=7860, help="Port for the server.") 154 | parser.add_argument("--model_dir", required=True, help="Path or identifier of the model.") 155 | parser.add_argument("--format_key", default="action_op_sensitive", help="Key to select the prompt format.") 156 | parser.add_argument("--platform", default="Mac", help="Platform information string.") 157 | parser.add_argument("--output_dir", default="annotated_images", help="Directory to save annotated images.") 158 | args = parser.parse_args() 159 | 160 | format_dict = { 161 | "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)", 162 | "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)", 163 | "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)", 164 | "status_action_op": "(Answer in Status-Action-Operation format.)", 165 | "action_op": "(Answer in Action-Operation format.)" 166 | } 167 | 168 | if args.format_key not in format_dict: 169 | raise ValueError(f"Invalid format_key. Available keys: {list(format_dict.keys())}") 170 | 171 | global tokenizer, model 172 | tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True) 173 | model = AutoModelForCausalLM.from_pretrained( 174 | args.model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto" 175 | ).eval() 176 | 177 | platform_str = f"(Platform: {args.platform})\n" 178 | format_str = format_dict[args.format_key] 179 | 180 | with gr.Blocks(analytics_enabled=False) as demo: 181 | gr.HTML("
CogAgent Demo") 182 | gr.HTML( 183 | "This Demo is for learning and communication purposes only. Users must assume responsibility for the risks associated with AI-generated planning and operations.
") 184 | 185 | with gr.Row(): 186 | img_path = gr.Image(label="Upload a Screenshot", type="filepath", height=400) 187 | output_img = gr.Image(type="filepath", label="Annotated Image", height=400, interactive=False) 188 | 189 | with gr.Row(): 190 | with gr.Column(scale=2): 191 | chatbot = gr.Chatbot(height=300) 192 | task = gr.Textbox(show_label=True, placeholder="Input...", label="Task") 193 | submitBtn = gr.Button("Submit") 194 | with gr.Column(scale=1): 195 | max_length = gr.Slider(0, 8192, value=1024, step=1.0, label="Maximum length", interactive=True) 196 | undo_last_round_btn = gr.Button("Back to Last Round") 197 | clear_history_btn = gr.Button("Clear All History") 198 | 199 | # Interrupt procedure 200 | stop_now_btn = gr.Button("Stop Now", variant="stop") 201 | 202 | submitBtn.click( 203 | user, [task, chatbot], [task, chatbot], queue=False 204 | ).then( 205 | predict, 206 | [chatbot, max_length, img_path, gr.State(platform_str), gr.State(format_str), 207 | gr.State(args.output_dir)], 208 | [chatbot, output_img], 209 | queue=True 210 | ) 211 | 212 | undo_last_round_btn.click(undo_last_round, [chatbot, output_img], [chatbot, output_img], queue=False) 213 | clear_history_btn.click(clear_all_history, None, [chatbot, output_img], queue=False) 214 | stop_now_btn.click(stop_now, None, [chatbot, output_img], queue=False) 215 | 216 | demo.queue() 217 | demo.launch(server_name=args.host, server_port=args.port) 218 | 219 | 220 | if __name__ == "__main__": 221 | main() 222 | -------------------------------------------------------------------------------- /app/vllm_openai_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a vLLM OpenAI Server demo with vLLM for the CogAgent model, 3 | using the OpenAI API to interact with the model. 
4 | 5 | You can specify the model path, host, and port via command-line arguments, for example: 6 | python vllm_openai_server.py --model_path THUDM/cogagent-9b-20241220 --host 0.0.0.0 --port 8000 7 | """ 8 | 9 | import argparse 10 | import gc 11 | import time 12 | import base64 13 | from contextlib import asynccontextmanager 14 | from typing import List, Literal, Union, Tuple, Optional 15 | import torch 16 | import uvicorn 17 | import requests 18 | from fastapi import FastAPI, HTTPException 19 | from fastapi.middleware.cors import CORSMiddleware 20 | from pydantic import BaseModel, Field 21 | from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine 22 | from PIL import Image 23 | from io import BytesIO 24 | 25 | TORCH_TYPE = ( 26 | torch.bfloat16 27 | if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 28 | else torch.float16 29 | ) 30 | 31 | 32 | @asynccontextmanager 33 | async def lifespan(app: FastAPI): 34 | """ 35 | An asynchronous context manager for managing the lifecycle of the FastAPI app. 36 | It ensures that GPU memory is cleared after the app's lifecycle ends, which is essential 37 | for efficient resource management in GPU environments. 38 | """ 39 | yield 40 | if torch.cuda.is_available(): 41 | torch.cuda.empty_cache() 42 | torch.cuda.ipc_collect() 43 | 44 | 45 | app = FastAPI(lifespan=lifespan) 46 | 47 | app.add_middleware( 48 | CORSMiddleware, 49 | allow_origins=["*"], 50 | allow_credentials=True, 51 | allow_methods=["*"], 52 | allow_headers=["*"], 53 | ) 54 | 55 | 56 | class ModelCard(BaseModel): 57 | """ 58 | A Pydantic model representing a model card, which provides metadata about a machine learning model. 59 | It includes fields like model ID, owner, and creation time.
60 | """ 61 | 62 | id: str 63 | object: str = "model" 64 | created: int = Field(default_factory=lambda: int(time.time())) 65 | owned_by: str = "owner" 66 | root: Optional[str] = None 67 | parent: Optional[str] = None 68 | permission: Optional[list] = None 69 | 70 | 71 | class ModelList(BaseModel): 72 | object: str = "list" 73 | data: List[ModelCard] = [] 74 | 75 | 76 | class ImageUrl(BaseModel): 77 | url: str 78 | 79 | 80 | class TextContent(BaseModel): 81 | type: Literal["text"] 82 | text: str 83 | 84 | 85 | class ImageUrlContent(BaseModel): 86 | type: Literal["image_url"] 87 | image_url: ImageUrl 88 | 89 | 90 | ContentItem = Union[TextContent, ImageUrlContent] 91 | 92 | 93 | class ChatMessageInput(BaseModel): 94 | role: Literal["user", "assistant", "system"] 95 | content: Union[str, List[ContentItem]] 96 | name: Optional[str] = None 97 | 98 | 99 | class ChatMessageResponse(BaseModel): 100 | role: Literal["assistant"] 101 | content: str = None 102 | name: Optional[str] = None 103 | 104 | 105 | class DeltaMessage(BaseModel): 106 | role: Optional[Literal["user", "assistant", "system"]] = None 107 | content: Optional[str] = None 108 | 109 | 110 | class ChatCompletionRequest(BaseModel): 111 | model: str 112 | messages: List[ChatMessageInput] 113 | temperature: Optional[float] = 0.8 114 | top_p: Optional[float] = 0.8 115 | max_tokens: Optional[int] = None 116 | stream: Optional[bool] = False 117 | # Additional parameters 118 | repetition_penalty: Optional[float] = 1.0 119 | 120 | 121 | class ChatCompletionResponseChoice(BaseModel): 122 | index: int 123 | message: ChatMessageResponse 124 | 125 | 126 | class ChatCompletionResponseStreamChoice(BaseModel): 127 | index: int 128 | delta: DeltaMessage 129 | 130 | 131 | class UsageInfo(BaseModel): 132 | prompt_tokens: int = 0 133 | total_tokens: int = 0 134 | completion_tokens: Optional[int] = 0 135 | 136 | 137 | class ChatCompletionResponse(BaseModel): 138 | model: str 139 | object: Literal["chat.completion", 
"chat.completion.chunk"] 140 | choices: List[ 141 | Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice] 142 | ] 143 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 144 | usage: Optional[UsageInfo] = None 145 | 146 | 147 | @app.get("/v1/models", response_model=ModelList) 148 | async def list_models(): 149 | """ 150 | An endpoint to list available models. It returns a list of model cards. 151 | This is useful for clients to query and understand what models are available for use. 152 | """ 153 | model_card = ModelCard(id="CogAgent") 154 | return ModelList(data=[model_card]) 155 | 156 | 157 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 158 | async def create_chat_completion(request: ChatCompletionRequest): 159 | """ 160 | An endpoint to create chat completions given a set of messages and model parameters. 161 | Returns either a single completion or streams tokens as they are generated. 162 | """ 163 | global model 164 | 165 | if len(request.messages) < 1 or request.messages[-1].role == "assistant": 166 | raise HTTPException(status_code=400, detail="Invalid request") 167 | 168 | gen_params = dict( 169 | messages=request.messages, 170 | temperature=request.temperature, 171 | top_p=request.top_p, 172 | max_tokens=request.max_tokens or 1024, 173 | echo=False, 174 | stream=False, 175 | repetition_penalty=request.repetition_penalty, 176 | ) 177 | 178 | response = None 179 | messages = gen_params["messages"] 180 | query, image = process_history_and_images(messages) 181 | 182 | async for response in vllm_gen(model, query, image): 183 | pass 184 | 185 | usage = UsageInfo() 186 | print(response) 187 | message = ChatMessageResponse( 188 | role="assistant", 189 | content=response["text"], 190 | ) 191 | choice_data = ChatCompletionResponseChoice( 192 | index=0, 193 | message=message, 194 | ) 195 | task_usage = UsageInfo.model_validate(response["usage"]) 196 | for usage_key, usage_value in 
task_usage.model_dump().items(): 197 | setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) 198 | 199 | return ChatCompletionResponse( 200 | model=request.model, 201 | choices=[choice_data], 202 | object="chat.completion", 203 | usage=usage, 204 | ) 205 | 206 | 207 | def process_history_and_images( 208 | messages: List[ChatMessageInput], 209 | ) -> Tuple[Optional[str], Optional[Image.Image]]: 210 | """ 211 | Process history messages to extract text, identify the last user query, 212 | and convert base64 encoded image URLs to PIL images. 213 | 214 | Args: 215 | messages(List[ChatMessageInput]): List of ChatMessageInput objects. 216 | Returns: 217 | A tuple: 218 | - query (str): The user's query text. 219 | - image (PIL.Image.Image or None): The extracted image, if any. 220 | """ 221 | image = None 222 | for message in messages: 223 | content = message.content 224 | 225 | # Extract text content 226 | if isinstance(content, list): # text 227 | text_content = " ".join( 228 | item.text for item in content if isinstance(item, TextContent) 229 | ) 230 | else: 231 | # If content is a string, treat it directly as text 232 | text_content = content 233 | 234 | if isinstance(content, list): # image 235 | for item in content: 236 | if isinstance(item, ImageUrlContent): 237 | image_url = item.image_url.url 238 | if image_url.startswith("data:image/jpeg;base64,"): 239 | # Base64 encoded image 240 | base64_encoded_image = image_url.split( 241 | "data:image/jpeg;base64," 242 | )[1] 243 | image_data = base64.b64decode(base64_encoded_image) 244 | image = Image.open(BytesIO(image_data)).convert("RGB") 245 | else: 246 | # Fetch image from a remote URL 247 | response = requests.get(image_url, verify=False) 248 | image = Image.open(BytesIO(response.content)).convert("RGB") 249 | 250 | return text_content, image 251 | 252 | 253 | async def vllm_gen( 254 | model: AsyncLLMEngine, 255 | messages: Optional[str], 256 | image: 
Optional[Image.Image] 257 | ): 258 | # Use vllm to perform inference. 259 | # For details on the meaning of the inputs and params_dict, see the vLLM documentation 260 | inputs = { 261 | "prompt": messages, 262 | "multi_modal_data": {"image": image}, 263 | } 264 | params_dict = { 265 | "n": 1, 266 | "best_of": 1, 267 | "presence_penalty": 1.0, 268 | "frequency_penalty": 0.0, 269 | "temperature": 0.6, 270 | "top_p": 0.8, 271 | "top_k": -1, 272 | "ignore_eos": False, 273 | "max_tokens": 8192, 274 | "logprobs": None, 275 | "prompt_logprobs": None, 276 | "skip_special_tokens": True, 277 | "stop_token_ids": [151329, 151336, 151338], 278 | } 279 | sampling_params = SamplingParams(**params_dict) 280 | 281 | async for output in model.generate( 282 | prompt=inputs, 283 | sampling_params=sampling_params, 284 | request_id=f"{time.time()}" 285 | ): 286 | input_echo_len = len(output.prompt_token_ids) - 1601 287 | output_echo_len = len(output.outputs[0].token_ids) 288 | yield { 289 | "text": output.outputs[0].text, 290 | "usage": { 291 | "prompt_tokens": input_echo_len, 292 | "completion_tokens": output_echo_len, 293 | "total_tokens": input_echo_len + output_echo_len, 294 | }, 295 | } 296 | 297 | 298 | def load_model(model_dir: str): 299 | engine_args = AsyncEngineArgs( 300 | model=model_dir, 301 | tensor_parallel_size=1, 302 | dtype="bfloat16", 303 | trust_remote_code=True, 304 | gpu_memory_utilization=0.9, 305 | enforce_eager=True, 306 | worker_use_ray=True, 307 | disable_log_requests=True, 308 | # If you run into OOM errors, consider enabling the parameters below 309 | # enable_chunked_prefill=True, 310 | # max_num_batched_tokens=8192 311 | ) 312 | engine = AsyncLLMEngine.from_engine_args(engine_args) 313 | return engine 314 | 315 | 316 | gc.collect() 317 | torch.cuda.empty_cache() 318 | 319 | if __name__ == "__main__": 320 | # Use argparse to control model_path, host, and port from command line arguments 321 | parser = argparse.ArgumentParser(description="vLLM OpenAI Server Demo for CogAgent") 322 | parser.add_argument( 323 | 
"--model_path", required=True, help="Path or name of the CogAgent model" 324 | ) 325 | parser.add_argument("--host", default="0.0.0.0", help="Host to run the server on") 326 | parser.add_argument( 327 | "--port", type=int, default=8000, help="Port to run the server on" 328 | ) 329 | args = parser.parse_args() 330 | 331 | # Load model 332 | model = load_model(args.model_path) 333 | 334 | uvicorn.run(app, host=args.host, port=args.port, workers=1) 335 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 CogAgent team @ Zhipu AI 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /app/openai_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates an OpenAI Server demo with transformers for the CogAgent model, 3 | using the OpenAI API to interact with the model. 
4 | 5 | You can specify the model path, host, and port via command-line arguments, for example: 6 | python openai_demo.py --model_path THUDM/cogagent-9b-20241220 --host 0.0.0.0 --port 8000 7 | """ 8 | 9 | import argparse 10 | import gc 11 | import threading 12 | import time 13 | import base64 14 | from contextlib import asynccontextmanager 15 | from typing import List, Literal, Union, Tuple, Optional 16 | import torch 17 | import uvicorn 18 | import requests 19 | from fastapi import FastAPI, HTTPException 20 | from fastapi.middleware.cors import CORSMiddleware 21 | from pydantic import BaseModel, Field 22 | from sse_starlette.sse import EventSourceResponse 23 | from transformers import AutoTokenizer, AutoModel, TextIteratorStreamer 24 | from PIL import Image 25 | from io import BytesIO 26 | from pathlib import Path 27 | 28 | # Determine the appropriate torch dtype based on the GPU capabilities 29 | TORCH_TYPE = ( 30 | torch.bfloat16 31 | if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 32 | else torch.float16 33 | ) 34 | 35 | 36 | @asynccontextmanager 37 | async def lifespan(app: FastAPI): 38 | """ 39 | An asynchronous context manager for managing the lifecycle of the FastAPI app. 40 | It ensures that GPU memory is cleared after the app's lifecycle ends, which is essential 41 | for efficient resource management in GPU environments. 42 | """ 43 | yield 44 | if torch.cuda.is_available(): 45 | torch.cuda.empty_cache() 46 | torch.cuda.ipc_collect() 47 | 48 | 49 | app = FastAPI(lifespan=lifespan) 50 | 51 | # Enable CORS so that the API can be called from anywhere 52 | app.add_middleware( 53 | CORSMiddleware, 54 | allow_origins=["*"], 55 | allow_credentials=True, 56 | allow_methods=["*"], 57 | allow_headers=["*"], 58 | ) 59 | 60 | 61 | class ModelCard(BaseModel): 62 | """ 63 | A Pydantic model representing a model card, which provides metadata about a machine learning model. 64 | It includes fields like model ID, owner, and creation time. 
65 | """ 66 | 67 | id: str 68 | object: str = "model" 69 | created: int = Field(default_factory=lambda: int(time.time())) 70 | owned_by: str = "owner" 71 | root: Optional[str] = None 72 | parent: Optional[str] = None 73 | permission: Optional[list] = None 74 | 75 | 76 | class ModelList(BaseModel): 77 | object: str = "list" 78 | data: List[ModelCard] = [] 79 | 80 | 81 | class ImageUrl(BaseModel): 82 | url: str 83 | 84 | 85 | class TextContent(BaseModel): 86 | type: Literal["text"] 87 | text: str 88 | 89 | 90 | class ImageUrlContent(BaseModel): 91 | type: Literal["image_url"] 92 | image_url: ImageUrl 93 | 94 | 95 | ContentItem = Union[TextContent, ImageUrlContent] 96 | 97 | 98 | class ChatMessageInput(BaseModel): 99 | role: Literal["user", "assistant", "system"] 100 | content: Union[str, List[ContentItem]] 101 | name: Optional[str] = None 102 | 103 | 104 | class ChatMessageResponse(BaseModel): 105 | role: Literal["assistant"] 106 | content: str = None 107 | name: Optional[str] = None 108 | 109 | 110 | class DeltaMessage(BaseModel): 111 | role: Optional[Literal["user", "assistant", "system"]] = None 112 | content: Optional[str] = None 113 | 114 | 115 | class ChatCompletionRequest(BaseModel): 116 | model: str 117 | messages: List[ChatMessageInput] 118 | temperature: Optional[float] = 0.8 119 | top_p: Optional[float] = 0.8 120 | max_tokens: Optional[int] = None 121 | stream: Optional[bool] = False 122 | repetition_penalty: Optional[float] = 1.0 123 | 124 | 125 | class ChatCompletionResponseChoice(BaseModel): 126 | index: int 127 | message: ChatMessageResponse 128 | 129 | 130 | class ChatCompletionResponseStreamChoice(BaseModel): 131 | index: int 132 | delta: DeltaMessage 133 | 134 | 135 | class UsageInfo(BaseModel): 136 | prompt_tokens: int = 0 137 | total_tokens: int = 0 138 | completion_tokens: Optional[int] = 0 139 | 140 | 141 | class ChatCompletionResponse(BaseModel): 142 | model: str 143 | object: Literal["chat.completion", "chat.completion.chunk"] 144 | choices: 
List[ 145 | Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice] 146 | ] 147 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 148 | usage: Optional[UsageInfo] = None 149 | 150 | 151 | @app.get("/v1/models", response_model=ModelList) 152 | async def list_models(): 153 | """ 154 | An endpoint to list available models. It returns a list of model cards. 155 | This is useful for clients to query and understand what models are available for use. 156 | """ 157 | model_card = ModelCard(id="CogAgent") 158 | return ModelList(data=[model_card]) 159 | 160 | 161 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 162 | async def create_chat_completion(request: ChatCompletionRequest): 163 | """ 164 | An endpoint to create chat completions given a set of messages and model parameters. 165 | Returns either a single completion or streams tokens as they are generated. 166 | """ 167 | global model, tokenizer 168 | 169 | if len(request.messages) < 1 or request.messages[-1].role == "assistant": 170 | raise HTTPException(status_code=400, detail="Invalid request") 171 | 172 | gen_params = dict( 173 | messages=request.messages, 174 | temperature=request.temperature, 175 | top_p=request.top_p, 176 | max_tokens=request.max_tokens or 1024, 177 | echo=False, 178 | stream=request.stream, 179 | repetition_penalty=request.repetition_penalty, 180 | ) 181 | 182 | if request.stream: 183 | # If streaming is requested, return an EventSourceResponse that yields tokens as they are generated 184 | generate = predict(request.model, gen_params) 185 | return EventSourceResponse(generate, media_type="text/event-stream") 186 | 187 | # Otherwise, return a complete response after generation 188 | response = generate_cogagent(model, tokenizer, gen_params) 189 | 190 | usage = UsageInfo() 191 | message = ChatMessageResponse(role="assistant", content=response["text"]) 192 | choice_data = ChatCompletionResponseChoice(index=0, message=message) 193 | 
194 | task_usage = UsageInfo.model_validate(response["usage"]) 195 | for usage_key, usage_value in task_usage.model_dump().items(): 196 | setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) 197 | 198 | return ChatCompletionResponse( 199 | model=request.model, 200 | choices=[choice_data], 201 | object="chat.completion", 202 | usage=usage, 203 | ) 204 | 205 | 206 | def predict(model_id: str, params: dict): 207 | """ 208 | A generator function that streams the model output tokens. 209 | Used for the `stream=True` scenario, returning tokens as SSE events. 210 | """ 211 | global model, tokenizer 212 | 213 | # Initially, return the role delta message 214 | choice_data = ChatCompletionResponseStreamChoice( 215 | index=0, delta=DeltaMessage(role="assistant") 216 | ) 217 | chunk = ChatCompletionResponse( 218 | model=model_id, choices=[choice_data], object="chat.completion.chunk" 219 | ) 220 | yield chunk.model_dump_json(exclude_unset=True) 221 | 222 | previous_text = "" 223 | for new_response in generate_stream_cogagent(model, tokenizer, params): 224 | decoded_unicode = new_response["text"] 225 | delta_text = decoded_unicode[len(previous_text) :] 226 | previous_text = decoded_unicode 227 | delta = DeltaMessage(content=delta_text, role="assistant") 228 | choice_data = ChatCompletionResponseStreamChoice(index=0, delta=delta) 229 | chunk = ChatCompletionResponse( 230 | model=model_id, choices=[choice_data], object="chat.completion.chunk" 231 | ) 232 | yield chunk.model_dump_json(exclude_unset=True) 233 | 234 | # End of stream message 235 | choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage()) 236 | chunk = ChatCompletionResponse( 237 | model=model_id, choices=[choice_data], object="chat.completion.chunk" 238 | ) 239 | yield chunk.model_dump_json(exclude_unset=True) 240 | 241 | 242 | def generate_cogagent(model: AutoModel, tokenizer: AutoTokenizer, params: dict): 243 | """ 244 | Generates a response using the CogAgent model. 
245 | It processes the chat history and any provided images, 246 | and then invokes the model to generate a complete response. 247 | """ 248 | response = None 249 | for response in generate_stream_cogagent(model, tokenizer, params): 250 | pass 251 | return response 252 | 253 | 254 | def process_history_and_images( 255 | messages: List[ChatMessageInput], 256 | ) -> Tuple[Optional[str], Optional[Image.Image]]: 257 | """ 258 | Process history messages to extract text, identify the last user query, 259 | and convert base64 encoded image URLs to PIL images. 260 | 261 | Args: 262 | messages(List[ChatMessageInput]): List of ChatMessageInput objects. 263 | Returns: 264 | A tuple: 265 | - query (str): The user's query text. 266 | - image (PIL.Image.Image or None): The extracted image, if any. 267 | """ 268 | image = None 269 | text_content = "" 270 | for message in messages: 271 | content = message.content 272 | 273 | # Extract text content 274 | if isinstance(content, list): 275 | extracted_texts = [ 276 | item.text for item in content if isinstance(item, TextContent) 277 | ] 278 | text_content = " ".join(extracted_texts) 279 | 280 | else: 281 | # If content is a string, treat it directly as text 282 | text_content = content 283 | 284 | # Extract image content 285 | if isinstance(content, list): 286 | for item in content: 287 | if isinstance(item, ImageUrlContent): 288 | image_url = item.image_url.url 289 | if image_url.startswith("data:image/jpeg;base64,"): 290 | # Base64 encoded image 291 | base64_encoded_image = image_url.split( 292 | "data:image/jpeg;base64," 293 | )[1] 294 | image_data = base64.b64decode(base64_encoded_image) 295 | image = Image.open(BytesIO(image_data)).convert("RGB") 296 | else: 297 | # Fetch image from a remote URL 298 | response = requests.get(image_url, verify=False) 299 | image = Image.open(BytesIO(response.content)).convert("RGB") 300 | return text_content, image 301 | 302 | 303 | @torch.inference_mode() 304 | def 
generate_stream_cogagent(model: AutoModel, tokenizer: AutoTokenizer, params: dict): 305 | """ 306 | Streams the generation results from the model token-by-token. 307 | Uses TextIteratorStreamer to yield partial responses as they are generated. 308 | """ 309 | messages = params["messages"] 310 | temperature = float(params.get("temperature", 1.0)) 311 | top_p = float(params.get("top_p", 1.0)) 312 | max_new_tokens = int(params.get("max_tokens", 256)) 313 | query, image = process_history_and_images(messages) 314 | 315 | # Apply a chat template (assumed to be provided by the model or custom logic) 316 | model_inputs = tokenizer.apply_chat_template( 317 | [{"role": "user", "image": image, "content": query}], 318 | add_generation_prompt=True, 319 | tokenize=True, 320 | return_tensors="pt", 321 | return_dict=True, 322 | ).to(model.device) 323 | 324 | input_echo_len = len(model_inputs["input_ids"][0]) 325 | streamer = TextIteratorStreamer( 326 | tokenizer=tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True 327 | ) 328 | gen_kwargs = { 329 | "max_new_tokens": max_new_tokens,  # limit generated tokens only; "max_length" would also count the prompt 330 | "do_sample": temperature > 1e-5, 331 | "top_p": top_p if temperature > 1e-5 else 0, 332 | "top_k": 1, 333 | "streamer": streamer, 334 | } 335 | if temperature > 1e-5: 336 | gen_kwargs["temperature"] = temperature 337 | 338 | generated_text = "" 339 | 340 | def generate_text(): 341 | with torch.no_grad(): 342 | model.generate(**model_inputs, **gen_kwargs) 343 | 344 | generation_thread = threading.Thread(target=generate_text) 345 | generation_thread.start() 346 | 347 | total_len = input_echo_len 348 | for next_text in streamer: 349 | generated_text += next_text 350 | total_len = input_echo_len + len(tokenizer.encode(generated_text))  # prompt tokens + approximate completion tokens 351 | yield { 352 | "text": generated_text, 353 | "usage": { 354 | "prompt_tokens": input_echo_len, 355 | "completion_tokens": total_len - input_echo_len, 356 | "total_tokens": total_len, 357 | }, 358 | } 359 | 360 | generation_thread.join() 361 | yield {
362 | "text": generated_text, 363 | "usage": { 364 | "prompt_tokens": input_echo_len, 365 | "completion_tokens": total_len - input_echo_len, 366 | "total_tokens": total_len, 367 | }, 368 | } 369 | 370 | 371 | # Clean up GPU memory if possible 372 | gc.collect() 373 | torch.cuda.empty_cache() 374 | 375 | if __name__ == "__main__": 376 | # Use argparse to control model_path, host, and port from command line arguments 377 | parser = argparse.ArgumentParser(description="OpenAI Server Demo for CogAgent") 378 | parser.add_argument( 379 | "--model_path", required=True, help="Path or name of the CogAgent model" 380 | ) 381 | parser.add_argument("--host", default="0.0.0.0", help="Host to run the server on") 382 | parser.add_argument( 383 | "--port", type=int, default=8000, help="Port to run the server on" 384 | ) 385 | args = parser.parse_args() 386 | 387 | model_dir = Path(args.model_path).expanduser().resolve() 388 | 389 | # Load tokenizer 390 | tokenizer = AutoTokenizer.from_pretrained( 391 | args.model_path, trust_remote_code=True, encode_special_tokens=True 392 | ) 393 | # Load model 394 | model = AutoModel.from_pretrained( 395 | args.model_path, 396 | torch_dtype=TORCH_TYPE, 397 | trust_remote_code=True, 398 | device_map="auto", 399 | ).eval() 400 | 401 | # Run the Uvicorn server with the specified host and port 402 | uvicorn.run(app, host=args.host, port=args.port, workers=1) 403 | -------------------------------------------------------------------------------- /Action_space_en.md: -------------------------------------------------------------------------------- 1 | ## Four Mouse Operations 2 | 3 | Mouse operations include: Left Click (`CLICK`), Left Double Click (`DOUBLE_CLICK`), Right Click (`RIGHT_CLICK`), and 4 | Mouse Hover (`HOVER`). For example, 5 | `CLICK(box=[[387,248,727,317]], element_type='Clickable text', element_info='Click to add Title')`. 
The parameters 6 | supported by these four actions are as follows: 7 | 8 | | Parameter Name | Optional | Explanation | 9 | |----------------|----------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 10 | | box | Required | Represents a rectangle on the screen in the form `[[a,b,c,d]]`, where a/b/c/d are three-digit numbers ranging from 000 to 999.
Assume the screen has width w and height h, with the top-left corner of the screen at (0, 0). The top-left corner of the rectangle is (a / 1000 * w, b / 1000 * h), and the bottom-right corner of the rectangle is (c / 1000 * w, d / 1000 * h).
The actual operation position is the center of the rectangle. | 11 | | element_type | Optional | A description of the type of the element being operated on, for example, "Clickable text" | 12 | | element_info | Optional | A description of the content of the element being operated on, for example, "Click to add Title" | 13 | 14 | ## Text Input (TYPE) 15 | 16 | Text input refers to entering text at a given location, for example, 17 | `TYPE(box=[[387,249,727,317]], text='CogAgent', element_type='Text input box', element_info='CogAgent')`. The parameters 18 | it supports are as follows: 19 | 20 | | Parameter Name | Optional | Explanation | 21 | |----------------|----------|----------------------------------------------------------------| 22 | | box | Required | Please refer to the explanation in "Four Mouse Operations". | 23 | | element_type | Optional | Please refer to the explanation in "Four Mouse Operations". | 24 | | element_info | Optional | Please refer to the explanation in "Four Mouse Operations". | 25 | | text | Required | The text content that needs to be input. This parameter may contain variables in the form `__CogName_xxx__`. During actual execution of the "Text Input" action, these variables should be replaced with actual values. For more details, please refer to [here](#jump). | 26 | 27 | ## Four Scrolling Operations 28 | 29 | Scrolling operations include: Scroll Up (`SCROLL_UP`), Scroll Down (`SCROLL_DOWN`), Scroll Left (`SCROLL_LEFT`), and 30 | Scroll Right (`SCROLL_RIGHT`). For example, 31 | `SCROLL_DOWN(box=[[000,086,999,932]], element_type='Scroll', element_info='Scroll', step_count=5)`.
The parameters 32 | supported by these four actions are as follows: 33 | 34 | | Parameter Name | Optional | Explanation | 35 | |----------------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 36 | | box | Required | Please refer to the explanation in "Four Mouse Operations". | 37 | | element_type | Optional | Please refer to the explanation in "Four Mouse Operations". | 38 | | element_info | Optional | Please refer to the explanation in "Four Mouse Operations". | 39 | | step_count | Required | The number of steps to scroll, where one step corresponds to one notch of the mouse wheel. Note: Both system settings and application types can affect the actual effect of one step of scrolling, so it is difficult for the model to accurately predict the required number of scrolling steps. | 40 | 41 | ## Keyboard Press (KEY_PRESS) 42 | 43 | Keyboard press refers to pressing and releasing a given button in sequence, for example, `KEY_PRESS(key='F11')`. This 44 | operation type has only one required parameter `key`, which represents the name of the key that needs to be pressed, 45 | such as numeric keys (0–9), letters (A-Z). In addition, `KEY_PRESS` supports the following common keys, as shown in the 46 | table below. 
47 | 48 | | | **Windows** | **macOS** | 49 | |------------------------------|---------------------|-------------------------| 50 | | **Enter/Return** | Return | Return | 51 | | **Space** | Space | Space | 52 | | **Ctrl key (left/right)** | Lcontrol / Rcontrol | N/A | 53 | | **Alt key (left/right)** | Lmenu / Rmenu | N/A | 54 | | **Control key (left/right)** | N/A | Control / Right Control | 55 | | **Command key (left/right)** | N/A | Command / Right Command | 56 | | **Shift key (left/right)** | Lshift / Rshift | Shift / Right Shift | 57 | | **Arrow key - Up** | Up | Up Arrow | 58 | | **Arrow key - Down** | Down | Down Arrow | 59 | | **Arrow key - Left** | Left | Left Arrow | 60 | | **Arrow key - Right** | Right | Right Arrow | 61 | 62 | ## Gestures (Combination Keys) 63 | 64 | A gesture expresses a key combination, for example, `ctrl+f` to perform a search. The corresponding structured expression is 65 | `GESTURE(actions=[KEY_DOWN(key='Lcontrol'), KEY_PRESS(key='F'), KEY_UP(key='Lcontrol')])`. `GESTURE` contains only one 66 | parameter `actions`, which takes a list as its value. Each element in the list is one of the following three actions: 67 | 68 | 1. `KEY_DOWN`: Press a key without releasing it. 69 | 70 | 2. `KEY_PRESS`: Tap a key, which means to press and release it. 71 | 72 | 3. `KEY_UP`: Release a key that has been pressed. 73 | 74 | ## Launching an Application or a URL (LAUNCH) 75 | 76 | Directly open an application or a link in a browser. The `LAUNCH` operation accepts two parameters, `app` and `url`, 77 | where `app` represents the name of the application to be opened and `url` represents the link to be opened. If both 78 | parameters are provided, only `url` takes effect. For example: 79 | 80 | 1. `LAUNCH(app='Settings', url='None')`: Open the system settings. 81 | 82 | 2. `LAUNCH(app='None', url='baidu.com')`: Open the Baidu homepage.
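
The precedence rule for `LAUNCH` (when both `app` and `url` are given, only `url` takes effect) can be sketched as a small resolver. This is an illustrative sketch, not code from the CogAgent client: `resolve_launch_target` is a hypothetical helper, and treating the literal string `'None'` as "not provided" is an assumption drawn from the two examples above.

```python
from typing import Optional, Tuple


def resolve_launch_target(app: Optional[str] = None, url: Optional[str] = None) -> Tuple[str, str]:
    """Return ("url", value) or ("app", value) for a LAUNCH action.

    Hypothetical helper: the literal string 'None' is treated as "not provided",
    which is an assumption based on the examples in this section.
    """

    def is_missing(value: Optional[str]) -> bool:
        return value is None or value.strip() in ("", "None")

    if not is_missing(url):
        # When both parameters are provided, only `url` takes effect.
        return ("url", url.strip())
    if not is_missing(app):
        return ("app", app.strip())
    raise ValueError("LAUNCH needs at least one of `app` or `url`")


print(resolve_launch_target(app="Settings", url="None"))   # ('app', 'Settings')
print(resolve_launch_target(app="None", url="baidu.com"))  # ('url', 'baidu.com')
```

How the resolved target is actually opened (a browser call, an OS launcher, and so on) depends on the platform and is left out of the sketch.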
83 | 84 | ## Quoting Text Content (QUOTE_TEXT) 85 | 86 | Identify and process the text content in a given area, and store the result in a variable for subsequent use. For 87 | example: 88 | 89 | 1. 90 | `QUOTE_TEXT(box=[[387,249,727,317]], element_type='Text', element_info='Price after coupon: 17.00', output='__CogName_ProductPrice__', result='17.00')`; 91 | 92 | 2. 93 | `QUOTE_TEXT(box=[[000,086,999,932]], auto_scroll=True, element_type='Window', element_info='CogAgent Technical Report Blog', output='__CogName_TechnicalReport__')`. 94 | 95 | The parameters it supports are as follows: 96 | 97 | | Parameter Name | Optional | Explanation | 98 | |----------------|----------|----------------------------------------------------------------| 99 | | box | Required | Please refer to the explanation in "Four Mouse Operations". | 100 | | element_type | Optional | Please refer to the explanation in "Four Mouse Operations". | 101 | | element_info | Optional | Please refer to the explanation in "Four Mouse Operations". | 102 | | output | Required | The variable name indicating where the quoted result is stored. The format is `__CogName_xxx__`. | 103 | | result | Optional | Represents the result of the text quote. If the text quote result is too long, the value of `result` will contain ellipses, or this parameter may not be present at all. In such cases, the CogAgent client-side application needs to call an OCR service to obtain the quoted result. | 104 | | auto_scroll | Optional | Defaults to `False`. If `auto_scroll` is true, the CogAgent client-side application needs to scroll down until it reaches the bottom of the list, collecting the list content as the result of the text quote.
When the text to be quoted is very long, `auto_scroll` should be set to true. | 105 | 106 | ## Calling a Large Language Model (LLM) and Using Variables 107 | 108 | Organize prompts and call a large language model to compute results. The parameters received by this action are as 109 | follows: 110 | 111 | | Parameter Name | Optional | Explanation | 112 | |----------------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 113 | | prompt | Required | The prompt used to call the large language model, where the variable names used will be replaced with actual values. | 114 | | output | Required | The variable name indicating where the result of the large language model call is stored. The format is `__CogName_xxx__`. | 115 | | result | Optional | Represents the result of the large language model call. If the result is too long, the value of `result` will contain ellipses, or this parameter may not be present at all. In such cases, the CogAgent client-side application needs to call the large language model service to obtain the result. | 116 | 117 | For example, the following two operations can be used to summarize the entire content of this page: 118 | 119 | 1. Quote the entire content of this page. Since there is a lot of content on the page, `auto_scroll=True` needs to be 120 | set when quoting. 121 | 122 | 1. `QUOTE_TEXT(box=[[000,086,999,932]], auto_scroll=True, element_type='Window', element_info='CogAgent Technical Report Blog', output='__CogName_TechnicalReport__')` 123 | 124 | 2. Call the large language model to summarize the content of the technical report. 
The content of 125 | the technical report has been stored in the variable `__CogName_TechnicalReport__`, so this variable should be used 126 | directly in the `prompt` parameter; when calling the large language model to generate the summary content, 127 | `__CogName_TechnicalReport__` needs to be replaced with the actual value. 128 | 129 | 1. `LLM(prompt='Summarize the following content: __CogName_TechnicalReport__', output='__CogName_TechnicalReportSummary__')` 130 | 131 | ## Quoting Clipboard Content (QUOTE_CLIPBOARD) 132 | 133 | Store the content of the clipboard in a variable for use in subsequent steps. Many web pages and applications provide 134 | a "click to copy to clipboard" feature. `QUOTE_CLIPBOARD` allows the model to quickly obtain and use the content of the 135 | clipboard. 136 | Here is an example of a structured expression: 137 | `QUOTE_CLIPBOARD(output='__CogName_QuickSortCode__', result='def quick_sort(arr):\n\tif len(arr) <= 1:\n\t\treturn arr\n\t...')` 138 | 139 | ## End (END) 140 | 141 | A special operation indicating that the task has been completed. 
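
Two client-side steps described in this document lend themselves to a short sketch: converting a `box` on the 0 to 999 grid into the pixel position where the action is performed, and replacing `__CogName_xxx__` variables with stored values before a `TYPE` or `LLM` action runs. The helper names `box_center` and `substitute_variables` are hypothetical, and leaving unknown variables untouched is a choice made for this sketch rather than documented behavior.

```python
import re
from typing import Dict, Tuple


def box_center(box: Tuple[int, int, int, int], width: int, height: int) -> Tuple[float, float]:
    """Map a box [a, b, c, d] on the 0-999 grid to the pixel center of the rectangle."""
    a, b, c, d = box
    left, top = a / 1000 * width, b / 1000 * height
    right, bottom = c / 1000 * width, d / 1000 * height
    # The action is performed at the center of the rectangle.
    return ((left + right) / 2, (top + bottom) / 2)


# Matches placeholders of the form __CogName_xxx__.
VARIABLE_PATTERN = re.compile(r"__CogName_\w+?__")


def substitute_variables(text: str, variables: Dict[str, str]) -> str:
    """Replace each __CogName_xxx__ placeholder with its stored value.

    Unknown placeholders are left as-is (an assumption of this sketch).
    """
    return VARIABLE_PATTERN.sub(lambda m: variables.get(m.group(0), m.group(0)), text)


# Example: a CLICK box on a 1920x1080 screen.
print(box_center((387, 248, 727, 317), width=1920, height=1080))

# Example: filling the prompt of an LLM action from a previously quoted variable.
stored = {"__CogName_TechnicalReport__": "CogAgent is a GUI agent."}
print(substitute_variables("Summarize the following content: __CogName_TechnicalReport__", stored))
# prints "Summarize the following content: CogAgent is a GUI agent."
```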
142 | 143 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # CogAgent 2 | 3 | Read this in [English](README.md) 4 | 5 | - 🔥 **2024.12** We open-sourced the **latest CogAgent-9B-20241220 model**. Compared with the previous version of CogAgent, `CogAgent-9B-20241220` 6 | is significantly improved in GUI perception, reasoning and prediction accuracy, completeness of the action space, and task generality and generalization, and it accepts screenshots and language interaction in both Chinese and English. 7 | - 🏆 **2024.6** CogAgent was accepted to CVPR 2024 and selected as a conference Highlight (top 3%). 8 | - 2023.12 We **open-sourced the first GUI agent, CogAgent** (the repository for that version is [here](https://github.com/THUDM/CogVLM)) and **published the corresponding paper 9 | 📖 [CogAgent paper](https://arxiv.org/abs/2312.08914)**. 10 | 11 | ## About the Model 12 | 13 | ### Model Resources 14 | 15 | | Model | Model Download | Technical Documentation | Online Demo | 16 | |:--------------------:|:--------------------:|--------------------|--------------------| 17 | | cogagent-9b-20241220 | [🤗 HuggingFace](https://huggingface.co/THUDM/cogagent-9b-20241220)
[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/cogagent-9b-20241220)
[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/cogagent-9b-20241220)
[🧩 Modelers (Ascend)](https://modelers.cn/models/zhipuai/cogagent-9b-20241220) | [📄 Official Technical Blog](https://cogagent.aminer.cn/blog#/articles/cogagent-9b-20241220-technical-report)
[📘 Practical Guide (Chinese)](https://zhipu-ai.feishu.cn/wiki/MhPYwtpBhinuoikNIYYcyu8dnKv?fromScene=spaceOverview) | [🤗 HuggingFace Space](https://huggingface.co/spaces/THUDM-HF-SPACE/CogAgent-Demo)
[🤖 ModelScope Space](https://modelscope.cn/studios/ZhipuAI/CogAgent-Demo)
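[🧩 Modelers Space (Ascend)](https://modelers.cn/spaces/zhipuai/CogAgent) | 18 | 19 | ### Model Introduction 20 | 21 | The `CogAgent-9B-20241220` model is based on [GLM-4V-9B](https://huggingface.co/THUDM/glm-4v-9b), 22 | a bilingual open-source VLM base model. Through data collection and optimization, multi-stage training, and strategy improvements, `CogAgent-9B-20241220` 23 | achieves significant improvements in GUI perception, reasoning and prediction accuracy, completeness of the action space, and task generality and generalization, and it accepts screenshots and language interaction in both Chinese and English. This version of the CogAgent model is already used in Zhipu AI's [GLM-PC product](https://cogagent.aminer.cn/home). 24 | We hope this release helps researchers and developers advance the research and application of GUI agents based on vision-language models. 25 | 26 | ### Capability Showcase 27 | 28 | The CogAgent-9b-20241220 model achieves state-of-the-art results on GUI Agent and GUI Grounding 29 | benchmarks across multiple platforms and categories. In the [CogAgent-9b-20241220 technical blog](https://cogagent.aminer.cn/blog#/articles/cogagent-9b-20241220-technical-report), 30 | we compare it with API-based commercial models (GPT-4o-20240806, Claude-3.5-Sonnet), commercial API + GUI Grounding models (GPT-4o + UGround, GPT-4o + 31 | OS-ATLAS), and open-source GUI Agent models (Qwen2-VL, ShowUI, SeeClick). The results show that 32 | **CogAgent achieves leading results on GUI grounding (Screenspot), single-step operation (OmniAct), the Chinese step-wise internal benchmark (CogAgentBench-basic-cn), and multi-step operation (OSWorld)**, 33 | falling slightly behind only on OSWorld, where it trails Claude-3.5-Sonnet, which is specialized for Computer Use, and GPT-4o combined with an external GUI Grounding model. 34 | 35 |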
36 |
37 | 38 |

CogAgent wishes you a Merry Christmas! Let the model automatically send Christmas greetings to your friends.

39 |
40 |
41 | 42 |

Want to open an issue? Let CogAgent send the email for you.

43 |
44 |
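45 | 46 | **Table of Contents** 47 | 48 | - [CogAgent](#cogagent) 49 | - [About the Model](#about-the-model) 50 | - [Model Resources](#model-resources) 51 | - [Model Introduction](#model-introduction) 52 | - [Capability Showcase](#capability-showcase) 53 | - [Inference and Fine-tuning Costs](#inference-and-fine-tuning-costs) 54 | - [Model Input and Output](#model-input-and-output) 55 | - [User Input](#user-input) 56 | - [Model Output](#model-output) 57 | - [An Example](#an-example) 58 | - [Notes](#注意事项) 59 | - [Running the Model](#运行模型) 60 | - [Environment Setup](#环境配置) 61 | - [Running the Agent APP Example](#运行-agent-app-示例) 62 | - [Fine-tuning the Model](#微调模型) 63 | - [Previous Work](#先前的工作) 64 | - [License](#协议) 65 | - [Citation](#引用) 66 | - [R\&D Team \& Acknowledgements](#研发团队--致谢) 67 | 68 | ### Inference and Fine-tuning Costs 69 | 70 | + Inference at `BF16` precision requires at least `29GB` of GPU memory. `INT4` inference is not recommended because of the significant performance loss. `INT4` inference uses about 8GB of GPU memory, and 71 | `INT8` inference uses about 15GB. In `inference/cli_demo.py`, the two corresponding lines are commented out; uncomment them to use `INT4` or `INT8` 72 | inference. This option supports NVIDIA devices only. 73 | + The GPUs in all figures above are A100 and H100 GPUs; for other devices, you need to work out the GPU/host memory requirements yourself. 74 | + For SFT, this code freezes the `Vision Encoder`, uses Batch Size = 1, and fine-tunes on `8 * A100`, with input tokens (including the `1600` image tokens) totaling 75 | 2048 76 | tokens. This code cannot run SFT with the `Vision Encoder` unfrozen. For LoRA, the `Vision Encoder` is not frozen, Batch Size = 1, and fine-tuning runs on 77 | `1 * A100`, with input tokens (including the `1600` image tokens) totaling 2048 tokens. Under these settings, SFT requires 8 GPUs with at least 78 | `60GB` of memory each, and LoRA requires 1 GPU with at least `70GB` of memory, which cannot be split across devices. 79 | + SFT has not been tested on `Ascend devices`. Testing was done only on an `Atlas800 training server cluster`. The inference code needs to be adapted to the model-loading method described at the `Ascend device` download link. 80 | + The online demo links cannot control a computer; they only show the model's inference results. We recommend deploying the model locally. 81 | 82 | ## Model Input and Output 83 | 84 | `cogagent-9b-20241220` is an agent-style execution model rather than a conversational model. It does not support multi-turn conversation, but it does support a continuous execution history (that is, each step starts a new session and passes the past history to the model). The CogAgent workflow is shown in the figure below: 85 | 86 |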
87 | 88 |
89 | 90 | **To achieve the best GUI Agent performance, we adopt a strict input and output format.** 91 | This section shows how to format your input before passing it to the model, and the structured reply the model returns. 92 | 93 | ### User Input 94 | 95 | You can refer to [app/client.py#L115](https://github.com/THUDM/CogAgent/blob/e3ca6f4dc94118d3dfb749f195cbb800ee4543ce/app/client.py#L115) 96 | to assemble the user input prompt. A minimal example of the assembly code is shown below: 97 | 98 | ``` python 99 | 100 | current_platform = identify_os() # "Mac" or "WIN" or "Mobile"; note that the value is case-sensitive 101 | platform_str = f"(Platform: {current_platform})\n" 102 | format_str = "(Answer in Action-Operation-Sensitive format.)\n" # You can use other format to replace "Action-Operation-Sensitive" 103 | 104 | history_str = "\nHistory steps: " 105 | for index, (grounded_op_func, action) in enumerate(zip(history_grounded_op_funcs, history_actions)): 106 | history_str += f"\n{index}. {grounded_op_func}\t{action}" # start from 0. 107 | 108 | query = f"Task: {task}{history_str}\n{platform_str}{format_str}" # Be careful about the \n 109 | 110 | ``` 111 | 112 | The assembled Python string looks like this: 113 | 114 | ``` python 115 | "Task: Search for doors, click doors on sale and filter by brands \"Mastercraft\".\nHistory steps: \n0. CLICK(box=[[352,102,786,139]], element_info='Search')\tLeft click on the search box located in the middle top of the screen next to the Menards logo.\n1. TYPE(box=[[352,102,786,139]], text='doors', element_info='Search')\tIn the search input box at the top, type 'doors'.\n2. CLICK(box=[[787,102,809,139]], element_info='SEARCH')\tLeft click on the magnifying glass icon next to the search bar to perform the search.\n3. SCROLL_DOWN(box=[[0,209,998,952]], step_count=5, element_info='[None]')\tScroll down the page to see the available doors.\n4. CLICK(box=[[280,708,710,809]], element_info='Doors on Sale')\tClick the \"Doors On Sale\" button in the middle of the page to view the doors that are currently on sale.\n(Platform: WIN)\n(Answer in Action-Operation format.)\n" 116 | ``` 117 | 118 | The printed result is as follows: 119 | > 120 | > Task: Search for doors, click doors on sale and filter by brands "Mastercraft".
121 | > 122 | > History steps: 123 | > 124 | > 0. CLICK(box=[[352,102,786,139]], element_info='Search') Left click on the search box located in the middle top of 125 | the screen next to the Menards logo. 126 | > 1. TYPE(box=[[352,102,786,139]], text='doors', element_info='Search') In the search input box at the top, type ' 127 | doors'. 128 | > 2. CLICK(box=[[787,102,809,139]], element_info='SEARCH') Left click on the magnifying glass icon next to the search 129 | bar to perform the search. 130 | > 3. SCROLL_DOWN(box=[[0,209,998,952]], step_count=5, element_info='[None]') Scroll down the page to see the available 131 | doors. 132 | > 4. CLICK(box=[[280,708,710,809]], element_info='Doors on Sale') Click the "Doors On Sale" button in the middle of the 133 | page to view the doors that are currently on sale. 134 | > 135 | > (Platform: WIN) 136 | > 137 | > (Answer in Action-Operation format.) 138 | > 139 | 140 | For the meaning and representation of each field, read on or refer to the ["Prompt Concatenation" section of the Practical Guide (Chinese)](https://zhipu-ai.feishu.cn/wiki/D9FTwQ78fitS3CkZHUjcKEWTned). 141 | 142 | 1. `task` field 143 | 144 | The task description entered by the user, similar to a plain-text prompt. It guides the `cogagent-9b-20241220` model in carrying out the user's instruction. Keep it concise and clear. 145 | 146 | 2. `platform` field 147 | 148 | `cogagent-9b-20241220` can act as an agent on several platforms. Three operating systems with graphical interfaces are supported: 149 | - Windows 10 and 11: use the `WIN` field. 150 | - macOS 14 and 15: use the `Mac` field. 151 | - Android 13, 14, 15, and other Android UI distributions with nearly identical GUI and UI operation: use the `Mobile` field. 152 | 153 | On other systems the results may be poor, but you can try the `Mobile` field for mobile devices, the `WIN` field for Windows devices, and the `Mac` 154 | field for Mac devices. 155 | 156 | 3. `format` field 157 | 158 | The format in which the user wants `cogagent-9b-20241220` to return data. The options are: 159 | - `Answer in Action-Operation-Sensitive format.`: The default format used by the demo in this repository. Returns the model's action, the corresponding operation, and the corresponding sensitivity level. 160 | - `Answer in Status-Plan-Action-Operation format.`: Returns the model's status, plan, action, and the corresponding operation. 161 | - `Answer in Status-Action-Operation-Sensitive format.`: Returns the model's status, action, the corresponding operation, and the corresponding sensitivity level. 162 | - `Answer in Status-Action-Operation format.`: Returns the model's status, action, and the corresponding operation. 163 | - `Answer in Action-Operation format.`: Returns the model's action and the corresponding operation. 164 | 165 | 4.
`history` 字段 166 | 167 | 拼接顺序和结果应该如下所示: 168 | ``` 169 | query = f'{task}{history}{platform}{format}' 170 | ``` 171 | 5. `继续功能` 172 | CogAgent允许用户让模型`继续回答`。这需要用户在`{task}`后加入`[Continue]\n`字段。在这种情况下,拼接顺序和结果应该如下所示: 173 | ``` 174 | query = f'{task}[Continue]\n{history}{platform}{format}' 175 | ``` 176 | 177 | ### 模型返回部分 178 | 179 | 1. 敏感操作: 包括 `<<敏感操作>> <<一般操作>>` 几种类型,只有`format`字段中含`Sensitive`的时候返回。 180 | 2. `Plan`, `Status`, `Action` 字段: 用于描述模型的行为和操作。只有要求返回对应字段的时候返回,例如带有`Action`则返回 181 | `Action`字段内容。 182 | 3. 常规回答部分,这部分回答会在格式化回答之前,表示综述。 183 | 4. `Grounded Operation` 字段: 184 | 用于描述模型的具体操作,包括操作的位置,类型,以及具体的操作内容。其中 `box` 代表执行区域的坐标,`element_type` 代表执行的元素类型, 185 | `element_info` 代表执行的元素描述。这些信息被一个 `操作指令` 操作所包裹。具体的动作空间请参考[这里](Action_space.md)。 186 | 187 | ### 一个例子 188 | 189 | 用户的任务是希望帮忙将所有邮件标记为已读,用户使用的是 Mac系统,希望返回的是Action-Operation-Sensitive格式。 190 | 正确拼接后的**提示词**应该为: 191 | 192 | ``` 193 | Task: 帮我将所有的邮件标注为已读 194 | History steps: 195 | (Platform: Mac) 196 | (Answer in Action-Operation-Sensitive format.) 197 | ``` 198 | 199 | 注意,即使没有操作历史,也需要在 prompt 中拼接上“History steps:”。接着,这里展现了不同格式要求下的**返回结果**: 200 | 201 | 202 |
Answer in Action-Operation-Sensitive format

```
Action: 点击页面顶部工具栏中的“全部标为已读”按钮,将所有邮件标记为已读。
Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='可点击文本', element_info='全部标为已读')
<<一般操作>>
```

Answer in Status-Plan-Action-Operation format

```
Status: 当前处于邮箱界面[[0, 2, 998, 905]],左侧是邮箱分类[[1, 216, 144, 570]],中间是收件箱[[144, 216, 998, 903]],已经点击“全部标为已读”按钮[[223, 178, 311, 210]]。
Plan: 未来需要: 1. 点击“全部标记为已读”按钮;2. 任务结束。
Action: 点击收件箱页面顶部中间的“全部标记为已读”按钮,将所有邮件标记为已读。
Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='可点击文本', element_info='全部标为已读')
```

Answer in Status-Action-Operation-Sensitive format

```
Status: 当前处于邮箱界面[[0, 2, 998, 905]],左侧是邮箱分类[[1, 216, 144, 570]],中间是收件箱[[144, 216, 998, 903]],已经点击“全部标为已读”按钮[[223, 178, 311, 210]]。
Action: 点击页面顶部工具栏中的“全部标为已读”按钮,将所有邮件标记为已读。
Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='可点击文本', element_info='全部标为已读')
<<一般操作>>
```

Answer in Status-Action-Operation format

```
Status: 当前处于邮箱界面[[0, 2, 998, 905]],左侧是邮箱分类[[1, 216, 144, 570]],中间是收件箱[[144, 216, 998, 903]],已经点击“全部标为已读”按钮[[223, 178, 311, 210]]。
Action: 在收件箱页面顶部,点击“全部标记为已读”按钮,将所有邮件标记为已读。
Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='可点击文本', element_info='全部标为已读')
```

Answer in Action-Operation format

```
Action: 在左侧邮件列表中,右键单击第一封邮件,以打开操作菜单。
Grounded Operation: RIGHT_CLICK(box=[[154,275,343,341]], element_info='[AXCell]')
```

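The response formats shown above can be split into their fields with a few regular expressions. The following is a minimal sketch, not the repository's parser: the names `parse_response` and `box_to_pixels` are hypothetical, and the conversion assumes that `box` values lie on a 0 to 1000 scale, which is how `app/client.py` treats them when it divides each value by 1000 before drawing.

```python
import re
from typing import Dict, List, Optional

# Field labels as they appear in the sample responses above.
FIELD_PATTERN = re.compile(
    r"^(Status|Plan|Action|Grounded Operation):\s*(.*)$", re.MULTILINE
)
# Sensitivity marker, e.g. <<一般操作>>; present only for "Sensitive" formats.
SENSITIVE_PATTERN = re.compile(r"<<(.+?)>>")
# First box=[[x1,y1,x2,y2]] argument inside a grounded operation.
BOX_PATTERN = re.compile(r"box=\[\[(\d+),(\d+),(\d+),(\d+)\]\]")


def parse_response(response: str) -> Dict[str, Optional[str]]:
    """Split a model response into its labelled fields (hypothetical helper)."""
    parsed: Dict[str, Optional[str]] = {
        "Status": None,
        "Plan": None,
        "Action": None,
        "Grounded Operation": None,
        "Sensitive": None,
    }
    for name, value in FIELD_PATTERN.findall(response):
        parsed[name] = value.strip()
    marker = SENSITIVE_PATTERN.search(response)
    if marker:
        parsed["Sensitive"] = marker.group(1)
    return parsed


def box_to_pixels(operation: str, width: int, height: int) -> Optional[List[int]]:
    """Convert the first box from the assumed 0-1000 scale to pixel coordinates."""
    match = BOX_PATTERN.search(operation)
    if match is None:
        return None
    x1, y1, x2, y2 = (int(v) for v in match.groups())
    return [
        x1 * width // 1000,
        y1 * height // 1000,
        x2 * width // 1000,
        y2 * height // 1000,
    ]
```

Fields that the chosen format does not request stay `None`, so the same helper can serve all five formats. Because the model returns a plain string rather than JSON, a caller would likely want to treat a missing `Grounded Operation` as "no action" instead of raising.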
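### Notes

1. This model is not a conversational model and does not support multi-turn dialogue. Send specific instructions and assemble the history in the way described above.
2. The model requires an image input; GUI Agent tasks cannot be carried out with text-only conversation.
3. The model's output follows a strict format; parse it strictly according to our specification. The output is a plain string (STR); JSON output is not supported.

## Running the model

### Environment setup

Make sure you have **Python 3.10.16** or later installed, then install the following dependencies:

```shell
pip install -r requirements.txt
```

To run local inference based on `transformers`, use the following command:

```shell
python inference/cli_demo.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 --output_image_path ./results --format_key status_action_op_sensitive
```

This is an interactive command-line demo in which you enter the path of the image for each round. If the model's response contains a bbox, an image with the bbox drawn on it is written out, marking the region where the operation should be performed. The image is saved under
`output_image_path` and named `{your input image name}_{round}.png`. `format_key` selects the format in which the model responds.
The `platform` field determines which platform you are serving (for example, with `Mac`, every screenshot you upload must come from a `Mac` system).

To run the online web demo, in which you upload images one after another and the model returns the corresponding bbox and operation category, use the command below. It behaves the same as the
`HuggingFace Space`
online demo.

```shell
python inference/web_demo.py --host 0.0.0.0 --port 7860 --model_dir THUDM/cogagent-9b-20241220 --format_key status_action_op_sensitive --platform "Mac" --output_dir ./results
```

### Running the Agent APP example

We provide a basic demo app for developers that demonstrates the GUI capabilities of the `cogagent-9b-20241220` model. The demo shows how to deploy the model on a GPU server
and drive automated GUI operations with `cogagent-9b-20241220` from a local computer.

> We cannot guarantee the safety of the AI's behavior; please use it with caution.
>
> This example is for academic reference only. We assume no legal liability arising from it.

If you are interested in the app, see the [documentation](app/README.md).

### Fine-tuning the model

If you are interested in fine-tuning the `cogagent-9b-20241220` model, see [here](finetune/README.md).

## Previous work

In November 2023 we released the first generation of the CogAgent model. You can find the related code and weights in the [CogVLM&CogAgent official repository](https://github.com/THUDM/CogVLM).
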
CogVLM

📖 Paper: CogVLM: Visual Expert for Pretrained Language Models

CogVLM is a powerful open-source visual language model (VLM). CogVLM-17B has 10 billion visual parameters and 7 billion language parameters, and supports image understanding at 490*490 resolution as well as multi-turn dialogue.

CogVLM-17B achieves state-of-the-art performance on 10 classic cross-modal benchmarks, including NoCaps, Flickr30k captioning, RefCOCO, RefCOCO+, RefCOCOg, Visual7W, GQA, ScienceQA, VizWiz VQA, and TDIUC.

CogAgent

📖 Paper: CogAgent: A Visual Language Model for GUI Agents

CogAgent is an open-source visual language model improved from CogVLM. CogAgent-18B has 11 billion visual parameters and 7 billion language parameters, and supports image understanding at 1120*1120 resolution. On top of CogVLM's capabilities, it adds GUI image agent capabilities.

CogAgent-18B achieves state-of-the-art generalist performance on 9 classic cross-modal benchmarks, including VQAv2, OK-VQA, TextVQA, ST-VQA, ChartQA, infoVQA, DocVQA, MM-Vet, and POPE. It significantly outperforms existing models on GUI operation datasets including AITW and Mind2Web.

## License

- Code in this GitHub repository is released under the [Apache2.0 LICENSE](LICENSE).

- Use of the model weights must follow the [Model License](MODEL_LICENSE).

## Citation

If you find our work useful, please cite our paper:

```
@misc{hong2023cogagent,
      title={CogAgent: A Visual Language Model for GUI Agents},
      author={Wenyi Hong and Weihan Wang and Qingsong Lv and Jiazheng Xu and Wenmeng Yu and Junhui Ji and Yan Wang and Zihan Wang and Yuxiao Dong and Ming Ding and Jie Tang},
      year={2023},
      eprint={2312.08914},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

```

## Development team & acknowledgements

**Institutions**: Tsinghua University, Zhipu AI

**Team members**: 洪文逸,纪骏辉,潘立航,岳远昌,庞常毓,薛思言,王果,王维汉,胥嘉政,杨慎,顾晓韬,东昱晓,唐杰

**Acknowledgements**: We thank the Zhipu AI data team for their strong support, including 张笑涵、薛钊、陈陆、杜竟杰、王思瑜、张颖, and all the annotators, who worked hard to collect and annotate the training and test data for the CogAgent
model. We also thank 张昱轩、胡晓伟、陈浩 of the Zhipu AI open-source team for their engineering work in open-sourcing the model.

--------------------------------------------------------------------------------
/app/client.py:
--------------------------------------------------------------------------------
"""
This script creates an OpenAI Request demo for the CogAgent model.
All parameters are now controlled by arguments, allowing you to run the script
with a single line of code specifying all needed parameters.
"""

import argparse
import base64
import platform
import pyautogui
import re
import os
import gradio as gr
import threading
from PIL import Image, ImageDraw
from typing import List, Dict, Any, Optional, Tuple
from functools import partial
# Import the `agent` function from `register` (assumed to be in the same directory)
from register import agent

stop_event = threading.Event()


def create_chat_completion(
    api_key: str,
    base_url: str,
    model: str,
    messages: List[Dict[str, Any]],
    max_length: int = 512,
    top_p: float = 1.0,
    temperature: float = 1.0,
    presence_penalty: float = 1.0,
    stream: bool = False,
) -> Any:
    """
    Creates a chat completion request.

    Parameters:
    - api_key (str): API key for authentication.
    - base_url (str): The base URL for the API endpoint.
    - model (str): Model name to use (e.g., "cogagent-9b-20241220").
    - messages (List[Dict[str, Any]]): A list of messages for the conversation, where each message is a dictionary.
    - max_length (int, optional): The maximum length of the response. Default is 512.
    - top_p (float, optional): Controls nucleus sampling. Default is 1.0.
    - temperature (float, optional): Sampling temperature to control randomness. Default is 1.0.
    - presence_penalty (float, optional): Presence penalty for the model. Default is 1.0.
    - stream (bool, optional): Whether to stream the response. Default is False.

    Returns:
    - Any: The response from the chat completion API.
    """
    from openai import OpenAI

    client = OpenAI(api_key=api_key, base_url=base_url)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        stream=stream,
        timeout=60,
        max_tokens=max_length,
        temperature=temperature,
        presence_penalty=presence_penalty,
        top_p=top_p,
    )
    if response:
        return response.choices[0].message.content


def encode_image(image_path: str) -> str:
    """
    Encodes an image file into a base64 string.

    Parameters:
        image_path (str): The path to the image file.

    Returns:
        str: The base64-encoded string representation of the image.

    Raises:
        FileNotFoundError: If the specified image file is not found.
        IOError: If an error occurs during file reading.
    """
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except FileNotFoundError:
        raise FileNotFoundError(f"File not found: {image_path}")
    except IOError as e:
        raise IOError(f"Error reading file {image_path}: {e}")


def identify_os() -> str:
    """
    Identifies the operating system based on the platform information.

    Returns:
    - str: "Mac" if the system is macOS, "WIN" if the system is Windows.

    Raises:
    - ValueError: If the operating system is not supported.
    """

    #TODO: Need check if windows platform can run the demo.

    os_detail = platform.platform().lower()
    if "mac" in os_detail:
        return "Mac"
    elif "windows" in os_detail:
        return "WIN"
    else:
        raise ValueError(
            f"This {os_detail} operating system is not currently supported!"
        )


def formatting_input(
    task: str, history_step: List[str], history_action: List[str], round_num: int
) -> List[Dict[str, Any]]:
    """
    Formats input data into a structured message for further processing.

    Parameters:
    - task (str): The task or query the user is asking about.
    - history_step (List[str]): A list of historical steps in the conversation.
    - history_action (List[str]): A list of actions corresponding to the history steps.
    - round_num (int): The current round number (used to identify the image file).

    Returns:
    - List[Dict[str, Any]]: A list of messages formatted as dictionaries.

    Raises:
    - ValueError: If the lengths of `history_step` and `history_action` do not match.
    """
    current_platform = identify_os()
    platform_str = f"(Platform: {current_platform})\n"
    format_str = "(Answer in Status-Plan-Action-Operation-Sensitive format.)\n"

    if len(history_step) != len(history_action):
        raise ValueError("Mismatch in lengths of history_step and history_action.")

    history_str = "\nHistory steps: "
    for index, (step, action) in enumerate(zip(history_step, history_action)):
        history_str += f"\n{index}. {step}\t{action}"

    query = f"Task: {task}{history_str}\n{platform_str}{format_str}"

    # Create image URL with base64 encoding.
    # The screenshot is saved as PNG, so the data URL declares image/png.
    img_url = f"data:image/png;base64,{encode_image(f'caches/img_{round_num}.png')}"

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": query,
                },
                {
                    "type": "image_url",
                    "image_url": {"url": img_url},
                },
            ],
        },
    ]
    return messages


def shot_current_screen(round_num: int):
    """
    Captures a screenshot of the current screen and saves it to the cache directory.

    Parameters:
    - round_num (int): The current round number for naming the image file.
    """
    img = pyautogui.screenshot()
    img.save(f"caches/img_{round_num}.png")


def extract_grounded_operation(response: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Extracts the grounded operation and action from the response text.

    Parameters:
    - response (str): The model's response text.

    Returns:
    - (step, action) (Tuple[Optional[str], Optional[str]]): Extracted step and action from the response.
    """
    grounded_pattern = r"Grounded Operation:\s*(.*)"
    action_pattern = r"Action:\s*(.*)"

    step = None
    action = None

    matches_history = re.search(grounded_pattern, response)
    matches_actions = re.search(action_pattern, response)
    if matches_history:
        step = matches_history.group(1)
    if matches_actions:
        action = matches_actions.group(1)

    return step, action


def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
    """
    Draws red bounding boxes on the given image and saves it.

    Parameters:
    - image (PIL.Image.Image): The image on which to draw the bounding boxes.
    - boxes (List[List[float]]): A list of bounding boxes, each defined as [x_min, y_min, x_max, y_max].
      Coordinates are expected to be normalized (0 to 1).
    - save_path (str): The path to save the updated image.
    """
    draw = ImageDraw.Draw(image)
    for box in boxes:
        x_min = int(box[0] * image.width)
        y_min = int(box[1] * image.height)
        x_max = int(box[2] * image.width)
        y_max = int(box[3] * image.height)
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
    image.save(save_path)


def extract_bboxes(response: str, round_num: int):
    """
    Extracts bounding boxes from the response and draws them on the corresponding screenshot.

    Parameters:
    - response (str): The response text containing bounding box information.
    - round_num (int): The round number to identify which image to annotate.
    """
    box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
    matches = re.findall(box_pattern, response)
    if matches:
        boxes = [[int(x) / 1000 for x in match] for match in matches]
        img_save_path = f"caches/img_{round_num}_bbox.png"
        image = Image.open(f"caches/img_{round_num}.png").convert("RGB")
        draw_boxes_on_image(image, boxes, img_save_path)


def is_balanced(s: str) -> bool:
    """
    Checks if the parentheses in a string are balanced.

    Parameters:
    - s (str): The string to check.

    Returns:
    - bool: True if parentheses are balanced, False otherwise.
    """
    stack = []
    mapping = {")": "(", "]": "[", "}": "{"}
    if "(" not in s:
        return False
    for char in s:
        if char in mapping.values():
            stack.append(char)
        elif char in mapping.keys():
            if not stack or mapping[char] != stack.pop():
                return False
    return not stack


def extract_operation(step: Optional[str]) -> Dict[str, Any]:
    """
    Extracts the operation and other details from the grounded operation step.

    Parameters:
    - step (Optional[str]): The grounded operation string.

    Returns:
    - Dict[str, Any]: A dictionary containing the operation details.
    """
    if step is None or not is_balanced(step):
        return {"operation": "NO_ACTION"}

    op, detail = step.split("(", 1)
    detail = "(" + detail
    others_pattern = r"(\w+)\s*=\s*([^,)]+)"
    others = re.findall(others_pattern, detail)
    Grounded_Operation = dict(others)

    boxes_pattern = r"box=\[\[(.*?)\]\]"
    boxes = re.findall(boxes_pattern, detail)
    if boxes:
        Grounded_Operation["box"] = list(map(int, boxes[0].split(",")))
    Grounded_Operation["operation"] = op.strip()

    return Grounded_Operation


def workflow(
    api_key: str,
    base_url: str,
    model: str,
    chatbot: List[List[str]],
    max_length: int,
    top_p: float,
    temperature: float,
):
    """
    Main workflow for handling a chatbot interaction loop.

    Parameters:
    - api_key (str): API key for accessing the chatbot API.
    - base_url (str): Base URL for the chatbot API.
    - model (str): Model name to use.
    - chatbot (list): The initial history of the chatbot interaction.
    - max_length (int): Maximum response length for the chatbot.
    - top_p (float): Top-p sampling value for response generation.
    - temperature (float): Temperature value for response randomness.

    Yields:
    - history (list): Updated history of the chatbot interaction.
    - output_image (str): Path to the generated output image.
    """
    history_step = []
    history_action = []
    history = chatbot
    round_num = 1
    task = chatbot[-1][0] if chatbot and chatbot[-1] else "No task provided"

    try:
        while True:
            print(f"\033[92m Round {round_num}: \033[0m")
            if round_num > 15:
                break  # Exit the loop after 15 rounds

            # Capture the current screen for the round
            shot_current_screen(round_num)

            # Format input messages for the chatbot
            messages = formatting_input(task, history_step, history_action, round_num)

            # Call the chatbot API to get a response
            response = create_chat_completion(
                api_key=api_key,
                base_url=base_url,
                model=model,
                messages=messages,
                max_length=max_length,
                top_p=top_p,
                temperature=temperature,
                stream=False,
            )

            # Extract grounded operations and actions from the response
            step, action = extract_grounded_operation(response)
            history_step.append(step if step else "")
            history_action.append(action if action else "")

            # Process bounding boxes and operations
            extract_bboxes(response, round_num)
            grounded_operation = extract_operation(step)

            if grounded_operation["operation"] == "NO_ACTION":
                break

            # Execute the grounded operation using the agent
            status = agent(grounded_operation)

            # Update the history with the latest response
            history.append([f"Round {round_num}", response])

            # Prepare the output image path
            output_image = f"caches/img_{round_num}_bbox.png"
            if status == "END" or stop_event.is_set():
                output_image = f"caches/img_{round_num - 1}_bbox.png"
                yield history, output_image
                break
            else:
                yield history, output_image

            round_num += 1
    finally:
        # Clear the stop event at the end of the workflow
        stop_event.clear()


def switch():
    """
    Sets the stop event to terminate the workflow.
    """
    stop_event.set()


def gradio_web(
    api_key: str,
    base_url: str,
    model: str,
    client_name: str = "127.0.0.1",
    client_port: int = 8080,
):
    """
    Launches a Gradio-based web application for interacting with CogAgent.

    Parameters:
    - api_key (str): OpenAI API key.
    - base_url (str): OpenAI API base URL.
    - model (str): Model name to use.
    - client_name (str): The gradio IP or hostname for hosting the app.
    - client_port (int): The port number for the gradio.
    """
    with gr.Blocks() as demo:
        # Title text only; the HTML markup around it is not recoverable here.
        gr.HTML("CogAgent Gradio Chat Demo")

        # Top row: Chatbot and Image upload
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                chatbot = gr.Chatbot(height=240)
            with gr.Column(scale=1, min_width=160):
                img_path = gr.Image(
                    label="Operation Area Screenshot",
                    type="filepath",
                    show_fullscreen_button=True,
                )

        # Bottom row: Task input, system controls
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                task = gr.Textbox(
                    show_label=True,
                    placeholder="Please enter your task description",
                    label="Task",
                    max_length=320,
                )
                submit_button = gr.Button("Submit")
                clear_button = gr.Button("Clear History")

            with gr.Column(scale=1, min_width=160):
                max_length = gr.Slider(
                    minimum=0,
                    maximum=8192,
                    value=4096,
                    step=1.0,
                    label="Maximum Length",
                    interactive=True,
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.8,
                    step=0.01,
                    label="Top P",
                    interactive=True,
                )
                temperature = gr.Slider(
                    minimum=0.01,
                    maximum=1.0,
                    value=0.6,
                    step=0.01,
                    label="Temperature",
                    interactive=True,
                )

        # Stop button
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                stop_button = gr.Button("Stop", variant="stop", size="lg")

        # Define functions for button actions
        def user_input(task, history):
            """Handles user task submission."""
            return "", history + [[task, "Please wait for CogAgent's operation..."]]

        def raise_error():
            """Raises an error to terminate the program."""
            raise gr.Error("The program has been terminated!")

        def warning_start():
            """Displays a warning when CogAgent starts processing."""
            return gr.Warning(
                "CogAgent is processing. Please do not interact with the keyboard or mouse.",
                duration=60,
            )

        def warning_end():
            """Displays a warning when CogAgent finishes processing."""
            return gr.Warning("CogAgent has finished. Please input a new task.")

        # Create a partial workflow function with fixed parameters
        workflow_partial = partial(workflow, api_key, base_url, model)

        # Button actions and callbacks
        submit_button.click(
            user_input, inputs=[task, chatbot], outputs=[task, chatbot], queue=False
        ).then(warning_start, inputs=None, outputs=None).then(
            workflow_partial,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=[chatbot, img_path],
        ).then(warning_end, inputs=None, outputs=None)

        clear_button.click(
            lambda: (None, None), inputs=None, outputs=[chatbot, img_path], queue=False
        )

        stop_button.click(switch).then(raise_error, inputs=None, outputs=None)

    demo.queue()
    demo.launch(server_name=client_name, server_port=client_port)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the CogAgent demo with all parameters controlled by command-line arguments."
    )
    parser.add_argument("--api_key", required=True, help="OpenAI API Key.")
    parser.add_argument("--base_url", required=True, help="OpenAI API Base URL.")
    parser.add_argument("--model", default="CogAgent", help="Model name to use.")
    parser.add_argument(
        "--client_name",
        default="127.0.0.1",
        help="The IP or hostname for Gradio.",
    )
    parser.add_argument(
        "--client_port", type=int, default=8080, help="The port number for Gradio."
    )

    args = parser.parse_args()

    if not os.path.exists("caches"):
        os.makedirs("caches")

    gradio_web(
        api_key=args.api_key,
        base_url=args.base_url,
        model=args.model,
        client_name=args.client_name,
        client_port=args.client_port,
    )

--------------------------------------------------------------------------------
/finetune/finetune.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional
from PIL import Image

# For Ascend NPU, please add this
# import torch_npu
# from torch_npu.contrib import transfer_to_npu

app = typer.Typer(pretty_exceptions_show_locals=False)


class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        output_ids = (
            [feature["output_ids"] for feature in features]
            if "output_ids" in features[0].keys()
            else None
        )
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                    (max_output_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                    max_output_length - len(feature["output_ids"])
                )
                if isinstance(feature["output_ids"], list):
                    feature["output_ids"] = feature["output_ids"] + remainder
                else:
                    feature["output_ids"] = np.concatenate(
                        [feature["output_ids"], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)


class Seq2SeqTrainer(_Seq2SeqTrainer):

    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict,
        prediction_loss_only: bool,
        ignore_keys=None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        with torch.no_grad():
            if self.args.predict_with_generate:
                output_ids = inputs.pop("output_ids", None)

            if "labels" in inputs:
                del inputs["labels"]

            loss, generated_tokens, labels = super().prediction_step(
                model=model,
                inputs=inputs,
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys,
                **gen_kwargs,
            )

            if generated_tokens is not None:
                generated_tokens = generated_tokens[:, inputs["input_ids"].size()[1] :]

            if self.args.predict_with_generate:
                labels = output_ids

            del inputs, output_ids
            torch.cuda.empty_cache()

        return loss, generated_tokens, labels


@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int
    combine: bool
    freezeV: bool

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = "no"
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> "FinetuningConfig":
        training_args = kwargs.get("training_args", None)
        if training_args is not None and not isinstance(
            training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get("generation_config")
            if not isinstance(gen_config, GenerationConfig):
                training_args["generation_config"] = GenerationConfig(**gen_config)
            kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get("data_config")
        if not isinstance(data_config, DataConfig):
            kwargs["data_config"] = DataConfig(**data_config)

        peft_config = kwargs.get("peft_config", None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs["peft_config"] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
        path = Path(path)
        parser = yaml.YAML(typ="safe", pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)


def _load_datasets(
    data_dir: str,
    data_format: str,
    data_files: dict[NamedSplit, str],
    num_proc: Optional[int],
) -> DatasetDict:
    if data_format == ".jsonl":
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
        self,
        split: NamedSplit,
        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
        batched: bool = True,
        remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return
        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
            # This is default params of orig_dataset.map, and you can change it smaller
            # https://github.com/THUDM/GLM-4/issues/277
            writer_batch_size=1000,
            batch_size=1000,
        )


def process_batch(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
    combine: bool,
) -> dict[str, list]:
    batched_conv = batch["messages"]
    batched_input_ids = []
    batched_attention_mask = []
    batched_position_ids = []
    batched_labels = []
    batched_images = []

    max_length = max_input_length + max_output_length

    for conv in batched_conv:
        input_ids = [151331, 151333]
        attention_mask = [1, 1]
        position_ids = list(range(len(input_ids)))
        loss_masks = [False, False]
        images = []

        conv[0]['image'] = Image.open(conv[len(conv)-2]['image']).convert("RGB")
        conv[len(conv)-2]['image'] = None

        for message in conv:
            loss_mask_val = (
                False if message["role"] in ("system", "user", "observation") else True
            )

            new_input_ids_all = tokenizer.apply_chat_template(
                [message], tokenize=True, return_dict=True, padding=True
            )
            new_input_ids = new_input_ids_all["input_ids"][0][2:]
            new_attention_mask = new_input_ids_all["attention_mask"][0][2:]
            new_position_ids = list(
                range(position_ids[-1] + 1, position_ids[-1] + 1 + len(new_input_ids))
            )
            if 'images' in new_input_ids_all:
                images.append(new_input_ids_all["images"])

            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            input_ids += new_input_ids
            attention_mask += new_attention_mask
            position_ids += new_position_ids
            loss_masks += new_loss_masks

        input_ids.append(151336)  # EOS
        attention_mask.append(1)
        position_ids.append(len(position_ids))
        loss_masks.append(False)

        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
labels.append(input_id) 296 | else: 297 | labels.append(-100) 298 | 299 | batched_input_ids.append(input_ids[:max_length]) 300 | batched_attention_mask.append(attention_mask[:max_length]) 301 | batched_position_ids.append(position_ids[:max_length]) 302 | batched_labels.append(labels[:max_length]) 303 | batched_images.append(images[0][0]) 304 | 305 | del ( 306 | batched_conv, 307 | conv, 308 | input_ids, 309 | attention_mask, 310 | position_ids, 311 | loss_masks, 312 | message, 313 | new_input_ids, 314 | new_loss_masks, 315 | labels, 316 | input_id, 317 | mask, 318 | ) 319 | torch.cuda.empty_cache() 320 | 321 | return { 322 | "input_ids": batched_input_ids, 323 | "attention_mask": batched_attention_mask, 324 | "position_ids": batched_position_ids, 325 | "labels": batched_labels, 326 | "images": batched_images, 327 | } 328 | 329 | 330 | def process_batch_eval( 331 | batch: Mapping[str, Sequence], 332 | tokenizer: PreTrainedTokenizer, 333 | max_input_length: int, 334 | max_output_length: int, 335 | combine: bool, 336 | ) -> dict[str, list]: 337 | batched_conv = batch["messages"] 338 | batched_input_ids = [] 339 | batched_attention_mask = [] 340 | batched_position_ids = [] 341 | batched_output_ids = [] 342 | batched_images = [] 343 | 344 | for conv in batched_conv: 345 | 346 | idx = len(conv) - 2 347 | conv[0]['image'] = Image.open(conv[idx]["image"]).convert("RGB") 348 | conv[idx]['image'] = None 349 | 350 | new_input_ids_all = tokenizer.apply_chat_template( 351 | conv, tokenize=True, return_dict=True, padding=True 352 | ) 353 | 354 | input_ids = new_input_ids_all["input_ids"][0] 355 | attention_mask = new_input_ids_all["attention_mask"][0] 356 | position_ids = list(range(len(input_ids))) 357 | 358 | dialogue_parts = [0] 359 | user_idx = [] 360 | for idx, token_id in enumerate(input_ids): 361 | if token_id == 151337: 362 | dialogue_parts.append(idx + 1) 363 | elif token_id == 151336: 364 | user_idx.append(idx) 365 | 366 | if user_idx[-1] != len(input_ids): 367 | 
user_idx.append(len(input_ids)) 368 | 369 | # Split the conversation into multiple dialogue segments 370 | for end_idx in range(1, len(dialogue_parts)): 371 | input_segment = input_ids[: dialogue_parts[end_idx]] 372 | attention_segment = attention_mask[: dialogue_parts[end_idx]] 373 | position_segment = position_ids[: dialogue_parts[end_idx]] 374 | output_segment = input_ids[ 375 | dialogue_parts[end_idx] : user_idx[end_idx] 376 | ] 377 | 378 | output_segment.append(151336) # Add EOS token 379 | 380 | batched_input_ids.append(input_segment[:max_input_length]) 381 | batched_attention_mask.append(attention_segment[:max_input_length]) 382 | batched_position_ids.append(position_segment[:max_input_length]) 383 | batched_output_ids.append(output_segment[:max_output_length]) 384 | batched_images.append(new_input_ids_all["images"][0]) 385 | 386 | del ( 387 | batched_conv, 388 | input_ids, 389 | attention_mask, 390 | position_ids, 391 | new_input_ids_all, 392 | output_segment, 393 | ) 394 | torch.cuda.empty_cache() 395 | 396 | return { 397 | "input_ids": batched_input_ids, 398 | "attention_mask": batched_attention_mask, 399 | "position_ids": batched_position_ids, 400 | "output_ids": batched_output_ids, 401 | "images": batched_images, 402 | } 403 | 404 | 405 | def load_tokenizer_and_model( 406 | model_dir: str, 407 | peft_config: Optional[PeftConfig] = None, 408 | ): 409 | tokenizer = AutoTokenizer.from_pretrained( 410 | model_dir, padding_side="left", trust_remote_code=True 411 | ) 412 | if peft_config is not None: 413 | model = AutoModelForCausalLM.from_pretrained( 414 | model_dir, 415 | trust_remote_code=True, 416 | use_cache=False, 417 | torch_dtype=torch.bfloat16, # Must use BFloat 16 418 | ) 419 | model = get_peft_model(model, peft_config) 420 | model.print_trainable_parameters() 421 | else: 422 | model = AutoModelForCausalLM.from_pretrained( 423 | model_dir, 424 | trust_remote_code=True, 425 | use_cache=False, 426 | torch_dtype=torch.bfloat16, 427 | ) 428 | return 
tokenizer, model 429 | 430 | 431 | def compute_metrics(eval_preds: EvalPrediction, tokenizer): 432 | batched_pred_ids, batched_label_ids = eval_preds 433 | batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id 434 | batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id 435 | metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 436 | for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids): 437 | pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip() 438 | label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip() 439 | pred_tokens = list(jieba.cut(pred_txt)) 440 | label_tokens = list(jieba.cut(label_txt)) 441 | rouge = Rouge() 442 | scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens)) 443 | for k, v in scores[0].items(): 444 | metrics_dct[k].append(round(v["f"] * 100, 4)) 445 | metrics_dct["bleu-4"].append( 446 | sentence_bleu( 447 | [label_tokens], 448 | pred_tokens, 449 | smoothing_function=SmoothingFunction().method3, 450 | ) 451 | ) 452 | return {k: np.mean(v) for k, v in metrics_dct.items()} 453 | 454 | 455 | @app.command() 456 | def main( 457 | data_dir: Annotated[str, typer.Argument(help="")], 458 | model_dir: Annotated[ 459 | str, 460 | typer.Argument( 461 | help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file." 462 | ), 463 | ], 464 | config_file: Annotated[str, typer.Argument(help="")], 465 | auto_resume_from_checkpoint: str = typer.Argument( 466 | default="", 467 | help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. 
If the input is no, restart training", 468 | ), 469 | ): 470 | ft_config = FinetuningConfig.from_file(config_file) 471 | tokenizer, model = load_tokenizer_and_model( 472 | model_dir, peft_config=ft_config.peft_config 473 | ) 474 | 475 | if ft_config.freezeV: 476 | for param in model.transformer.vision.parameters(): 477 | param.requires_grad = False 478 | data_manager = DataManager(data_dir, ft_config.data_config) 479 | 480 | train_dataset = data_manager.get_dataset( 481 | Split.TRAIN, 482 | functools.partial( 483 | process_batch, 484 | combine=ft_config.combine, # Not use now 485 | tokenizer=tokenizer, 486 | max_input_length=ft_config.max_input_length, 487 | max_output_length=ft_config.max_output_length, 488 | ), 489 | batched=True, 490 | ) 491 | print("train_dataset:", train_dataset) 492 | 493 | val_dataset = data_manager.get_dataset( 494 | Split.VALIDATION, 495 | functools.partial( 496 | process_batch_eval, 497 | combine=ft_config.combine, 498 | tokenizer=tokenizer, 499 | max_input_length=ft_config.max_input_length, 500 | max_output_length=ft_config.max_output_length, 501 | ), 502 | batched=True, 503 | ) 504 | 505 | if val_dataset is not None: 506 | print("val_dataset:", val_dataset) 507 | test_dataset = data_manager.get_dataset( 508 | Split.TEST, 509 | functools.partial( 510 | process_batch_eval, 511 | combine=ft_config.combine, 512 | tokenizer=tokenizer, 513 | max_input_length=ft_config.max_input_length, 514 | max_output_length=ft_config.max_output_length, 515 | ), 516 | batched=True, 517 | ) 518 | if test_dataset is not None: 519 | print("test_dataset:", test_dataset) 520 | 521 | ft_config.training_args.generation_config.pad_token_id = 151329 522 | ft_config.training_args.generation_config.eos_token_id = [151329, 151336, 151338] 523 | 524 | trainer = Seq2SeqTrainer( 525 | model=model, 526 | args=ft_config.training_args, 527 | data_collator=DataCollatorForSeq2Seq( 528 | tokenizer=tokenizer, 529 | padding="longest", 530 | return_tensors="pt", 531 | ), 532 | 
train_dataset=train_dataset, 533 | eval_dataset=val_dataset, 534 | compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer), 535 | ) 536 | 537 | if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None: 538 | trainer.train() 539 | else: 540 | output_dir = ft_config.training_args.output_dir 541 | dirlist = os.listdir(output_dir) 542 | checkpoint_sn = 0 543 | for checkpoint_str in dirlist: 544 | if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1: 545 | checkpoint = int(checkpoint_str.replace("checkpoint-", "")) 546 | if checkpoint > checkpoint_sn: 547 | checkpoint_sn = checkpoint 548 | if auto_resume_from_checkpoint.upper() == "YES": 549 | if checkpoint_sn > 0: 550 | model.gradient_checkpointing_enable() 551 | model.enable_input_require_grads() 552 | checkpoint_directory = os.path.join( 553 | output_dir, "checkpoint-" + str(checkpoint_sn) 554 | ) 555 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn)) 556 | trainer.train(resume_from_checkpoint=checkpoint_directory) 557 | else: 558 | trainer.train() 559 | else: 560 | if auto_resume_from_checkpoint.isdigit(): 561 | if int(auto_resume_from_checkpoint) > 0: 562 | checkpoint_sn = int(auto_resume_from_checkpoint) 563 | model.gradient_checkpointing_enable() 564 | model.enable_input_require_grads() 565 | checkpoint_directory = os.path.join( 566 | output_dir, "checkpoint-" + str(checkpoint_sn) 567 | ) 568 | print("resume checkpoint from checkpoint-" + str(checkpoint_sn)) 569 | trainer.train(resume_from_checkpoint=checkpoint_directory) 570 | else: 571 | print( 572 | auto_resume_from_checkpoint, 573 | "The specified checkpoint sn(" 574 | + auto_resume_from_checkpoint 575 | + ") has not been saved. 
Please search for the correct checkpoint in the model output directory", 576 | ) 577 | 578 | if test_dataset is not None: 579 | trainer.predict(test_dataset) 580 | 581 | 582 | if __name__ == "__main__": 583 | app() 584 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CogAgent: An open-sourced VLM-based GUI Agent 2 | 3 | [中文文档](README_zh.md) 4 | 5 | - 🔥 🆕 **December 2024:** We open-sourced **the latest version of the CogAgent-9B-20241220 model**. Compared to the 6 | previous version of CogAgent, `CogAgent-9B-20241220` features significant improvements in GUI perception, reasoning 7 | accuracy, action space completeness, task universality, and generalization. It supports bilingual (Chinese and 8 | English) interaction through both screen captures and natural language. 9 | 10 | - 🏆 **June 2024:** CogAgent was accepted by **CVPR 2024** and recognized as a conference Highlight (top 3%). 11 | 12 | - **December 2023:** We **open-sourced the first GUI Agent**: **CogAgent** (with the former repository 13 | available [here](https://github.com/THUDM/CogVLM)) and **published the corresponding paper: 14 | 📖 [CogAgent Paper](https://arxiv.org/abs/2312.08914)**. 
15 |
16 | ## Model Introduction
17 |
18 | | Model | Model Download Links | Technical Documentation | Online Demo |
19 | |:---:|:---:|---|---|
20 | | cogagent-9b-20241220 | [🤗 HuggingFace](https://huggingface.co/THUDM/cogagent-9b-20241220)<br>[🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/cogagent-9b-20241220)<br>[🟣 WiseModel](https://wisemodel.cn/models/ZhipuAI/cogagent-9b-20241220)<br>[🧩 Modelers (Ascend)](https://modelers.cn/models/zhipuai/cogagent-9b-20241220) | [📄 Official Technical Blog](https://cogagent.aminer.cn/blog#/articles/cogagent-9b-20241220-technical-report)<br>[📘 Practical Guide (Chinese)](https://zhipu-ai.feishu.cn/wiki/MhPYwtpBhinuoikNIYYcyu8dnKv?fromScene=spaceOverview) | [🤗 HuggingFace Space](https://huggingface.co/spaces/THUDM-HF-SPACE/CogAgent-Demo)<br>[🤖 ModelScope Space](https://modelscope.cn/studios/ZhipuAI/CogAgent-Demo)<br>[🧩 Modelers Space (Ascend)](https://modelers.cn/spaces/zhipuai/CogAgent) |
21 |
22 | ### Model Overview
23 |
24 | The `CogAgent-9B-20241220` model is based on [GLM-4V-9B](https://huggingface.co/THUDM/glm-4v-9b), a bilingual open-source
25 | VLM base model. Through data collection and optimization, multi-stage training, and strategy improvements,
26 | `CogAgent-9B-20241220` achieves significant advancements in GUI perception, inference prediction accuracy, action space
27 | completeness, and generalizability across tasks. The model supports bilingual (Chinese and English) interaction with
28 | both screenshots and language input. This version of the CogAgent model has already been applied in
29 | ZhipuAI's [GLM-PC product](https://cogagent.aminer.cn/home). We hope the release of this model can assist researchers
30 | and developers in advancing the research and applications of GUI agents based on vision-language models.
31 |
32 | ### Capability Demonstrations
33 |
34 | The CogAgent-9b-20241220 model has achieved state-of-the-art results across multiple platforms and categories in GUI
35 | Agent tasks and GUI Grounding Benchmarks. In
36 | the [CogAgent-9b-20241220 Technical Blog](https://cogagent.aminer.cn/blog#/articles/cogagent-9b-20241220-technical-report),
37 | we compared it against API-based commercial models (GPT-4o-20240806, Claude-3.5-Sonnet), commercial API + GUI Grounding
38 | models (GPT-4o + UGround, GPT-4o + OS-ATLAS), and open-source GUI Agent models (Qwen2-VL, ShowUI, SeeClick). The results
39 | demonstrate that **CogAgent leads in GUI localization (ScreenSpot), single-step operations (OmniACT), the Chinese
40 | step-wise in-house benchmark (CogAgentBench-basic-cn), and multi-step operations (OSWorld)**, with only a slight
41 | disadvantage in OSWorld compared to Claude-3.5-Sonnet, which specializes in Computer Use, and GPT-4o combined with
42 | external GUI Grounding models.
43 |
44 |
47 | [Demo: CogAgent wishes you a Merry Christmas! Let the large model automatically send Christmas greetings to your friends.]
51 | [Demo: Want to open an issue? Let CogAgent help you send an email.]
54 | 55 | 56 | **Table of Contents** 57 | 58 | - [CogAgent](#cogagent) 59 | - [Model Introduction](#model-introduction) 60 | - [Model Overview](#model-overview) 61 | - [Capability Demonstrations](#capability-demonstrations) 62 | - [Inference and Fine-tuning Costs](#inference-and-fine-tuning-costs) 63 | - [Model Inputs and Outputs](#model-inputs-and-outputs) 64 | - [User Input](#user-input) 65 | - [Model Output](#model-output) 66 | - [An Example](#an-example) 67 | - [Notes](#notes) 68 | - [Running the Model](#running-the-model) 69 | - [Environment Setup](#environment-setup) 70 | - [Running an Agent APP Example](#running-an-agent-app-example) 71 | - [Fine-tuning the Model](#fine-tuning-the-model) 72 | - [Previous Work](#previous-work) 73 | - [License](#license) 74 | - [Citation](#citation) 75 | - [Research and Development Team \& Acknowledgements](#research-and-development-team---acknowledgements) 76 | 77 | ### Inference and Fine-tuning Costs 78 | 79 | + The model requires at least 29GB of VRAM for inference at `BF16` precision. Using `INT4` precision for inference is 80 | not recommended due to significant performance loss. The VRAM usage for `INT4` inference is about 8GB, while for 81 | `INT8` inference it is about 15GB. In the `inference/cli_demo.py` file, we have commented out these two lines. You can 82 | uncomment them and use `INT4` or `INT8` inference. This solution is only supported on NVIDIA devices. 83 | + All GPU references above refer to A100 or H100 GPUs. For other devices, you need to calculate the required GPU/CPU 84 | memory accordingly. 85 | + During SFT (Supervised Fine-Tuning), this codebase freezes the `Vision Encoder`, uses a batch size of 1, and trains on 86 | `8 * A100` GPUs. The total input tokens (including images, which account for `1600` tokens) add up to 2048 tokens. 87 | This codebase cannot conduct SFT fine-tuning without freezing the `Vision Encoder`. 
88 | For LoRA fine-tuning, the `Vision Encoder` is **not** frozen; the batch size is 1, on `1 * A100` GPU. The total input
89 | tokens (including images, `1600` tokens) also amount to 2048 tokens. In the above setup, SFT fine-tuning requires at
90 | least `60GB` of GPU memory per GPU (with 8 GPUs), while LoRA fine-tuning requires at least `70GB` of GPU memory on a
91 | single GPU (cannot be split).
92 | + `Ascend devices` have not been tested for SFT fine-tuning. We have only tested them on the `Atlas800` training server
93 | cluster. You need to modify the inference code accordingly based on the loading mechanism described in the
94 | `Ascend device` download link.
95 | + The online demo link does **not** support controlling computers; it only allows you to view the model's inference
96 | results. We recommend deploying the model locally.
97 |
98 | ## Model Inputs and Outputs
99 |
100 | `cogagent-9b-20241220` is an agent-type execution model rather than a conversational model. It does not support
101 | continuous dialogue, but it **does** support a continuous execution history. (In other words, every step must start a
102 | new conversation session, and the past execution history must be provided to the model again.) The workflow of
103 | CogAgent is illustrated below:
104 |
105 |
106 | [Figure: CogAgent workflow diagram]
107 |
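The stateless workflow described above can be sketched as a loop that rebuilds the whole prompt on every step. This is a minimal illustration, not the repository's implementation: `build_query` mirrors the concatenation order documented under "User Input", while `run_task`, `step_fn`, and the use of `None` as a stop signal are hypothetical names and conventions chosen for this sketch. In a real agent, `step_fn` would capture a screenshot, call the model, and parse its reply.

```python
from typing import Callable, List, Optional, Tuple

# (grounded_operation, action_description) pairs, oldest first
History = List[Tuple[str, str]]


def build_query(task: str, history: History, platform: str, answer_format: str) -> str:
    """Join task, history, platform, and format into a single prompt string."""
    history_str = "\nHistory steps: "
    for index, (grounded_op, action) in enumerate(history):  # indices start at 0
        history_str += f"\n{index}. {grounded_op}\t{action}"
    return (
        f"Task: {task}{history_str}\n"
        f"(Platform: {platform})\n"
        f"(Answer in {answer_format} format.)\n"
    )


def run_task(
    task: str,
    step_fn: Callable[[str], Optional[Tuple[str, str]]],
    platform: str = "Mac",
    answer_format: str = "Action-Operation",
    max_steps: int = 5,
) -> History:
    """Call step_fn once per step; every call is a fresh single-turn request."""
    history: History = []
    for _ in range(max_steps):
        # No chat state is kept between calls: the history travels in the prompt.
        query = build_query(task, history, platform, answer_format)
        step = step_fn(query)  # hypothetical: screenshot + model call + parsing
        if step is None:  # hypothetical stop signal used only in this sketch
            break
        history.append(step)
    return history


if __name__ == "__main__":
    scripted = iter([
        ("CLICK(box=[[352,102,786,139]], element_info='Search')", "Left click on the search box."),
        ("TYPE(box=[[352,102,786,139]], text='doors', element_info='Search')", "Type 'doors'."),
    ])
    print(run_task("Search for doors", lambda query: next(scripted, None)))
```

Keeping the history in a plain list makes each request self-contained, which matches the requirement that a new session is started for every step.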
108 | 109 | **To achieve optimal GUI Agent performance, we have adopted a strict input-output format.** 110 | Below is how users should format their inputs and feed them to the model, and how to interpret the model’s responses. 111 | 112 | ### User Input 113 | 114 | You can refer 115 | to [app/client.py#L115](https://github.com/THUDM/CogAgent/blob/e3ca6f4dc94118d3dfb749f195cbb800ee4543ce/app/client.py#L115) 116 | for constructing user input prompts. A minimal example of user input concatenation code is shown below: 117 | 118 | ``` python 119 | 120 | current_platform = identify_os() # "Mac" or "WIN" or "Mobile". Pay attention to case sensitivity. 121 | platform_str = f"(Platform: {current_platform})\n" 122 | format_str = "(Answer in Action-Operation-Sensitive format.)\n" # You can use other format to replace "Action-Operation-Sensitive" 123 | 124 | history_str = "\nHistory steps: " 125 | for index, (grounded_op_func, action) in enumerate(zip(history_grounded_op_funcs, history_actions)): 126 | history_str += f"\n{index}. {grounded_op_func}\t{action}" # start from 0. 127 | 128 | query = f"Task: {task}{history_str}\n{platform_str}{format_str}" # Be careful about the \n 129 | 130 | ``` 131 | 132 | The concatenated Python string: 133 | 134 | ``` python 135 | "Task: Search for doors, click doors on sale and filter by brands \"Mastercraft\".\nHistory steps: \n0. CLICK(box=[[352,102,786,139]], element_info='Search')\tLeft click on the search box located in the middle top of the screen next to the Menards logo.\n1. TYPE(box=[[352,102,786,139]], text='doors', element_info='Search')\tIn the search input box at the top, type 'doors'.\n2. CLICK(box=[[787,102,809,139]], element_info='SEARCH')\tLeft click on the magnifying glass icon next to the search bar to perform the search.\n3. SCROLL_DOWN(box=[[0,209,998,952]], step_count=5, element_info='[None]')\tScroll down the page to see the available doors.\n4. 
CLICK(box=[[280,708,710,809]], element_info='Doors on Sale')\tClick the \"Doors On Sale\" button in the middle of the page to view the doors that are currently on sale.\n(Platform: WIN)\n(Answer in Action-Operation format.)\n" 136 | ``` 137 | 138 | Printed prompt: 139 | > 140 | > Task: Search for doors, click doors on sale and filter by brands "Mastercraft". 141 | > 142 | > History steps: 143 | > 144 | > 0. CLICK(box=[[352,102,786,139]], element_info='Search') Left click on the search box located in the middle top of 145 | the screen next to the Menards logo. 146 | > 1. TYPE(box=[[352,102,786,139]], text='doors', element_info='Search') In the search input box at the top, type ' 147 | doors'. 148 | > 2. CLICK(box=[[787,102,809,139]], element_info='SEARCH') Left click on the magnifying glass icon next to the search 149 | bar to perform the search. 150 | > 3. SCROLL_DOWN(box=[[0,209,998,952]], step_count=5, element_info='[None]') Scroll down the page to see the available 151 | doors. 152 | > 4. CLICK(box=[[280,708,710,809]], element_info='Doors on Sale') Click the "Doors On Sale" button in the middle of the 153 | page to view the doors that are currently on sale. 154 | > 155 | > 156 | > (Platform: WIN) 157 | > 158 | > (Answer in Action-Operation format.) 159 | > 160 | 161 | If you want to understand the meaning and representation of each field in detail, please continue reading or refer to 162 | the [Practical Documentation (in Chinese), "Prompt Concatenation" section](https://zhipu-ai.feishu.cn/wiki/D9FTwQ78fitS3CkZHUjcKEWTned). 163 | 164 | 1. **`task` field** 165 | The user’s task description, in text format similar to a prompt. This input instructs the `cogagent-9b-20241220` 166 | model on how to carry out the user’s request. Keep it concise and clear. 167 | 168 | 2. **`platform` field** 169 | `cogagent-9b-20241220` supports agent operations on multiple platforms with graphical interfaces. 
We currently 170 | support three systems: 171 | - Windows 10, 11: Use the `WIN` field. 172 | - macOS 14, 15: Use the `Mac` field. 173 | - Android 13, 14, 15 (and other Android UI variants with similar GUI operations): Use the `Mobile` field. 174 | 175 | If your system is not among these, the effectiveness may be suboptimal. You can try using `Mobile` for mobile 176 | devices, `WIN` for Windows, or `Mac` for Mac. 177 | 178 | 3. **`format` field** 179 | The format in which the user wants `cogagent-9b-20241220` to return data. We provide several options: 180 | - `Answer in Action-Operation-Sensitive format.`: The default demo return type in this repo. Returns the model’s 181 | actions, corresponding operations, and the sensitivity level. 182 | - `Answer in Status-Plan-Action-Operation format.`: Returns the model’s status, plan, and corresponding operations. 183 | - `Answer in Status-Action-Operation-Sensitive format.`: Returns the model’s status, actions, corresponding 184 | operations, and sensitivity. 185 | - `Answer in Status-Action-Operation format.`: Returns the model’s status and actions. 186 | - `Answer in Action-Operation format.`: Returns the model’s actions and corresponding operations. 187 | 188 | 4. **`history` field** 189 | This should be concatenated in the following order: 190 | ``` 191 | query = f'{task}{history}{platform}{format}' 192 | ``` 193 | 194 | 5. **`Continue` field** 195 | CogAgent allows users to let the model `continue answering`. This requires users to append the `[Continue]\n` field after `{task}`. In such cases, the concatenation sequence and result should be as follows: 196 | ``` 197 | query = f'{task}[Continue]\n{history}{platform}{format}' 198 | ``` 199 | 200 | ### Model Output 201 | 202 | 1. **Sensitive operations**: Includes `<<敏感操作>>` (“sensitive operation”) and `<<一般操作>>` (“general operation”). 203 | These are only returned if you request the `Sensitive` format. 204 | 2. 
**`Plan`, `Status`, `Action` fields**: Used to describe the model’s behavior and operations. Only returned if you
205 | request the corresponding fields. For example, if the format includes `Action`, then the model returns the `Action`
206 | field.
207 | 3. **General answer section**: A summary that appears prior to the formatted answer.
208 | 4. **`Grounded Operation` field**:
209 | Describes the model’s specific operations, including the location of the operation, the operation type, and the
210 | action details. The `box` attribute indicates the coordinate region for execution, `element_type` indicates the
211 | element type, and `element_info` describes the element. These details are wrapped within a “操作指令” (operation
212 | command). For the definition of the action space, please refer to [here](Action_space.md).
213 |
214 | ### An Example
215 |
216 | Suppose the user wants to mark all emails as read. The user is on a Mac and wants the model to answer in
217 | `Action-Operation-Sensitive` format. The properly **concatenated prompt** should be:
218 |
219 | ```
220 | Task: Please mark all my emails as read
221 | History steps:
222 | (Platform: Mac)
223 | (Answer in Action-Operation-Sensitive format.)
224 | ```
225 |
226 | Note: even if there are no historical actions, "History steps:" must still be included in the prompt. Below are
227 | **sample outputs** for different format requirements:
228 |
229 |
230 | Answer in Action-Operation-Sensitive format 231 | 232 | ``` 233 | Action: Click the 'Mark all as read' button in the top toolbar of the page to mark all emails as read. 234 | Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='Clickable text', element_info='Mark all emails as read') 235 | <<一般操作>> 236 | ``` 237 | 238 |
239 | 240 |
241 | Answer in Status-Plan-Action-Operation format 242 | 243 | ``` 244 | Status: Currently in the email interface [[0, 2, 998, 905]], with the email categories on the left [[1, 216, 144, 570]], and the inbox in the center [[144, 216, 998, 903]]. The "Mark all as read" button has been clicked [[223, 178, 311, 210]]. 245 | Plan: Future tasks: 1. Click the 'Mark all as read' button; 2. Task complete. 246 | Action: Click the "Mark all as read" button at the top center of the inbox page to mark all emails as read. 247 | Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='Clickable text', element_info='Mark all emails as read') 248 | ``` 249 | 250 |
251 | 252 |
253 | Answer in Status-Action-Operation-Sensitive format 254 | 255 | ``` 256 | Status: Currently in the email interface [[0, 2, 998, 905]], with the email categories on the left [[1, 216, 144, 570]], and the inbox in the center [[144, 216, 998, 903]]. The "Mark all as read" button has been clicked [[223, 178, 311, 210]]. 257 | Action: Click the "Mark all as read" button at the top center of the inbox page to mark all emails as read. 258 | Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='Clickable text', element_info='Mark all emails as read') 259 | <<一般操作>> 260 | ``` 261 | 262 |
263 | 264 |
265 | Answer in Status-Action-Operation format 266 | 267 | ``` 268 | Status: Currently in the email interface [[0, 2, 998, 905]], with the email categories on the left [[1, 216, 144, 570]], and the inbox in the center [[144, 216, 998, 903]]. The "Mark all as read" button has been clicked [[223, 178, 311, 210]]. 269 | Action: Click the "Mark all as read" button at the top center of the inbox page to mark all emails as read. 270 | Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='Clickable text', element_info='Mark all emails as read') 271 | ``` 272 | 273 |
274 | 275 |
276 | Answer in Action-Operation format 277 | 278 | ``` 279 | Action: Right-click the first email in the left email list to open the action menu. 280 | Grounded Operation: RIGHT_CLICK(box=[[154,275,343,341]], element_info='[AXCell]') 281 | ``` 282 | 283 |
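Because the model returns a plain string rather than JSON, a client has to split the response into its fields itself. The sketch below shows one possible way to do that for the sample outputs above. `parse_response` and its returned keys are hypothetical names for illustration, and it assumes each field sits on a single line starting with its label, as in the samples. It is not the parser used in `app/client.py`.

```python
import re
from typing import Dict

# Field labels as they appear in the sample outputs.
FIELD_PREFIXES = {
    "status": "Status:",
    "plan": "Plan:",
    "action": "Action:",
    "operation": "Grounded Operation:",
}
BOX_PATTERN = re.compile(r"box=\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]")
OP_NAME_PATTERN = re.compile(r"^([A-Z_]+)\(")


def parse_response(text: str) -> Dict[str, object]:
    """Split a model response into labelled fields (None when a field is absent)."""
    parsed: Dict[str, object] = {key: None for key in FIELD_PREFIXES}
    parsed.update({"op_name": None, "box": None, "sensitive": None})

    for raw_line in text.splitlines():
        line = raw_line.strip()
        for key, prefix in FIELD_PREFIXES.items():
            if line.startswith(prefix):
                parsed[key] = line[len(prefix):].strip()

    # The sensitivity marker is only present for "...-Sensitive" formats.
    if "<<敏感操作>>" in text:
        parsed["sensitive"] = True
    elif "<<一般操作>>" in text:
        parsed["sensitive"] = False

    operation = parsed["operation"]
    if isinstance(operation, str):
        name_match = OP_NAME_PATTERN.match(operation)
        if name_match:
            parsed["op_name"] = name_match.group(1)
        box_match = BOX_PATTERN.search(operation)
        if box_match:
            parsed["box"] = [int(value) for value in box_match.groups()]
    return parsed


if __name__ == "__main__":
    sample = (
        "Action: Click the 'Mark all as read' button in the top toolbar of the page to mark all emails as read.\n"
        "Grounded Operation: CLICK(box=[[219,186,311,207]], element_type='Clickable text', "
        "element_info='Mark all emails as read')\n"
        "<<一般操作>>"
    )
    print(parse_response(sample))
```

Returning `None` for missing fields lets the same function handle all five answer formats, since each format simply omits the fields that were not requested.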
284 |
285 | ### Notes
286 |
287 | 1. This model is **not** a conversational model and does **not** support continuous dialogue. Please send specific
288 | commands and follow our recommended method for concatenating the history.
289 | 2. The model **requires** images as input; text-only input cannot accomplish GUI Agent tasks.
290 | 3. The model’s output adheres to a strict format. Please parse it strictly according to our requirements. The output is
291 | in **string** format; JSON output is **not** supported.
292 |
293 | ## Running the Model
294 |
295 | ### Environment Setup
296 |
297 | Make sure you have installed **Python 3.10.16** or above, and then install the following dependencies:
298 |
299 | ```shell
300 | pip install -r requirements.txt
301 | ```
302 |
303 | To run local inference based on `transformers`, you can run the command below:
304 |
305 | ```shell
306 | python inference/cli_demo.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 --output_image_path ./results --format_key status_action_op_sensitive
307 | ```
308 |
309 | This is an interactive command-line script. You will need to provide the path to your images. If the model returns results
310 | containing bounding boxes, it will output an image with those bounding boxes, indicating the region where the operation
311 | should be executed. The image is saved to `output_image_path`, with the file name `{your_input_image_name}_{round}.png`.
312 | The `format_key` indicates in which format you want the model to respond. The `platform` field specifies which platform
313 | you are using (e.g., `Mac`). Therefore, all uploaded screenshots must be from macOS if `platform` is set to `Mac`.
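To draw or act on a returned bounding box, its values have to be mapped onto the actual screenshot. The `box` values in the examples in this README all fall between 0 and 999 regardless of screen size, which suggests coordinates relative to the image. The sketch below assumes a 0 to 1000 relative scale. That scale, and the helper names `box_to_pixels` and `box_center`, are assumptions made for illustration, so check `inference/cli_demo.py` for the conversion the demo actually uses.

```python
from typing import Sequence, Tuple


def box_to_pixels(
    box: Sequence[int], width: int, height: int, scale: int = 1000
) -> Tuple[int, int, int, int]:
    """Map a relative [x1, y1, x2, y2] box onto an image of width x height pixels.

    `scale` is the assumed size of the relative coordinate space (hypothetical default).
    """
    x1, y1, x2, y2 = box
    return (
        round(x1 * width / scale),
        round(y1 * height / scale),
        round(x2 * width / scale),
        round(y2 * height / scale),
    )


def box_center(pixel_box: Sequence[int]) -> Tuple[int, int]:
    """Return the integer center of a pixel box, e.g. as a click target."""
    x1, y1, x2, y2 = pixel_box
    return ((x1 + x2) // 2, (y1 + y2) // 2)


if __name__ == "__main__":
    pixel_box = box_to_pixels([219, 186, 311, 207], width=2000, height=1000)
    print(pixel_box, box_center(pixel_box))
```

Keeping the scale as a parameter makes the assumption explicit and easy to change if the model's coordinate convention turns out to differ.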
314 | 315 | If you want to run an online web demo, which supports continuous image uploads for interactive inference, you can run: 316 | 317 | ```shell 318 | python inference/web_demo.py --host 0.0.0.0 --port 7860 --model_dir THUDM/cogagent-9b-20241220 --format_key status_action_op_sensitive --platform "Mac" --output_dir ./results 319 | ``` 320 | 321 | This code provides the same experience as the `HuggingFace Space` online demo. The model will return the corresponding 322 | bounding boxes and execution categories. 323 | 324 | ### Running an Agent APP Example 325 | 326 | We have prepared a basic demo app for developers to illustrate the GUI capabilities of `cogagent-9b-20241220`. The demo 327 | shows how to deploy the model on a GPU-equipped server and run the `cogagent-9b-20241220` model locally to perform 328 | automated GUI operations. 329 | 330 | > We cannot guarantee the safety of AI behavior; please exercise caution when using it. 331 | > This example is only for academic reference. We assume no legal responsibility for any issues resulting from this 332 | > example. 333 | 334 | If you are interested in this APP, feel free to check out the [documentation](app/README.md). 335 | 336 | ### Fine-tuning the Model 337 | 338 | If you are interested in fine-tuning the `cogagent-9b-20241220` model, please refer to [here](finetune/README.md). 339 | 340 | ## Previous Work 341 | 342 | In November 2023, we released the first generation of CogAgent. You can find related code and model weights in 343 | the [CogVLM & CogAgent Official Repository](https://github.com/THUDM/CogVLM). 344 | 345 |
346 |
352 | **CogVLM**
353 | 📖 Paper: CogVLM: Visual Expert for Pretrained Language Models
354 | CogVLM is a powerful open-source Vision-Language Model (VLM). CogVLM-17B has 10B visual parameters and 7B language parameters, supporting image understanding at a resolution of 490x490, as well as multi-round dialogue.
355 | CogVLM-17B achieves state-of-the-art performance on 10 classic multimodal benchmarks, including NoCaps, Flickr30k captioning, RefCOCO, RefCOCO+, RefCOCOg, Visual7W, GQA, ScienceQA, VizWiz VQA, and TDIUC.
356 |
358 | **CogAgent**
359 | 📖 Paper: CogAgent: A Visual Language Model for GUI Agents
360 | CogAgent is an open-source vision-language model built on CogVLM. CogAgent-18B has 11B visual parameters and 7B language parameters. It supports image understanding at a resolution of 1120x1120. Building on CogVLM’s capabilities, CogAgent further adds GUI image agent capabilities.
361 | CogAgent-18B delivers state-of-the-art general performance on 9 classic vision-language benchmarks, including VQAv2, OK-VQA, TextVQA, ST-VQA, ChartQA, InfoVQA, DocVQA, MM-Vet, and POPE. It also significantly outperforms existing models on GUI operation datasets such as AITW and Mind2Web.
364 |
365 |
366 | ## License
367 |
368 | - The [Apache 2.0 License](LICENSE) applies to the use of the code in this GitHub repository.
369 | - For the model weights, please follow the [Model License](MODEL_LICENSE).
370 |
371 | ## Citation
372 |
373 | If you find our work helpful, please consider citing the following paper:
374 |
375 | ```
376 | @misc{hong2023cogagent,
377 | title={CogAgent: A Visual Language Model for GUI Agents},
378 | author={Wenyi Hong and Weihan Wang and Qingsong Lv and Jiazheng Xu and Wenmeng Yu and Junhui Ji and Yan Wang and Zihan Wang and Yuxiao Dong and Ming Ding and Jie Tang},
379 | year={2023},
380 | eprint={2312.08914},
381 | archivePrefix={arXiv},
382 | primaryClass={cs.CV}
383 | }
384 |
385 | ```
386 |
387 | ## Research and Development Team & Acknowledgements
388 |
389 | **R&D Institutions**: Tsinghua University, Zhipu AI
390 |
391 | **Team members**: Wenyi Hong, Junhui Ji, Lihang Pan, Yuanchang Yue, Changyu Pang, Siyan Xue, Guo Wang, Weihan Wang,
392 | Jiazheng Xu, Shen Yang, Xiaotao Gu, Yuxiao Dong, Jie Tang
393 |
394 | **Acknowledgement**: We would like to thank the Zhipu AI data team for their strong support, including Xiaohan Zhang,
395 | Zhao Xue, Lu Chen, Jingjie Du, Siyu Wang, Ying Zhang, and all annotators. They worked hard to collect and annotate the
396 | training and testing data of the CogAgent model. We also thank Yuxuan Zhang, Xiaowei Hu, and Hao Chen from the Zhipu AI
397 | open source team for their engineering efforts in open sourcing the model.
398 | --------------------------------------------------------------------------------