├── .gitignore ├── LICENSE ├── README.md ├── README_cn.md ├── assets ├── AI-logo.png ├── all_tabs_demo.gif ├── asr_demo.gif ├── chat_demo.gif ├── segmentation_demo.gif ├── tts_demo.gif ├── video_convertor_demo.gif ├── video_inpainter_demo.gif ├── visualchat_demo.jpg └── visualchat_demo.png ├── configs ├── asr_demo.yml ├── base.yml ├── chatbot_demo.yml ├── segmentation_demo.yml ├── tts_demo.yml ├── video_convertor_demo.yml ├── video_inpainter_demo.yml ├── visualchat_demo.yml └── webui_configs.yml ├── demo ├── bgm │ ├── Paris.mp3 │ ├── illusionary_daytime.mp3 │ ├── time_back.mp3 │ └── windy_hill.mp3 ├── fastsam │ └── examples │ │ ├── dogs.jpg │ │ ├── sa_10039.jpg │ │ ├── sa_11025.jpg │ │ ├── sa_1309.jpg │ │ ├── sa_192.jpg │ │ ├── sa_414.jpg │ │ ├── sa_561.jpg │ │ ├── sa_862.jpg │ │ └── sa_8776.jpg └── video_inpainter │ ├── test-sample0.mp4 │ ├── test-sample1.mp4 │ ├── test-sample2.mp4 │ ├── test-sample3.mp4 │ └── test-sample4.mp4 ├── model_weights └── chatglm │ └── chatglm2-6b-int4 │ ├── config.json │ ├── configuration_chatglm.py │ ├── modeling_chatglm.py │ ├── quantization.py │ ├── tokenization_chatglm.py │ └── tokenizer_config.json ├── models ├── RAFT │ ├── __init__.py │ ├── corr.py │ ├── datasets.py │ ├── demo.py │ ├── extractor.py │ ├── raft.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── flow_viz_pt.py │ │ ├── frame_utils.py │ │ └── utils.py ├── __init__.py ├── misc.py ├── modules │ ├── base_module.py │ ├── deformconv.py │ ├── flow_comp_raft.py │ ├── flow_loss_utils.py │ ├── sparse_transformer.py │ └── spectral_norm.py ├── propainter.py ├── recurrent_flow_completion.py ├── sam │ ├── README.md │ ├── __init__.py │ ├── assets │ │ ├── masks1.png │ │ ├── masks2.jpg │ │ ├── model_diagram.png │ │ ├── notebook1.png │ │ └── notebook2.png │ ├── linter.sh │ ├── notebooks │ │ ├── automatic_mask_generator_example.ipynb │ │ ├── images │ │ │ ├── dog.jpg │ │ │ ├── groceries.jpg │ │ │ └── truck.jpg │ │ ├── onnx_model_example.ipynb │ │ └── predictor_example.ipynb │ ├── scripts │ │ ├── amg.py │ │ └── export_onnx_model.py │ ├── segment_anything │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── automatic_mask_generator.py │ │ ├── build_sam.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── image_encoder.py │ │ │ ├── mask_decoder.py │ │ │ ├── prompt_encoder.py │ │ │ ├── sam.py │ │ │ └── transformer.py │ │ ├── predictor.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── amg.py │ │ │ ├── onnx.py │ │ │ └── transforms.py │ ├── setup.cfg │ └── setup.py └── tracker │ ├── base_tracker.py │ ├── config │ └── __init__.py │ ├── inference │ ├── __init__.py │ ├── image_feature_store.py │ ├── inference_core.py │ ├── kv_memory_store.py │ ├── memory_manager.py │ ├── object_info.py │ ├── object_manager.py │ └── utils │ │ ├── __init__.py │ │ ├── args_utils.py │ │ ├── burst_utils.py │ │ ├── frame_utils.py │ │ └── results_utils.py │ ├── model │ ├── __init__.py │ ├── aux_modules.py │ ├── big_modules.py │ ├── channel_attn.py │ ├── cutie.py │ ├── group_modules.py │ ├── losses.py │ ├── modules.py │ ├── transformer │ │ ├── __init__.py │ │ ├── object_summarizer.py │ │ ├── object_transformer.py │ │ ├── positional_encoding.py │ │ └── transformer_layers.py │ └── utils │ │ ├── __init__.py │ │ ├── memory_utils.py │ │ ├── parameter_groups.py │ │ └── resnet.py │ └── utils │ ├── __init__.py │ ├── image_saver.py │ ├── load_subset.py │ ├── log_integrator.py │ ├── logger.py │ ├── mask_mapper.py │ ├── palette.py │ ├── pano_utils.py │ ├── point_features.py │ ├── range_transform.py 
│ ├── tensor_utils.py │ └── time_estimator.py ├── requirements.txt ├── tools ├── __init__.py ├── ai_wrapper.py ├── base.py ├── chatglm_handler.py ├── chatvlm_handler.py ├── edgetts_handler.py ├── fastsam_handler.py ├── gpt_handler.py ├── inpainting_handler.py ├── sam_handler.py ├── tracker_handler.py ├── visualglm_handler.py └── whisper_handler.py ├── utils ├── __init__.py ├── chatglm_utils.py ├── fastsam │ ├── __init__.py │ ├── decoder.py │ ├── model.py │ ├── predict.py │ ├── prompt.py │ ├── tools.py │ └── utils.py ├── gradio_tabs │ ├── __init__.py │ ├── audio_tabs.py │ ├── chat_tabs.py │ ├── image_tabs.py │ └── video_tabs.py ├── gradio_utils.py ├── mask_painter.py ├── misc.py └── painter.py └── webui.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 jasonaidm 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README_cn.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 8 | 9 |

AI-WEBUI: A universal web interface for AI creation. A handy tool for image, audio, and video processing

10 | 11 | 12 | ⭐ If this project helps you, please give it a star. Thanks! 🤗 13 |
14 | 15 | ## 🌟 1. Introduction 16 | ai-webui is a browser-based interface that aims to provide a general-purpose platform for AI creation. 17 | drawing 18 | drawing 19 | 20 | The project provides basic capabilities such as image segmentation, object tracking, image inpainting, speech recognition, and speech synthesis, as well as higher-level features built by combining them, such as chat Q&A, video translation, and video watermark removal, which can greatly improve the efficiency of short-video creation. 21 | 22 | ## ⚡2. Installation 23 | 24 | To install and use AI-WebUI, follow the steps below: 25 | 26 | ### 2.1 Clone this project to your local machine 27 | 28 | ```bash 29 | git clone https://github.com/jasonaidm/ai_webui.git 30 | ``` 31 | 32 | ### 2.2 Enter the project directory 33 | 34 | ```bash 35 | cd ai_webui 36 | ``` 37 | ### 2.3 Create a virtual environment 38 | ```bash 39 | conda create -n aiwebui python=3.11 40 | conda activate aiwebui 41 | ``` 42 | 43 | ### 2.4 Install the required dependencies 44 | 45 | ```bash 46 | apt install ffmpeg -y 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | 51 | ## 🚀3. Quick start 52 | 53 | Using AI-WebUI is simple: just follow the instructions shown in the interface. You can supply your creative inputs by uploading videos, audio, or images, or by entering text, and then interact with the model outputs. 54 | ```bash 55 | python webui.py -c ./configs/webui_configs.yml 56 | ``` 57 | 58 | After startup, open `http://localhost:9090/?__theme=dark` in your browser to view the project interface. 59 | 60 | ### 3.1 Single-feature demos 61 | Considering that some users' personal computers have limited GPU capacity, we provide single-feature demos so that one AI feature can be run on its own without launching the entire project. 62 | 63 | 1. Image segmentation 64 | - Panoptic segmentation 65 | - Prompt-based segmentation from point coordinates 66 | - Prompt-based segmentation from text 67 | ```bash 68 | python webui.py -c ./configs/segmentation_demo.yml 69 | ``` 70 | ![segmentation demo](./assets/segmentation_demo.gif) 71 | 72 | 2. Speech recognition 73 | - Multilingual recognition, including Chinese and English 74 | ```bash 75 | python webui.py -c ./configs/asr_demo.yml 76 | ``` 77 | ![asr demo](./assets/asr_demo.gif) 78 | 79 | 3. Speech synthesis 80 | - Multilingual synthesis, including Chinese and English 81 | ```bash 82 | python webui.py -c ./configs/tts_demo.yml 83 | ``` 84 | ![tts demo](./assets/tts_demo.gif) 85 | 86 | 87 | ### 3.2 Combined-feature demos 88 | More complex features are obtained by combining multiple AI models, which places higher demands on GPU resources. 89 | 1. Chat Q&A 90 | - Streaming text conversation 91 | - Voice conversation 92 | ```bash 93 | python webui.py -c ./configs/chatbot_demo.yml 94 | ``` 95 | ![chatbot demo](./assets/chat_demo.gif) 96 | 97 | 2. Video inpainting 98 | - Watermark removal 99 | - Mosaic removal 100 | - Object tracking 101 | - Removing specific objects from a video 102 | 103 | ```bash 104 | python webui.py -c ./configs/video_inpainter_demo.yml 105 | ``` 106 | ![video_inpainter demo](./assets/video_inpainter_demo.gif) 107 | 108 | 3. Video conversion 109 | - Audio/video separation 110 | - Frame cropping 111 | - Adding noise to frames 112 | - Frame extraction and sampling 113 | - Audio recognition 114 | - Subtitle translation 115 | - Speech synthesis 116 | - Background music (BGM) insertion 117 | - One-click video generation (easy repurposing of videos from overseas platforms) 118 | ```bash 119 | python webui.py -c ./configs/video_convertor_demo.yml 120 | ``` 121 | ![video_convertor demo](./assets/video_convertor_demo.gif) 122 | 123 | ### 3.3 Launching all features 124 | Enable every AI feature with the following command: 125 | ```bash 126 | python webui.py -c ./configs/webui_configs.yml 127 | ``` 128 | Because model loading takes quite a long time, it is recommended to defer loading each model until its first inference after startup. 129 | The loading strategy of each AI model can be controlled with the "init_model_when_start_server" option in the configs/base.yml configuration file. 130 |
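For illustration, the small standalone sketch below (not part of this repository; it assumes PyYAML is installed and is run from the project root) reads `configs/base.yml` and reports, for each handler that exposes the flag, whether its weights are loaded at server startup or deferred to the first inference:

```python
# Illustrative helper (not part of ai_webui): summarize the loading strategy
# declared in configs/base.yml for each AI handler.
import yaml

with open("./configs/base.yml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

for handler, params in cfg.items():
    if isinstance(params, dict) and "init_model_when_start_server" in params:
        when = "at server startup" if params["init_model_when_start_server"] else "on first inference"
        print(f"{handler}: weights are loaded {when}")
```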
131 | ## 🔥4. Model files 132 | ### 4.1 Downloading the model files 133 | | Model | File size | Small-model set | Download link | 134 | | :--- | :--- | :--- | :--- | 135 | | chatglm2-6b-int4 | 3.7G | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 136 | | chatglm2-6b | 12G | | [Tsinghua University Cloud](https://cloud.tsinghua.edu.cn/d/674208019e314311ab5c/?p=%2Fchatglm2-6b&mode=list) | 137 | | sam_vit_b | 358M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 138 | | sam_vit_h | 2.4G | | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 139 | | FastSAM-s | 23M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 140 | | FastSAM-x | 138M | | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 141 | | ProPainter | 150M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 142 | | raft-things | 20M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 143 | | recurrent_flow_completion | 19M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 144 | | cutie | 134M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 145 | | whisper-small | 461M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 146 | | whisper-large-v3 | 2.9G | | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) | 147 | 148 | - The Baidu Netdisk extraction code is: zogk 149 | 150 | ### 4.2 Directory layout of the model weight files 151 | ``` 152 | model_weights/ 153 | ├── chatglm 154 | │ └── chatglm2-6b-int4 155 | │ ├── config.json 156 | │ ├── configuration_chatglm.py 157 | │ ├── modeling_chatglm.py 158 | │ ├── pytorch_model.bin 159 | │ ├── quantization.py 160 | │ ├── tokenization_chatglm.py 161 | │ ├── tokenizer.model 162 | │ └── tokenizer_config.json 163 | ├── fastsam 164 | │ ├── FastSAM-s.pt 165 | │ └── FastSAM-x.pt 166 | ├── propainter 167 | │ ├── ProPainter.pth 168 | │ ├── cutie-base-mega.pth 169 | │ ├── raft-things.pth 170 | │ └── recurrent_flow_completion.pth 171 | ├── sam 172 | │ ├── sam_vit_b.pth 173 | │ └── sam_vit_h.pth 174 | └── whisper 175 | ├── large-v3.pt 176 | └── small.pt 177 | ``` 178 | If your GPU has less than 8 GB of memory, you may only be able to run the small models; their quality is noticeably worse, so run the large models if your hardware allows it. 179 |
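Before launching, it can help to confirm that the downloaded weights match the layout above. A minimal standalone check (not part of the repository; the paths follow the small-model set marked ✅ in the table and the tree in section 4.2, so adjust the list if you downloaded the larger models instead):

```python
# Illustrative check (not part of ai_webui): verify that the expected weight
# files from section 4.2 exist under ./model_weights before starting the web UI.
from pathlib import Path

EXPECTED = [
    "chatglm/chatglm2-6b-int4/pytorch_model.bin",
    "fastsam/FastSAM-s.pt",
    "propainter/ProPainter.pth",
    "propainter/cutie-base-mega.pth",
    "propainter/raft-things.pth",
    "propainter/recurrent_flow_completion.pth",
    "sam/sam_vit_b.pth",
    "whisper/small.pt",
]

root = Path("./model_weights")
missing = [p for p in EXPECTED if not (root / p).exists()]
if missing:
    print("Missing weight files:", *missing, sep="\n  ")
else:
    print("All expected weight files are present.")
```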
180 | ## 5. Contributing 181 | 182 | If you have any suggestions or feature requests, feel free to open an issue. 183 | 184 | ## 6. References 185 | - [Segment-and-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything) 186 | - [ProPainter](https://github.com/sczhou/ProPainter) 187 | - [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) 188 | - [segment-anything](https://github.com/facebookresearch/segment-anything) 189 | - [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) 190 | - [whisper](https://github.com/openai/whisper) -------------------------------------------------------------------------------- /assets/AI-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/AI-logo.png -------------------------------------------------------------------------------- /assets/all_tabs_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/all_tabs_demo.gif -------------------------------------------------------------------------------- /assets/asr_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/asr_demo.gif -------------------------------------------------------------------------------- /assets/chat_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/chat_demo.gif -------------------------------------------------------------------------------- /assets/segmentation_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/segmentation_demo.gif -------------------------------------------------------------------------------- /assets/tts_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/tts_demo.gif -------------------------------------------------------------------------------- /assets/video_convertor_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/video_convertor_demo.gif -------------------------------------------------------------------------------- /assets/video_inpainter_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/video_inpainter_demo.gif -------------------------------------------------------------------------------- /assets/visualchat_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/visualchat_demo.jpg -------------------------------------------------------------------------------- /assets/visualchat_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/visualchat_demo.png
-------------------------------------------------------------------------------- /configs/asr_demo.yml: -------------------------------------------------------------------------------- 1 | base: ['./configs/base.yml'] 2 | 3 | asr_task: 4 | switch: true 5 | name: 语音识别 6 | audio_asr_text: 7 | label: 'text' 8 | 9 | -------------------------------------------------------------------------------- /configs/base.yml: -------------------------------------------------------------------------------- 1 | 2 | home_desc: 站在AI「人工智能」浪潮的风口,我们可以更好地应对未来的挑战!! 3 | server_name: "0.0.0.0" 4 | server_port: 9090 5 | 6 | # AI模型相关参数 7 | ChatGLMHandler: 8 | llm_model_path: ./model_weights/chatglm/chatglm2-6b-int4 9 | num_gpus: 1 10 | trust_remote_code: true 11 | init_model_when_start_server: false 12 | 13 | ChatVLMHandler: 14 | vlm_model_path: ./model_weights/Qwen-VL-Chat # /mnt/d/ai_dev/visualglm-6b/model_weights 15 | device: 'cuda:1' 16 | quant: 8 17 | trust_remote_code: true 18 | init_model_when_start_server: false 19 | 20 | WhisperHandler: 21 | model_name: large-v3 22 | model_dir: ./model_weights/whisper 23 | device: 'cuda:1' 24 | init_model_when_start_server: false 25 | language_map: { 26 | "普通话": "zh", 27 | "英语": "en", 28 | } 29 | 30 | GPTHandler: 31 | api_url: null 32 | 33 | FastSAMHandler: 34 | fastsam_model_path: ./model_weights/fastsam/FastSAM-x.pt 35 | device: 'cuda:0' 36 | 37 | SAMHandler: 38 | model_ckpt: ./model_weights/sam/sam_vit_h.pth 39 | model_type: vit_h 40 | device: 'cuda:0' 41 | 42 | InpaintingHandler: 43 | propainter_ckpt: ./model_weights/propainter/ProPainter.pth 44 | raft_ckpt: ./model_weights/propainter/raft-things.pth 45 | flow_completion_ckpt: ./model_weights/propainter/recurrent_flow_completion.pth 46 | device: 'cuda:0' 47 | use_half: true 48 | 49 | TrackerHandler: 50 | cutie_ckpt: ./model_weights/propainter/cutie-base-mega.pth 51 | device: 'cuda:0' 52 | 53 | 54 | EdgeTTSHandler: 55 | output_dir: /tmp 56 | -------------------------------------------------------------------------------- /configs/chatbot_demo.yml: -------------------------------------------------------------------------------- 1 | base: ['./configs/base.yml'] 2 | 3 | chatbot: 4 | switch: true 5 | name: 聊天问答 6 | chatbot_win: 7 | height: 268 8 | llm_model_type: 9 | choices: ["chatglm"] 10 | value: "chatglm" 11 | 12 | tts_task: 13 | switch: true 14 | name: 语音合成 15 | output_dir: ./results/tts_outputs 16 | model_type: 17 | choices: ["edge_tts", "so_vits_svc"] 18 | value: "edge_tts" 19 | tts_voice: 20 | choices: &tts_voice_list ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", 21 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural"] 22 | value: "zh-CN-YunxiNeural" 23 | tts_rate: 24 | minimum: -100 25 | maximum: 100 26 | value: 0 27 | step: 5 28 | tts_volume: 29 | minimum: -100 30 | maximum: 100 31 | value: 0 32 | step: 5 33 | tts_pitch: 34 | minimum: -100 35 | maximum: 100 36 | value: 0 37 | step: 5 38 | -------------------------------------------------------------------------------- /configs/segmentation_demo.yml: -------------------------------------------------------------------------------- 1 | base: ['./configs/base.yml'] 2 | 3 | segmentation_task: 4 | switch: true 5 | name: 图像分割(SAM) 6 | 7 | 8 | -------------------------------------------------------------------------------- /configs/tts_demo.yml: -------------------------------------------------------------------------------- 1 | base: ['./configs/base.yml'] 2 | 3 | tts_task: 4 
| switch: true 5 | name: 语音合成 6 | output_dir: /data1/zjx/ai_webui/products/tts_outputs 7 | model_type: 8 | choices: ["edge_tts", "so_vits_svc"] 9 | value: "edge_tts" 10 | tts_voice: 11 | choices: ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", 12 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural"] 13 | value: "zh-CN-YunxiNeural" 14 | tts_rate: 15 | minimum: -100 16 | maximum: 100 17 | value: 0 18 | step: 5 19 | tts_volume: 20 | minimum: -100 21 | maximum: 100 22 | value: 0 23 | step: 5 24 | tts_pitch: 25 | minimum: -100 26 | maximum: 100 27 | value: 0 28 | step: 5 -------------------------------------------------------------------------------- /configs/video_convertor_demo.yml: -------------------------------------------------------------------------------- 1 | base: ['./configs/base.yml'] 2 | 3 | video_convertor: 4 | switch: true 5 | name: 视频转换 6 | video_upload_win: 7 | height: 400 8 | width: null 9 | autoplay: true 10 | aspect_ratio_box: 11 | choices: ["16/9", "9/16", "4/3"] 12 | value: "16/9" 13 | move_up_rate: 14 | minimum: 0 15 | maximum: 0.3 16 | step: 0.05 17 | value: 0.15 18 | add_noice: 19 | value: false 20 | video_segment_length: 21 | minimum: 1 22 | maximum: 10 23 | step: 1 24 | value: 1 25 | audio_speech_rate: 26 | minimum: 0.1 27 | maximum: 1. 28 | step: 0.1 29 | value: 0.9 30 | subtitling_recognition: 31 | value: true 32 | translate_engine: 33 | choices: ["chatglm"] 34 | value: "chatglm" 35 | subtitling_language: 36 | choices: ["英语", "普通话"] 37 | value: "普通话" 38 | collage_short_video: 39 | value: true 40 | voice_role: 41 | choices: ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", 42 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural" 43 | ] 44 | value: "zh-CN-YunxiNeural" 45 | bgm_name: 46 | choices: ["Paris", "illusionary_daytime", "time_back", "windy_hill"] 47 | value: "Paris" 48 | watermark: 49 | choices: ["拉普拉丝", "川陀"] 50 | value: "拉普拉丝" 51 | 52 | -------------------------------------------------------------------------------- /configs/video_inpainter_demo.yml: -------------------------------------------------------------------------------- 1 | base: ['./configs/base.yml'] 2 | 3 | video_inpainter: 4 | switch: true 5 | name: 视频修复 6 | interactive_state: 7 | mask_save: false 8 | -------------------------------------------------------------------------------- /configs/visualchat_demo.yml: -------------------------------------------------------------------------------- 1 | base: ['./configs/base.yml'] 2 | 3 | visualchat: 4 | switch: true 5 | name: 多模态问答 6 | model_type: 7 | choices: ["qwen-vl-chat"] 8 | value: "qwen-vl-chat" 9 | -------------------------------------------------------------------------------- /configs/webui_configs.yml: -------------------------------------------------------------------------------- 1 | base: ['./configs/base.yml'] 2 | 3 | segmentation_task: 4 | switch: true 5 | name: 图像分割(SAM) 6 | 7 | asr_task: 8 | switch: true 9 | name: 语音识别 10 | audio_asr_text: 11 | label: 'text' 12 | 13 | tts_task: 14 | switch: true 15 | name: 语音合成 16 | output_dir: /data1/zjx/ai_webui/products/tts_outputs 17 | model_type: 18 | choices: ["edge_tts", "so_vits_svc"] 19 | value: "edge_tts" 20 | tts_voice: 21 | choices: ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", 22 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", 
"zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural"] 23 | value: "zh-CN-YunxiNeural" 24 | tts_rate: 25 | minimum: -100 26 | maximum: 100 27 | value: 0 28 | step: 5 29 | tts_volume: 30 | minimum: -100 31 | maximum: 100 32 | value: 0 33 | step: 5 34 | tts_pitch: 35 | minimum: -100 36 | maximum: 100 37 | value: 0 38 | step: 5 39 | 40 | chatbot: 41 | switch: true 42 | name: 聊天问答 43 | chatbot_win: 44 | height: 268 45 | llm_model_type: 46 | choices: ["chatglm"] 47 | value: "chatglm" 48 | 49 | visualchat: 50 | switch: true 51 | name: 多模态问答 52 | model_type: 53 | choices: ["visualglm"] 54 | value: "visualglm" 55 | 56 | video_inpainter: 57 | switch: true 58 | name: 视频修复 59 | interactive_state: 60 | mask_save: false 61 | 62 | video_convertor: 63 | switch: true 64 | name: 视频转换 65 | video_upload_win: 66 | height: 400 67 | width: null 68 | autoplay: true 69 | aspect_ratio_box: 70 | choices: ["16/9", "9/16", "4/3"] 71 | value: "16/9" 72 | move_up_rate: 73 | minimum: 0 74 | maximum: 0.3 75 | step: 0.05 76 | value: 0.15 77 | add_noice: 78 | value: false 79 | video_segment_length: 80 | minimum: 1 81 | maximum: 10 82 | step: 1 83 | value: 1 84 | audio_speech_rate: 85 | minimum: 0.1 86 | maximum: 1. 87 | step: 0.1 88 | value: 0.9 89 | subtitling_recognition: 90 | value: true 91 | translate_engine: 92 | choices: ["chatglm"] 93 | value: "chatglm" 94 | subtitling_language: 95 | choices: ["英语", "普通话"] 96 | value: "普通话" 97 | collage_short_video: 98 | value: true 99 | voice_role: 100 | choices: ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", 101 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural" 102 | ] 103 | value: "zh-CN-YunxiNeural" 104 | bgm_name: 105 | choices: ["Paris", "illusionary_daytime", "time_back", "windy_hill"] 106 | value: "Paris" 107 | watermark: 108 | choices: ["拉普拉丝", "川陀"] 109 | value: "拉普拉丝" 110 | 111 | -------------------------------------------------------------------------------- /demo/bgm/Paris.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/bgm/Paris.mp3 -------------------------------------------------------------------------------- /demo/bgm/illusionary_daytime.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/bgm/illusionary_daytime.mp3 -------------------------------------------------------------------------------- /demo/bgm/time_back.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/bgm/time_back.mp3 -------------------------------------------------------------------------------- /demo/bgm/windy_hill.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/bgm/windy_hill.mp3 -------------------------------------------------------------------------------- /demo/fastsam/examples/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/dogs.jpg 
-------------------------------------------------------------------------------- /demo/fastsam/examples/sa_10039.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_10039.jpg -------------------------------------------------------------------------------- /demo/fastsam/examples/sa_11025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_11025.jpg -------------------------------------------------------------------------------- /demo/fastsam/examples/sa_1309.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_1309.jpg -------------------------------------------------------------------------------- /demo/fastsam/examples/sa_192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_192.jpg -------------------------------------------------------------------------------- /demo/fastsam/examples/sa_414.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_414.jpg -------------------------------------------------------------------------------- /demo/fastsam/examples/sa_561.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_561.jpg -------------------------------------------------------------------------------- /demo/fastsam/examples/sa_862.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_862.jpg -------------------------------------------------------------------------------- /demo/fastsam/examples/sa_8776.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_8776.jpg -------------------------------------------------------------------------------- /demo/video_inpainter/test-sample0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample0.mp4 -------------------------------------------------------------------------------- /demo/video_inpainter/test-sample1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample1.mp4 -------------------------------------------------------------------------------- /demo/video_inpainter/test-sample2.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample2.mp4 -------------------------------------------------------------------------------- /demo/video_inpainter/test-sample3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample3.mp4 -------------------------------------------------------------------------------- /demo/video_inpainter/test-sample4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample4.mp4 -------------------------------------------------------------------------------- /model_weights/chatglm/chatglm2-6b-int4/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "THUDM/chatglm2-6b", 3 | "model_type": "chatglm", 4 | "architectures": [ 5 | "ChatGLMModel" 6 | ], 7 | "auto_map": { 8 | "AutoConfig": "configuration_chatglm.ChatGLMConfig", 9 | "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", 10 | "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration" 11 | }, 12 | "add_bias_linear": false, 13 | "add_qkv_bias": true, 14 | "apply_query_key_layer_scaling": true, 15 | "apply_residual_connection_post_layernorm": false, 16 | "attention_dropout": 0.0, 17 | "attention_softmax_in_fp32": true, 18 | "bias_dropout_fusion": true, 19 | "ffn_hidden_size": 13696, 20 | "fp32_residual_connection": false, 21 | "hidden_dropout": 0.0, 22 | "hidden_size": 4096, 23 | "kv_channels": 128, 24 | "layernorm_epsilon": 1e-05, 25 | "multi_query_attention": true, 26 | "multi_query_group_num": 2, 27 | "num_attention_heads": 32, 28 | "num_layers": 28, 29 | "original_rope": true, 30 | "padded_vocab_size": 65024, 31 | "post_layer_norm": true, 32 | "quantization_bit": 4, 33 | "rmsnorm": true, 34 | "seq_length": 32768, 35 | "use_cache": true, 36 | "torch_dtype": "float16", 37 | "transformers_version": "4.27.1", 38 | "tie_word_embeddings": false, 39 | "eos_token_id": 2, 40 | "pad_token_id": 0 41 | } -------------------------------------------------------------------------------- /model_weights/chatglm/chatglm2-6b-int4/configuration_chatglm.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class ChatGLMConfig(PretrainedConfig): 5 | model_type = "chatglm" 6 | def __init__( 7 | self, 8 | num_layers=28, 9 | padded_vocab_size=65024, 10 | hidden_size=4096, 11 | ffn_hidden_size=13696, 12 | kv_channels=128, 13 | num_attention_heads=32, 14 | seq_length=2048, 15 | hidden_dropout=0.0, 16 | attention_dropout=0.0, 17 | layernorm_epsilon=1e-5, 18 | rmsnorm=True, 19 | apply_residual_connection_post_layernorm=False, 20 | post_layer_norm=True, 21 | add_bias_linear=False, 22 | add_qkv_bias=False, 23 | bias_dropout_fusion=True, 24 | multi_query_attention=False, 25 | multi_query_group_num=1, 26 | apply_query_key_layer_scaling=True, 27 | attention_softmax_in_fp32=True, 28 | fp32_residual_connection=False, 29 | quantization_bit=0, 30 | pre_seq_len=None, 31 | prefix_projection=False, 32 | **kwargs 33 | ): 34 | self.num_layers = num_layers 35 | self.vocab_size = padded_vocab_size 36 | self.padded_vocab_size = padded_vocab_size 37 | self.hidden_size = hidden_size 38 | 
self.ffn_hidden_size = ffn_hidden_size 39 | self.kv_channels = kv_channels 40 | self.num_attention_heads = num_attention_heads 41 | self.seq_length = seq_length 42 | self.hidden_dropout = hidden_dropout 43 | self.attention_dropout = attention_dropout 44 | self.layernorm_epsilon = layernorm_epsilon 45 | self.rmsnorm = rmsnorm 46 | self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm 47 | self.post_layer_norm = post_layer_norm 48 | self.add_bias_linear = add_bias_linear 49 | self.add_qkv_bias = add_qkv_bias 50 | self.bias_dropout_fusion = bias_dropout_fusion 51 | self.multi_query_attention = multi_query_attention 52 | self.multi_query_group_num = multi_query_group_num 53 | self.apply_query_key_layer_scaling = apply_query_key_layer_scaling 54 | self.attention_softmax_in_fp32 = attention_softmax_in_fp32 55 | self.fp32_residual_connection = fp32_residual_connection 56 | self.quantization_bit = quantization_bit 57 | self.pre_seq_len = pre_seq_len 58 | self.prefix_projection = prefix_projection 59 | super().__init__(**kwargs) -------------------------------------------------------------------------------- /model_weights/chatglm/chatglm2-6b-int4/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name_or_path": "THUDM/chatglm-6b", 3 | "remove_space": false, 4 | "do_lower_case": false, 5 | "tokenizer_class": "ChatGLMTokenizer", 6 | "auto_map": { 7 | "AutoTokenizer": [ 8 | "tokenization_chatglm.ChatGLMTokenizer", 9 | null 10 | ] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /models/RAFT/__init__.py: -------------------------------------------------------------------------------- 1 | # from .demo import RAFT_infer 2 | from .raft import RAFT 3 | -------------------------------------------------------------------------------- /models/RAFT/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 
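"""All-pairs correlation volume: dot products between every spatial location of fmap1 and every location of fmap2, scaled by 1/sqrt(dim)."""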
54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class CorrLayer(torch.autograd.Function): 64 | @staticmethod 65 | def forward(ctx, fmap1, fmap2, coords, r): 66 | fmap1 = fmap1.contiguous() 67 | fmap2 = fmap2.contiguous() 68 | coords = coords.contiguous() 69 | ctx.save_for_backward(fmap1, fmap2, coords) 70 | ctx.r = r 71 | corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) 72 | return corr 73 | 74 | @staticmethod 75 | def backward(ctx, grad_corr): 76 | fmap1, fmap2, coords = ctx.saved_tensors 77 | grad_corr = grad_corr.contiguous() 78 | fmap1_grad, fmap2_grad, coords_grad = \ 79 | correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) 80 | return fmap1_grad, fmap2_grad, coords_grad, None 81 | 82 | 83 | class AlternateCorrBlock: 84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 85 | self.num_levels = num_levels 86 | self.radius = radius 87 | 88 | self.pyramid = [(fmap1, fmap2)] 89 | for i in range(self.num_levels): 90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 92 | self.pyramid.append((fmap1, fmap2)) 93 | 94 | def __call__(self, coords): 95 | 96 | coords = coords.permute(0, 2, 3, 1) 97 | B, H, W, _ = coords.shape 98 | 99 | corr_list = [] 100 | for i in range(self.num_levels): 101 | r = self.radius 102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) 103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) 104 | 105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) 107 | corr_list.append(corr.squeeze(1)) 108 | 109 | corr = torch.stack(corr_list, dim=1) 110 | corr = corr.reshape(B, -1, H, W) 111 | return corr / 16.0 112 | -------------------------------------------------------------------------------- /models/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import cv2 5 | import glob 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | from .raft import RAFT 11 | from .utils import flow_viz 12 | from .utils.utils import InputPadder 13 | 14 | 15 | 16 | DEVICE = 'cuda' 17 | 18 | def load_image(imfile): 19 | img = np.array(Image.open(imfile)).astype(np.uint8) 20 | img = torch.from_numpy(img).permute(2, 0, 1).float() 21 | return img 22 | 23 | 24 | def load_image_list(image_files): 25 | images = [] 26 | for imfile in sorted(image_files): 27 | images.append(load_image(imfile)) 28 | 29 | images = torch.stack(images, dim=0) 30 | images = images.to(DEVICE) 31 | 32 | padder = InputPadder(images.shape) 33 | return padder.pad(images)[0] 34 | 35 | 36 | def viz(img, flo): 37 | img = img[0].permute(1,2,0).cpu().numpy() 38 | flo = flo[0].permute(1,2,0).cpu().numpy() 39 | 40 | # map flow to rgb image 41 | flo = flow_viz.flow_to_image(flo) 42 | # img_flo = np.concatenate([img, flo], axis=0) 43 | img_flo = flo 44 | 45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) 46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 47 | # cv2.waitKey() 48 | 49 | 50 | def demo(args): 51 | model = torch.nn.DataParallel(RAFT(args)) 52 | model.load_state_dict(torch.load(args.model)) 53 | 54 | model = model.module 55 | model.to(DEVICE) 56 | model.eval() 57 | 58 | with torch.no_grad(): 
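# Inference only: gather every PNG/JPG frame under args.path, then estimate optical flow between each pair of consecutive frames and visualize the result.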
59 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 60 | glob.glob(os.path.join(args.path, '*.jpg')) 61 | 62 | images = load_image_list(images) 63 | for i in range(images.shape[0]-1): 64 | image1 = images[i,None] 65 | image2 = images[i+1,None] 66 | 67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 68 | viz(image1, flow_up) 69 | 70 | 71 | def RAFT_infer(args): 72 | model = torch.nn.DataParallel(RAFT(args)) 73 | model.load_state_dict(torch.load(args.model)) 74 | 75 | model = model.module 76 | model.to(DEVICE) 77 | model.eval() 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /models/RAFT/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock, SmallUpdateBlock 7 | from .extractor import BasicEncoder, SmallEncoder 8 | from .corr import CorrBlock, AlternateCorrBlock 9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in args._get_kwargs(): 42 | args.dropout = 0 43 | 44 | if 'alternate_corr' not in args._get_kwargs(): 45 | args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | 59 | def freeze_bn(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.BatchNorm2d): 62 | m.eval() 63 | 64 | def initialize_flow(self, img): 65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 66 | N, C, H, W = img.shape 67 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 68 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 69 | 70 | # optical flow computed as difference: flow = coords1 - coords0 71 | return coords0, coords1 72 | 73 | def upsample_flow(self, flow, mask): 74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 75 | N, _, H, W = flow.shape 76 | mask = mask.view(N, 1, 9, 8, 8, H, W) 77 | mask = torch.softmax(mask, dim=2) 78 | 79 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 81 | 82 | up_flow = torch.sum(mask * up_flow, dim=2) 83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 84 | return up_flow.reshape(N, 2, 8*H, 8*W) 85 | 86 | 87 | def 
forward(self, image1, image2, iters=12, flow_init=None, test_mode=True): 88 | """ Estimate optical flow between pair of frames """ 89 | 90 | # image1 = 2 * (image1 / 255.0) - 1.0 91 | # image2 = 2 * (image2 / 255.0) - 1.0 92 | 93 | image1 = image1.contiguous() 94 | image2 = image2.contiguous() 95 | 96 | hdim = self.hidden_dim 97 | cdim = self.context_dim 98 | 99 | # run the feature network 100 | with autocast(enabled=self.args.mixed_precision): 101 | fmap1, fmap2 = self.fnet([image1, image2]) 102 | 103 | fmap1 = fmap1.float() 104 | fmap2 = fmap2.float() 105 | 106 | if self.args.alternate_corr: 107 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | else: 109 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 110 | 111 | # run the context network 112 | with autocast(enabled=self.args.mixed_precision): 113 | cnet = self.cnet(image1) 114 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 115 | net = torch.tanh(net) 116 | inp = torch.relu(inp) 117 | 118 | coords0, coords1 = self.initialize_flow(image1) 119 | 120 | if flow_init is not None: 121 | coords1 = coords1 + flow_init 122 | 123 | flow_predictions = [] 124 | for itr in range(iters): 125 | coords1 = coords1.detach() 126 | corr = corr_fn(coords1) # index correlation volume 127 | 128 | flow = coords1 - coords0 129 | with autocast(enabled=self.args.mixed_precision): 130 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 131 | 132 | # F(t+1) = F(t) + \Delta(t) 133 | coords1 = coords1 + delta_flow 134 | 135 | # upsample predictions 136 | if up_mask is None: 137 | flow_up = upflow8(coords1 - coords0) 138 | else: 139 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 140 | 141 | flow_predictions.append(flow_up) 142 | 143 | if test_mode: 144 | return coords1 - coords0, flow_up 145 | 146 | return flow_predictions 147 | -------------------------------------------------------------------------------- /models/RAFT/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = 
nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale 
mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /models/RAFT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow_viz import flow_to_image 2 | from .frame_utils import writeFlow 3 | -------------------------------------------------------------------------------- /models/RAFT/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /models/RAFT/utils/flow_viz_pt.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization 2 | import torch 3 | torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 4 | 5 | @torch.no_grad() 6 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor: 7 | 8 | """ 9 | Converts a flow to an RGB image. 10 | 11 | Args: 12 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. 13 | 14 | Returns: 15 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds 16 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. 17 | """ 18 | 19 | if flow.dtype != torch.float: 20 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") 21 | 22 | orig_shape = flow.shape 23 | if flow.ndim == 3: 24 | flow = flow[None] # Add batch dim 25 | 26 | if flow.ndim != 4 or flow.shape[1] != 2: 27 | raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") 28 | 29 | max_norm = torch.sum(flow**2, dim=1).sqrt().max() 30 | epsilon = torch.finfo((flow).dtype).eps 31 | normalized_flow = flow / (max_norm + epsilon) 32 | img = _normalized_flow_to_image(normalized_flow) 33 | 34 | if len(orig_shape) == 3: 35 | img = img[0] # Remove batch dim 36 | return img 37 | 38 | @torch.no_grad() 39 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: 40 | 41 | """ 42 | Converts a batch of normalized flow to an RGB image. 
43 | 44 | Args: 45 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) 46 | Returns: 47 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 48 | """ 49 | 50 | N, _, H, W = normalized_flow.shape 51 | device = normalized_flow.device 52 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) 53 | colorwheel = _make_colorwheel().to(device) # shape [55x3] 54 | num_cols = colorwheel.shape[0] 55 | norm = torch.sum(normalized_flow**2, dim=1).sqrt() 56 | a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi 57 | fk = (a + 1) / 2 * (num_cols - 1) 58 | k0 = torch.floor(fk).to(torch.long) 59 | k1 = k0 + 1 60 | k1[k1 == num_cols] = 0 61 | f = fk - k0 62 | 63 | for c in range(colorwheel.shape[1]): 64 | tmp = colorwheel[:, c] 65 | col0 = tmp[k0] / 255.0 66 | col1 = tmp[k1] / 255.0 67 | col = (1 - f) * col0 + f * col1 68 | col = 1 - norm * (1 - col) 69 | flow_image[:, c, :, :] = torch.floor(255. * col) 70 | return flow_image 71 | 72 | 73 | @torch.no_grad() 74 | def _make_colorwheel() -> torch.Tensor: 75 | """ 76 | Generates a color wheel for optical flow visualization as presented in: 77 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 78 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. 79 | 80 | Returns: 81 | colorwheel (Tensor[55, 3]): Colorwheel Tensor. 82 | """ 83 | 84 | RY = 15 85 | YG = 6 86 | GC = 4 87 | CB = 11 88 | BM = 13 89 | MR = 6 90 | 91 | ncols = RY + YG + GC + CB + BM + MR 92 | colorwheel = torch.zeros((ncols, 3)) 93 | col = 0 94 | 95 | # RY 96 | colorwheel[0:RY, 0] = 255 97 | colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY) 98 | col = col + RY 99 | # YG 100 | colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG) 101 | colorwheel[col : col + YG, 1] = 255 102 | col = col + YG 103 | # GC 104 | colorwheel[col : col + GC, 1] = 255 105 | colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC) 106 | col = col + GC 107 | # CB 108 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB) 109 | colorwheel[col : col + CB, 2] = 255 110 | col = col + CB 111 | # BM 112 | colorwheel[col : col + BM, 2] = 255 113 | colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM) 114 | col = col + BM 115 | # MR 116 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR) 117 | colorwheel[col : col + MR, 0] = 255 118 | return colorwheel 119 | -------------------------------------------------------------------------------- /models/RAFT/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /models/RAFT/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/__init__.py -------------------------------------------------------------------------------- /models/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import logging 8 | import numpy as np 9 | from os import path as osp 10 | 11 | def constant_init(module, val, bias=0): 12 | if hasattr(module, 'weight') and module.weight is not None: 13 | nn.init.constant_(module.weight, val) 14 | if hasattr(module, 'bias') and module.bias is not None: 15 | nn.init.constant_(module.bias, bias) 16 | 
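# Usage sketch (illustrative only, not part of the original module; layer sizes are arbitrary):
# constant_init above simply fills a layer's weight/bias with fixed values, e.g. for
# zero-initialising an offset-prediction conv the way the deformable-conv code in this
# repo zeroes its conv_offset branch.
#
#   conv = nn.Conv2d(64, 27, kernel_size=3, padding=1)
#   constant_init(conv, val=0., bias=0.)
#   assert torch.all(conv.weight == 0) and torch.all(conv.bias == 0)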
17 | initialized_logger = {} 18 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 19 | """Get the root logger. 20 | The logger will be initialized if it has not been initialized. By default a 21 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 22 | also be added. 23 | Args: 24 | logger_name (str): root logger name. Default: 'basicsr'. 25 | log_file (str | None): The log filename. If specified, a FileHandler 26 | will be added to the root logger. 27 | log_level (int): The root logger level. Note that only the process of 28 | rank 0 is affected, while other processes will set the level to 29 | "Error" and be silent most of the time. 30 | Returns: 31 | logging.Logger: The root logger. 32 | """ 33 | logger = logging.getLogger(logger_name) 34 | # if the logger has been initialized, just return it 35 | if logger_name in initialized_logger: 36 | return logger 37 | 38 | format_str = '%(asctime)s %(levelname)s: %(message)s' 39 | stream_handler = logging.StreamHandler() 40 | stream_handler.setFormatter(logging.Formatter(format_str)) 41 | logger.addHandler(stream_handler) 42 | logger.propagate = False 43 | 44 | if log_file is not None: 45 | logger.setLevel(log_level) 46 | # add file handler 47 | # file_handler = logging.FileHandler(log_file, 'w') 48 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log 49 | file_handler.setFormatter(logging.Formatter(format_str)) 50 | file_handler.setLevel(log_level) 51 | logger.addHandler(file_handler) 52 | initialized_logger[logger_name] = True 53 | return logger 54 | 55 | 56 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ 57 | torch.__version__)[0][:3])] >= [1, 12, 0] 58 | 59 | def gpu_is_available(): 60 | if IS_HIGH_VERSION: 61 | if torch.backends.mps.is_available(): 62 | return True 63 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False 64 | 65 | def get_device(gpu_id=None): 66 | if gpu_id is None: 67 | gpu_str = '' 68 | elif isinstance(gpu_id, int): 69 | gpu_str = f':{gpu_id}' 70 | else: 71 | raise TypeError('Input should be int value.') 72 | 73 | if IS_HIGH_VERSION: 74 | if torch.backends.mps.is_available(): 75 | return torch.device('mps'+gpu_str) 76 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') 77 | 78 | 79 | def set_random_seed(seed): 80 | """Set random seeds.""" 81 | random.seed(seed) 82 | np.random.seed(seed) 83 | torch.manual_seed(seed) 84 | torch.cuda.manual_seed(seed) 85 | torch.cuda.manual_seed_all(seed) 86 | 87 | 88 | def get_time_str(): 89 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 90 | 91 | 92 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 93 | """Scan a directory to find the interested files. 94 | 95 | Args: 96 | dir_path (str): Path of the directory. 97 | suffix (str | tuple(str), optional): File suffix that we are 98 | interested in. Default: None. 99 | recursive (bool, optional): If set to True, recursively scan the 100 | directory. Default: False. 101 | full_path (bool, optional): If set to True, include the dir_path. 102 | Default: False. 103 | 104 | Returns: 105 | A generator for all the interested files with relative pathes. 
106 | """ 107 | 108 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 109 | raise TypeError('"suffix" must be a string or tuple of strings') 110 | 111 | root = dir_path 112 | 113 | def _scandir(dir_path, suffix, recursive): 114 | for entry in os.scandir(dir_path): 115 | if not entry.name.startswith('.') and entry.is_file(): 116 | if full_path: 117 | return_path = entry.path 118 | else: 119 | return_path = osp.relpath(entry.path, root) 120 | 121 | if suffix is None: 122 | yield return_path 123 | elif return_path.endswith(suffix): 124 | yield return_path 125 | else: 126 | if recursive: 127 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 128 | else: 129 | continue 130 | 131 | return _scandir(dir_path, suffix=suffix, recursive=recursive) -------------------------------------------------------------------------------- /models/modules/base_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from functools import reduce 6 | 7 | class BaseNetwork(nn.Module): 8 | def __init__(self): 9 | super(BaseNetwork, self).__init__() 10 | 11 | def print_network(self): 12 | if isinstance(self, list): 13 | self = self[0] 14 | num_params = 0 15 | for param in self.parameters(): 16 | num_params += param.numel() 17 | print( 18 | 'Network [%s] was created. Total number of parameters: %.1f million. ' 19 | 'To see the architecture, do print(network).' % 20 | (type(self).__name__, num_params / 1000000)) 21 | 22 | def init_weights(self, init_type='normal', gain=0.02): 23 | ''' 24 | initialize network's weights 25 | init_type: normal | xavier | kaiming | orthogonal 26 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 27 | ''' 28 | def init_func(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('InstanceNorm2d') != -1: 31 | if hasattr(m, 'weight') and m.weight is not None: 32 | nn.init.constant_(m.weight.data, 1.0) 33 | if hasattr(m, 'bias') and m.bias is not None: 34 | nn.init.constant_(m.bias.data, 0.0) 35 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 36 | or classname.find('Linear') != -1): 37 | if init_type == 'normal': 38 | nn.init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | nn.init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'xavier_uniform': 42 | nn.init.xavier_uniform_(m.weight.data, gain=1.0) 43 | elif init_type == 'kaiming': 44 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 45 | elif init_type == 'orthogonal': 46 | nn.init.orthogonal_(m.weight.data, gain=gain) 47 | elif init_type == 'none': # uses pytorch's default init method 48 | m.reset_parameters() 49 | else: 50 | raise NotImplementedError( 51 | 'initialization method [%s] is not implemented' % 52 | init_type) 53 | if hasattr(m, 'bias') and m.bias is not None: 54 | nn.init.constant_(m.bias.data, 0.0) 55 | 56 | self.apply(init_func) 57 | 58 | # propagate to children 59 | for m in self.children(): 60 | if hasattr(m, 'init_weights'): 61 | m.init_weights(init_type, gain) 62 | 63 | 64 | class Vec2Feat(nn.Module): 65 | def __init__(self, channel, hidden, kernel_size, stride, padding): 66 | super(Vec2Feat, self).__init__() 67 | self.relu = nn.LeakyReLU(0.2, inplace=True) 68 | c_out = reduce((lambda x, y: x * y), kernel_size) * channel 69 | self.embedding = nn.Linear(hidden, c_out) 70 | self.kernel_size = kernel_size 71 | self.stride = 
stride 72 | self.padding = padding 73 | self.bias_conv = nn.Conv2d(channel, 74 | channel, 75 | kernel_size=3, 76 | stride=1, 77 | padding=1) 78 | 79 | def forward(self, x, t, output_size): 80 | b_, _, _, _, c_ = x.shape 81 | x = x.view(b_, -1, c_) 82 | feat = self.embedding(x) 83 | b, _, c = feat.size() 84 | feat = feat.view(b * t, -1, c).permute(0, 2, 1) 85 | feat = F.fold(feat, 86 | output_size=output_size, 87 | kernel_size=self.kernel_size, 88 | stride=self.stride, 89 | padding=self.padding) 90 | feat = self.bias_conv(feat) 91 | return feat 92 | 93 | 94 | class FusionFeedForward(nn.Module): 95 | def __init__(self, dim, hidden_dim=1960, t2t_params=None): 96 | super(FusionFeedForward, self).__init__() 97 | # We set hidden_dim as a default to 1960 98 | self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) 99 | self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) 100 | assert t2t_params is not None 101 | self.t2t_params = t2t_params 102 | self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49 103 | 104 | def forward(self, x, output_size): 105 | n_vecs = 1 106 | for i, d in enumerate(self.t2t_params['kernel_size']): 107 | n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - 108 | (d - 1) - 1) / self.t2t_params['stride'][i] + 1) 109 | 110 | x = self.fc1(x) 111 | b, n, c = x.size() 112 | normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1) 113 | normalizer = F.fold(normalizer, 114 | output_size=output_size, 115 | kernel_size=self.t2t_params['kernel_size'], 116 | padding=self.t2t_params['padding'], 117 | stride=self.t2t_params['stride']) 118 | 119 | x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), 120 | output_size=output_size, 121 | kernel_size=self.t2t_params['kernel_size'], 122 | padding=self.t2t_params['padding'], 123 | stride=self.t2t_params['stride']) 124 | 125 | x = F.unfold(x / normalizer, 126 | kernel_size=self.t2t_params['kernel_size'], 127 | padding=self.t2t_params['padding'], 128 | stride=self.t2t_params['stride']).permute( 129 | 0, 2, 1).contiguous().view(b, n, c) 130 | x = self.fc2(x) 131 | return x 132 | -------------------------------------------------------------------------------- /models/modules/deformconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init as init 4 | from torch.nn.modules.utils import _pair, _single 5 | import math 6 | 7 | class ModulatedDeformConv2d(nn.Module): 8 | def __init__(self, 9 | in_channels, 10 | out_channels, 11 | kernel_size, 12 | stride=1, 13 | padding=0, 14 | dilation=1, 15 | groups=1, 16 | deform_groups=1, 17 | bias=True): 18 | super(ModulatedDeformConv2d, self).__init__() 19 | 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _pair(kernel_size) 23 | self.stride = stride 24 | self.padding = padding 25 | self.dilation = dilation 26 | self.groups = groups 27 | self.deform_groups = deform_groups 28 | self.with_bias = bias 29 | # enable compatibility with nn.Conv2d 30 | self.transposed = False 31 | self.output_padding = _single(0) 32 | 33 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 34 | if bias: 35 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 36 | else: 37 | self.register_parameter('bias', None) 38 | self.init_weights() 39 | 40 | def init_weights(self): 41 | n = self.in_channels 42 | for k in self.kernel_size: 43 | n *= k 44 | stdv 
= 1. / math.sqrt(n) 45 | self.weight.data.uniform_(-stdv, stdv) 46 | if self.bias is not None: 47 | self.bias.data.zero_() 48 | 49 | if hasattr(self, 'conv_offset'): 50 | self.conv_offset.weight.data.zero_() 51 | self.conv_offset.bias.data.zero_() 52 | 53 | def forward(self, x, offset, mask): 54 | pass -------------------------------------------------------------------------------- /models/sam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/__init__.py -------------------------------------------------------------------------------- /models/sam/assets/masks1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/masks1.png -------------------------------------------------------------------------------- /models/sam/assets/masks2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/masks2.jpg -------------------------------------------------------------------------------- /models/sam/assets/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/model_diagram.png -------------------------------------------------------------------------------- /models/sam/assets/notebook1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/notebook1.png -------------------------------------------------------------------------------- /models/sam/assets/notebook2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/notebook2.png -------------------------------------------------------------------------------- /models/sam/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 
33 | -------------------------------------------------------------------------------- /models/sam/notebooks/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/notebooks/images/dog.jpg -------------------------------------------------------------------------------- /models/sam/notebooks/images/groceries.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/notebooks/images/groceries.jpg -------------------------------------------------------------------------------- /models/sam/notebooks/images/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/notebooks/images/truck.jpg -------------------------------------------------------------------------------- /models/sam/segment_anything/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/segment_anything/.DS_Store -------------------------------------------------------------------------------- /models/sam/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /models/sam/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /models/sam/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /models/sam/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /models/sam/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /models/sam/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /models/sam/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=segment_anything 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /models/sam/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="segment_anything", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /models/tracker/base_tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from omegaconf import OmegaConf 5 | 6 | import sys 7 | 8 | from models.tracker.config import CONFIG 9 | from models.tracker.model.cutie import CUTIE 10 | from models.tracker.inference.inference_core import InferenceCore 11 | from models.tracker.utils.mask_mapper import MaskMapper 12 | 13 | from utils.painter import mask_painter 14 | 15 | 16 | class BaseTracker: 17 | def __init__(self, cutie_checkpoint, device) -> None: 18 | """ 19 | device: model device 20 | cutie_checkpoint: checkpoint of XMem model 21 | """ 22 | config = OmegaConf.create(CONFIG) 23 | 24 | # initialise XMem 25 | network = CUTIE(config).to(device).eval() 26 | model_weights = torch.load(cutie_checkpoint, map_location=device) 27 | network.load_weights(model_weights) 28 | 29 | # initialise IncerenceCore 30 | self.tracker = InferenceCore(network, config) 31 | self.device = device 32 | 33 | # changable properties 34 | self.mapper = MaskMapper() 35 | self.initialised = False 36 | 37 | @torch.no_grad() 38 | def resize_mask(self, mask): 39 | # mask transform is applied AFTER mapper, so we need to post-process it in eval.py 40 | h, w = mask.shape[-2:] 41 | min_hw = min(h, w) 42 | return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), 43 | mode='nearest') 44 | 45 | @torch.no_grad() 46 | def image_to_torch(self, frame: np.ndarray, device: str = 'cuda'): 47 | # frame: H*W*3 numpy array 48 | frame = frame.transpose(2, 0, 1) 49 | frame = torch.from_numpy(frame).float().to(device, non_blocking=True) / 255 50 | return frame 51 | 52 | @torch.no_grad() 53 | def track(self, frame, first_frame_annotation=None): 54 | """ 55 | Input: 56 | frames: numpy arrays (H, W, 3) 57 | logit: numpy array (H, W), logit 58 | 59 | Output: 60 | mask: numpy arrays (H, W) 61 | logit: numpy arrays, probability map (H, W) 62 | painted_image: numpy array (H, W, 3) 63 | """ 64 | 65 | if first_frame_annotation is not None: # first frame mask 66 | # initialisation 67 | mask, labels = self.mapper.convert_mask(first_frame_annotation) 68 | mask = torch.Tensor(mask).to(self.device) 69 | else: 70 | mask = None 71 | labels = None 72 | 73 | # prepare inputs 74 | frame_tensor = self.image_to_torch(frame, self.device) 75 | 76 | # track one frame 77 | probs = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W 78 | 79 | # convert to mask 80 | out_mask = torch.argmax(probs, dim=0) 81 | out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) 82 | 83 | final_mask = np.zeros_like(out_mask) 84 | 85 | # map back 86 | for k, v in self.mapper.remappings.items(): 87 | final_mask[out_mask == v] = k 88 | 89 | num_objs = final_mask.max() 90 | painted_image = frame 91 | for obj in range(1, num_objs+1): 92 | if np.max(final_mask==obj) == 0: 93 | continue 94 | painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1) 95 | 96 | return final_mask, final_mask, painted_image 97 | 
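# Usage sketch (illustrative only, not part of the original class; the checkpoint path and
# the `frames`/`init_mask` variables are hypothetical): the tracker is seeded with an integer
# object mask on the first frame and then run mask-free on the remaining frames.
#
#   tracker = BaseTracker('model_weights/cutie.pth', device='cuda')
#   mask, _, painted = tracker.track(frames[0], first_frame_annotation=init_mask)
#   for frame in frames[1:]:
#       mask, _, painted = tracker.track(frame)
#   tracker.clear_memory()   # see the method below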
98 | @torch.no_grad() 99 | def clear_memory(self): 100 | self.tracker.clear_memory() 101 | self.mapper.clear_labels() 102 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /models/tracker/config/__init__.py: -------------------------------------------------------------------------------- 1 | CONFIG = {'exp_id': 'default', 'dataset': 'd17-val', 'amp': False, 'output_dir': None, 'flip_aug': False, 'max_internal_size': -1, 'image_directory': None, 'mask_directory': None, 'json_directory': None, 'size': None, 'save_all': None, 'use_all_masks': None, 'use_long_term': None, 'mem_every': 5, 'max_mem_frames': 5, 'long_term': {'count_usage': True, 'max_mem_frames': 10, 'min_mem_frames': 5, 'num_prototypes': 128, 'max_num_tokens': 10000, 'buffer_tokens': 2000}, 'top_k': 30, 'stagger_updates': 5, 'chunk_size': -1, 'save_scores': False, 'save_aux': False, 'visualize': False, 'model': {'pixel_mean': [0.485, 0.456, 0.406], 'pixel_std': [0.229, 0.224, 0.225], 'pixel_dim': 256, 'key_dim': 64, 'value_dim': 256, 'sensory_dim': 256, 'embed_dim': 256, 'pixel_encoder': {'type': 'resnet50', 'ms_dims': [1024, 512, 256]}, 'mask_encoder': {'type': 'resnet18', 'final_dim': 256}, 'pixel_pe_scale': 32, 'pixel_pe_temperature': 128, 'object_transformer': {'embed_dim': '${model.embed_dim}', 'ff_dim': 2048, 'num_heads': 8, 'num_blocks': 3, 'num_queries': 16, 'read_from_pixel': {'input_norm': False, 'input_add_pe': False, 'add_pe_to_qkv': [True, True, False]}, 'read_from_past': {'add_pe_to_qkv': [True, True, False]}, 'read_from_memory': {'add_pe_to_qkv': [True, True, False]}, 'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False}, 'query_self_attention': {'add_pe_to_qkv': [True, True, False]}, 'pixel_self_attention': {'add_pe_to_qkv': [True, True, False]}}, 'object_summarizer': {'embed_dim': '${model.object_transformer.embed_dim}', 'num_summaries': '${model.object_transformer.num_queries}', 'add_pe': True}, 'aux_loss': {'sensory': {'enabled': True, 'weight': 0.01}, 'query': {'enabled': True, 'weight': 0.01}}, 'mask_decoder': {'up_dims': [256, 128, 128]}}} -------------------------------------------------------------------------------- /models/tracker/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/inference/__init__.py -------------------------------------------------------------------------------- /models/tracker/inference/image_feature_store.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Iterable 3 | import torch 4 | from models.tracker.model.cutie import CUTIE 5 | 6 | 7 | class ImageFeatureStore: 8 | """ 9 | A cache for image features. 10 | These features might be reused at different parts of the inference pipeline. 11 | This class provide an interface for reusing these features. 12 | It is the user's responsibility to delete redundant features. 13 | 14 | Feature of a frame should be associated with a unique index -- typically the frame id. 
15 | """ 16 | def __init__(self, network: CUTIE, no_warning: bool = False): 17 | self.network = network 18 | self._store = {} 19 | self.no_warning = no_warning 20 | 21 | def _encode_feature(self, index: int, image: torch.Tensor) -> None: 22 | ms_features, pix_feat = self.network.encode_image(image) 23 | key, shrinkage, selection = self.network.transform_key(ms_features[0]) 24 | self._store[index] = (ms_features, pix_feat, key, shrinkage, selection) 25 | 26 | def get_features(self, index: int, 27 | image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor): 28 | if index not in self._store: 29 | self._encode_feature(index, image) 30 | 31 | return self._store[index][:2] 32 | 33 | def get_key(self, index: int, 34 | image: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): 35 | if index not in self._store: 36 | self._encode_feature(index, image) 37 | 38 | return self._store[index][2:] 39 | 40 | def delete(self, index: int) -> None: 41 | if index in self._store: 42 | del self._store[index] 43 | 44 | def __len__(self): 45 | return len(self._store) 46 | 47 | def __del__(self): 48 | if len(self._store) > 0 and not self.no_warning: 49 | warnings.warn(f'Leaking {self._store.keys()} in the image feature store') 50 | -------------------------------------------------------------------------------- /models/tracker/inference/object_info.py: -------------------------------------------------------------------------------- 1 | class ObjectInfo: 2 | """ 3 | Store meta information for an object 4 | """ 5 | def __init__(self, id: int): 6 | self.id = id 7 | self.poke_count = 0 # count number of detections missed 8 | 9 | def poke(self) -> None: 10 | self.poke_count += 1 11 | 12 | def unpoke(self) -> None: 13 | self.poke_count = 0 14 | 15 | def __hash__(self): 16 | return hash(self.id) 17 | 18 | def __eq__(self, other): 19 | if type(other) == int: 20 | return self.id == other 21 | return self.id == other.id 22 | 23 | def __repr__(self): 24 | return f'(ID: {self.id})' 25 | -------------------------------------------------------------------------------- /models/tracker/inference/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/inference/utils/__init__.py -------------------------------------------------------------------------------- /models/tracker/inference/utils/args_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from omegaconf import DictConfig 3 | 4 | log = logging.getLogger() 5 | 6 | 7 | def get_dataset_cfg(cfg: DictConfig): 8 | dataset_name = cfg.dataset 9 | data_cfg = cfg.datasets[dataset_name] 10 | 11 | potential_overrides = [ 12 | 'image_directory', 13 | 'mask_directory', 14 | 'json_directory', 15 | 'size', 16 | 'save_all', 17 | 'use_all_masks', 18 | 'use_long_term', 19 | 'mem_every', 20 | ] 21 | 22 | for override in potential_overrides: 23 | if cfg[override] is not None: 24 | log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}') 25 | data_cfg[override] = cfg[override] 26 | # escalte all potential overrides to the top-level config 27 | if override in data_cfg: 28 | cfg[override] = data_cfg[override] 29 | 30 | return data_cfg 31 | -------------------------------------------------------------------------------- /models/tracker/inference/utils/burst_utils.py: -------------------------------------------------------------------------------- 
1 | from os import path 2 | import copy 3 | import json 4 | 5 | 6 | class BURSTResultHandler: 7 | def __init__(self, dataset_json): 8 | self.dataset_json = copy.deepcopy(dataset_json) 9 | 10 | # get rid of the segmentations while keeping the metadata 11 | self.dataset_json['sequences'] = [] 12 | 13 | def add_sequence(self, sequence_json): 14 | self.dataset_json['sequences'].append(sequence_json) 15 | 16 | def dump(self, root): 17 | json_path = path.join(root, 'predictions.json') 18 | with open(json_path, 'w') as f: 19 | json.dump(self.dataset_json, f) -------------------------------------------------------------------------------- /models/tracker/inference/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | import torch 3 | 4 | from models.tracker.inference.object_info import ObjectInfo 5 | 6 | 7 | class FrameInfo: 8 | def __init__(self, image: torch.Tensor, mask: torch.Tensor, segments_info: List[ObjectInfo], 9 | ti: int, info: Dict): 10 | self.image = image 11 | self.mask = mask 12 | self.segments_info = segments_info 13 | self.ti = ti 14 | self.info = info 15 | 16 | @property 17 | def name(self) -> str: 18 | return self.info['frame'] 19 | 20 | @property 21 | def shape(self) -> Tuple(int): 22 | return self.info['shape'] 23 | 24 | @property 25 | def need_save(self) -> bool: 26 | return self.info['save'] 27 | -------------------------------------------------------------------------------- /models/tracker/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/model/__init__.py -------------------------------------------------------------------------------- /models/tracker/model/aux_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | For computing auxiliary outputs for auxiliary losses 3 | """ 4 | from typing import Dict 5 | from omegaconf import DictConfig 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .group_modules import GConv2d 10 | from models.tracker.utils.tensor_utils import aggregate 11 | 12 | 13 | class LinearPredictor(nn.Module): 14 | def __init__(self, x_dim: int, pix_dim: int): 15 | super().__init__() 16 | self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1) 17 | 18 | def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 19 | # pixel_feat: B*pix_dim*H*W 20 | # x: B*num_objects*x_dim*H*W 21 | num_objects = x.shape[1] 22 | x = self.projection(x) 23 | 24 | pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) 25 | logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1] 26 | return logits 27 | 28 | 29 | class DirectPredictor(nn.Module): 30 | def __init__(self, x_dim: int): 31 | super().__init__() 32 | self.projection = GConv2d(x_dim, 1, kernel_size=1) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | # x: B*num_objects*x_dim*H*W 36 | logits = self.projection(x).squeeze(2) 37 | return logits 38 | 39 | 40 | class AuxComputer(nn.Module): 41 | def __init__(self, cfg: DictConfig): 42 | super().__init__() 43 | 44 | use_sensory_aux = cfg.model.aux_loss.sensory.enabled 45 | self.use_query_aux = cfg.model.aux_loss.query.enabled 46 | 47 | sensory_dim = cfg.model.sensory_dim 48 | embed_dim = cfg.model.embed_dim 49 | 50 | if use_sensory_aux: 51 | self.sensory_aux = LinearPredictor(sensory_dim, embed_dim) 52 | else: 53 | 
self.sensory_aux = None 54 | 55 | def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor: 56 | prob = torch.sigmoid(logits) 57 | if selector is not None: 58 | prob = prob * selector 59 | logits = aggregate(prob, dim=1) 60 | return logits 61 | 62 | def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor], 63 | selector: torch.Tensor) -> Dict[str, torch.Tensor]: 64 | sensory = aux_input['sensory'] 65 | q_logits = aux_input['q_logits'] 66 | 67 | aux_output = {} 68 | aux_output['attn_mask'] = aux_input['attn_mask'] 69 | 70 | if self.sensory_aux is not None: 71 | # B*num_objects*H*W 72 | logits = self.sensory_aux(pix_feat, sensory) 73 | aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector) 74 | if self.use_query_aux: 75 | # B*num_objects*num_levels*H*W 76 | aux_output['q_logits'] = self._aggregate_with_selector( 77 | torch.stack(q_logits, dim=2), 78 | selector.unsqueeze(2) if selector is not None else None) 79 | 80 | return aux_output -------------------------------------------------------------------------------- /models/tracker/model/channel_attn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class CAResBlock(nn.Module): 8 | def __init__(self, in_dim: int, out_dim: int, residual: bool = True): 9 | super().__init__() 10 | self.residual = residual 11 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) 12 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) 13 | 14 | t = int((abs(math.log2(out_dim)) + 1) // 2) 15 | k = t if t % 2 else t + 1 16 | self.pool = nn.AdaptiveAvgPool2d(1) 17 | self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False) 18 | 19 | if self.residual: 20 | if in_dim == out_dim: 21 | self.downsample = nn.Identity() 22 | else: 23 | self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | r = x 27 | x = self.conv1(F.relu(x)) 28 | x = self.conv2(F.relu(x)) 29 | 30 | b, c = x.shape[:2] 31 | w = self.pool(x).view(b, 1, c) 32 | w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1 33 | 34 | if self.residual: 35 | x = x * w + self.downsample(r) 36 | else: 37 | x = x * w 38 | 39 | return x 40 | -------------------------------------------------------------------------------- /models/tracker/model/group_modules.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .channel_attn import CAResBlock 6 | 7 | 8 | def interpolate_groups(g: torch.Tensor, ratio: float, mode: str, 9 | align_corners: bool) -> torch.Tensor: 10 | batch_size, num_objects = g.shape[:2] 11 | g = F.interpolate(g.flatten(start_dim=0, end_dim=1), 12 | scale_factor=ratio, 13 | mode=mode, 14 | align_corners=align_corners) 15 | g = g.view(batch_size, num_objects, *g.shape[1:]) 16 | return g 17 | 18 | 19 | def upsample_groups(g: torch.Tensor, 20 | ratio: float = 2, 21 | mode: str = 'bilinear', 22 | align_corners: bool = False) -> torch.Tensor: 23 | return interpolate_groups(g, ratio, mode, align_corners) 24 | 25 | 26 | def downsample_groups(g: torch.Tensor, 27 | ratio: float = 1 / 2, 28 | mode: str = 'area', 29 | align_corners: bool = None) -> torch.Tensor: 30 | return interpolate_groups(g, ratio, mode, align_corners) 
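# Shape sketch (illustrative only, not part of the original file; toy sizes): the helpers above
# fold the per-object dimension into the batch before interpolating, so a group tensor of shape
# (batch, num_objects, C, H, W) keeps its leading dimensions and only H, W are rescaled.
#
#   g = torch.randn(2, 3, 64, 16, 16)
#   assert upsample_groups(g).shape == (2, 3, 64, 32, 32)
#   assert downsample_groups(g).shape == (2, 3, 64, 8, 8)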
31 | 32 | 33 | class GConv2d(nn.Conv2d): 34 | def forward(self, g: torch.Tensor) -> torch.Tensor: 35 | batch_size, num_objects = g.shape[:2] 36 | g = super().forward(g.flatten(start_dim=0, end_dim=1)) 37 | return g.view(batch_size, num_objects, *g.shape[1:]) 38 | 39 | 40 | class GroupResBlock(nn.Module): 41 | def __init__(self, in_dim: int, out_dim: int): 42 | super().__init__() 43 | 44 | if in_dim == out_dim: 45 | self.downsample = nn.Identity() 46 | else: 47 | self.downsample = GConv2d(in_dim, out_dim, kernel_size=1) 48 | 49 | self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1) 50 | self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1) 51 | 52 | def forward(self, g: torch.Tensor) -> torch.Tensor: 53 | out_g = self.conv1(F.relu(g)) 54 | out_g = self.conv2(F.relu(out_g)) 55 | 56 | g = self.downsample(g) 57 | 58 | return out_g + g 59 | 60 | 61 | class MainToGroupDistributor(nn.Module): 62 | def __init__(self, 63 | x_transform: Optional[nn.Module] = None, 64 | g_transform: Optional[nn.Module] = None, 65 | method: str = 'cat', 66 | reverse_order: bool = False): 67 | super().__init__() 68 | 69 | self.x_transform = x_transform 70 | self.g_transform = g_transform 71 | self.method = method 72 | self.reverse_order = reverse_order 73 | 74 | def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor: 75 | num_objects = g.shape[1] 76 | 77 | if self.x_transform is not None: 78 | x = self.x_transform(x) 79 | 80 | if self.g_transform is not None: 81 | g = self.g_transform(g) 82 | 83 | if not skip_expand: 84 | x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) 85 | if self.method == 'cat': 86 | if self.reverse_order: 87 | g = torch.cat([g, x], 2) 88 | else: 89 | g = torch.cat([x, g], 2) 90 | elif self.method == 'add': 91 | g = x + g 92 | elif self.method == 'mulcat': 93 | g = torch.cat([x * g, g], dim=2) 94 | elif self.method == 'muladd': 95 | g = x * g + g 96 | else: 97 | raise NotImplementedError 98 | 99 | return g 100 | 101 | 102 | class GroupFeatureFusionBlock(nn.Module): 103 | def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int): 104 | super().__init__() 105 | 106 | x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1) 107 | g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1) 108 | 109 | self.distributor = MainToGroupDistributor(x_transform=x_transform, 110 | g_transform=g_transform, 111 | method='add') 112 | self.block1 = CAResBlock(out_dim, out_dim) 113 | self.block2 = CAResBlock(out_dim, out_dim) 114 | 115 | def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: 116 | batch_size, num_objects = g.shape[:2] 117 | 118 | g = self.distributor(x, g) 119 | 120 | g = g.flatten(start_dim=0, end_dim=1) 121 | 122 | g = self.block1(g) 123 | g = self.block2(g) 124 | 125 | g = g.view(batch_size, num_objects, *g.shape[1:]) 126 | 127 | return g -------------------------------------------------------------------------------- /models/tracker/model/losses.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from omegaconf import DictConfig 3 | from collections import defaultdict 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from models.tracker.utils.point_features import calculate_uncertainty, point_sample, get_uncertain_point_coords_with_randomness 8 | from models.tracker.utils.tensor_utils import cls_to_one_hot 9 | 10 | 11 | @torch.jit.script 12 | def ce_loss(logits: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor: 13 | # logits: 
T*C*num_points 14 | loss = F.cross_entropy(logits, soft_gt, reduction='none') 15 | # sum over temporal dimension 16 | return loss.sum(0).mean() 17 | 18 | 19 | @torch.jit.script 20 | def dice_loss(mask: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor: 21 | # mask: T*C*num_points 22 | # soft_gt: T*C*num_points 23 | # ignores the background 24 | mask = mask[:, 1:].flatten(start_dim=2) 25 | gt = soft_gt[:, 1:].float().flatten(start_dim=2) 26 | numerator = 2 * (mask * gt).sum(-1) 27 | denominator = mask.sum(-1) + gt.sum(-1) 28 | loss = 1 - (numerator + 1) / (denominator + 1) 29 | return loss.sum(0).mean() 30 | 31 | 32 | class LossComputer: 33 | def __init__(self, cfg: DictConfig, stage_cfg: DictConfig): 34 | super().__init__() 35 | self.point_supervision = stage_cfg.point_supervision 36 | self.num_points = stage_cfg.train_num_points 37 | self.oversample_ratio = stage_cfg.oversample_ratio 38 | self.importance_sample_ratio = stage_cfg.importance_sample_ratio 39 | 40 | self.sensory_weight = cfg.model.aux_loss.sensory.weight 41 | self.query_weight = cfg.model.aux_loss.query.weight 42 | 43 | def mask_loss(self, logits: torch.Tensor, 44 | soft_gt: torch.Tensor) -> (torch.Tensor, torch.Tensor): 45 | assert self.point_supervision 46 | 47 | with torch.no_grad(): 48 | # sample point_coords 49 | point_coords = get_uncertain_point_coords_with_randomness( 50 | logits, lambda x: calculate_uncertainty(x), self.num_points, self.oversample_ratio, 51 | self.importance_sample_ratio) 52 | # get gt labels 53 | point_labels = point_sample(soft_gt, point_coords, align_corners=False) 54 | point_logits = point_sample(logits, point_coords, align_corners=False) 55 | # point_labels and point_logits: B*C*num_points 56 | 57 | loss_ce = ce_loss(point_logits, point_labels) 58 | loss_dice = dice_loss(point_logits.softmax(dim=1), point_labels) 59 | 60 | return loss_ce, loss_dice 61 | 62 | def compute(self, data: Dict[str, torch.Tensor], 63 | num_objects: List[int]) -> Dict[str, torch.Tensor]: 64 | batch_size, num_frames = data['rgb'].shape[:2] 65 | losses = defaultdict(float) 66 | t_range = range(1, num_frames) 67 | 68 | for bi in range(batch_size): 69 | logits = torch.stack([data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range], 70 | dim=0) 71 | cls_gt = data['cls_gt'][bi, 1:] # remove gt for the first frame 72 | soft_gt = cls_to_one_hot(cls_gt, num_objects[bi]) 73 | 74 | loss_ce, loss_dice = self.mask_loss(logits, soft_gt) 75 | losses['loss_ce'] += loss_ce / batch_size 76 | losses['loss_dice'] += loss_dice / batch_size 77 | 78 | aux = [data[f'aux_{ti}'] for ti in t_range] 79 | if 'sensory_logits' in aux[0]: 80 | sensory_log = torch.stack( 81 | [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0) 82 | loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt) 83 | losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight 84 | losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight 85 | if 'q_logits' in aux[0]: 86 | num_levels = aux[0]['q_logits'].shape[2] 87 | 88 | for l in range(num_levels): 89 | query_log = torch.stack( 90 | [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0) 91 | loss_ce, loss_dice = self.mask_loss(query_log, soft_gt) 92 | losses[f'aux_query_ce_l{l}'] += loss_ce / batch_size * self.query_weight 93 | losses[f'aux_query_dice_l{l}'] += loss_dice / batch_size * self.query_weight 94 | 95 | losses['total_loss'] = sum(losses.values()) 96 | 97 | return losses 98 | 
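The losses above implement PointRend-style point supervision: LossComputer samples uncertain point coordinates from the predicted logits, reads logits and soft labels at those points, and combines cross-entropy with a dice loss that skips the background channel. Below is a minimal sketch of calling the two scripted loss functions directly (an editor's illustration, not part of the repository; it assumes the repository root is on PYTHONPATH and uses arbitrary tensor shapes):

import torch
from models.tracker.model.losses import ce_loss, dice_loss

T, C, P = 4, 3, 1024                                   # frames, classes incl. background, sampled points
point_logits = torch.randn(T, C, P)                    # logits at the sampled points
point_labels = torch.softmax(torch.randn(T, C, P), 1)  # soft ground truth at the same points
loss_ce = ce_loss(point_logits, point_labels)
loss_dice = dice_loss(point_logits.softmax(dim=1), point_labels)
print(loss_ce.item(), loss_dice.item())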
-------------------------------------------------------------------------------- /models/tracker/model/modules.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterable 2 | import torch 3 | import torch.nn as nn 4 | 5 | from models.tracker.model.group_modules import * 6 | 7 | 8 | class MaskUpsampleBlock(nn.Module): 9 | def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): 10 | super().__init__() 11 | self.distributor = MainToGroupDistributor(method='add') 12 | self.out_conv = GroupResBlock(in_dim, out_dim) 13 | self.scale_factor = scale_factor 14 | 15 | def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: 16 | g = upsample_groups(in_g, ratio=self.scale_factor) 17 | g = self.distributor(skip_f, g) 18 | g = self.out_conv(g) 19 | return g 20 | 21 | 22 | class DecoderFeatureProcessor(nn.Module): 23 | def __init__(self, decoder_dims: List[int], out_dims: List[int]): 24 | super().__init__() 25 | self.transforms = nn.ModuleList([ 26 | nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims) 27 | ]) 28 | 29 | def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]: 30 | outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)] 31 | return outputs 32 | 33 | 34 | # @torch.jit.script 35 | def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor: 36 | # h: batch_size * num_objects * hidden_dim * h * w 37 | # values: batch_size * num_objects * (hidden_dim*3) * h * w 38 | dim = values.shape[2] // 3 39 | forget_gate = torch.sigmoid(values[:, :, :dim]) 40 | update_gate = torch.sigmoid(values[:, :, dim:dim * 2]) 41 | new_value = torch.tanh(values[:, :, dim * 2:]) 42 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value 43 | return new_h 44 | 45 | 46 | class SensoryUpdater(nn.Module): 47 | # Used in the decoder, multi-scale feature + GRU 48 | def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): 49 | super().__init__() 50 | self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) 51 | self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) 52 | self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) 53 | 54 | self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) 55 | 56 | nn.init.xavier_normal_(self.transform.weight) 57 | 58 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: 59 | g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ 60 | self.g4_conv(downsample_groups(g[2], ratio=1/4)) 61 | 62 | with torch.cuda.amp.autocast(enabled=False): 63 | g = g.float() 64 | h = h.float() 65 | values = self.transform(torch.cat([g, h], dim=2)) 66 | new_h = _recurrent_update(h, values) 67 | 68 | return new_h 69 | 70 | 71 | class SensoryDeepUpdater(nn.Module): 72 | def __init__(self, f_dim: int, sensory_dim: int): 73 | super().__init__() 74 | self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) 75 | 76 | nn.init.xavier_normal_(self.transform.weight) 77 | 78 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: 79 | with torch.cuda.amp.autocast(enabled=False): 80 | g = g.float() 81 | h = h.float() 82 | values = self.transform(torch.cat([g, h], dim=2)) 83 | new_h = _recurrent_update(h, values) 84 | 85 | return new_h 86 | -------------------------------------------------------------------------------- /models/tracker/model/transformer/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/model/transformer/__init__.py -------------------------------------------------------------------------------- /models/tracker/model/transformer/object_summarizer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | from omegaconf import DictConfig 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from models.tracker.model.transformer.positional_encoding import PositionalEncoding 8 | 9 | 10 | # @torch.jit.script 11 | def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor, 12 | logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): 13 | # value: B*num_objects*H*W*value_dim 14 | # logits: B*num_objects*H*W*num_summaries 15 | # masks: B*num_objects*H*W*num_summaries: 1 if allowed 16 | weights = logits.sigmoid() * masks 17 | # B*num_objects*num_summaries*value_dim 18 | sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value) 19 | # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1 20 | area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1) 21 | 22 | # B*num_objects*num_summaries*value_dim 23 | return sums, area 24 | 25 | 26 | class ObjectSummarizer(nn.Module): 27 | def __init__(self, model_cfg: DictConfig): 28 | super().__init__() 29 | 30 | this_cfg = model_cfg.object_summarizer 31 | self.value_dim = model_cfg.value_dim 32 | self.embed_dim = this_cfg.embed_dim 33 | self.num_summaries = this_cfg.num_summaries 34 | self.add_pe = this_cfg.add_pe 35 | self.pixel_pe_scale = model_cfg.pixel_pe_scale 36 | self.pixel_pe_temperature = model_cfg.pixel_pe_temperature 37 | 38 | if self.add_pe: 39 | self.pos_enc = PositionalEncoding(self.embed_dim, 40 | scale=self.pixel_pe_scale, 41 | temperature=self.pixel_pe_temperature) 42 | 43 | self.input_proj = nn.Linear(self.value_dim, self.embed_dim) 44 | self.feature_pred = nn.Sequential( 45 | nn.Linear(self.embed_dim, self.embed_dim), 46 | nn.ReLU(inplace=True), 47 | nn.Linear(self.embed_dim, self.embed_dim), 48 | ) 49 | self.weights_pred = nn.Sequential( 50 | nn.Linear(self.embed_dim, self.embed_dim), 51 | nn.ReLU(inplace=True), 52 | nn.Linear(self.embed_dim, self.num_summaries), 53 | ) 54 | 55 | def forward(self, 56 | masks: torch.Tensor, 57 | value: torch.Tensor, 58 | need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]): 59 | # masks: B*num_objects*(H0)*(W0) 60 | # value: B*num_objects*value_dim*H*W 61 | # -> B*num_objects*H*W*value_dim 62 | h, w = value.shape[-2:] 63 | masks = F.interpolate(masks, size=(h, w), mode='area') 64 | masks = masks.unsqueeze(-1) 65 | inv_masks = 1 - masks 66 | repeated_masks = torch.cat([ 67 | masks.expand(-1, -1, -1, -1, self.num_summaries // 2), 68 | inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2), 69 | ], 70 | dim=-1) 71 | 72 | value = value.permute(0, 1, 3, 4, 2) 73 | value = self.input_proj(value) 74 | if self.add_pe: 75 | pe = self.pos_enc(value) 76 | value = value + pe 77 | 78 | with torch.cuda.amp.autocast(enabled=False): 79 | value = value.float() 80 | feature = self.feature_pred(value) 81 | logits = self.weights_pred(value) 82 | sums, area = _weighted_pooling(repeated_masks, feature, logits) 83 | 84 | summaries = torch.cat([sums, area], dim=-1) 85 | 86 | if need_weights: 87 | return summaries, logits 88 | else: 89 | return summaries, None 
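ObjectSummarizer condenses each object's pixel-level value features into a fixed set of summaries: predicted weights are gated by the (foreground, background) mask pair, pooled via the einsum in _weighted_pooling, and the pooled area is appended so downstream layers can renormalize. The sketch below is an editor's illustration of the expected shapes only; the config keys mirror how this module reads model_cfg, but the concrete values (value_dim, embed_dim, num_summaries, PE scale/temperature) are assumptions rather than the project's defaults:

import torch
from omegaconf import OmegaConf
from models.tracker.model.transformer.object_summarizer import ObjectSummarizer

model_cfg = OmegaConf.create({
    'value_dim': 256,
    'pixel_pe_scale': 32,
    'pixel_pe_temperature': 128,
    'object_summarizer': {'embed_dim': 256, 'num_summaries': 16, 'add_pe': True},
})
summarizer = ObjectSummarizer(model_cfg)
masks = torch.rand(1, 2, 224, 224)       # B x num_objects x H0 x W0
value = torch.randn(1, 2, 256, 14, 14)   # B x num_objects x value_dim x H x W
summaries, _ = summarizer(masks, value)
print(summaries.shape)                    # (1, 2, 16, 257): pooled features + area per summary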
-------------------------------------------------------------------------------- /models/tracker/model/transformer/positional_encoding.py: -------------------------------------------------------------------------------- 1 | # Reference: 2 | # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py 3 | # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py 4 | 5 | import math 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | 12 | def get_emb(sin_inp: torch.Tensor) -> torch.Tensor: 13 | """ 14 | Gets a base embedding for one dimension with sin and cos intertwined 15 | """ 16 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) 17 | return torch.flatten(emb, -2, -1) 18 | 19 | 20 | class PositionalEncoding(nn.Module): 21 | def __init__(self, 22 | dim: int, 23 | scale: float = math.pi * 2, 24 | temperature: float = 10000, 25 | normalize: bool = True, 26 | channel_last: bool = True, 27 | transpose_output: bool = False): 28 | super().__init__() 29 | dim = int(np.ceil(dim / 4) * 2) 30 | self.dim = dim 31 | inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim)) 32 | self.register_buffer("inv_freq", inv_freq) 33 | self.normalize = normalize 34 | self.scale = scale 35 | self.eps = 1e-6 36 | self.channel_last = channel_last 37 | self.transpose_output = transpose_output 38 | 39 | self.cached_penc = None # the cache is irrespective of the number of objects 40 | 41 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 42 | """ 43 | :param tensor: A 4/5d tensor of size 44 | channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c) 45 | channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w) 46 | :return: positional encoding tensor that has the same shape as the input if the input is 4d 47 | if the input is 5d, the output is broadcastable along the k-dimension 48 | """ 49 | if len(tensor.shape) != 4 and len(tensor.shape) != 5: 50 | raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!') 51 | 52 | if len(tensor.shape) == 5: 53 | # take a sample from the k dimension 54 | num_objects = tensor.shape[1] 55 | tensor = tensor[:, 0] 56 | else: 57 | num_objects = None 58 | 59 | if self.channel_last: 60 | batch_size, h, w, c = tensor.shape 61 | else: 62 | batch_size, c, h, w = tensor.shape 63 | 64 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: 65 | if num_objects is None: 66 | return self.cached_penc 67 | else: 68 | return self.cached_penc.unsqueeze(1) 69 | 70 | self.cached_penc = None 71 | 72 | pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype) 73 | pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype) 74 | if self.normalize: 75 | pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale 76 | pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale 77 | 78 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 79 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 80 | emb_y = get_emb(sin_inp_y).unsqueeze(1) 81 | emb_x = get_emb(sin_inp_x) 82 | 83 | emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype) 84 | emb[:, :, :self.dim] = emb_x 85 | emb[:, :, self.dim:] = emb_y 86 | 87 | if not self.channel_last and self.transpose_output: 88 | # cancelled out 89 | pass 90 | elif (not self.channel_last) or (self.transpose_output): 91 | emb = emb.permute(2, 0, 1) 92 | 93 | self.cached_penc = 
emb.unsqueeze(0).repeat(batch_size, 1, 1, 1) 94 | if num_objects is None: 95 | return self.cached_penc 96 | else: 97 | return self.cached_penc.unsqueeze(1) 98 | 99 | 100 | if __name__ == '__main__': 101 | pe = PositionalEncoding(8).cuda() 102 | input = torch.ones((1, 8, 8, 8)).cuda() 103 | output = pe(input) 104 | # print(output) 105 | print(output[0, :, 0, 0]) 106 | print(output[0, :, 0, 5]) 107 | print(output[0, 0, :, 0]) 108 | print(output[0, 0, 0, :]) 109 | -------------------------------------------------------------------------------- /models/tracker/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/model/utils/__init__.py -------------------------------------------------------------------------------- /models/tracker/model/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from typing import Optional, Union, Tuple 4 | 5 | 6 | # @torch.jit.script 7 | def get_similarity(mk: torch.Tensor, 8 | ms: torch.Tensor, 9 | qk: torch.Tensor, 10 | qe: torch.Tensor, 11 | add_batch_dim: bool = False) -> torch.Tensor: 12 | # used for training/inference and memory reading/memory potentiation 13 | # mk: B x CK x [N] - Memory keys 14 | # ms: B x 1 x [N] - Memory shrinkage 15 | # qk: B x CK x [HW/P] - Query keys 16 | # qe: B x CK x [HW/P] - Query selection 17 | # Dimensions in [] are flattened 18 | if add_batch_dim: 19 | mk, ms = mk.unsqueeze(0), ms.unsqueeze(0) 20 | qk, qe = qk.unsqueeze(0), qe.unsqueeze(0) 21 | 22 | CK = mk.shape[1] 23 | mk = mk.flatten(start_dim=2) 24 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None 25 | qk = qk.flatten(start_dim=2) 26 | qe = qe.flatten(start_dim=2) if qe is not None else None 27 | 28 | if qe is not None: 29 | # See XMem's appendix for derivation 30 | mk = mk.transpose(1, 2) 31 | a_sq = (mk.pow(2) @ qe) 32 | two_ab = 2 * (mk @ (qk * qe)) 33 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) 34 | similarity = (-a_sq + two_ab - b_sq) 35 | else: 36 | # similar to STCN if we don't have the selection term 37 | a_sq = mk.pow(2).sum(1).unsqueeze(2) 38 | two_ab = 2 * (mk.transpose(1, 2) @ qk) 39 | similarity = (-a_sq + two_ab) 40 | 41 | if ms is not None: 42 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW 43 | else: 44 | similarity = similarity / math.sqrt(CK) # B*N*HW 45 | 46 | return similarity 47 | 48 | 49 | def do_softmax( 50 | similarity: torch.Tensor, 51 | top_k: Optional[int] = None, 52 | inplace: bool = False, 53 | return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 54 | # normalize similarity with top-k softmax 55 | # similarity: B x N x [HW/P] 56 | # use inplace with care 57 | if top_k is not None: 58 | values, indices = torch.topk(similarity, k=top_k, dim=1) 59 | 60 | x_exp = values.exp_() 61 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True) 62 | if inplace: 63 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW 64 | affinity = similarity 65 | else: 66 | affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW 67 | else: 68 | maxes = torch.max(similarity, dim=1, keepdim=True)[0] 69 | x_exp = torch.exp(similarity - maxes) 70 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) 71 | affinity = x_exp / x_exp_sum 72 | indices = None 73 | 74 | if return_usage: 75 | return affinity, affinity.sum(dim=2) 76 | 77 | return affinity 78 | 79 
| 80 | def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor, 81 | qe: torch.Tensor) -> torch.Tensor: 82 | # shorthand used in training with no top-k 83 | similarity = get_similarity(mk, ms, qk, qe) 84 | affinity = do_softmax(similarity) 85 | return affinity 86 | 87 | 88 | def readout(affinity: torch.Tensor, mv: torch.Tensor) -> torch.Tensor: 89 | B, CV, T, H, W = mv.shape 90 | 91 | mo = mv.view(B, CV, T * H * W) 92 | mem = torch.bmm(mo, affinity) 93 | mem = mem.view(B, CV, H, W) 94 | 95 | return mem 96 | -------------------------------------------------------------------------------- /models/tracker/model/utils/parameter_groups.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | log = logging.getLogger() 4 | 5 | 6 | def get_parameter_groups(model, stage_cfg, print_log=False): 7 | """ 8 | Assign different weight decays and learning rates to different parameters. 9 | Returns a parameter group which can be passed to the optimizer. 10 | """ 11 | weight_decay = stage_cfg.weight_decay 12 | embed_weight_decay = stage_cfg.embed_weight_decay 13 | backbone_lr_ratio = stage_cfg.backbone_lr_ratio 14 | base_lr = stage_cfg.learning_rate 15 | 16 | backbone_params = [] 17 | embed_params = [] 18 | other_params = [] 19 | 20 | embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe'] 21 | embedding_names = [e + '.weight' for e in embedding_names] 22 | 23 | # inspired by detectron2 24 | memo = set() 25 | for name, param in model.named_parameters(): 26 | if not param.requires_grad: 27 | continue 28 | # Avoid duplicating parameters 29 | if param in memo: 30 | continue 31 | memo.add(param) 32 | 33 | if name.startswith('module'): 34 | name = name[7:] 35 | 36 | inserted = False 37 | if name.startswith('pixel_encoder.'): 38 | backbone_params.append(param) 39 | inserted = True 40 | if print_log: 41 | log.info(f'{name} counted as a backbone parameter.') 42 | else: 43 | for e in embedding_names: 44 | if name.endswith(e): 45 | embed_params.append(param) 46 | inserted = True 47 | if print_log: 48 | log.info(f'{name} counted as an embedding parameter.') 49 | break 50 | 51 | if not inserted: 52 | other_params.append(param) 53 | 54 | parameter_groups = [ 55 | { 56 | 'params': backbone_params, 57 | 'lr': base_lr * backbone_lr_ratio, 58 | 'weight_decay': weight_decay 59 | }, 60 | { 61 | 'params': embed_params, 62 | 'lr': base_lr, 63 | 'weight_decay': embed_weight_decay 64 | }, 65 | { 66 | 'params': other_params, 67 | 'lr': base_lr, 68 | 'weight_decay': weight_decay 69 | }, 70 | ] 71 | 72 | return parameter_groups -------------------------------------------------------------------------------- /models/tracker/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/utils/__init__.py -------------------------------------------------------------------------------- /models/tracker/utils/load_subset.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def load_subset(path): 5 | with open(path, mode='r') as f: 6 | subset = set(f.read().splitlines()) 7 | return subset 8 | 9 | 10 | def load_empty_masks(path): 11 | with open(path, mode='r') as f: 12 | empty_masks = json.load(f) 13 | return empty_masks 14 | -------------------------------------------------------------------------------- /models/tracker/utils/log_integrator.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Integrate numerical values for some iterations 3 | Typically used for loss computation / logging to tensorboard 4 | Call finalize and create a new Integrator when you want to display/log 5 | """ 6 | from typing import Dict, Callable, Tuple 7 | import torch 8 | from models.tracker.utils.logger import TensorboardLogger 9 | 10 | 11 | class Integrator: 12 | def __init__(self, logger: TensorboardLogger, distributed: bool = True): 13 | self.values = {} 14 | self.counts = {} 15 | self.hooks = [] # List is used here to maintain insertion order 16 | 17 | self.logger = logger 18 | 19 | self.distributed = distributed 20 | self.local_rank = torch.distributed.get_rank() 21 | self.world_size = torch.distributed.get_world_size() 22 | 23 | def add_tensor(self, key: str, tensor: torch.Tensor): 24 | if key not in self.values: 25 | self.counts[key] = 1 26 | if type(tensor) == float or type(tensor) == int: 27 | self.values[key] = tensor 28 | else: 29 | self.values[key] = tensor.mean().item() 30 | else: 31 | self.counts[key] += 1 32 | if type(tensor) == float or type(tensor) == int: 33 | self.values[key] += tensor 34 | else: 35 | self.values[key] += tensor.mean().item() 36 | 37 | def add_dict(self, tensor_dict: Dict[str, torch.Tensor]): 38 | for k, v in tensor_dict.items(): 39 | self.add_tensor(k, v) 40 | 41 | def add_hook(self, hook: Callable[[torch.Tensor], Tuple[str, torch.Tensor]]): 42 | """ 43 | Adds a custom hook, i.e. compute new metrics using values in the dict 44 | The hook takes the dict as argument, and returns a (k, v) tuple 45 | e.g. for computing IoU 46 | """ 47 | if type(hook) == list: 48 | self.hooks.extend(hook) 49 | else: 50 | self.hooks.append(hook) 51 | 52 | def reset_except_hooks(self): 53 | self.values = {} 54 | self.counts = {} 55 | 56 | # Average and output the metrics 57 | def finalize(self, exp_id: str, prefix: str, it: int) -> None: 58 | 59 | for hook in self.hooks: 60 | k, v = hook(self.values) 61 | self.add_tensor(k, v) 62 | 63 | outputs = {} 64 | for k, v in self.values.items(): 65 | 66 | if k[:4] == 'hide': 67 | continue 68 | 69 | avg = v / self.counts[k] 70 | 71 | if self.distributed: 72 | # Inplace operation 73 | avg = torch.tensor(avg).cuda() 74 | torch.distributed.reduce(avg, dst=0) 75 | 76 | if self.local_rank == 0: 77 | avg = (avg / self.world_size).cpu().item() 78 | outputs[k] = avg 79 | else: 80 | # Simple does it 81 | outputs[k] = avg 82 | 83 | if (not self.distributed) or (self.local_rank == 0): 84 | self.logger.log_metrics(exp_id, prefix, outputs, it) 85 | -------------------------------------------------------------------------------- /models/tracker/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dumps things to tensorboard and console 3 | """ 4 | 5 | import os 6 | import logging 7 | import datetime 8 | from typing import Dict 9 | import numpy as np 10 | from PIL import Image 11 | 12 | from torch.utils.tensorboard import SummaryWriter 13 | from models.tracker.utils.time_estimator import TimeEstimator 14 | 15 | 16 | def tensor_to_numpy(image): 17 | image_np = (image.numpy() * 255).astype('uint8') 18 | return image_np 19 | 20 | 21 | def detach_to_cpu(x): 22 | return x.detach().cpu() 23 | 24 | 25 | def fix_width_trunc(x): 26 | return ('{:.9s}'.format('{:0.9f}'.format(x))) 27 | 28 | 29 | class TensorboardLogger: 30 | def __init__(self, run_dir, py_logger: logging.Logger, *, enabled_tb): 31 | self.run_dir = run_dir 32 | self.py_log =
py_logger 33 | if enabled_tb: 34 | self.tb_log = SummaryWriter(run_dir) 35 | else: 36 | self.tb_log = None 37 | 38 | # Get current git info for logging 39 | try: 40 | import git 41 | repo = git.Repo(".") 42 | git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha) 43 | except (ImportError, RuntimeError): 44 | print('Failed to fetch git info. Defaulting to None') 45 | git_info = 'None' 46 | 47 | self.log_string('git', git_info) 48 | 49 | # used when logging metrics 50 | self.time_estimator: TimeEstimator = None 51 | 52 | def log_scalar(self, tag, x, it): 53 | if self.tb_log is None: 54 | return 55 | self.tb_log.add_scalar(tag, x, it) 56 | 57 | def log_metrics(self, exp_id, prefix, metrics: Dict, it): 58 | msg = f'{exp_id}-{prefix} - it {it:6d}: ' 59 | metrics_msg = '' 60 | for k, v in sorted(metrics.items()): 61 | self.log_scalar(f'{prefix}/{k}', v, it) 62 | metrics_msg += f'{k: >10}:{v:.7f},\t' 63 | 64 | if self.time_estimator is not None: 65 | self.time_estimator.update() 66 | avg_time = self.time_estimator.get_and_reset_avg_time() 67 | est = self.time_estimator.get_est_remaining(it) 68 | est = datetime.timedelta(seconds=est) 69 | if est.days > 0: 70 | remaining_str = f'{est.days}d {est.seconds // 3600}h' 71 | else: 72 | remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m' 73 | eta = datetime.datetime.now() + est 74 | eta_str = eta.strftime('%Y-%m-%d %H:%M:%S') 75 | time_msg = f'avg_time:{avg_time:.3f},remaining:{remaining_str},eta:{eta_str},\t' 76 | msg = f'{msg} {time_msg}' 77 | 78 | msg = f'{msg} {metrics_msg}' 79 | self.py_log.info(msg) 80 | 81 | def log_image(self, stage_name, tag, image, it): 82 | image_dir = os.path.join(self.run_dir, f'{stage_name}_images') 83 | os.makedirs(image_dir, exist_ok=True) 84 | 85 | image = Image.fromarray(image) 86 | image.save(os.path.join(image_dir, f'{tag}_{it}.png')) 87 | 88 | def log_string(self, tag, x): 89 | self.py_log.info(f'{tag} - {x}') 90 | if self.tb_log is None: 91 | return 92 | self.tb_log.add_text(tag, x) 93 | 94 | def debug(self, x): 95 | self.py_log.debug(x) 96 | 97 | def info(self, x): 98 | self.py_log.info(x) 99 | 100 | def warning(self, x): 101 | self.py_log.warning(x) 102 | 103 | def error(self, x): 104 | self.py_log.error(x) 105 | 106 | def critical(self, x): 107 | self.py_log.critical(x) 108 | -------------------------------------------------------------------------------- /models/tracker/utils/mask_mapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def all_to_onehot(masks, labels): 5 | if len(masks.shape) == 3: 6 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 7 | else: 8 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) 9 | 10 | for ni, l in enumerate(labels): 11 | Ms[ni] = (masks == l).astype(np.uint8) 12 | 13 | return Ms 14 | 15 | class MaskMapper: 16 | """ 17 | This class is used to convert a indexed-mask to a one-hot representation. 18 | It also takes care of remapping non-continuous indices 19 | It has two modes: 20 | 1. Default. Only masks with new indices are supposed to go into the remapper. 21 | This is also the case for YouTubeVOS. 22 | i.e., regions with index 0 are not "background", but "don't care". 23 | 24 | 2. Exhaustive. Regions with index 0 are considered "background". 25 | Every single pixel is considered to be "labeled". 
26 | """ 27 | def __init__(self): 28 | self.labels = [] 29 | self.remappings = {} 30 | 31 | # if coherent, no mapping is required 32 | self.coherent = True 33 | 34 | def clear_labels(self): 35 | self.labels = [] 36 | self.remappings = {} 37 | # if coherent, no mapping is required 38 | self.coherent = True 39 | 40 | def convert_mask(self, mask, exhaustive=False): 41 | # mask is in index representation, H*W numpy array 42 | labels = np.unique(mask).astype(np.uint8) 43 | labels = labels[labels!=0].tolist() 44 | 45 | new_labels = list(set(labels) - set(self.labels)) 46 | if not exhaustive: 47 | assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' 48 | 49 | # add new remappings 50 | for i, l in enumerate(new_labels): 51 | self.remappings[l] = i+len(self.labels)+1 52 | if self.coherent and i+len(self.labels)+1 != l: 53 | self.coherent = False 54 | 55 | if exhaustive: 56 | new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1) 57 | else: 58 | if self.coherent: 59 | new_mapped_labels = new_labels 60 | else: 61 | new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1) 62 | 63 | self.labels.extend(new_labels) 64 | # mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float() 65 | mask = torch.from_numpy(mask).float() 66 | # mask num_objects*H*W 67 | return mask, new_mapped_labels 68 | 69 | 70 | def remap_index_mask(self, mask): 71 | # mask is in index representation, H*W numpy array 72 | if self.coherent: 73 | return mask 74 | 75 | new_mask = np.zeros_like(mask) 76 | for l, i in self.remappings.items(): 77 | new_mask[mask==i] = l 78 | return new_mask -------------------------------------------------------------------------------- /models/tracker/utils/palette.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 
\xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0' 4 | 5 | youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f' 6 | 7 | davis_palette_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3) 8 | 9 | youtube_palette_np = np.frombuffer(youtube_palette, dtype=np.uint8).reshape(-1, 3) -------------------------------------------------------------------------------- /models/tracker/utils/pano_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from threading import Lock 3 | 4 | 5 | class ID2RGBConverter: 6 | def __init__(self): 7 | self.all_id = [] 8 | self.obj_to_id = {} 9 | self.lock = Lock() 10 | 11 | def _id_to_rgb(self, id: int): 12 | rgb = np.zeros((3, ), dtype=np.uint8) 13 | for i in range(3): 14 | rgb[i] = id % 256 15 | id = id // 256 16 | return rgb 17 | 18 | def convert(self, obj: int): 19 | with self.lock: 20 | if obj in self.obj_to_id: 21 | id = self.obj_to_id[obj] 22 | else: 23 | while True: 24 | id = np.random.randint(255, 256**3) 25 | if id not in self.all_id: 26 | break 27 | self.obj_to_id[obj] = id 28 | self.all_id.append(id) 29 | 30 | return id, self._id_to_rgb(id) 31 | -------------------------------------------------------------------------------- /models/tracker/utils/point_features.py: -------------------------------------------------------------------------------- 1 | # This file is copied from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py 2 | # such that users do not need to install detectron2 just for these two functions 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | from typing import List 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | def cat(tensors: List[torch.Tensor], dim: int = 0): 11 | """ 12 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 13 | """ 14 | assert isinstance(tensors, (list, tuple)) 15 | if len(tensors) == 1: 16 | return tensors[0] 17 | return torch.cat(tensors, dim) 18 | 19 | 20 | def calculate_uncertainty(sem_seg_logits): 21 | """ 22 | For each location of the prediction `sem_seg_logits` we estimate uncerainty as the 23 | difference between top first and top second predicted logits. 24 | Args: 25 | mask_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and 26 | C is the number of foreground classes. The values are logits. 27 | Returns: 28 | scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with 29 | the most uncertain locations having the highest uncertainty score. 
30 | """ 31 | if sem_seg_logits.shape[1] == 2: 32 | # binary segmentation 33 | return -(torch.abs(sem_seg_logits[:, 1:2])) 34 | top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0] 35 | return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) 36 | 37 | 38 | def point_sample(input, point_coords, **kwargs): 39 | """ 40 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 41 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 42 | [0, 1] x [0, 1] square. 43 | Args: 44 | input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. 45 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 46 | [0, 1] x [0, 1] normalized point coordinates. 47 | Returns: 48 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 49 | features for points in `point_coords`. The features are obtained via bilinear 50 | interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 51 | """ 52 | add_dim = False 53 | if point_coords.dim() == 3: 54 | add_dim = True 55 | point_coords = point_coords.unsqueeze(2) 56 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 57 | if add_dim: 58 | output = output.squeeze(3) 59 | return output 60 | 61 | 62 | def get_uncertain_point_coords_with_randomness(coarse_logits, uncertainty_func, num_points, 63 | oversample_ratio, importance_sample_ratio): 64 | """ 65 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties 66 | are calculated for each point using 'uncertainty_func' function that takes point's logit 67 | prediction as input. 68 | See PointRend paper for details. 69 | Args: 70 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 71 | class-specific or class-agnostic prediction. 72 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 73 | contains logit predictions for P points and returns their uncertainties as a Tensor of 74 | shape (N, 1, P). 75 | num_points (int): The number of points P to sample. 76 | oversample_ratio (int): Oversampling parameter. 77 | importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. 78 | Returns: 79 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 80 | sampled points. 81 | """ 82 | assert oversample_ratio >= 1 83 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 84 | num_boxes = coarse_logits.shape[0] 85 | num_sampled = int(num_points * oversample_ratio) 86 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 87 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 88 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 89 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 90 | # to incorrect results. 91 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between 92 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. 93 | # However, if we calculate uncertainties for the coarse predictions first, 94 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. 
95 | point_uncertainties = uncertainty_func(point_logits) 96 | num_uncertain_points = int(importance_sample_ratio * num_points) 97 | num_random_points = num_points - num_uncertain_points 98 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 99 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) 100 | idx += shift[:, None] 101 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 102 | 2) 103 | if num_random_points > 0: 104 | point_coords = cat( 105 | [ 106 | point_coords, 107 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), 108 | ], 109 | dim=1, 110 | ) 111 | return point_coords -------------------------------------------------------------------------------- /models/tracker/utils/range_transform.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | im_mean = (124, 116, 104) 4 | 5 | im_normalization = transforms.Normalize( 6 | mean=[0.485, 0.456, 0.406], 7 | std=[0.229, 0.224, 0.225] 8 | ) 9 | 10 | inv_im_trans = transforms.Normalize( 11 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 12 | std=[1/0.229, 1/0.224, 1/0.225]) 13 | -------------------------------------------------------------------------------- /models/tracker/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterable 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | # STM 7 | def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]): 8 | h, w = in_img.shape[-2:] 9 | 10 | if h % d > 0: 11 | new_h = h + d - h % d 12 | else: 13 | new_h = h 14 | if w % d > 0: 15 | new_w = w + d - w % d 16 | else: 17 | new_w = w 18 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) 19 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) 20 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 21 | out = F.pad(in_img, pad_array) 22 | return out, pad_array 23 | 24 | 25 | def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor: 26 | if len(img.shape) == 4: 27 | if pad[2] + pad[3] > 0: 28 | img = img[:, :, pad[2]:-pad[3], :] 29 | if pad[0] + pad[1] > 0: 30 | img = img[:, :, :, pad[0]:-pad[1]] 31 | elif len(img.shape) == 3: 32 | if pad[2] + pad[3] > 0: 33 | img = img[:, pad[2]:-pad[3], :] 34 | if pad[0] + pad[1] > 0: 35 | img = img[:, :, pad[0]:-pad[1]] 36 | elif len(img.shape) == 5: 37 | if pad[2] + pad[3] > 0: 38 | img = img[:, :, :, pad[2]:-pad[3], :] 39 | if pad[0] + pad[1] > 0: 40 | img = img[:, :, :, :, pad[0]:-pad[1]] 41 | else: 42 | raise NotImplementedError 43 | return img 44 | 45 | 46 | # @torch.jit.script 47 | def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor: 48 | with torch.cuda.amp.autocast(enabled=False): 49 | prob = prob.float() 50 | new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob], 51 | dim).clamp(1e-7, 1 - 1e-7) 52 | logits = torch.log((new_prob / (1 - new_prob))) 53 | 54 | return logits 55 | 56 | 57 | # @torch.jit.script 58 | def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor: 59 | # cls_gt: B*1*H*W 60 | B, _, H, W = cls_gt.shape 61 | one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1) 62 | return one_hot -------------------------------------------------------------------------------- /models/tracker/utils/time_estimator.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class TimeEstimator: 5 | def __init__(self, total_iter, step_size): 6 | self.avg_time_window = [] # window-based average 7 | self.exp_avg_time = None # exponential moving average 8 | self.alpha = 0.7 # for exponential moving average 9 | 10 | self.last_time = time.time() # would not be accurate for the first iteration but well 11 | self.total_iter = total_iter 12 | self.step_size = step_size 13 | 14 | self.buffering_exp = True 15 | 16 | # call this at a fixed interval 17 | # does not have to be every step 18 | def update(self): 19 | curr_time = time.time() 20 | time_per_iter = curr_time - self.last_time 21 | self.last_time = curr_time 22 | 23 | self.avg_time_window.append(time_per_iter) 24 | 25 | if self.buffering_exp: 26 | if self.exp_avg_time is not None: 27 | # discard the first iteration call to not pollute the ema 28 | self.buffering_exp = False 29 | self.exp_avg_time = time_per_iter 30 | else: 31 | self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter 32 | 33 | def get_est_remaining(self, it): 34 | if self.exp_avg_time is None: 35 | return 0 36 | 37 | remaining_iter = self.total_iter - it 38 | return remaining_iter * self.exp_avg_time / self.step_size 39 | 40 | def get_and_reset_avg_time(self): 41 | avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size 42 | self.avg_time_window = [] 43 | return avg 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | edge-tts==6.1.9 2 | ultralytics == 8.0.120 3 | einops 4 | av 5 | sentencepiece 6 | cpm_kernels 7 | edge-tts==6.1.9 8 | ffmpeg-python==0.2 9 | clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33 10 | diffusers==0.21.4 11 | loguru 12 | numpy 13 | pandas 14 | opencv-python 15 | torchvision 16 | torch==2.1.1 17 | Pillow 18 | protobuf==4.23.4 19 | PyYAML 20 | segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 21 | transformers 22 | gradio==3.39.0 23 | openai-whisper==20231117 24 | anyconfig 25 | mdtex2html 26 | librosa 27 | moviepy 28 | omegaconf 29 | SwissArmyTransformer>=0.4.4 30 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .chatglm_handler import ChatGLMHandler 3 | from .whisper_handler import WhisperHandler 4 | from .edgetts_handler import EdgeTTSHandler 5 | from .gpt_handler import GPTHandler 6 | from .ai_wrapper import AIWrapper 7 | from .fastsam_handler import FastSAMHandler -------------------------------------------------------------------------------- /tools/base.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class BaseHandler: 4 | def __init__(self, args) -> None: 5 | self.model = None 6 | self.handle_args = args.get(type(self).__name__, {}) 7 | if self.handle_args.get('init_model_when_start_server'): 8 | self.init_model() 9 | 10 | def infer(self, input_data, **kwargs): 11 | raise NotImplementedError 12 | 13 | def init_model(self): 14 | raise NotImplementedError -------------------------------------------------------------------------------- /tools/chatglm_handler.py: 
-------------------------------------------------------------------------------- 1 | from .base import BaseHandler 2 | from utils.chatglm_utils import * 3 | 4 | 5 | class ChatGLMHandler(BaseHandler): 6 | def __init__(self, args): 7 | super().__init__(args) 8 | self.llm_model_path = self.handle_args.get('llm_model_path') 9 | self.num_gpus = self.handle_args.get('num_gpus', 2) 10 | if os.getenv('CUDA_VISIBLE_DEVICES'): 11 | self.num_gpus = min(self.num_gpus, len(os.environ['CUDA_VISIBLE_DEVICES'])) 12 | 13 | self.trust_remote_code = self.handle_args.get('trust_remote_code', True) 14 | self.device = self.handle_args.get('device', 'cuda:0') 15 | 16 | def init_model(self): 17 | if self.model is None: 18 | self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_path, trust_remote_code=True) 19 | if 'int' in os.path.basename(self.llm_model_path): 20 | self.model = AutoModel.from_pretrained(self.llm_model_path, trust_remote_code=True).to(self.device) 21 | else: 22 | self.model = load_model_on_gpus(self.llm_model_path, num_gpus=self.num_gpus) 23 | self.model = self.model.eval() 24 | 25 | def infer(self, input_text, history=[], max_length=2048, top_p=90, temperature=95, **kwargs): 26 | self.init_model() 27 | response, history = self.model.chat(self.tokenizer, input_text, history=history, max_length=max_length, top_p=top_p, temperature=temperature) 28 | return response, history 29 | 30 | def stream_chat(self, input_text, chatbot, max_length, top_p, temperature, history, past_key_values): 31 | self.init_model() 32 | chatbot.append((parse_text(input_text), "")) 33 | for response, history, past_key_values in self.model.stream_chat(self.tokenizer, input_text, history, past_key_values=past_key_values, 34 | return_past_key_values=True, 35 | max_length=max_length, top_p=top_p, 36 | temperature=temperature): 37 | chatbot[-1] = (parse_text(input_text), parse_text(response)) 38 | 39 | yield chatbot, history, past_key_values 40 | 41 | 42 | -------------------------------------------------------------------------------- /tools/chatvlm_handler.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM 2 | from .base import BaseHandler 3 | from utils.chatglm_utils import * 4 | 5 | 6 | class ChatVLMHandler(BaseHandler): 7 | def __init__(self, args): 8 | super().__init__(args) 9 | self.vlm_model_path = self.handle_args.get('vlm_model_path') 10 | self.num_gpus = self.handle_args.get('num_gpus', 2) 11 | if os.getenv('CUDA_VISIBLE_DEVICES'): 12 | self.num_gpus = min(self.num_gpus, len(os.environ['CUDA_VISIBLE_DEVICES'])) 13 | 14 | self.trust_remote_code = self.handle_args.get('trust_remote_code', True) 15 | self.device = self.handle_args.get('device', 'cuda:0') 16 | 17 | def init_model(self): 18 | if self.model is None: 19 | self.tokenizer = AutoTokenizer.from_pretrained(self.vlm_model_path, trust_remote_code=True) 20 | self.model = AutoModelForCausalLM.from_pretrained(self.vlm_model_path, device_map="cuda", trust_remote_code=self.trust_remote_code) 21 | self.model = self.model.eval() 22 | 23 | def chat(self, user_input, image, chatbot, history=None, **kwargs): 24 | self.init_model() 25 | 26 | query = self.tokenizer.from_list_format([ 27 | {'image': self.preprocess(image)}, 28 | {'text': user_input}, 29 | ]) 30 | response, history = self.model.chat(self.tokenizer, query=query, history=history) 31 | chatbot.append((parse_text(user_input), parse_text(response))) 32 | # image = tokenizer.draw_bbox_on_latest_picture(response, history) 33 | 
return chatbot, history 34 | 35 | def chat_stream(self, user_input, image, chatbot, history=None, **kwargs): 36 | self.init_model() 37 | chatbot.append((parse_text(user_input), "")) 38 | query = self.tokenizer.from_list_format([ 39 | {'image': self.preprocess(image)}, 40 | {'text': user_input}, 41 | ]) 42 | for response in self.model.chat_stream(self.tokenizer, query=query, history=history): 43 | chatbot[-1] = (parse_text(user_input), parse_text(response)) 44 | yield chatbot, history 45 | 46 | def preprocess(self, image): 47 | return image 48 | 49 | 50 | -------------------------------------------------------------------------------- /tools/edgetts_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import edge_tts 3 | import random 4 | import asyncio 5 | import pandas as pd 6 | from datetime import datetime 7 | from .base import BaseHandler 8 | from loguru import logger 9 | 10 | 11 | async def streaming_with_subtitles(text, audio_file, webvtt_file, voice, rate, volume, pitch) -> None: 12 | communicate = edge_tts.Communicate(text=text, voice=voice, rate=rate, volume=volume, pitch=pitch) 13 | submaker = edge_tts.SubMaker() 14 | with open(audio_file, "wb") as file: 15 | async for chunk in communicate.stream(): 16 | if chunk["type"] == "audio": 17 | file.write(chunk["data"]) 18 | elif chunk["type"] == "WordBoundary": 19 | submaker.create_sub((chunk["offset"], chunk["duration"]), chunk["text"]) 20 | 21 | with open(webvtt_file, "w", encoding="utf-8") as file: 22 | subtitles_info = submaker.generate_subs() 23 | file.write(subtitles_info) 24 | return subtitles_info 25 | 26 | 27 | class EdgeTTSHandler(BaseHandler): 28 | def __init__(self, args, **kwargs): 29 | # output_dir='/data1/zjx/ai_webui/products/audio_tmp', 30 | super().__init__(args) 31 | self.output_dir = self.handle_args.get("output_dir", "/tmp") 32 | 33 | def infer(self, tts_text_file=None, tts_text=None, voice='zh-CN-YunxiNeural', rate=0, volume=0, pitch=0, **kwargs): 34 | # format adaptation: convert the numeric rate/volume/pitch values into the string forms edge-tts expects 35 | if rate >= 0: 36 | rate = f"+{rate}%" 37 | else: 38 | rate = f"{rate}%" 39 | if volume >= 0: 40 | volume = f"+{volume}%" 41 | else: 42 | volume = f"{volume}%" 43 | if pitch >= 0: 44 | pitch = f"+{pitch}Hz" 45 | else: 46 | pitch = f"{pitch}Hz" 47 | 48 | if tts_text_file: 49 | tts_text_file_path = tts_text_file.name 50 | text = "" 51 | with open(tts_text_file_path, "r") as f: 52 | for line in f: 53 | text += ' ' + line.rstrip() 54 | self.output_dir = os.path.dirname(tts_text_file_path) 55 | file_tag = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + "_" + os.path.basename(tts_text_file_path).split('.')[0] 56 | elif tts_text: 57 | file_tag = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + "_" + str(random.randint(0, 1000)) 58 | text = tts_text 59 | 60 | audio_file = os.path.join(self.output_dir, file_tag + ".mp3") 61 | webvtt_file = os.path.join(self.output_dir, file_tag + ".vtt") 62 | loop = asyncio.new_event_loop() # .get_event_loop_policy() 63 | asyncio.set_event_loop(loop) 64 | subtitles_info = loop.run_until_complete(streaming_with_subtitles(text, audio_file, webvtt_file, voice, rate, volume, pitch)) 65 | loop.close() 66 | 67 | # post-processing: parse the WebVTT output into per-sentence subtitle records 68 | proc_subtitles = [] 69 | contents = subtitles_info.split("\r\n") 70 | for idx in range(1, len(contents)): 71 | if ' --> ' not in contents[idx]: 72 | continue 73 | start, end = contents[idx].split(' --> ') 74 | sentence = contents[idx+1].replace(' ', '') 75 | proc_subtitles.append({ 76 | "start": start, 77 | "end": end, 78 | "sentence": sentence 79 | }) 80 | df =
pd.DataFrame(proc_subtitles) 81 | srt_file = webvtt_file.replace(".vtt", ".srt") 82 | excel_file = webvtt_file.replace(".vtt", ".xlsx") 83 | with open(srt_file, "w", encoding="utf-8") as f: 84 | for idx, row in enumerate(proc_subtitles): 85 | f.write(f"{idx+1}\n") 86 | f.write(f"{row['start']} --> {row['end']}\n") 87 | f.write(f"{row['sentence']}\n\n") 88 | 89 | df.to_excel(excel_file, index=False) 90 | file_list = [audio_file, webvtt_file, srt_file, excel_file] 91 | return audio_file, file_list 92 | 93 | def init_model(self): 94 | # Nothing to do 95 | logger.warning("## 无需初始化EdgeTTSHandler的模型!") 96 | 97 | if __name__ == "__main__": 98 | text_file = "/data1/zjx/ai_webui/raw_materials/test.txt" 99 | text_type = "file" 100 | audio_file = "test.mp3" 101 | webvtt_file = "test.vtt" 102 | voice = "zh-CN-YunxiNeural" 103 | edge_tts_handle = EdgeTTSHandler({}) 104 | # edge_tts_engine.streaming_with_subtitles(text_file, text_type, audio_file, webvtt_file, voice) 105 | # asyncio.run(amain()) 106 | 107 | 108 | -------------------------------------------------------------------------------- /tools/fastsam_handler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ultralytics import YOLO 3 | from .base import BaseHandler 4 | from PIL import ImageDraw 5 | import gradio as gr 6 | from utils.fastsam.tools import format_results, fast_process, point_prompt, text_prompt 7 | 8 | 9 | class FastSAMHandler(BaseHandler): 10 | """ 11 | FastSAMHandler is a handler for FastSAM. 12 | """ 13 | 14 | def __init__(self, args): 15 | super().__init__(args) 16 | self.fastsam_model_path = self.handle_args.get('fastsam_model_path') 17 | self.device = self.handle_args.get('device') 18 | self.global_points = [] 19 | self.global_point_label = [] 20 | 21 | def segment_everything( 22 | self, 23 | input, 24 | input_size=1024, 25 | iou_threshold=0.7, 26 | conf_threshold=0.25, 27 | better_quality=False, 28 | withContours=True, 29 | use_retina=True, 30 | text="", 31 | wider=False, 32 | mask_random_color=True, 33 | ): 34 | self.init_model() 35 | input_size = int(input_size) # make sure imgsz is an integer 36 | # Thanks for the suggestion by hysts in HuggingFace. 37 | w, h = input.size 38 | scale = input_size / max(w, h) 39 | new_w = int(w * scale) 40 | new_h = int(h * scale) 41 | input = input.resize((new_w, new_h)) 42 | 43 | results = self.model(input, 44 | device=self.device, 45 | retina_masks=True, 46 | iou=iou_threshold, 47 | conf=conf_threshold, 48 | imgsz=input_size,) 49 | 50 | if len(text) > 0: 51 | results = format_results(results[0], 0) 52 | annotations, _ = text_prompt(results, text, input, device=self.device, wider=wider) 53 | annotations = np.array([annotations]) 54 | else: 55 | annotations = results[0].masks.data 56 | 57 | fig = fast_process(annotations=annotations, 58 | image=input, 59 | device=self.device, 60 | scale=(1024 // input_size), 61 | better_quality=better_quality, 62 | mask_random_color=mask_random_color, 63 | bbox=None, 64 | use_retina=use_retina, 65 | withContours=withContours,) 66 | return fig 67 | 68 | def segment_with_points( 69 | self, 70 | input, 71 | input_size=1024, 72 | iou_threshold=0.7, 73 | conf_threshold=0.25, 74 | better_quality=False, 75 | withContours=True, 76 | use_retina=True, 77 | mask_random_color=True, 78 | ): 79 | self.init_model() 80 | input_size = int(input_size) # make sure imgsz is an integer 81 | # Thanks for the suggestion by hysts in HuggingFace.
82 | w, h = input.size 83 | scale = input_size / max(w, h) 84 | new_w = int(w * scale) 85 | new_h = int(h * scale) 86 | input = input.resize((new_w, new_h)) 87 | 88 | scaled_points = [[int(x * scale) for x in point] for point in self.global_points] 89 | 90 | results = self.model(input, 91 | device=self.device, 92 | retina_masks=True, 93 | iou=iou_threshold, 94 | conf=conf_threshold, 95 | imgsz=input_size,) 96 | 97 | results = format_results(results[0], 0) 98 | annotations, _ = point_prompt(results, scaled_points, self.global_point_label, new_h, new_w) 99 | annotations = np.array([annotations]) 100 | 101 | fig = fast_process(annotations=annotations, 102 | image=input, 103 | device=self.device, 104 | scale=(1024 // input_size), 105 | better_quality=better_quality, 106 | mask_random_color=mask_random_color, 107 | bbox=None, 108 | use_retina=use_retina, 109 | withContours=withContours,) 110 | 111 | self.global_points = [] 112 | self.global_point_label = [] 113 | return fig, None 114 | 115 | def get_points_with_draw(self, image, label, evt: gr.SelectData): 116 | x, y = evt.index[0], evt.index[1] 117 | point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255) 118 | self.global_points.append([x, y]) 119 | self.global_point_label.append(1 if label == 'Add Mask' else 0) 120 | 121 | print(x, y, label == 'Add Mask') 122 | 123 | # 创建一个可以在图像上绘图的对象 124 | draw = ImageDraw.Draw(image) 125 | draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color) 126 | return image 127 | 128 | 129 | def init_model(self): 130 | if self.model is None: 131 | self.model = YOLO(self.fastsam_model_path) 132 | -------------------------------------------------------------------------------- /tools/gpt_handler.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | try: 4 | from .base import BaseHandler 5 | except: 6 | from base import BaseHandler 7 | 8 | 9 | key="""Bearer eyJhbGciOiJSUzI1NiIsInNvdXJjZSI6IklBTSIsImtpZCI6IjIwTml0b3h6S1FxMHJPdTlvYVk2UzA3bFBnUTBjSnFQIn0.eyJzdWIiOiI2MDAwMDI5NTIiLCJpc3MiOiJodHRwczpcL1wvZGV2LWEuc3p6aGlqaW5nLmNvbVwvaWFtLW9uZWlkIiwiYmFpQnVVc2VySWQiOiIyMDkxOCIsImZpbmFuY2VVc2VySWQiOiIxMDUyNjI4NTM2MTY1MDA3MzYwIiwiZmVpU2h1SWQiOiJvbl8xOTdkMTRmZjM5MmFhYjZhM2I2ZWQ0NmQ4MWY3NjI4NSIsImF1ZCI6IjQ0MTEyODg1MTkwNTU4MTA1NiIsInNjb3BlIjoib3BlbmlkIiwibmFtZSI6IuacsemUpuelpSIsImZpbmFuY2VTaWduVXNlcklkIjoiMTAyMzI1MzM2MjI4OTI1MDMwNCIsInVzZXJUeXBlIjoiSU5ORVIiLCJleHAiOjE3MDIzNjcxODIsImlhdCI6MTcwMTc2MjM4MiwianRpIjoiY2ZhMmE1NjBhMWU4NGY5NmE0NmMyZmZhMDM4YjUwYTYiLCJlbWFpbCI6InpodWppbnhpYW5nQHpqLnRlY2giLCJhY2NvdW50Ijoiemh1amlueGlhbmcifQ.RWw-8EaD3hMoiNmcXEfR0Y-A1R35aMqCWIKWnse50YM4INBtQNivVHFbsZZRNwup8XRTQe7PWQl0dPfoGH-Kz7eR3xCIZ6--ATW_PP8CbqGTCnbWAbVxT3enXBZjpgx5qE-JN5g-ko5bwpPylah4Wg2B8n6T87wC8Iczc-Aps4L7oevWjPCrne8tha4g3AWmuWGgn00LJjy4cQlIK9ETqLeVJJAZEDm82GLpYePISFERzs4-olnGz5IKcALtVXnX_KAXIloy5q1f3TgeNfPNkfdL1_TeXsrj-TZkJFW9OUVmNtiW7yFOdBN9t9Gv6HfP-Onva9pnOankMSLNXyOLQQ""" 10 | class GPTHandler(BaseHandler): 11 | def __init__(self, args={}, **kwargs): 12 | super().__init__(args) 13 | self.api_url = self.handle_args.get("api_url") 14 | self.headers = { 15 | 'Content-Type': 'application/json', 16 | 'Authorization': key 17 | } 18 | 19 | def infer(self, input_text, **kwargs): 20 | data = { 21 | 'messages':[ 22 | { 23 | 'role': 'system', 24 | 'content': input_text, 25 | # 'history': history 26 | } 27 | ] 28 | } 29 | result = requests.post(self.api_url, headers=self.headers, 
data=json.dumps(data)) 30 | text = result.json()['data']['choices'][0]['message']['content'] 31 | return text, None 32 | 33 | 34 | if __name__ == "__main__": 35 | gpt_handler = GPTHandler({}) 36 | input_text = '帮我分析这个报错:ImportError: attempted relative import with no known parent package' 37 | text,_ = gpt_handler.infer(input_text) 38 | print(text) 39 | 40 | 41 | -------------------------------------------------------------------------------- /tools/visualglm_handler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import BaseHandler 3 | from utils.chatglm_utils import * 4 | 5 | 6 | class VisualGLMHandler(BaseHandler): 7 | def __init__(self, args): 8 | super().__init__(args) 9 | self.model_path = self.handle_args.get('model_path') 10 | self.trust_remote_code = self.handle_args.get('trust_remote_code', True) 11 | self.device = self.handle_args.get('device', 'cuda:0') 12 | self.quant = self.handle_args.get('quant') 13 | 14 | def init_model(self): 15 | if self.model is None: 16 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=self.trust_remote_code) 17 | if self.quant in [4, 8]: 18 | self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=self.trust_remote_code).quantize(self.quant).half().to(self.device) 19 | else: 20 | self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=self.trust_remote_code).half().to(self.device) 21 | self.model = self.model.eval() 22 | 23 | def stream_chat(self, input, image_path, chatbot, max_length, top_p, temperature, history): 24 | self.init_model() 25 | if image_path is None: 26 | return [(input, "图片不能为空。请重新上传图片并重试。")], [] 27 | chatbot.append((parse_text(input), "")) 28 | with torch.no_grad(): 29 | for response, history in self.model.stream_chat(self.tokenizer, image_path, input, history, max_length=max_length, top_p=top_p, 30 | temperature=temperature): 31 | chatbot[-1] = (parse_text(input), parse_text(response)) 32 | 33 | yield chatbot, history 34 | 35 | def stream_chat2(self, image_path, chatbot, max_length, top_p, temperature): 36 | self.init_model() 37 | input, history = "描述这张图片。", [] 38 | chatbot.append((parse_text(input), "")) 39 | with torch.no_grad(): 40 | for response, history in self.model.stream_chat(self.tokenizer, image_path, input, history, max_length=max_length, 41 | top_p=top_p, 42 | temperature=temperature): 43 | chatbot[-1] = (parse_text(input), parse_text(response)) 44 | 45 | yield chatbot, history 46 | 47 | 48 | -------------------------------------------------------------------------------- /tools/whisper_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base import BaseHandler 3 | import whisper 4 | import pandas as pd 5 | 6 | class WhisperHandler(BaseHandler): 7 | def __init__(self, args, **kwargs): 8 | # model_name, model_dir, device='cuda:0' 9 | super().__init__(args) 10 | self.model_name = self.handle_args.get('model_name') 11 | self.device = self.handle_args.get('device', 'cuda:0') 12 | self.model_dir = self.handle_args.get('model_dir', os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')) 13 | self.language_map = self.handle_args.get('language_map', {'英语': 'en', '普通话': 'zh'}) 14 | 15 | def infer(self, audio_file, language=None): 16 | self.init_model() 17 | if not isinstance(audio_file, str): 18 | audio_file = audio_file.name 19 | save_dir = os.path.dirname(audio_file) 20 | audio_name = 
os.path.basename(audio_file).split('.')[0] 21 | os.makedirs(save_dir, exist_ok=True) 22 | excel_file = os.path.join(save_dir, f"{audio_name}_speech_res.xlsx") 23 | result = self.model.transcribe(audio_file, language=self.language_map.get(language), verbose=True) 24 | # rec_language = result['language'] 25 | df = pd.DataFrame(result['segments']) 26 | df = df[['id', 'no_speech_prob', 'start', 'end', 'text']] 27 | df.to_excel(excel_file, index=False) 28 | text = ' '.join(df['text'].tolist()) 29 | result['text'] = text 30 | result['segments'] = excel_file 31 | return result 32 | 33 | def predict_v2(self, audio_file): 34 | # load excel_file 35 | audio = whisper.load_audio(audio_file) 36 | audio = whisper.pad_or_trim(audio) 37 | 38 | # make log-Mel spectrogram and move to the same device as the model 39 | mel = whisper.log_mel_spectrogram(audio).to(self.model.device) 40 | 41 | # detect the spoken language 42 | _, probs = self.model.detect_language(mel) 43 | print(f"Detected language: {max(probs, key=probs.get)}") 44 | 45 | # decode the audio 46 | options = whisper.DecodingOptions() 47 | result = whisper.decode(self.model, mel, options) 48 | 49 | # print the recognized text 50 | print(result.text) 51 | 52 | def init_model(self): 53 | if self.model is None: 54 | self.model = whisper.load_model(name=self.model_name, device=self.device, download_root=self.model_dir) 55 | 56 | 57 | LANGUAGES = { 58 | "en": "english", 59 | "zh": "chinese", 60 | "de": "german", 61 | "es": "spanish", 62 | "ru": "russian", 63 | "ko": "korean", 64 | "fr": "french", 65 | "ja": "japanese", 66 | "pt": "portuguese", 67 | "tr": "turkish", 68 | "pl": "polish", 69 | "ca": "catalan", 70 | "nl": "dutch", 71 | "ar": "arabic", 72 | "sv": "swedish", 73 | "it": "italian", 74 | "id": "indonesian", 75 | "hi": "hindi", 76 | "fi": "finnish", 77 | "vi": "vietnamese", 78 | "he": "hebrew", 79 | "uk": "ukrainian", 80 | "el": "greek", 81 | "ms": "malay", 82 | "cs": "czech", 83 | "ro": "romanian", 84 | "da": "danish", 85 | "hu": "hungarian", 86 | "ta": "tamil", 87 | "no": "norwegian", 88 | "th": "thai", 89 | "ur": "urdu", 90 | "hr": "croatian", 91 | "bg": "bulgarian", 92 | "lt": "lithuanian", 93 | "la": "latin", 94 | "mi": "maori", 95 | "ml": "malayalam", 96 | "cy": "welsh", 97 | "sk": "slovak", 98 | "te": "telugu", 99 | "fa": "persian", 100 | "lv": "latvian", 101 | "bn": "bengali", 102 | "sr": "serbian", 103 | "az": "azerbaijani", 104 | "sl": "slovenian", 105 | "kn": "kannada", 106 | "et": "estonian", 107 | "mk": "macedonian", 108 | "br": "breton", 109 | "eu": "basque", 110 | "is": "icelandic", 111 | "hy": "armenian", 112 | "ne": "nepali", 113 | "mn": "mongolian", 114 | "bs": "bosnian", 115 | "kk": "kazakh", 116 | "sq": "albanian", 117 | "sw": "swahili", 118 | "gl": "galician", 119 | "mr": "marathi", 120 | "pa": "punjabi", 121 | "si": "sinhala", 122 | "km": "khmer", 123 | "sn": "shona", 124 | "yo": "yoruba", 125 | "so": "somali", 126 | "af": "afrikaans", 127 | "oc": "occitan", 128 | "ka": "georgian", 129 | "be": "belarusian", 130 | "tg": "tajik", 131 | "sd": "sindhi", 132 | "gu": "gujarati", 133 | "am": "amharic", 134 | "yi": "yiddish", 135 | "lo": "lao", 136 | "uz": "uzbek", 137 | "fo": "faroese", 138 | "ht": "haitian creole", 139 | "ps": "pashto", 140 | "tk": "turkmen", 141 | "nn": "nynorsk", 142 | "mt": "maltese", 143 | "sa": "sanskrit", 144 | "lb": "luxembourgish", 145 | "my": "myanmar", 146 | "bo": "tibetan", 147 | "tl": "tagalog", 148 | "mg": "malagasy", 149 | "as": "assamese", 150 | "tt": "tatar", 151 | "haw": "hawaiian", 152 | "ln": 
"lingala", 153 | "ha": "hausa", 154 | "ba": "bashkir", 155 | "jw": "javanese", 156 | "su": "sundanese", 157 | "yue": "cantonese", 158 | } -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/utils/__init__.py -------------------------------------------------------------------------------- /utils/chatglm_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Union, Optional, Tuple 3 | from torch.nn import Module 4 | from transformers import AutoModel, AutoTokenizer 5 | 6 | 7 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: 8 | # transformer.word_embeddings 占用1层 9 | # transformer.final_layernorm 和 lm_head 占用1层 10 | # transformer.layers 占用 28 层 11 | # 总共30层分配到num_gpus张卡上 12 | num_trans_layers = 28 13 | per_gpu_layers = 30 / num_gpus 14 | 15 | # bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError 16 | # windows下 model.device 会被设置成 transformer.word_embeddings.device 17 | # linux下 model.device 会被设置成 lm_head.device 18 | # 在调用chat或者stream_chat时,input_ids会被放到model.device上 19 | # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError 20 | # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上 21 | # 本文件来源于https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py 22 | # 仅此处做少许修改以支持ChatGLM3 23 | device_map = { 24 | 'transformer.embedding.word_embeddings': 0, 25 | 'transformer.encoder.final_layernorm': 0, 26 | 'transformer.output_layer': 0, 27 | 'transformer.rotary_pos_emb': 0, 28 | 'lm_head': 0 29 | } 30 | 31 | used = 2 32 | gpu_target = 0 33 | for i in range(num_trans_layers): 34 | if used >= per_gpu_layers: 35 | gpu_target += 1 36 | used = 0 37 | assert gpu_target < num_gpus 38 | device_map[f'transformer.encoder.layers.{i}'] = gpu_target 39 | used += 1 40 | 41 | return device_map 42 | 43 | 44 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2, 45 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module: 46 | if num_gpus < 2 and device_map is None: 47 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda() 48 | else: 49 | from accelerate import dispatch_model 50 | 51 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half() 52 | 53 | if device_map is None: 54 | device_map = auto_configure_device_map(num_gpus) 55 | 56 | model = dispatch_model(model, device_map=device_map) 57 | 58 | return model 59 | 60 | 61 | def parse_text(text): 62 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" 63 | lines = text.split("\n") 64 | lines = [line for line in lines if line != ""] 65 | count = 0 66 | for i, line in enumerate(lines): 67 | if "```" in line: 68 | count += 1 69 | items = line.split('`') 70 | if count % 2 == 1: 71 | lines[i] = f'
<pre><code class="language-{items[-1]}">' 72 |             else: 73 |                 lines[i] = f'<br></code></pre>' 74 |         else: 75 |             if i > 0: 76 |                 if count % 2 == 1: 77 |                     line = line.replace("`", "\`") 78 |                     line = line.replace("<", "&lt;") 79 |                     line = line.replace(">", "&gt;") 80 |                     line = line.replace(" ", "&nbsp;") 81 |                     line = line.replace("*", "&ast;") 82 |                     line = line.replace("_", "&lowbar;") 83 |                     line = line.replace("-", "&#45;") 84 |                     line = line.replace(".", "&#46;") 85 |                     line = line.replace("!", "&#33;") 86 |                     line = line.replace("(", "&#40;") 87 |                     line = line.replace(")", "&#41;") 88 |                     line = line.replace("$", "&#36;") 89 |                 lines[i] = "<br>
"+line 90 | text = "".join(lines) 91 | return text 92 | -------------------------------------------------------------------------------- /utils/fastsam/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FastSAM 2 | from .predict import FastSAMPredictor 3 | from .prompt import FastSAMPrompt 4 | # from .val import FastSAMValidator 5 | from .decoder import FastSAMDecoder 6 | from .tools import format_results, fast_process, point_prompt, text_prompt 7 | from PIL import ImageDraw 8 | import numpy as np 9 | 10 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder' 11 | -------------------------------------------------------------------------------- /utils/fastsam/decoder.py: -------------------------------------------------------------------------------- 1 | from .model import FastSAM 2 | import numpy as np 3 | from PIL import Image 4 | import clip 5 | from typing import Optional, List, Tuple, Union 6 | 7 | 8 | class FastSAMDecoder: 9 | def __init__( 10 | self, 11 | model: FastSAM, 12 | device: str='cpu', 13 | conf: float=0.4, 14 | iou: float=0.9, 15 | imgsz: int=1024, 16 | retina_masks: bool=True, 17 | ): 18 | self.model = model 19 | self.device = device 20 | self.retina_masks = retina_masks 21 | self.imgsz = imgsz 22 | self.conf = conf 23 | self.iou = iou 24 | self.image = None 25 | self.image_embedding = None 26 | 27 | def run_encoder(self, image): 28 | if isinstance(image,str): 29 | image = np.array(Image.open(image)) 30 | self.image = image 31 | image_embedding = self.model( 32 | self.image, 33 | device=self.device, 34 | retina_masks=self.retina_masks, 35 | imgsz=self.imgsz, 36 | conf=self.conf, 37 | iou=self.iou 38 | ) 39 | return image_embedding[0].numpy() 40 | 41 | def run_decoder( 42 | self, 43 | image_embedding, 44 | point_prompt: Optional[np.ndarray]=None, 45 | point_label: Optional[np.ndarray]=None, 46 | box_prompt: Optional[np.ndarray]=None, 47 | text_prompt: Optional[str]=None, 48 | )->np.ndarray: 49 | self.image_embedding = image_embedding 50 | if point_prompt is not None: 51 | ann = self.point_prompt(points=point_prompt, pointlabel=point_label) 52 | return ann 53 | elif box_prompt is not None: 54 | ann = self.box_prompt(bbox=box_prompt) 55 | return ann 56 | elif text_prompt is not None: 57 | ann = self.text_prompt(text=text_prompt) 58 | return ann 59 | else: 60 | return None 61 | 62 | def box_prompt(self, bbox): 63 | assert (bbox[2] != 0 and bbox[3] != 0) 64 | masks = self.image_embedding.masks.data 65 | target_height = self.image.shape[0] 66 | target_width = self.image.shape[1] 67 | h = masks.shape[1] 68 | w = masks.shape[2] 69 | if h != target_height or w != target_width: 70 | bbox = [ 71 | int(bbox[0] * w / target_width), 72 | int(bbox[1] * h / target_height), 73 | int(bbox[2] * w / target_width), 74 | int(bbox[3] * h / target_height), ] 75 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 76 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 77 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w 78 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h 79 | 80 | # IoUs = torch.zeros(len(masks), dtype=torch.float32) 81 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) 82 | 83 | masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2)) 84 | orig_masks_area = np.sum(masks, axis=(1, 2)) 85 | 86 | union = bbox_area + orig_masks_area - masks_area 87 | IoUs = masks_area / union 88 | max_iou_index = np.argmax(IoUs) 89 | 90 | return 
np.array([masks[max_iou_index].cpu().numpy()]) 91 | 92 | def point_prompt(self, points, pointlabel): # numpy 93 | 94 | masks = self._format_results(self.image_embedding[0], 0) 95 | target_height = self.image.shape[0] 96 | target_width = self.image.shape[1] 97 | h = masks[0]['segmentation'].shape[0] 98 | w = masks[0]['segmentation'].shape[1] 99 | if h != target_height or w != target_width: 100 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] 101 | onemask = np.zeros((h, w)) 102 | masks = sorted(masks, key=lambda x: x['area'], reverse=True) 103 | for i, annotation in enumerate(masks): 104 | if type(annotation) == dict: 105 | mask = annotation['segmentation'] 106 | else: 107 | mask = annotation 108 | for i, point in enumerate(points): 109 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: 110 | onemask[mask] = 1 111 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: 112 | onemask[mask] = 0 113 | onemask = onemask >= 1 114 | return np.array([onemask]) 115 | 116 | def _format_results(self, result, filter=0): 117 | annotations = [] 118 | n = len(result.masks.data) 119 | for i in range(n): 120 | annotation = {} 121 | mask = result.masks.data[i] == 1.0 122 | 123 | if np.sum(mask) < filter: 124 | continue 125 | annotation['id'] = i 126 | annotation['segmentation'] = mask 127 | annotation['bbox'] = result.boxes.data[i] 128 | annotation['score'] = result.boxes.conf[i] 129 | annotation['area'] = annotation['segmentation'].sum() 130 | annotations.append(annotation) 131 | return annotations 132 | -------------------------------------------------------------------------------- /utils/fastsam/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | FastSAM model interface. 4 | 5 | Usage - Predict: 6 | from ultralytics import FastSAM 7 | 8 | model = FastSAM('last.pt') 9 | results = model.predict('ultralytics/assets/bus.jpg') 10 | """ 11 | 12 | from ultralytics.yolo.cfg import get_cfg 13 | from ultralytics.yolo.engine.exporter import Exporter 14 | from ultralytics.yolo.engine.model import YOLO 15 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir 16 | from ultralytics.yolo.utils.checks import check_imgsz 17 | 18 | from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode 19 | from .predict import FastSAMPredictor 20 | 21 | 22 | class FastSAM(YOLO): 23 | 24 | @smart_inference_mode() 25 | def predict(self, source=None, stream=False, **kwargs): 26 | """ 27 | Perform prediction using the YOLO model. 28 | 29 | Args: 30 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on. 31 | Accepts all source types accepted by the YOLO model. 32 | stream (bool): Whether to stream the predictions or not. Defaults to False. 33 | **kwargs : Additional keyword arguments passed to the predictor. 34 | Check the 'configuration' section in the documentation for all available options. 35 | 36 | Returns: 37 | (List[ultralytics.yolo.engine.results.Results]): The prediction results. 38 | """ 39 | if source is None: 40 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' 41 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. 
Using 'source={source}'.") 42 | overrides = self.overrides.copy() 43 | overrides['conf'] = 0.25 44 | overrides.update(kwargs) # prefer kwargs 45 | overrides['mode'] = kwargs.get('mode', 'predict') 46 | assert overrides['mode'] in ['track', 'predict'] 47 | overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python 48 | self.predictor = FastSAMPredictor(overrides=overrides) 49 | self.predictor.setup_model(model=self.model, verbose=False) 50 | try: 51 | return self.predictor(source, stream=stream) 52 | except Exception as e: 53 | return None 54 | 55 | def train(self, **kwargs): 56 | """Function trains models but raises an error as FastSAM models do not support training.""" 57 | raise NotImplementedError("Currently, the training codes are on the way.") 58 | 59 | def val(self, **kwargs): 60 | """Run validation given dataset.""" 61 | overrides = dict(task='segment', mode='val') 62 | overrides.update(kwargs) # prefer kwargs 63 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 64 | args.imgsz = check_imgsz(args.imgsz, max_dim=1) 65 | validator = FastSAM(args=args) 66 | validator(model=self.model) 67 | self.metrics = validator.metrics 68 | return validator.metrics 69 | 70 | @smart_inference_mode() 71 | def export(self, **kwargs): 72 | """ 73 | Export model. 74 | 75 | Args: 76 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs 77 | """ 78 | overrides = dict(task='detect') 79 | overrides.update(kwargs) 80 | overrides['mode'] = 'export' 81 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 82 | args.task = self.task 83 | if args.imgsz == DEFAULT_CFG.imgsz: 84 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed 85 | if args.batch == DEFAULT_CFG.batch: 86 | args.batch = 1 # default to 1 if not modified 87 | return Exporter(overrides=args)(model=self.model) 88 | 89 | def info(self, detailed=False, verbose=True): 90 | """ 91 | Logs model info. 92 | 93 | Args: 94 | detailed (bool): Show detailed information about model. 95 | verbose (bool): Controls verbosity. 96 | """ 97 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) 98 | 99 | def __call__(self, source=None, stream=False, **kwargs): 100 | """Calls the 'predict' function with given arguments to perform object detection.""" 101 | return self.predict(source, stream, **kwargs) 102 | 103 | def __getattr__(self, attr): 104 | """Raises error if object has no requested attribute.""" 105 | name = self.__class__.__name__ 106 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. 
See valid attributes below.\n{self.__doc__}") 107 | -------------------------------------------------------------------------------- /utils/fastsam/predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ultralytics.yolo.engine.results import Results 4 | from ultralytics.yolo.utils import DEFAULT_CFG, ops 5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 6 | from .utils import bbox_iou 7 | 8 | class FastSAMPredictor(DetectionPredictor): 9 | 10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 11 | super().__init__(cfg, overrides, _callbacks) 12 | self.args.task = 'segment' 13 | 14 | def postprocess(self, preds, img, orig_imgs): 15 | """TODO: filter by classes.""" 16 | p = ops.non_max_suppression(preds[0], 17 | self.args.conf, 18 | self.args.iou, 19 | agnostic=self.args.agnostic_nms, 20 | max_det=self.args.max_det, 21 | nc=len(self.model.names), 22 | classes=self.args.classes) 23 | 24 | results = [] 25 | if len(p) == 0 or len(p[0]) == 0: 26 | print("No object detected.") 27 | return results 28 | 29 | full_box = torch.zeros_like(p[0][0]) 30 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 31 | full_box = full_box.view(1, -1) 32 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) 33 | if critical_iou_index.numel() != 0: 34 | full_box[0][4] = p[0][critical_iou_index][:,4] 35 | full_box[0][6:] = p[0][critical_iou_index][:,6:] 36 | p[0][critical_iou_index] = full_box 37 | 38 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 39 | for i, pred in enumerate(p): 40 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 41 | path = self.batch[0] 42 | img_path = path[i] if isinstance(path, list) else path 43 | if not len(pred): # save empty boxes 44 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 45 | continue 46 | if self.args.retina_masks: 47 | if not isinstance(orig_imgs, torch.Tensor): 48 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 49 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 50 | else: 51 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 52 | if not isinstance(orig_imgs, torch.Tensor): 53 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 54 | results.append( 55 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 56 | return results 57 | -------------------------------------------------------------------------------- /utils/fastsam/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | 6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): 7 | '''Adjust bounding boxes to stick to image border if they are within a certain threshold. 
8 | Args: 9 | boxes: (n, 4) 10 | image_shape: (height, width) 11 | threshold: pixel threshold 12 | Returns: 13 | adjusted_boxes: adjusted bounding boxes 14 | ''' 15 | 16 | # Image dimensions 17 | h, w = image_shape 18 | 19 | # Adjust boxes 20 | boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor( 21 | 0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1 22 | boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor( 23 | 0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1 24 | boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor( 25 | w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2 26 | boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor( 27 | h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2 28 | 29 | return boxes 30 | 31 | 32 | 33 | def convert_box_xywh_to_xyxy(box): 34 | x1 = box[0] 35 | y1 = box[1] 36 | x2 = box[0] + box[2] 37 | y2 = box[1] + box[3] 38 | return [x1, y1, x2, y2] 39 | 40 | 41 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): 42 | '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. 43 | Args: 44 | box1: (4, ) 45 | boxes: (n, 4) 46 | Returns: 47 | high_iou_indices: Indices of boxes with IoU > thres 48 | ''' 49 | boxes = adjust_bboxes_to_image_border(boxes, image_shape) 50 | # obtain coordinates for intersections 51 | x1 = torch.max(box1[0], boxes[:, 0]) 52 | y1 = torch.max(box1[1], boxes[:, 1]) 53 | x2 = torch.min(box1[2], boxes[:, 2]) 54 | y2 = torch.min(box1[3], boxes[:, 3]) 55 | 56 | # compute the area of intersection 57 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) 58 | 59 | # compute the area of both individual boxes 60 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) 61 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 62 | 63 | # compute the area of union 64 | union = box1_area + box2_area - intersection 65 | 66 | # compute the IoU 67 | iou = intersection / union # Should be shape (n, ) 68 | if raw_output: 69 | if iou.numel() == 0: 70 | return 0 71 | return iou 72 | 73 | # get indices of boxes with IoU > thres 74 | high_iou_indices = torch.nonzero(iou > iou_thres).flatten() 75 | 76 | return high_iou_indices 77 | 78 | 79 | def image_to_np_ndarray(image): 80 | if type(image) is str: 81 | return np.array(Image.open(image)) 82 | elif issubclass(type(image), Image.Image): 83 | return np.array(image) 84 | elif type(image) is np.ndarray: 85 | return image 86 | return None 87 | -------------------------------------------------------------------------------- /utils/gradio_tabs/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_tabs import * 2 | from .video_tabs import * 3 | from .chat_tabs import * 4 | from .audio_tabs import * -------------------------------------------------------------------------------- /utils/gradio_tabs/audio_tabs.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | 4 | # 语音识别 5 | def asr_tab(asr_args, ai_handler): 6 | with gr.Tab(asr_args['name']): 7 | audio_file = gr.Audio(type="filepath", autoplay=True) 8 | # text_input = gr.Textbox(label="Question") 9 | audio_asr_text = gr.Textbox(label="text") 10 | audio_asr_files = gr.components.File(label="音频识别输出的文件") 11 | audio_b1 = gr.Button("Recognize speech", variant="primary") 12 | 13 | audio_b1.click(ai_handler.asr_infer, inputs=[audio_file], 
outputs=[audio_asr_text, audio_asr_files]) 14 | 15 | def tts_tab(tts_args, ai_handler): 16 | with gr.Tab(tts_args['name']): 17 | model_type = gr.Dropdown(label="模型类型", choices=tts_args['model_type']['choices'], value=tts_args['model_type']['value']) 18 | with gr.Row(): 19 | tts_text_file = gr.components.File(label="上传txt文本") 20 | tts_text = gr.Textbox(label="text") 21 | 22 | tts_voice = gr.Dropdown(label="选择发音人", choices=tts_args['tts_voice']['choices'], value=tts_args['tts_voice']['value']) 23 | tts_rate = gr.Slider(label="语速", minimum=tts_args['tts_rate']['minimum'], max=tts_args['tts_rate']['maximum'], value=tts_args['tts_rate']['value'], step=tts_args['tts_rate']['step']) 24 | tts_volume = gr.Slider(label="音量", minimum=tts_args['tts_volume']['minimum'], max=tts_args['tts_volume']['maximum'], value=tts_args['tts_volume']['value'], step=tts_args['tts_volume']['step']) 25 | tts_pitch = gr.Slider(label="语调", minimum=tts_args['tts_pitch']['minimum'], max=tts_args['tts_pitch']['maximum'], value=tts_args['tts_pitch']['value'], step=tts_args['tts_pitch']['step']) 26 | tts_b1 = gr.Button("合成", variant="primary") 27 | # 输出可下载文件 28 | tts_audio_file = gr.Audio(type="filepath", autoplay=True, label="合成音频") 29 | tts_out_files = gr.components.File(label="合成文件下载列表") 30 | 31 | tts_b1.click(ai_handler.tts_infer, 32 | inputs=[tts_text_file, tts_text, tts_voice, tts_rate, tts_volume, tts_pitch], 33 | outputs=[tts_audio_file, tts_out_files] 34 | ) -------------------------------------------------------------------------------- /utils/gradio_tabs/chat_tabs.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from utils import misc 3 | 4 | def chat_tab(chatbot_args, tts_args, ai_handler): 5 | # 聊天机器人 text_args[''] 6 | with gr.Tab(chatbot_args['name']): 7 | gr.Chatbot.postprocess = misc.postprocess 8 | chatbot = gr.Chatbot(height=chatbot_args['chatbot_win']['height'],) 9 | with gr.Row(): 10 | with gr.Column(scale=4): 11 | with gr.Column(scale=12): 12 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=5).style( 13 | container=False) 14 | 15 | with gr.Column(min_width=32, scale=1): 16 | text_submitBtn = gr.Button("发送文字聊天信息", variant="primary") 17 | # speech_file = gr.Audio(type="filepath", min_width=25, autoplay=False, label="语音文件or录音") 18 | speech_file = gr.Audio(label="Audio", sources="microphone", type="filepath", elem_id='audio') 19 | speech_submitBtn = gr.Button("发送语音信息") 20 | chat_replay_audio = gr.Audio(type="filepath", autoplay=True, label="AI回话内容") 21 | 22 | with gr.Column(scale=1): 23 | llm_model_type = gr.Dropdown(choices=chatbot_args['llm_model_type']['choices'], 24 | value=chatbot_args['llm_model_type']['value'], 25 | label="大语言模型") 26 | chat_tts_voice = gr.Dropdown(label="选择发音人", choices=tts_args['tts_voice']['choices'], value=tts_args['tts_voice']['value']) 27 | text_emptyBtn = gr.Button("Clear History") 28 | max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True) 29 | top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P", interactive=True) 30 | temperature = gr.Slider(0, 1, value=0.85, step=0.01, label="Temperature", interactive=True) 31 | 32 | history = gr.State([]) 33 | past_key_values = gr.State(None) 34 | 35 | text_submitBtn.click(ai_handler.chatglm_handler.stream_chat, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values], 36 | [chatbot, history, past_key_values], show_progress=True) 37 | text_submitBtn.click(misc.reset_user_input, [], 
[user_input]) 38 | 39 | # speech_submitBtn.click(ai_handler.llm_infer, [llm_model_type, user_input, speech_file, history, max_length, top_p, temperature], 40 | # [chat_replay_audio, history], show_progress=True) 41 | speech_submitBtn.click(ai_handler.audio_chat, [llm_model_type, chat_tts_voice, speech_file, history, max_length, top_p, temperature], 42 | [chat_replay_audio, history], show_progress=True) 43 | 44 | # speech_submitBtn.click(fn=action, inputs=speech_submitBtn, outputs=speech_submitBtn).\ 45 | # then(fn=lambda: None, _js=click_js()).\ 46 | # then(fn=check_btn, inputs=speech_submitBtn).\ 47 | # success(fn=ai_handler.llm_infer, inputs=[llm_model_type, user_input, speech_file, history, max_length, top_p, temperature], outputs=[chat_replay_audio, history]) 48 | 49 | 50 | text_emptyBtn.click(misc.reset_state, outputs=[chatbot, history, past_key_values], show_progress=True) 51 | 52 | 53 | def visualchat_tab(visualchat_args, ai_handler): 54 | MAINTENANCE_NOTICE = 'Hint 1: If the app report "Something went wrong, connection error out", \ 55 | please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, \ 56 | it may take some time to upload and process. Please be patient and wait.' 57 | with gr.Tab(visualchat_args['name']): 58 | with gr.Row(): 59 | with gr.Column(scale=2): 60 | image_path = gr.Image(type="filepath", label="Image Prompt", value=None).style(height=480) 61 | with gr.Column(scale=4): 62 | chatbot = gr.Chatbot().style(height=480) 63 | with gr.Row(): 64 | with gr.Column(scale=2, min_width=100): 65 | max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) 66 | top_p = gr.Slider(0, 1, value=0.4, step=0.01, label="Top P", interactive=True) 67 | temperature = gr.Slider(0, 1, value=0.8, step=0.01, label="Temperature", interactive=True) 68 | with gr.Column(scale=4): 69 | with gr.Box(): 70 | with gr.Row(): 71 | with gr.Column(scale=2): 72 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=6).style( 73 | container=False) 74 | with gr.Column(scale=1, min_width=64): 75 | submitBtn = gr.Button("Submit", variant="primary") 76 | emptyBtn = gr.Button("Clear History") 77 | gr.Markdown('\n' + MAINTENANCE_NOTICE + '\n') 78 | history = gr.State([]) 79 | 80 | # submitBtn.click(ai_handler.visualchat_handler.stream_chat, [user_input, image_path, chatbot, max_length, top_p, temperature, history], [chatbot, history], 81 | # show_progress=True) 82 | # image_path.upload(ai_handler.visualchat_handler.stream_chat2, [image_path, chatbot, max_length, top_p, temperature], [chatbot, history], 83 | # show_progress=True) 84 | 85 | # submitBtn.click(ai_handler.chatvlm_handler.chat, [user_input, image_path, chatbot, history], [chatbot, history], show_progress=True) 86 | submitBtn.click(ai_handler.chatvlm_handler.chat_stream, [user_input, image_path, chatbot, history], [chatbot, history], show_progress=True) 87 | # image_path.clear(misc.reset_state, outputs=[image_path, chatbot, history], show_progress=True) 88 | # submitBtn.click(misc.reset_user_input, [], [user_input]) 89 | emptyBtn.click(misc.reset_state2, outputs=[chatbot, history], show_progress=True) 90 | -------------------------------------------------------------------------------- /utils/gradio_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Iterable 3 | import gradio as gr 4 | from gradio.themes.base import Base 5 | from gradio.themes.utils import 
colors, fonts, sizes 6 | import time 7 | 8 | 9 | class Seafoam(Base): 10 | def __init__( 11 | self, 12 | *, 13 | primary_hue: colors.Color | str = colors.gray, 14 | secondary_hue: colors.Color | str = colors.blue, 15 | neutral_hue: colors.Color | str = colors.neutral, 16 | spacing_size: sizes.Size | str = sizes.spacing_md, 17 | radius_size: sizes.Size | str = sizes.radius_md, 18 | text_size: sizes.Size | str = sizes.text_lg, 19 | font: fonts.Font 20 | | str 21 | | Iterable[fonts.Font | str] = ( 22 | fonts.GoogleFont("Quicksand"), 23 | "ui-sans-serif", 24 | "sans-serif", 25 | ), 26 | font_mono: fonts.Font 27 | | str 28 | | Iterable[fonts.Font | str] = ( 29 | fonts.GoogleFont("IBM Plex Mono"), 30 | "ui-monospace", 31 | "monospace", 32 | ), 33 | ): 34 | super().__init__( 35 | primary_hue=primary_hue, 36 | secondary_hue=secondary_hue, 37 | neutral_hue=neutral_hue, 38 | spacing_size=spacing_size, 39 | radius_size=radius_size, 40 | text_size=text_size, 41 | font=font, 42 | font_mono=font_mono, 43 | ) 44 | 45 | 46 | def click_js(): 47 | return """function audioRecord() { 48 | var xPathRes = document.evaluate ('//*[contains(@class, "record")]', document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null); 49 | xPathRes.singleNodeValue.click();}""" 50 | 51 | 52 | def action(btn): 53 | """Changes button text on click""" 54 | if btn == 'Speak': return 'Stop' 55 | else: return 'Speak' 56 | 57 | 58 | def check_btn(btn): 59 | """Checks for correct button text before invoking transcribe()""" 60 | if btn != 'Speak': raise Exception('Recording...') 61 | 62 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import anyconfig 4 | from tools import AIWrapper 5 | 6 | from utils.misc import parse_config 7 | from utils.gradio_utils import * 8 | import gradio as gr 9 | from gradio.themes.utils import colors 10 | from utils.gradio_tabs import * 11 | torch.manual_seed(1234) 12 | 13 | 14 | os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' 15 | 16 | def launch_webui(yaml_file, device_ids=None, share=None, **kwargs): 17 | if device_ids is not None: 18 | os.environ['CUDA_VISIBLE_DEVICES'] = device_ids 19 | with open(yaml_file, 'rb') as f: 20 | args = anyconfig.load(f) 21 | if 'base' in args: 22 | args = parse_config(args) 23 | 24 | video_convertor_args = args.get('video_convertor', {}) 25 | video_inpainter_args = args.get('video_inpainter', {}) 26 | segmentation_args = args.get('segmentation_task', {}) 27 | chat_args = args.get('chatbot', {}) 28 | visualchat_args = args.get('visualchat', {}) 29 | asr_args = args.get('asr_task', {}) 30 | tts_args = args.get('tts_task', {}) 31 | 32 | # 初始化AI引擎 33 | ai_handler = AIWrapper(args) 34 | 35 | # seafoam = gradio_utils.Seafoam() 36 | theme=gr.themes.Soft(primary_hue=colors.gray, neutral_hue=colors.neutral) 37 | with gr.Blocks(theme=theme) as web: 38 | # gr.Markdown(args["home_name"]) 39 | gr.HTML(f"""
{args["home_desc"]}
""") 40 | # Process text, audio or video file using this web 41 | 42 | # 视频剪辑 43 | if video_convertor_args.get('switch'): 44 | video_convertor_tab(video_convertor_args, ai_handler) 45 | 46 | # 视频修复 47 | if video_inpainter_args.get('switch'): 48 | video_inpainter_tab(video_inpainter_args, ai_handler) 49 | 50 | # 图像分割 51 | if segmentation_args.get('switch'): 52 | sam_tab(segmentation_args, ai_handler) 53 | 54 | # 聊天问答 55 | if chat_args.get('switch'): 56 | chat_tab(chat_args, tts_args, ai_handler) 57 | # 多模态问答 58 | if visualchat_args.get('switch'): 59 | visualchat_tab(visualchat_args, ai_handler) 60 | 61 | # 语音识别 62 | if asr_args.get('switch'): 63 | asr_tab(asr_args, ai_handler) 64 | 65 | # 语音合成 66 | if tts_args.get('switch'): 67 | tts_tab(tts_args, ai_handler) 68 | 69 | web.queue().launch(share=share, server_name=args["server_name"], server_port=args["server_port"]) 70 | 71 | 72 | def parse_opt(): 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('-f', '-c', '-cfg', '--yaml_file', type=str, default='configs/webui_configs.yml', 75 | help='yaml config file' 76 | ) 77 | parser.add_argument('-d', '--device_ids', type=str, default=None, help='device ids') 78 | parser.add_argument('-s', "--share", action="store_true", help='whether public url') 79 | opt = parser.parse_args() 80 | return opt 81 | 82 | 83 | if __name__ == "__main__": 84 | opt = parse_opt() 85 | launch_webui(**opt.__dict__) 86 | --------------------------------------------------------------------------------