├── .gitignore
├── LICENSE
├── README.md
├── README_cn.md
├── assets
│   ├── AI-logo.png
│   ├── all_tabs_demo.gif
│   ├── asr_demo.gif
│   ├── chat_demo.gif
│   ├── segmentation_demo.gif
│   ├── tts_demo.gif
│   ├── video_convertor_demo.gif
│   ├── video_inpainter_demo.gif
│   ├── visualchat_demo.jpg
│   └── visualchat_demo.png
├── configs
│   ├── asr_demo.yml
│   ├── base.yml
│   ├── chatbot_demo.yml
│   ├── segmentation_demo.yml
│   ├── tts_demo.yml
│   ├── video_convertor_demo.yml
│   ├── video_inpainter_demo.yml
│   ├── visualchat_demo.yml
│   └── webui_configs.yml
├── demo
│   ├── bgm
│   │   ├── Paris.mp3
│   │   ├── illusionary_daytime.mp3
│   │   ├── time_back.mp3
│   │   └── windy_hill.mp3
│   ├── fastsam
│   │   └── examples
│   │       ├── dogs.jpg
│   │       ├── sa_10039.jpg
│   │       ├── sa_11025.jpg
│   │       ├── sa_1309.jpg
│   │       ├── sa_192.jpg
│   │       ├── sa_414.jpg
│   │       ├── sa_561.jpg
│   │       ├── sa_862.jpg
│   │       └── sa_8776.jpg
│   └── video_inpainter
│       ├── test-sample0.mp4
│       ├── test-sample1.mp4
│       ├── test-sample2.mp4
│       ├── test-sample3.mp4
│       └── test-sample4.mp4
├── model_weights
│   └── chatglm
│       └── chatglm2-6b-int4
│           ├── config.json
│           ├── configuration_chatglm.py
│           ├── modeling_chatglm.py
│           ├── quantization.py
│           ├── tokenization_chatglm.py
│           └── tokenizer_config.json
├── models
│   ├── RAFT
│   │   ├── __init__.py
│   │   ├── corr.py
│   │   ├── datasets.py
│   │   ├── demo.py
│   │   ├── extractor.py
│   │   ├── raft.py
│   │   ├── update.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── augmentor.py
│   │       ├── flow_viz.py
│   │       ├── flow_viz_pt.py
│   │       ├── frame_utils.py
│   │       └── utils.py
│   ├── __init__.py
│   ├── misc.py
│   ├── modules
│   │   ├── base_module.py
│   │   ├── deformconv.py
│   │   ├── flow_comp_raft.py
│   │   ├── flow_loss_utils.py
│   │   ├── sparse_transformer.py
│   │   └── spectral_norm.py
│   ├── propainter.py
│   ├── recurrent_flow_completion.py
│   ├── sam
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── assets
│   │   │   ├── masks1.png
│   │   │   ├── masks2.jpg
│   │   │   ├── model_diagram.png
│   │   │   ├── notebook1.png
│   │   │   └── notebook2.png
│   │   ├── linter.sh
│   │   ├── notebooks
│   │   │   ├── automatic_mask_generator_example.ipynb
│   │   │   ├── images
│   │   │   │   ├── dog.jpg
│   │   │   │   ├── groceries.jpg
│   │   │   │   └── truck.jpg
│   │   │   ├── onnx_model_example.ipynb
│   │   │   └── predictor_example.ipynb
│   │   ├── scripts
│   │   │   ├── amg.py
│   │   │   └── export_onnx_model.py
│   │   ├── segment_anything
│   │   │   ├── .DS_Store
│   │   │   ├── __init__.py
│   │   │   ├── automatic_mask_generator.py
│   │   │   ├── build_sam.py
│   │   │   ├── modeling
│   │   │   │   ├── __init__.py
│   │   │   │   ├── common.py
│   │   │   │   ├── image_encoder.py
│   │   │   │   ├── mask_decoder.py
│   │   │   │   ├── prompt_encoder.py
│   │   │   │   ├── sam.py
│   │   │   │   └── transformer.py
│   │   │   ├── predictor.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── amg.py
│   │   │       ├── onnx.py
│   │   │       └── transforms.py
│   │   ├── setup.cfg
│   │   └── setup.py
│   └── tracker
│       ├── base_tracker.py
│       ├── config
│       │   └── __init__.py
│       ├── inference
│       │   ├── __init__.py
│       │   ├── image_feature_store.py
│       │   ├── inference_core.py
│       │   ├── kv_memory_store.py
│       │   ├── memory_manager.py
│       │   ├── object_info.py
│       │   ├── object_manager.py
│       │   └── utils
│       │       ├── __init__.py
│       │       ├── args_utils.py
│       │       ├── burst_utils.py
│       │       ├── frame_utils.py
│       │       └── results_utils.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── aux_modules.py
│       │   ├── big_modules.py
│       │   ├── channel_attn.py
│       │   ├── cutie.py
│       │   ├── group_modules.py
│       │   ├── losses.py
│       │   ├── modules.py
│       │   ├── transformer
│       │   │   ├── __init__.py
│       │   │   ├── object_summarizer.py
│       │   │   ├── object_transformer.py
│       │   │   ├── positional_encoding.py
│       │   │   └── transformer_layers.py
│       │   └── utils
│       │       ├── __init__.py
│       │       ├── memory_utils.py
│       │       ├── parameter_groups.py
│       │       └── resnet.py
│       └── utils
│           ├── __init__.py
│           ├── image_saver.py
│           ├── load_subset.py
│           ├── log_integrator.py
│           ├── logger.py
│           ├── mask_mapper.py
│           ├── palette.py
│           ├── pano_utils.py
│           ├── point_features.py
│           ├── range_transform.py
│           ├── tensor_utils.py
│           └── time_estimator.py
├── requirements.txt
├── tools
│   ├── __init__.py
│   ├── ai_wrapper.py
│   ├── base.py
│   ├── chatglm_handler.py
│   ├── chatvlm_handler.py
│   ├── edgetts_handler.py
│   ├── fastsam_handler.py
│   ├── gpt_handler.py
│   ├── inpainting_handler.py
│   ├── sam_handler.py
│   ├── tracker_handler.py
│   ├── visualglm_handler.py
│   └── whisper_handler.py
├── utils
│   ├── __init__.py
│   ├── chatglm_utils.py
│   ├── fastsam
│   │   ├── __init__.py
│   │   ├── decoder.py
│   │   ├── model.py
│   │   ├── predict.py
│   │   ├── prompt.py
│   │   ├── tools.py
│   │   └── utils.py
│   ├── gradio_tabs
│   │   ├── __init__.py
│   │   ├── audio_tabs.py
│   │   ├── chat_tabs.py
│   │   ├── image_tabs.py
│   │   └── video_tabs.py
│   ├── gradio_utils.py
│   ├── mask_painter.py
│   ├── misc.py
│   └── painter.py
└── webui.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 jasonaidm
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README_cn.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
8 |
9 | AI-WEBUI: A universal web interface for AI creation, a handy toolkit for image, audio, and video processing
10 |
11 |
12 | ⭐ If this project helps you, please give it a star. Thanks! 🤗
13 |
14 |
15 | ## 🌟 1. Introduction
16 | ai-webui is a browser-based interface intended to provide a universal platform for AI creation.
17 |
18 |
19 |
20 | The project provides basic capabilities such as image segmentation, object tracking, image inpainting, speech recognition, and speech synthesis, plus higher-level features built by combining them, such as chat Q&A, video translation, and video watermark removal, which can greatly speed up short-video production.
21 |
22 | ## ⚡2. Installation
23 |
24 | To install and use AI-WebUI, follow the steps below:
25 |
26 | ### 2.1 Clone this repository to your local machine
27 |
28 | ```bash
29 | git clone https://github.com/jasonaidm/ai_webui.git
30 | ```
31 |
32 | ### 2.2 Enter the project directory
33 |
34 | ```bash
35 | cd ai_webui
36 | ```
37 | ### 2.3 Create a virtual environment
38 | ```bash
39 | conda create -n aiwebui python=3.11
40 | conda activate aiwebui
41 | ```
42 |
43 | ### 2.4 Install the required dependencies
44 |
45 | ```bash
46 | apt install ffmpeg -y
47 | pip install -r requirements.txt
48 | ```
49 |
50 |
51 | ## 🚀3. Quick Start
52 |
53 | Using AI-WebUI is straightforward: just follow the instructions in the interface. You can provide creative inputs by uploading videos, audio, or images, or by entering text, and then interact with the model outputs.
54 | ```bash
55 | python webui.py -c ./configs/webui_configs.yml
56 | ```
57 |
58 | After startup, open `http://localhost:9090/?__theme=dark` in your browser to view the interface.
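
The host and port come from `configs/base.yml` (reproduced in full later in this dump); the relevant excerpt is:

```yaml
# configs/base.yml (excerpt): where the Gradio web server binds
server_name: "0.0.0.0"
server_port: 9090
```

If 9090 is already in use on your machine, changing `server_port` here should be enough to move the UI to another port.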
59 |
60 | ### 3.1 Single-feature demos
61 | Considering that some users' PCs have limited GPU performance, we provide single-feature demos: each AI feature can be run on its own without launching the whole project.
62 |
63 | 1. Image segmentation
64 | - Panoptic segmentation
65 | - Prompt-based segmentation from point coordinates
66 | - Segmentation from text prompts
67 | ```bash
68 | python webui.py -c ./configs/segmentation_demo.yml
69 | ```
70 | 
71 |
72 | 2. Speech recognition
73 | - Multilingual recognition, including Chinese and English
74 | ```bash
75 | python webui.py -c ./configs/asr_demo.yml
76 | ```
77 | 
78 |
79 | 3. Speech synthesis
80 | - Multilingual synthesis, including Chinese and English
81 | ```bash
82 | python webui.py -c ./configs/tts_demo.yml
83 | ```
84 | 
85 |
86 |
87 | ### 3.2 Combined-feature demos
88 | More complex features are built by combining several AI models, which places higher demands on GPU resources.
89 | 1. Chat Q&A
90 | - Streaming text conversation
91 | - Voice conversation
92 | ```bash
93 | python webui.py -c ./configs/chatbot_demo.yml
94 | ```
95 | 
96 |
97 | 2. Video inpainting
98 | - Watermark removal
99 | - Mosaic removal
100 | - Object tracking
101 | - Removing specific objects from a video
102 |
103 | ```bash
104 | python webui.py -c ./configs/video_inpainter_demo.yml
105 | ```
106 | 
107 |
108 | 3. Video conversion
109 | - Audio/video separation
110 | - Frame cropping
111 | - Adding noise to frames
112 | - Frame extraction
113 | - Audio recognition
114 | - Subtitle translation
115 | - Speech synthesis
116 | - BGM addition
117 | - One-click video generation (effortless repurposing of videos from overseas platforms)
118 | ```bash
119 | python webui.py -c ./configs/video_convertor_demo.yml
120 | ```
121 | 
122 |
123 | ### 3.3 Launching all features
124 | Run the following command to enable every AI feature:
125 | ```bash
126 | python webui.py -c ./configs/webui_configs.yml
127 | ```
128 | Because model loading takes a long time, it is recommended to defer loading each model until its first inference after startup.
129 | The loading strategy of each AI model can be controlled via the "init_model_when_start_server" option in the configs/base.yml configuration file, as in the excerpt below.
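
For example, a minimal sketch adapted from `configs/base.yml` that preloads Whisper at startup instead of on first use, assuming `true` means "load when the server starts", as the option name suggests:

```yaml
# configs/base.yml (excerpt): per-handler loading strategy
WhisperHandler:
  model_name: large-v3
  model_dir: ./model_weights/whisper
  device: 'cuda:1'
  init_model_when_start_server: true   # assumed: load this model at server startup rather than lazily
```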
130 |
131 | ## 🔥4. Model Files
132 | ### 4.1 Downloading the model files
133 | | Model | File size | Small-model set | Download link |
134 | | :--- | :--- | :--- | :--- |
135 | | chatglm2-6b-int4 | 3.7G | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A)|
136 | | chatglm2-6b | 12G | | [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/674208019e314311ab5c/?p=%2Fchatglm2-6b&mode=list) |
137 | | sam_vit_b | 358M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
138 | | sam_vit_h | 2.4G | | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
139 | | FastSAM-s | 23M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
140 | | FastSAM-x | 138M | | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
141 | | ProPainter | 150M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
142 | | raft-things | 20M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
143 | | recurrent_flow_completion | 19M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
144 | | cutie | 134M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
145 | | whisper-small | 461M | ✅ | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
146 | | whisper-large-v3 | 2.9G | | [Baidu Netdisk](https://pan.baidu.com/s/1d-eRdvX-wRgm4XUJ24G30A) |
147 |
148 | - Baidu Netdisk extraction code: zogk
149 |
150 | ### 4.2 Directory layout of the model weight files
151 | ```
152 | model_weights/
153 | ├── chatglm
154 | │ └── chatglm2-6b-int4
155 | │ ├── config.json
156 | │ ├── configuration_chatglm.py
157 | │ ├── modeling_chatglm.py
158 | │ ├── pytorch_model.bin
159 | │ ├── quantization.py
160 | │ ├── tokenization_chatglm.py
161 | │ ├── tokenizer.model
162 | │ └── tokenizer_config.json
163 | ├── fastsam
164 | │ ├── FastSAM-s.pt
165 | │ └── FastSAM-x.pt
166 | ├── propainter
167 | │ ├── ProPainter.pth
168 | │ ├── cutie-base-mega.pth
169 | │ ├── raft-things.pth
170 | │ └── recurrent_flow_completion.pth
171 | ├── sam
172 | │ ├── sam_vit_b.pth
173 | │ └── sam_vit_h.pth
174 | └── whisper
175 | ├── large-v3.pt
176 | └── small.pt
177 | ```
178 | If your GPU has less than 8 GB of memory, you may only be able to run the small models; their results are noticeably worse, though, so run the large models if your hardware allows.
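
A minimal sketch of low-VRAM overrides for `configs/base.yml`, assuming the small checkpoints from 4.1 are placed as in the tree above (the `vit_b` model type for `sam_vit_b.pth` is an assumption based on the usual SAM naming, not something this repo states):

```yaml
# Hypothetical low-VRAM settings for configs/base.yml (not an official preset)
WhisperHandler:
  model_name: small                                         # instead of large-v3
  model_dir: ./model_weights/whisper
FastSAMHandler:
  fastsam_model_path: ./model_weights/fastsam/FastSAM-s.pt  # instead of FastSAM-x.pt
SAMHandler:
  model_ckpt: ./model_weights/sam/sam_vit_b.pth             # instead of sam_vit_h.pth
  model_type: vit_b
```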
179 |
180 | ## 5. Contributing
181 |
182 | If you have any suggestions or feature requests, feel free to open an issue.
183 |
184 | ## 6. References
185 | - [Segment-and-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything)
186 | - [ProPainter](https://github.com/sczhou/ProPainter)
187 | - [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B)
188 | - [segment-anything](https://github.com/facebookresearch/segment-anything)
189 | - [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
190 | - [whisper](https://github.com/openai/whisper)
--------------------------------------------------------------------------------
/assets/AI-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/AI-logo.png
--------------------------------------------------------------------------------
/assets/all_tabs_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/all_tabs_demo.gif
--------------------------------------------------------------------------------
/assets/asr_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/asr_demo.gif
--------------------------------------------------------------------------------
/assets/chat_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/chat_demo.gif
--------------------------------------------------------------------------------
/assets/segmentation_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/segmentation_demo.gif
--------------------------------------------------------------------------------
/assets/tts_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/tts_demo.gif
--------------------------------------------------------------------------------
/assets/video_convertor_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/video_convertor_demo.gif
--------------------------------------------------------------------------------
/assets/video_inpainter_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/video_inpainter_demo.gif
--------------------------------------------------------------------------------
/assets/visualchat_demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/visualchat_demo.jpg
--------------------------------------------------------------------------------
/assets/visualchat_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/assets/visualchat_demo.png
--------------------------------------------------------------------------------
/configs/asr_demo.yml:
--------------------------------------------------------------------------------
1 | base: ['./configs/base.yml']
2 |
3 | asr_task:
4 | switch: true
5 | name: 语音识别
6 | audio_asr_text:
7 | label: 'text'
8 |
9 |
--------------------------------------------------------------------------------
/configs/base.yml:
--------------------------------------------------------------------------------
1 |
2 | home_desc: 站在AI「人工智能」浪潮的风口,我们可以更好地应对未来的挑战!!
3 | server_name: "0.0.0.0"
4 | server_port: 9090
5 |
6 | # AI模型相关参数
7 | ChatGLMHandler:
8 | llm_model_path: ./model_weights/chatglm/chatglm2-6b-int4
9 | num_gpus: 1
10 | trust_remote_code: true
11 | init_model_when_start_server: false
12 |
13 | ChatVLMHandler:
14 | vlm_model_path: ./model_weights/Qwen-VL-Chat # /mnt/d/ai_dev/visualglm-6b/model_weights
15 | device: 'cuda:1'
16 | quant: 8
17 | trust_remote_code: true
18 | init_model_when_start_server: false
19 |
20 | WhisperHandler:
21 | model_name: large-v3
22 | model_dir: ./model_weights/whisper
23 | device: 'cuda:1'
24 | init_model_when_start_server: false
25 | language_map: {
26 | "普通话": "zh",
27 | "英语": "en",
28 | }
29 |
30 | GPTHandler:
31 | api_url: null
32 |
33 | FastSAMHandler:
34 | fastsam_model_path: ./model_weights/fastsam/FastSAM-x.pt
35 | device: 'cuda:0'
36 |
37 | SAMHandler:
38 | model_ckpt: ./model_weights/sam/sam_vit_h.pth
39 | model_type: vit_h
40 | device: 'cuda:0'
41 |
42 | InpaintingHandler:
43 | propainter_ckpt: ./model_weights/propainter/ProPainter.pth
44 | raft_ckpt: ./model_weights/propainter/raft-things.pth
45 | flow_completion_ckpt: ./model_weights/propainter/recurrent_flow_completion.pth
46 | device: 'cuda:0'
47 | use_half: true
48 |
49 | TrackerHandler:
50 | cutie_ckpt: ./model_weights/propainter/cutie-base-mega.pth
51 | device: 'cuda:0'
52 |
53 |
54 | EdgeTTSHandler:
55 | output_dir: /tmp
56 |
--------------------------------------------------------------------------------
/configs/chatbot_demo.yml:
--------------------------------------------------------------------------------
1 | base: ['./configs/base.yml']
2 |
3 | chatbot:
4 | switch: true
5 | name: 聊天问答
6 | chatbot_win:
7 | height: 268
8 | llm_model_type:
9 | choices: ["chatglm"]
10 | value: "chatglm"
11 |
12 | tts_task:
13 | switch: true
14 | name: 语音合成
15 | output_dir: ./results/tts_outputs
16 | model_type:
17 | choices: ["edge_tts", "so_vits_svc"]
18 | value: "edge_tts"
19 | tts_voice:
20 | choices: &tts_voice_list ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural",
21 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural"]
22 | value: "zh-CN-YunxiNeural"
23 | tts_rate:
24 | minimum: -100
25 | maximum: 100
26 | value: 0
27 | step: 5
28 | tts_volume:
29 | minimum: -100
30 | maximum: 100
31 | value: 0
32 | step: 5
33 | tts_pitch:
34 | minimum: -100
35 | maximum: 100
36 | value: 0
37 | step: 5
38 |
--------------------------------------------------------------------------------
/configs/segmentation_demo.yml:
--------------------------------------------------------------------------------
1 | base: ['./configs/base.yml']
2 |
3 | segmentation_task:
4 | switch: true
5 | name: 图像分割(SAM)
6 |
7 |
8 |
--------------------------------------------------------------------------------
/configs/tts_demo.yml:
--------------------------------------------------------------------------------
1 | base: ['./configs/base.yml']
2 |
3 | tts_task:
4 | switch: true
5 | name: 语音合成
6 | output_dir: /data1/zjx/ai_webui/products/tts_outputs
7 | model_type:
8 | choices: ["edge_tts", "so_vits_svc"]
9 | value: "edge_tts"
10 | tts_voice:
11 | choices: ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural",
12 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural"]
13 | value: "zh-CN-YunxiNeural"
14 | tts_rate:
15 | minimum: -100
16 | maximum: 100
17 | value: 0
18 | step: 5
19 | tts_volume:
20 | minimum: -100
21 | maximum: 100
22 | value: 0
23 | step: 5
24 | tts_pitch:
25 | minimum: -100
26 | maximum: 100
27 | value: 0
28 | step: 5
--------------------------------------------------------------------------------
/configs/video_convertor_demo.yml:
--------------------------------------------------------------------------------
1 | base: ['./configs/base.yml']
2 |
3 | video_convertor:
4 | switch: true
5 | name: 视频转换
6 | video_upload_win:
7 | height: 400
8 | width: null
9 | autoplay: true
10 | aspect_ratio_box:
11 | choices: ["16/9", "9/16", "4/3"]
12 | value: "16/9"
13 | move_up_rate:
14 | minimum: 0
15 | maximum: 0.3
16 | step: 0.05
17 | value: 0.15
18 | add_noice:
19 | value: false
20 | video_segment_length:
21 | minimum: 1
22 | maximum: 10
23 | step: 1
24 | value: 1
25 | audio_speech_rate:
26 | minimum: 0.1
27 | maximum: 1.
28 | step: 0.1
29 | value: 0.9
30 | subtitling_recognition:
31 | value: true
32 | translate_engine:
33 | choices: ["chatglm"]
34 | value: "chatglm"
35 | subtitling_language:
36 | choices: ["英语", "普通话"]
37 | value: "普通话"
38 | collage_short_video:
39 | value: true
40 | voice_role:
41 | choices: ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural",
42 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural"
43 | ]
44 | value: "zh-CN-YunxiNeural"
45 | bgm_name:
46 | choices: ["Paris", "illusionary_daytime", "time_back", "windy_hill"]
47 | value: "Paris"
48 | watermark:
49 | choices: ["拉普拉丝", "川陀"]
50 | value: "拉普拉丝"
51 |
52 |
--------------------------------------------------------------------------------
/configs/video_inpainter_demo.yml:
--------------------------------------------------------------------------------
1 | base: ['./configs/base.yml']
2 |
3 | video_inpainter:
4 | switch: true
5 | name: 视频修复
6 | interactive_state:
7 | mask_save: false
8 |
--------------------------------------------------------------------------------
/configs/visualchat_demo.yml:
--------------------------------------------------------------------------------
1 | base: ['./configs/base.yml']
2 |
3 | visualchat:
4 | switch: true
5 | name: 多模态问答
6 | model_type:
7 | choices: ["qwen-vl-chat"]
8 | value: "qwen-vl-chat"
9 |
--------------------------------------------------------------------------------
/configs/webui_configs.yml:
--------------------------------------------------------------------------------
1 | base: ['./configs/base.yml']
2 |
3 | segmentation_task:
4 | switch: true
5 | name: 图像分割(SAM)
6 |
7 | asr_task:
8 | switch: true
9 | name: 语音识别
10 | audio_asr_text:
11 | label: 'text'
12 |
13 | tts_task:
14 | switch: true
15 | name: 语音合成
16 | output_dir: /data1/zjx/ai_webui/products/tts_outputs
17 | model_type:
18 | choices: ["edge_tts", "so_vits_svc"]
19 | value: "edge_tts"
20 | tts_voice:
21 | choices: ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural",
22 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural"]
23 | value: "zh-CN-YunxiNeural"
24 | tts_rate:
25 | minimum: -100
26 | maximum: 100
27 | value: 0
28 | step: 5
29 | tts_volume:
30 | minimum: -100
31 | maximum: 100
32 | value: 0
33 | step: 5
34 | tts_pitch:
35 | minimum: -100
36 | maximum: 100
37 | value: 0
38 | step: 5
39 |
40 | chatbot:
41 | switch: true
42 | name: 聊天问答
43 | chatbot_win:
44 | height: 268
45 | llm_model_type:
46 | choices: ["chatglm"]
47 | value: "chatglm"
48 |
49 | visualchat:
50 | switch: true
51 | name: 多模态问答
52 | model_type:
53 | choices: ["visualglm"]
54 | value: "visualglm"
55 |
56 | video_inpainter:
57 | switch: true
58 | name: 视频修复
59 | interactive_state:
60 | mask_save: false
61 |
62 | video_convertor:
63 | switch: true
64 | name: 视频转换
65 | video_upload_win:
66 | height: 400
67 | width: null
68 | autoplay: true
69 | aspect_ratio_box:
70 | choices: ["16/9", "9/16", "4/3"]
71 | value: "16/9"
72 | move_up_rate:
73 | minimum: 0
74 | maximum: 0.3
75 | step: 0.05
76 | value: 0.15
77 | add_noice:
78 | value: false
79 | video_segment_length:
80 | minimum: 1
81 | maximum: 10
82 | step: 1
83 | value: 1
84 | audio_speech_rate:
85 | minimum: 0.1
86 | maximum: 1.
87 | step: 0.1
88 | value: 0.9
89 | subtitling_recognition:
90 | value: true
91 | translate_engine:
92 | choices: ["chatglm"]
93 | value: "chatglm"
94 | subtitling_language:
95 | choices: ["英语", "普通话"]
96 | value: "普通话"
97 | collage_short_video:
98 | value: true
99 | voice_role:
100 | choices: ["zh-CN-YunxiNeural", "zh-CN-YunjianNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural",
101 | "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-CN-shaanxi-XiaoniNeural"
102 | ]
103 | value: "zh-CN-YunxiNeural"
104 | bgm_name:
105 | choices: ["Paris", "illusionary_daytime", "time_back", "windy_hill"]
106 | value: "Paris"
107 | watermark:
108 | choices: ["拉普拉丝", "川陀"]
109 | value: "拉普拉丝"
110 |
111 |
--------------------------------------------------------------------------------
/demo/bgm/Paris.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/bgm/Paris.mp3
--------------------------------------------------------------------------------
/demo/bgm/illusionary_daytime.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/bgm/illusionary_daytime.mp3
--------------------------------------------------------------------------------
/demo/bgm/time_back.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/bgm/time_back.mp3
--------------------------------------------------------------------------------
/demo/bgm/windy_hill.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/bgm/windy_hill.mp3
--------------------------------------------------------------------------------
/demo/fastsam/examples/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/dogs.jpg
--------------------------------------------------------------------------------
/demo/fastsam/examples/sa_10039.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_10039.jpg
--------------------------------------------------------------------------------
/demo/fastsam/examples/sa_11025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_11025.jpg
--------------------------------------------------------------------------------
/demo/fastsam/examples/sa_1309.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_1309.jpg
--------------------------------------------------------------------------------
/demo/fastsam/examples/sa_192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_192.jpg
--------------------------------------------------------------------------------
/demo/fastsam/examples/sa_414.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_414.jpg
--------------------------------------------------------------------------------
/demo/fastsam/examples/sa_561.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_561.jpg
--------------------------------------------------------------------------------
/demo/fastsam/examples/sa_862.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_862.jpg
--------------------------------------------------------------------------------
/demo/fastsam/examples/sa_8776.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/fastsam/examples/sa_8776.jpg
--------------------------------------------------------------------------------
/demo/video_inpainter/test-sample0.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample0.mp4
--------------------------------------------------------------------------------
/demo/video_inpainter/test-sample1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample1.mp4
--------------------------------------------------------------------------------
/demo/video_inpainter/test-sample2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample2.mp4
--------------------------------------------------------------------------------
/demo/video_inpainter/test-sample3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample3.mp4
--------------------------------------------------------------------------------
/demo/video_inpainter/test-sample4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/demo/video_inpainter/test-sample4.mp4
--------------------------------------------------------------------------------
/model_weights/chatglm/chatglm2-6b-int4/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "THUDM/chatglm2-6b",
3 | "model_type": "chatglm",
4 | "architectures": [
5 | "ChatGLMModel"
6 | ],
7 | "auto_map": {
8 | "AutoConfig": "configuration_chatglm.ChatGLMConfig",
9 | "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
10 | "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"
11 | },
12 | "add_bias_linear": false,
13 | "add_qkv_bias": true,
14 | "apply_query_key_layer_scaling": true,
15 | "apply_residual_connection_post_layernorm": false,
16 | "attention_dropout": 0.0,
17 | "attention_softmax_in_fp32": true,
18 | "bias_dropout_fusion": true,
19 | "ffn_hidden_size": 13696,
20 | "fp32_residual_connection": false,
21 | "hidden_dropout": 0.0,
22 | "hidden_size": 4096,
23 | "kv_channels": 128,
24 | "layernorm_epsilon": 1e-05,
25 | "multi_query_attention": true,
26 | "multi_query_group_num": 2,
27 | "num_attention_heads": 32,
28 | "num_layers": 28,
29 | "original_rope": true,
30 | "padded_vocab_size": 65024,
31 | "post_layer_norm": true,
32 | "quantization_bit": 4,
33 | "rmsnorm": true,
34 | "seq_length": 32768,
35 | "use_cache": true,
36 | "torch_dtype": "float16",
37 | "transformers_version": "4.27.1",
38 | "tie_word_embeddings": false,
39 | "eos_token_id": 2,
40 | "pad_token_id": 0
41 | }
--------------------------------------------------------------------------------
/model_weights/chatglm/chatglm2-6b-int4/configuration_chatglm.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 |
3 |
4 | class ChatGLMConfig(PretrainedConfig):
5 | model_type = "chatglm"
6 | def __init__(
7 | self,
8 | num_layers=28,
9 | padded_vocab_size=65024,
10 | hidden_size=4096,
11 | ffn_hidden_size=13696,
12 | kv_channels=128,
13 | num_attention_heads=32,
14 | seq_length=2048,
15 | hidden_dropout=0.0,
16 | attention_dropout=0.0,
17 | layernorm_epsilon=1e-5,
18 | rmsnorm=True,
19 | apply_residual_connection_post_layernorm=False,
20 | post_layer_norm=True,
21 | add_bias_linear=False,
22 | add_qkv_bias=False,
23 | bias_dropout_fusion=True,
24 | multi_query_attention=False,
25 | multi_query_group_num=1,
26 | apply_query_key_layer_scaling=True,
27 | attention_softmax_in_fp32=True,
28 | fp32_residual_connection=False,
29 | quantization_bit=0,
30 | pre_seq_len=None,
31 | prefix_projection=False,
32 | **kwargs
33 | ):
34 | self.num_layers = num_layers
35 | self.vocab_size = padded_vocab_size
36 | self.padded_vocab_size = padded_vocab_size
37 | self.hidden_size = hidden_size
38 | self.ffn_hidden_size = ffn_hidden_size
39 | self.kv_channels = kv_channels
40 | self.num_attention_heads = num_attention_heads
41 | self.seq_length = seq_length
42 | self.hidden_dropout = hidden_dropout
43 | self.attention_dropout = attention_dropout
44 | self.layernorm_epsilon = layernorm_epsilon
45 | self.rmsnorm = rmsnorm
46 | self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
47 | self.post_layer_norm = post_layer_norm
48 | self.add_bias_linear = add_bias_linear
49 | self.add_qkv_bias = add_qkv_bias
50 | self.bias_dropout_fusion = bias_dropout_fusion
51 | self.multi_query_attention = multi_query_attention
52 | self.multi_query_group_num = multi_query_group_num
53 | self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
54 | self.attention_softmax_in_fp32 = attention_softmax_in_fp32
55 | self.fp32_residual_connection = fp32_residual_connection
56 | self.quantization_bit = quantization_bit
57 | self.pre_seq_len = pre_seq_len
58 | self.prefix_projection = prefix_projection
59 | super().__init__(**kwargs)
--------------------------------------------------------------------------------
/model_weights/chatglm/chatglm2-6b-int4/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name_or_path": "THUDM/chatglm-6b",
3 | "remove_space": false,
4 | "do_lower_case": false,
5 | "tokenizer_class": "ChatGLMTokenizer",
6 | "auto_map": {
7 | "AutoTokenizer": [
8 | "tokenization_chatglm.ChatGLMTokenizer",
9 | null
10 | ]
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/models/RAFT/__init__.py:
--------------------------------------------------------------------------------
1 | # from .demo import RAFT_infer
2 | from .raft import RAFT
3 |
--------------------------------------------------------------------------------
/models/RAFT/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1)
38 | dy = torch.linspace(-r, r, 2*r+1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class CorrLayer(torch.autograd.Function):
64 | @staticmethod
65 | def forward(ctx, fmap1, fmap2, coords, r):
66 | fmap1 = fmap1.contiguous()
67 | fmap2 = fmap2.contiguous()
68 | coords = coords.contiguous()
69 | ctx.save_for_backward(fmap1, fmap2, coords)
70 | ctx.r = r
71 | corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
72 | return corr
73 |
74 | @staticmethod
75 | def backward(ctx, grad_corr):
76 | fmap1, fmap2, coords = ctx.saved_tensors
77 | grad_corr = grad_corr.contiguous()
78 | fmap1_grad, fmap2_grad, coords_grad = \
79 | correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
80 | return fmap1_grad, fmap2_grad, coords_grad, None
81 |
82 |
83 | class AlternateCorrBlock:
84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
85 | self.num_levels = num_levels
86 | self.radius = radius
87 |
88 | self.pyramid = [(fmap1, fmap2)]
89 | for i in range(self.num_levels):
90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
92 | self.pyramid.append((fmap1, fmap2))
93 |
94 | def __call__(self, coords):
95 |
96 | coords = coords.permute(0, 2, 3, 1)
97 | B, H, W, _ = coords.shape
98 |
99 | corr_list = []
100 | for i in range(self.num_levels):
101 | r = self.radius
102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
104 |
105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
107 | corr_list.append(corr.squeeze(1))
108 |
109 | corr = torch.stack(corr_list, dim=1)
110 | corr = corr.reshape(B, -1, H, W)
111 | return corr / 16.0
112 |
--------------------------------------------------------------------------------
/models/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import os
4 | import cv2
5 | import glob
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 |
10 | from .raft import RAFT
11 | from .utils import flow_viz
12 | from .utils.utils import InputPadder
13 |
14 |
15 |
16 | DEVICE = 'cuda'
17 |
18 | def load_image(imfile):
19 | img = np.array(Image.open(imfile)).astype(np.uint8)
20 | img = torch.from_numpy(img).permute(2, 0, 1).float()
21 | return img
22 |
23 |
24 | def load_image_list(image_files):
25 | images = []
26 | for imfile in sorted(image_files):
27 | images.append(load_image(imfile))
28 |
29 | images = torch.stack(images, dim=0)
30 | images = images.to(DEVICE)
31 |
32 | padder = InputPadder(images.shape)
33 | return padder.pad(images)[0]
34 |
35 |
36 | def viz(img, flo):
37 | img = img[0].permute(1,2,0).cpu().numpy()
38 | flo = flo[0].permute(1,2,0).cpu().numpy()
39 |
40 | # map flow to rgb image
41 | flo = flow_viz.flow_to_image(flo)
42 | # img_flo = np.concatenate([img, flo], axis=0)
43 | img_flo = flo
44 |
45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
47 | # cv2.waitKey()
48 |
49 |
50 | def demo(args):
51 | model = torch.nn.DataParallel(RAFT(args))
52 | model.load_state_dict(torch.load(args.model))
53 |
54 | model = model.module
55 | model.to(DEVICE)
56 | model.eval()
57 |
58 | with torch.no_grad():
59 | images = glob.glob(os.path.join(args.path, '*.png')) + \
60 | glob.glob(os.path.join(args.path, '*.jpg'))
61 |
62 | images = load_image_list(images)
63 | for i in range(images.shape[0]-1):
64 | image1 = images[i,None]
65 | image2 = images[i+1,None]
66 |
67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
68 | viz(image1, flow_up)
69 |
70 |
71 | def RAFT_infer(args):
72 | model = torch.nn.DataParallel(RAFT(args))
73 | model.load_state_dict(torch.load(args.model))
74 |
75 | model = model.module
76 | model.to(DEVICE)
77 | model.eval()
78 |
79 | return model
80 |
--------------------------------------------------------------------------------
/models/RAFT/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock, SmallUpdateBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .corr import CorrBlock, AlternateCorrBlock
9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in args._get_kwargs():
42 | args.dropout = 0
43 |
44 | if 'alternate_corr' not in args._get_kwargs():
45 | args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 |
59 | def freeze_bn(self):
60 | for m in self.modules():
61 | if isinstance(m, nn.BatchNorm2d):
62 | m.eval()
63 |
64 | def initialize_flow(self, img):
65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
66 | N, C, H, W = img.shape
67 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
68 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
69 |
70 | # optical flow computed as difference: flow = coords1 - coords0
71 | return coords0, coords1
72 |
73 | def upsample_flow(self, flow, mask):
74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
75 | N, _, H, W = flow.shape
76 | mask = mask.view(N, 1, 9, 8, 8, H, W)
77 | mask = torch.softmax(mask, dim=2)
78 |
79 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
81 |
82 | up_flow = torch.sum(mask * up_flow, dim=2)
83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
84 | return up_flow.reshape(N, 2, 8*H, 8*W)
85 |
86 |
87 | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True):
88 | """ Estimate optical flow between pair of frames """
89 |
90 | # image1 = 2 * (image1 / 255.0) - 1.0
91 | # image2 = 2 * (image2 / 255.0) - 1.0
92 |
93 | image1 = image1.contiguous()
94 | image2 = image2.contiguous()
95 |
96 | hdim = self.hidden_dim
97 | cdim = self.context_dim
98 |
99 | # run the feature network
100 | with autocast(enabled=self.args.mixed_precision):
101 | fmap1, fmap2 = self.fnet([image1, image2])
102 |
103 | fmap1 = fmap1.float()
104 | fmap2 = fmap2.float()
105 |
106 | if self.args.alternate_corr:
107 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 | else:
109 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
110 |
111 | # run the context network
112 | with autocast(enabled=self.args.mixed_precision):
113 | cnet = self.cnet(image1)
114 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
115 | net = torch.tanh(net)
116 | inp = torch.relu(inp)
117 |
118 | coords0, coords1 = self.initialize_flow(image1)
119 |
120 | if flow_init is not None:
121 | coords1 = coords1 + flow_init
122 |
123 | flow_predictions = []
124 | for itr in range(iters):
125 | coords1 = coords1.detach()
126 | corr = corr_fn(coords1) # index correlation volume
127 |
128 | flow = coords1 - coords0
129 | with autocast(enabled=self.args.mixed_precision):
130 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
131 |
132 | # F(t+1) = F(t) + \Delta(t)
133 | coords1 = coords1 + delta_flow
134 |
135 | # upsample predictions
136 | if up_mask is None:
137 | flow_up = upflow8(coords1 - coords0)
138 | else:
139 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
140 |
141 | flow_predictions.append(flow_up)
142 |
143 | if test_mode:
144 | return coords1 - coords0, flow_up
145 |
146 | return flow_predictions
147 |
--------------------------------------------------------------------------------
/models/RAFT/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 | # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/models/RAFT/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .flow_viz import flow_to_image
2 | from .frame_utils import writeFlow
3 |
--------------------------------------------------------------------------------
/models/RAFT/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 | Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 | Expects a two-dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
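A minimal usage sketch for the flow_to_image helper defined above. The synthetic flow field is purely illustrative, and the import path assumes the repository root is on PYTHONPATH.

import numpy as np
from models.RAFT.utils.flow_viz import flow_to_image

# Synthetic [H, W, 2] flow: horizontal motion growing from left to right.
h, w = 64, 96
u = np.tile(np.linspace(-1.0, 1.0, w, dtype=np.float32), (h, 1))
v = np.zeros((h, w), dtype=np.float32)
flow = np.stack([u, v], axis=-1)

rgb = flow_to_image(flow)    # uint8 visualization of shape [H, W, 3]
print(rgb.shape, rgb.dtype)  # (64, 96, 3) uint8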
/models/RAFT/utils/flow_viz_pt.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
2 | import torch
3 | torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
4 |
5 | @torch.no_grad()
6 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
7 |
8 | """
9 | Converts a flow to an RGB image.
10 |
11 | Args:
12 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
13 |
14 | Returns:
15 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds
16 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
17 | """
18 |
19 | if flow.dtype != torch.float:
20 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
21 |
22 | orig_shape = flow.shape
23 | if flow.ndim == 3:
24 | flow = flow[None] # Add batch dim
25 |
26 | if flow.ndim != 4 or flow.shape[1] != 2:
27 | raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
28 |
29 | max_norm = torch.sum(flow**2, dim=1).sqrt().max()
30 | epsilon = torch.finfo(flow.dtype).eps
31 | normalized_flow = flow / (max_norm + epsilon)
32 | img = _normalized_flow_to_image(normalized_flow)
33 |
34 | if len(orig_shape) == 3:
35 | img = img[0] # Remove batch dim
36 | return img
37 |
38 | @torch.no_grad()
39 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
40 |
41 | """
42 | Converts a batch of normalized flow to an RGB image.
43 |
44 | Args:
45 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
46 | Returns:
47 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
48 | """
49 |
50 | N, _, H, W = normalized_flow.shape
51 | device = normalized_flow.device
52 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
53 | colorwheel = _make_colorwheel().to(device) # shape [55x3]
54 | num_cols = colorwheel.shape[0]
55 | norm = torch.sum(normalized_flow**2, dim=1).sqrt()
56 | a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
57 | fk = (a + 1) / 2 * (num_cols - 1)
58 | k0 = torch.floor(fk).to(torch.long)
59 | k1 = k0 + 1
60 | k1[k1 == num_cols] = 0
61 | f = fk - k0
62 |
63 | for c in range(colorwheel.shape[1]):
64 | tmp = colorwheel[:, c]
65 | col0 = tmp[k0] / 255.0
66 | col1 = tmp[k1] / 255.0
67 | col = (1 - f) * col0 + f * col1
68 | col = 1 - norm * (1 - col)
69 | flow_image[:, c, :, :] = torch.floor(255. * col)
70 | return flow_image
71 |
72 |
73 | @torch.no_grad()
74 | def _make_colorwheel() -> torch.Tensor:
75 | """
76 | Generates a color wheel for optical flow visualization as presented in:
77 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
78 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
79 |
80 | Returns:
81 | colorwheel (Tensor[55, 3]): Colorwheel Tensor.
82 | """
83 |
84 | RY = 15
85 | YG = 6
86 | GC = 4
87 | CB = 11
88 | BM = 13
89 | MR = 6
90 |
91 | ncols = RY + YG + GC + CB + BM + MR
92 | colorwheel = torch.zeros((ncols, 3))
93 | col = 0
94 |
95 | # RY
96 | colorwheel[0:RY, 0] = 255
97 | colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY)
98 | col = col + RY
99 | # YG
100 | colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG)
101 | colorwheel[col : col + YG, 1] = 255
102 | col = col + YG
103 | # GC
104 | colorwheel[col : col + GC, 1] = 255
105 | colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC)
106 | col = col + GC
107 | # CB
108 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB)
109 | colorwheel[col : col + CB, 2] = 255
110 | col = col + CB
111 | # BM
112 | colorwheel[col : col + BM, 2] = 255
113 | colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM)
114 | col = col + BM
115 | # MR
116 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR)
117 | colorwheel[col : col + MR, 0] = 255
118 | return colorwheel
119 |
--------------------------------------------------------------------------------
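The PyTorch variant above mirrors the numpy version but works on batched tensors; a brief sketch with random flow, purely for shape/dtype reference:

import torch
from models.RAFT.utils.flow_viz_pt import flow_to_image

flow = torch.randn(2, 2, 64, 96)  # (N, 2, H, W), dtype torch.float
img = flow_to_image(flow)         # (N, 3, H, W), dtype torch.uint8
print(img.shape, img.dtype)       # torch.Size([2, 3, 64, 96]) torch.uint8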
/models/RAFT/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
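A small round-trip sketch for the Middlebury .flo reader/writer above; the output path is arbitrary, and readFlow's little-endian caveat applies:

import numpy as np
from models.RAFT.utils.frame_utils import writeFlow, readFlow

flow = np.random.randn(48, 64, 2).astype(np.float32)
writeFlow('example.flo', flow)      # tag + width + height header, then interleaved u/v
restored = readFlow('example.flo')  # float32 array of shape (48, 64, 2)
print(np.allclose(flow, restored))  # True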
/models/RAFT/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd):
75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
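InputPadder makes RAFT inputs divisible by 8 and crops predictions back to the original grid; a short sketch with Sintel-sized frames:

import torch
from models.RAFT.utils.utils import InputPadder

img1 = torch.rand(1, 3, 436, 1024)  # height 436 is not divisible by 8
img2 = torch.rand(1, 3, 436, 1024)
padder = InputPadder(img1.shape)
img1, img2 = padder.pad(img1, img2)
print(img1.shape)                   # torch.Size([1, 3, 440, 1024])

flow = torch.rand(1, 2, 440, 1024)  # e.g. a flow prediction at the padded resolution
print(padder.unpad(flow).shape)     # torch.Size([1, 2, 436, 1024])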
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/__init__.py
--------------------------------------------------------------------------------
/models/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import random
4 | import time
5 | import torch
6 | import torch.nn as nn
7 | import logging
8 | import numpy as np
9 | from os import path as osp
10 |
11 | def constant_init(module, val, bias=0):
12 | if hasattr(module, 'weight') and module.weight is not None:
13 | nn.init.constant_(module.weight, val)
14 | if hasattr(module, 'bias') and module.bias is not None:
15 | nn.init.constant_(module.bias, bias)
16 |
17 | initialized_logger = {}
18 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
19 | """Get the root logger.
20 | The logger will be initialized if it has not been initialized. By default a
21 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
22 | also be added.
23 | Args:
24 | logger_name (str): root logger name. Default: 'basicsr'.
25 | log_file (str | None): The log filename. If specified, a FileHandler
26 | will be added to the root logger.
27 | log_level (int): The root logger level. Note that only the process of
28 | rank 0 is affected, while other processes will set the level to
29 | "Error" and be silent most of the time.
30 | Returns:
31 | logging.Logger: The root logger.
32 | """
33 | logger = logging.getLogger(logger_name)
34 | # if the logger has been initialized, just return it
35 | if logger_name in initialized_logger:
36 | return logger
37 |
38 | format_str = '%(asctime)s %(levelname)s: %(message)s'
39 | stream_handler = logging.StreamHandler()
40 | stream_handler.setFormatter(logging.Formatter(format_str))
41 | logger.addHandler(stream_handler)
42 | logger.propagate = False
43 |
44 | if log_file is not None:
45 | logger.setLevel(log_level)
46 | # add file handler
47 | # file_handler = logging.FileHandler(log_file, 'w')
48 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
49 | file_handler.setFormatter(logging.Formatter(format_str))
50 | file_handler.setLevel(log_level)
51 | logger.addHandler(file_handler)
52 | initialized_logger[logger_name] = True
53 | return logger
54 |
55 |
56 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
57 | torch.__version__)[0][:3])] >= [1, 12, 0]
58 |
59 | def gpu_is_available():
60 | if IS_HIGH_VERSION:
61 | if torch.backends.mps.is_available():
62 | return True
63 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
64 |
65 | def get_device(gpu_id=None):
66 | if gpu_id is None:
67 | gpu_str = ''
68 | elif isinstance(gpu_id, int):
69 | gpu_str = f':{gpu_id}'
70 | else:
71 | raise TypeError('Input should be an int value.')
72 |
73 | if IS_HIGH_VERSION:
74 | if torch.backends.mps.is_available():
75 | return torch.device('mps'+gpu_str)
76 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
77 |
78 |
79 | def set_random_seed(seed):
80 | """Set random seeds."""
81 | random.seed(seed)
82 | np.random.seed(seed)
83 | torch.manual_seed(seed)
84 | torch.cuda.manual_seed(seed)
85 | torch.cuda.manual_seed_all(seed)
86 |
87 |
88 | def get_time_str():
89 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
90 |
91 |
92 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
93 | """Scan a directory to find the interested files.
94 |
95 | Args:
96 | dir_path (str): Path of the directory.
97 | suffix (str | tuple(str), optional): File suffix that we are
98 | interested in. Default: None.
99 | recursive (bool, optional): If set to True, recursively scan the
100 | directory. Default: False.
101 | full_path (bool, optional): If set to True, include the dir_path.
102 | Default: False.
103 |
104 | Returns:
105 | A generator for all the interested files with relative paths.
106 | """
107 |
108 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
109 | raise TypeError('"suffix" must be a string or tuple of strings')
110 |
111 | root = dir_path
112 |
113 | def _scandir(dir_path, suffix, recursive):
114 | for entry in os.scandir(dir_path):
115 | if not entry.name.startswith('.') and entry.is_file():
116 | if full_path:
117 | return_path = entry.path
118 | else:
119 | return_path = osp.relpath(entry.path, root)
120 |
121 | if suffix is None:
122 | yield return_path
123 | elif return_path.endswith(suffix):
124 | yield return_path
125 | else:
126 | if recursive:
127 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
128 | else:
129 | continue
130 |
131 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
--------------------------------------------------------------------------------
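A quick sketch of the helpers above, intended to be run from the repository root; the log file name is arbitrary:

from models.misc import get_device, get_root_logger, scandir, set_random_seed

set_random_seed(42)                    # seeds python, numpy and torch RNGs
device = get_device()                  # cuda / mps / cpu, depending on availability
logger = get_root_logger('ai_webui', log_file='ai_webui.log')
for cfg in scandir('configs', suffix='.yml'):
    logger.info(f'found config: {cfg}')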
/models/modules/base_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from functools import reduce
6 |
7 | class BaseNetwork(nn.Module):
8 | def __init__(self):
9 | super(BaseNetwork, self).__init__()
10 |
11 | def print_network(self):
12 | if isinstance(self, list):
13 | self = self[0]
14 | num_params = 0
15 | for param in self.parameters():
16 | num_params += param.numel()
17 | print(
18 | 'Network [%s] was created. Total number of parameters: %.1f million. '
19 | 'To see the architecture, do print(network).' %
20 | (type(self).__name__, num_params / 1000000))
21 |
22 | def init_weights(self, init_type='normal', gain=0.02):
23 | '''
24 | initialize network's weights
25 | init_type: normal | xavier | kaiming | orthogonal
26 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
27 | '''
28 | def init_func(m):
29 | classname = m.__class__.__name__
30 | if classname.find('InstanceNorm2d') != -1:
31 | if hasattr(m, 'weight') and m.weight is not None:
32 | nn.init.constant_(m.weight.data, 1.0)
33 | if hasattr(m, 'bias') and m.bias is not None:
34 | nn.init.constant_(m.bias.data, 0.0)
35 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1
36 | or classname.find('Linear') != -1):
37 | if init_type == 'normal':
38 | nn.init.normal_(m.weight.data, 0.0, gain)
39 | elif init_type == 'xavier':
40 | nn.init.xavier_normal_(m.weight.data, gain=gain)
41 | elif init_type == 'xavier_uniform':
42 | nn.init.xavier_uniform_(m.weight.data, gain=1.0)
43 | elif init_type == 'kaiming':
44 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
45 | elif init_type == 'orthogonal':
46 | nn.init.orthogonal_(m.weight.data, gain=gain)
47 | elif init_type == 'none': # uses pytorch's default init method
48 | m.reset_parameters()
49 | else:
50 | raise NotImplementedError(
51 | 'initialization method [%s] is not implemented' %
52 | init_type)
53 | if hasattr(m, 'bias') and m.bias is not None:
54 | nn.init.constant_(m.bias.data, 0.0)
55 |
56 | self.apply(init_func)
57 |
58 | # propagate to children
59 | for m in self.children():
60 | if hasattr(m, 'init_weights'):
61 | m.init_weights(init_type, gain)
62 |
63 |
64 | class Vec2Feat(nn.Module):
65 | def __init__(self, channel, hidden, kernel_size, stride, padding):
66 | super(Vec2Feat, self).__init__()
67 | self.relu = nn.LeakyReLU(0.2, inplace=True)
68 | c_out = reduce((lambda x, y: x * y), kernel_size) * channel
69 | self.embedding = nn.Linear(hidden, c_out)
70 | self.kernel_size = kernel_size
71 | self.stride = stride
72 | self.padding = padding
73 | self.bias_conv = nn.Conv2d(channel,
74 | channel,
75 | kernel_size=3,
76 | stride=1,
77 | padding=1)
78 |
79 | def forward(self, x, t, output_size):
80 | b_, _, _, _, c_ = x.shape
81 | x = x.view(b_, -1, c_)
82 | feat = self.embedding(x)
83 | b, _, c = feat.size()
84 | feat = feat.view(b * t, -1, c).permute(0, 2, 1)
85 | feat = F.fold(feat,
86 | output_size=output_size,
87 | kernel_size=self.kernel_size,
88 | stride=self.stride,
89 | padding=self.padding)
90 | feat = self.bias_conv(feat)
91 | return feat
92 |
93 |
94 | class FusionFeedForward(nn.Module):
95 | def __init__(self, dim, hidden_dim=1960, t2t_params=None):
96 | super(FusionFeedForward, self).__init__()
97 | # The default hidden_dim is 1960, which must be divisible by the t2t kernel area (49 for a 7x7 kernel)
98 | self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
99 | self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
100 | assert t2t_params is not None
101 | self.t2t_params = t2t_params
102 | self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49
103 |
104 | def forward(self, x, output_size):
105 | n_vecs = 1
106 | for i, d in enumerate(self.t2t_params['kernel_size']):
107 | n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
108 | (d - 1) - 1) / self.t2t_params['stride'][i] + 1)
109 |
110 | x = self.fc1(x)
111 | b, n, c = x.size()
112 | normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
113 | normalizer = F.fold(normalizer,
114 | output_size=output_size,
115 | kernel_size=self.t2t_params['kernel_size'],
116 | padding=self.t2t_params['padding'],
117 | stride=self.t2t_params['stride'])
118 |
119 | x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
120 | output_size=output_size,
121 | kernel_size=self.t2t_params['kernel_size'],
122 | padding=self.t2t_params['padding'],
123 | stride=self.t2t_params['stride'])
124 |
125 | x = F.unfold(x / normalizer,
126 | kernel_size=self.t2t_params['kernel_size'],
127 | padding=self.t2t_params['padding'],
128 | stride=self.t2t_params['stride']).permute(
129 | 0, 2, 1).contiguous().view(b, n, c)
130 | x = self.fc2(x)
131 | return x
132 |
--------------------------------------------------------------------------------
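FusionFeedForward folds its tokens back onto the frame grid, so the soft-split (t2t) parameters drive the shapes; the kernel/stride/padding values below are illustrative assumptions, and hidden_dim=1960 is divisible by the 7x7 kernel area:

import torch
from models.modules.base_module import FusionFeedForward

t2t_params = {'kernel_size': (7, 7), 'stride': (3, 3), 'padding': (3, 3)}
ff = FusionFeedForward(dim=512, hidden_dim=1960, t2t_params=t2t_params)

# A 60x108 output grid yields 20 * 36 = 720 soft-split tokens.
x = torch.rand(1, 720, 512)
out = ff(x, output_size=(60, 108))
print(out.shape)   # torch.Size([1, 720, 512])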
/models/modules/deformconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init as init
4 | from torch.nn.modules.utils import _pair, _single
5 | import math
6 |
7 | class ModulatedDeformConv2d(nn.Module):
8 | def __init__(self,
9 | in_channels,
10 | out_channels,
11 | kernel_size,
12 | stride=1,
13 | padding=0,
14 | dilation=1,
15 | groups=1,
16 | deform_groups=1,
17 | bias=True):
18 | super(ModulatedDeformConv2d, self).__init__()
19 |
20 | self.in_channels = in_channels
21 | self.out_channels = out_channels
22 | self.kernel_size = _pair(kernel_size)
23 | self.stride = stride
24 | self.padding = padding
25 | self.dilation = dilation
26 | self.groups = groups
27 | self.deform_groups = deform_groups
28 | self.with_bias = bias
29 | # enable compatibility with nn.Conv2d
30 | self.transposed = False
31 | self.output_padding = _single(0)
32 |
33 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
34 | if bias:
35 | self.bias = nn.Parameter(torch.Tensor(out_channels))
36 | else:
37 | self.register_parameter('bias', None)
38 | self.init_weights()
39 |
40 | def init_weights(self):
41 | n = self.in_channels
42 | for k in self.kernel_size:
43 | n *= k
44 | stdv = 1. / math.sqrt(n)
45 | self.weight.data.uniform_(-stdv, stdv)
46 | if self.bias is not None:
47 | self.bias.data.zero_()
48 |
49 | if hasattr(self, 'conv_offset'):
50 | self.conv_offset.weight.data.zero_()
51 | self.conv_offset.bias.data.zero_()
52 |
53 | def forward(self, x, offset, mask):
54 | pass
--------------------------------------------------------------------------------
/models/sam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/__init__.py
--------------------------------------------------------------------------------
/models/sam/assets/masks1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/masks1.png
--------------------------------------------------------------------------------
/models/sam/assets/masks2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/masks2.jpg
--------------------------------------------------------------------------------
/models/sam/assets/model_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/model_diagram.png
--------------------------------------------------------------------------------
/models/sam/assets/notebook1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/notebook1.png
--------------------------------------------------------------------------------
/models/sam/assets/notebook2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/assets/notebook2.png
--------------------------------------------------------------------------------
/models/sam/linter.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | {
5 | black --version | grep -E "23\." > /dev/null
6 | } || {
7 | echo "Linter requires 'black==23.*' !"
8 | exit 1
9 | }
10 |
11 | ISORT_VERSION=$(isort --version-number)
12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then
13 | echo "Linter requires isort==5.12.0 !"
14 | exit 1
15 | fi
16 |
17 | echo "Running isort ..."
18 | isort . --atomic
19 |
20 | echo "Running black ..."
21 | black -l 100 .
22 |
23 | echo "Running flake8 ..."
24 | if [ -x "$(command -v flake8)" ]; then
25 | flake8 .
26 | else
27 | python3 -m flake8 .
28 | fi
29 |
30 | echo "Running mypy..."
31 |
32 | mypy --exclude 'setup.py|notebooks' .
33 |
--------------------------------------------------------------------------------
/models/sam/notebooks/images/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/notebooks/images/dog.jpg
--------------------------------------------------------------------------------
/models/sam/notebooks/images/groceries.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/notebooks/images/groceries.jpg
--------------------------------------------------------------------------------
/models/sam/notebooks/images/truck.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/notebooks/images/truck.jpg
--------------------------------------------------------------------------------
/models/sam/segment_anything/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/sam/segment_anything/.DS_Store
--------------------------------------------------------------------------------
/models/sam/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
--------------------------------------------------------------------------------
/models/sam/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam_vit_h,
49 | "vit_h": build_sam_vit_h,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
--------------------------------------------------------------------------------
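A minimal loading sketch for the registry above; the checkpoint path is hypothetical and the official SAM ViT-B weights must be downloaded separately:

import torch
from models.sam.segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry['vit_b'](checkpoint='model_weights/sam_vit_b_01ec64.pth')
sam.to('cuda' if torch.cuda.is_available() else 'cpu')
predictor = SamPredictor(sam)
# predictor.set_image(rgb_image)  # HxWx3 uint8 RGB array
# masks, scores, logits = predictor.predict(point_coords=..., point_labels=...)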
/models/sam/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/models/sam/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
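Both building blocks above operate on standard tensors; a tiny shape check:

import torch
from models.sam.segment_anything.modeling.common import LayerNorm2d, MLPBlock

x = torch.randn(2, 256, 64, 64)
print(LayerNorm2d(256)(x).shape)   # torch.Size([2, 256, 64, 64])

tokens = torch.randn(2, 100, 256)
print(MLPBlock(embedding_dim=256, mlp_dim=2048)(tokens).shape)   # torch.Size([2, 100, 256])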
/models/sam/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/models/sam/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images to longest side 'target_length', as well as provides
19 | methods for resizing coordinates and boxes. Provides methods for
20 | transforming both numpy array and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
--------------------------------------------------------------------------------
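ResizeLongestSide rescales the image and the prompts consistently; a short sketch with a 480x640 input and SAM's usual target length of 1024:

import numpy as np
from models.sam.segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)
image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
print(transform.apply_image(image).shape)   # (768, 1024, 3): the longest side becomes 1024

coords = np.array([[320.0, 240.0]])         # (x, y) point prompt in original pixels
print(transform.apply_coords(coords, image.shape[:2]))  # [[512. 384.]]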
/models/sam/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | line_length=100
3 | multi_line_output=3
4 | include_trailing_comma=True
5 | known_standard_library=numpy,setuptools
6 | skip_glob=*/__init__.py
7 | known_myself=segment_anything
8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort
9 | no_lines_before=STDLIB,THIRDPARTY
10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER
11 | default_section=FIRSTPARTY
12 |
--------------------------------------------------------------------------------
/models/sam/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import find_packages, setup
8 |
9 | setup(
10 | name="segment_anything",
11 | version="1.0",
12 | install_requires=[],
13 | packages=find_packages(exclude="notebooks"),
14 | extras_require={
15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"],
16 | "dev": ["flake8", "isort", "black", "mypy"],
17 | },
18 | )
19 |
--------------------------------------------------------------------------------
/models/tracker/base_tracker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from omegaconf import OmegaConf
5 |
6 | import sys
7 |
8 | from models.tracker.config import CONFIG
9 | from models.tracker.model.cutie import CUTIE
10 | from models.tracker.inference.inference_core import InferenceCore
11 | from models.tracker.utils.mask_mapper import MaskMapper
12 |
13 | from utils.painter import mask_painter
14 |
15 |
16 | class BaseTracker:
17 | def __init__(self, cutie_checkpoint, device) -> None:
18 | """
19 | device: model device
20 | cutie_checkpoint: checkpoint of the CUTIE model
21 | """
22 | config = OmegaConf.create(CONFIG)
23 |
24 | # initialise CUTIE
25 | network = CUTIE(config).to(device).eval()
26 | model_weights = torch.load(cutie_checkpoint, map_location=device)
27 | network.load_weights(model_weights)
28 |
29 | # initialise InferenceCore
30 | self.tracker = InferenceCore(network, config)
31 | self.device = device
32 |
33 | # changeable properties
34 | self.mapper = MaskMapper()
35 | self.initialised = False
36 |
37 | @torch.no_grad()
38 | def resize_mask(self, mask):
39 | # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
40 | h, w = mask.shape[-2:]
41 | min_hw = min(h, w)
42 | return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
43 | mode='nearest')
44 |
45 | @torch.no_grad()
46 | def image_to_torch(self, frame: np.ndarray, device: str = 'cuda'):
47 | # frame: H*W*3 numpy array
48 | frame = frame.transpose(2, 0, 1)
49 | frame = torch.from_numpy(frame).float().to(device, non_blocking=True) / 255
50 | return frame
51 |
52 | @torch.no_grad()
53 | def track(self, frame, first_frame_annotation=None):
54 | """
55 | Input:
56 | frame: numpy array (H, W, 3)
57 | first_frame_annotation: numpy array (H, W), mask for the first frame (optional)
58 |
59 | Output:
60 | mask: numpy array (H, W)
61 | logit: numpy array (H, W); the final mask is currently returned in this slot as well
62 | painted_image: numpy array (H, W, 3)
63 | """
64 |
65 | if first_frame_annotation is not None: # first frame mask
66 | # initialisation
67 | mask, labels = self.mapper.convert_mask(first_frame_annotation)
68 | mask = torch.Tensor(mask).to(self.device)
69 | else:
70 | mask = None
71 | labels = None
72 |
73 | # prepare inputs
74 | frame_tensor = self.image_to_torch(frame, self.device)
75 |
76 | # track one frame
77 | probs = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
78 |
79 | # convert to mask
80 | out_mask = torch.argmax(probs, dim=0)
81 | out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
82 |
83 | final_mask = np.zeros_like(out_mask)
84 |
85 | # map back
86 | for k, v in self.mapper.remappings.items():
87 | final_mask[out_mask == v] = k
88 |
89 | num_objs = final_mask.max()
90 | painted_image = frame
91 | for obj in range(1, num_objs+1):
92 | if np.max(final_mask==obj) == 0:
93 | continue
94 | painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
95 |
96 | return final_mask, final_mask, painted_image
97 |
98 | @torch.no_grad()
99 | def clear_memory(self):
100 | self.tracker.clear_memory()
101 | self.mapper.clear_labels()
102 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
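A rough usage sketch for BaseTracker; the checkpoint path is hypothetical (the CUTIE weights must be downloaded separately) and the frames/mask here are placeholders:

import numpy as np
import torch
from models.tracker.base_tracker import BaseTracker

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tracker = BaseTracker('model_weights/cutie-base-mega.pth', device)

frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(3)]
first_mask = np.zeros((480, 640), dtype=np.uint8)
first_mask[100:200, 100:200] = 1   # object with label 1

mask, logit, painted = tracker.track(frames[0], first_mask)  # initialise from the annotation
for frame in frames[1:]:
    mask, logit, painted = tracker.track(frame)              # propagate to later frames
tracker.clear_memory()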
/models/tracker/config/__init__.py:
--------------------------------------------------------------------------------
1 | CONFIG = {'exp_id': 'default', 'dataset': 'd17-val', 'amp': False, 'output_dir': None, 'flip_aug': False, 'max_internal_size': -1, 'image_directory': None, 'mask_directory': None, 'json_directory': None, 'size': None, 'save_all': None, 'use_all_masks': None, 'use_long_term': None, 'mem_every': 5, 'max_mem_frames': 5, 'long_term': {'count_usage': True, 'max_mem_frames': 10, 'min_mem_frames': 5, 'num_prototypes': 128, 'max_num_tokens': 10000, 'buffer_tokens': 2000}, 'top_k': 30, 'stagger_updates': 5, 'chunk_size': -1, 'save_scores': False, 'save_aux': False, 'visualize': False, 'model': {'pixel_mean': [0.485, 0.456, 0.406], 'pixel_std': [0.229, 0.224, 0.225], 'pixel_dim': 256, 'key_dim': 64, 'value_dim': 256, 'sensory_dim': 256, 'embed_dim': 256, 'pixel_encoder': {'type': 'resnet50', 'ms_dims': [1024, 512, 256]}, 'mask_encoder': {'type': 'resnet18', 'final_dim': 256}, 'pixel_pe_scale': 32, 'pixel_pe_temperature': 128, 'object_transformer': {'embed_dim': '${model.embed_dim}', 'ff_dim': 2048, 'num_heads': 8, 'num_blocks': 3, 'num_queries': 16, 'read_from_pixel': {'input_norm': False, 'input_add_pe': False, 'add_pe_to_qkv': [True, True, False]}, 'read_from_past': {'add_pe_to_qkv': [True, True, False]}, 'read_from_memory': {'add_pe_to_qkv': [True, True, False]}, 'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False}, 'query_self_attention': {'add_pe_to_qkv': [True, True, False]}, 'pixel_self_attention': {'add_pe_to_qkv': [True, True, False]}}, 'object_summarizer': {'embed_dim': '${model.object_transformer.embed_dim}', 'num_summaries': '${model.object_transformer.num_queries}', 'add_pe': True}, 'aux_loss': {'sensory': {'enabled': True, 'weight': 0.01}, 'query': {'enabled': True, 'weight': 0.01}}, 'mask_decoder': {'up_dims': [256, 128, 128]}}}
--------------------------------------------------------------------------------
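CONFIG is a plain dict containing OmegaConf-style interpolations (e.g. '${model.embed_dim}'); they resolve once the dict is wrapped, as base_tracker.py does:

from omegaconf import OmegaConf
from models.tracker.config import CONFIG

cfg = OmegaConf.create(CONFIG)
print(cfg.model.object_transformer.embed_dim)  # 256 (resolved from '${model.embed_dim}')
print(cfg.long_term.max_mem_frames)            # 10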
/models/tracker/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/inference/__init__.py
--------------------------------------------------------------------------------
/models/tracker/inference/image_feature_store.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Iterable, Tuple
3 | import torch
4 | from models.tracker.model.cutie import CUTIE
5 |
6 |
7 | class ImageFeatureStore:
8 | """
9 | A cache for image features.
10 | These features might be reused at different parts of the inference pipeline.
11 | This class provides an interface for reusing these features.
12 | It is the user's responsibility to delete redundant features.
13 |
14 | Features of a frame should be associated with a unique index -- typically the frame id.
15 | """
16 | def __init__(self, network: CUTIE, no_warning: bool = False):
17 | self.network = network
18 | self._store = {}
19 | self.no_warning = no_warning
20 |
21 | def _encode_feature(self, index: int, image: torch.Tensor) -> None:
22 | ms_features, pix_feat = self.network.encode_image(image)
23 | key, shrinkage, selection = self.network.transform_key(ms_features[0])
24 | self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)
25 |
26 | def get_features(self, index: int,
27 | image: torch.Tensor) -> Tuple[Iterable[torch.Tensor], torch.Tensor]:
28 | if index not in self._store:
29 | self._encode_feature(index, image)
30 |
31 | return self._store[index][:2]
32 |
33 | def get_key(self, index: int,
34 | image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
35 | if index not in self._store:
36 | self._encode_feature(index, image)
37 |
38 | return self._store[index][2:]
39 |
40 | def delete(self, index: int) -> None:
41 | if index in self._store:
42 | del self._store[index]
43 |
44 | def __len__(self):
45 | return len(self._store)
46 |
47 | def __del__(self):
48 | if len(self._store) > 0 and not self.no_warning:
49 | warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
50 |
--------------------------------------------------------------------------------
/models/tracker/inference/object_info.py:
--------------------------------------------------------------------------------
1 | class ObjectInfo:
2 | """
3 | Store meta information for an object
4 | """
5 | def __init__(self, id: int):
6 | self.id = id
7 | self.poke_count = 0 # count number of detections missed
8 |
9 | def poke(self) -> None:
10 | self.poke_count += 1
11 |
12 | def unpoke(self) -> None:
13 | self.poke_count = 0
14 |
15 | def __hash__(self):
16 | return hash(self.id)
17 |
18 | def __eq__(self, other):
19 | if type(other) == int:
20 | return self.id == other
21 | return self.id == other.id
22 |
23 | def __repr__(self):
24 | return f'(ID: {self.id})'
25 |
--------------------------------------------------------------------------------
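ObjectInfo compares and hashes by its integer id, so it can be looked up with raw label values; a tiny check:

from models.tracker.inference.object_info import ObjectInfo

obj = ObjectInfo(3)
obj.poke()
obj.poke()
print(obj.poke_count)                  # 2 consecutive missed detections
print(obj == 3, hash(obj) == hash(3))  # True True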
/models/tracker/inference/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/inference/utils/__init__.py
--------------------------------------------------------------------------------
/models/tracker/inference/utils/args_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from omegaconf import DictConfig
3 |
4 | log = logging.getLogger()
5 |
6 |
7 | def get_dataset_cfg(cfg: DictConfig):
8 | dataset_name = cfg.dataset
9 | data_cfg = cfg.datasets[dataset_name]
10 |
11 | potential_overrides = [
12 | 'image_directory',
13 | 'mask_directory',
14 | 'json_directory',
15 | 'size',
16 | 'save_all',
17 | 'use_all_masks',
18 | 'use_long_term',
19 | 'mem_every',
20 | ]
21 |
22 | for override in potential_overrides:
23 | if cfg[override] is not None:
24 | log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
25 | data_cfg[override] = cfg[override]
26 | # escalate all potential overrides to the top-level config
27 | if override in data_cfg:
28 | cfg[override] = data_cfg[override]
29 |
30 | return data_cfg
31 |
--------------------------------------------------------------------------------
/models/tracker/inference/utils/burst_utils.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import copy
3 | import json
4 |
5 |
6 | class BURSTResultHandler:
7 | def __init__(self, dataset_json):
8 | self.dataset_json = copy.deepcopy(dataset_json)
9 |
10 | # get rid of the segmentations while keeping the metadata
11 | self.dataset_json['sequences'] = []
12 |
13 | def add_sequence(self, sequence_json):
14 | self.dataset_json['sequences'].append(sequence_json)
15 |
16 | def dump(self, root):
17 | json_path = path.join(root, 'predictions.json')
18 | with open(json_path, 'w') as f:
19 | json.dump(self.dataset_json, f)
--------------------------------------------------------------------------------
/models/tracker/inference/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 | import torch
3 |
4 | from models.tracker.inference.object_info import ObjectInfo
5 |
6 |
7 | class FrameInfo:
8 | def __init__(self, image: torch.Tensor, mask: torch.Tensor, segments_info: List[ObjectInfo],
9 | ti: int, info: Dict):
10 | self.image = image
11 | self.mask = mask
12 | self.segments_info = segments_info
13 | self.ti = ti
14 | self.info = info
15 |
16 | @property
17 | def name(self) -> str:
18 | return self.info['frame']
19 |
20 | @property
21 | def shape(self) -> Tuple[int, ...]:
22 | return self.info['shape']
23 |
24 | @property
25 | def need_save(self) -> bool:
26 | return self.info['save']
27 |
--------------------------------------------------------------------------------
/models/tracker/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/model/__init__.py
--------------------------------------------------------------------------------
/models/tracker/model/aux_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | For computing auxiliary outputs for auxiliary losses
3 | """
4 | from typing import Dict
5 | from omegaconf import DictConfig
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .group_modules import GConv2d
10 | from models.tracker.utils.tensor_utils import aggregate
11 |
12 |
13 | class LinearPredictor(nn.Module):
14 | def __init__(self, x_dim: int, pix_dim: int):
15 | super().__init__()
16 | self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)
17 |
18 | def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
19 | # pixel_feat: B*pix_dim*H*W
20 | # x: B*num_objects*x_dim*H*W
21 | num_objects = x.shape[1]
22 | x = self.projection(x)
23 |
24 | pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
25 | logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
26 | return logits
27 |
28 |
29 | class DirectPredictor(nn.Module):
30 | def __init__(self, x_dim: int):
31 | super().__init__()
32 | self.projection = GConv2d(x_dim, 1, kernel_size=1)
33 |
34 | def forward(self, x: torch.Tensor) -> torch.Tensor:
35 | # x: B*num_objects*x_dim*H*W
36 | logits = self.projection(x).squeeze(2)
37 | return logits
38 |
39 |
40 | class AuxComputer(nn.Module):
41 | def __init__(self, cfg: DictConfig):
42 | super().__init__()
43 |
44 | use_sensory_aux = cfg.model.aux_loss.sensory.enabled
45 | self.use_query_aux = cfg.model.aux_loss.query.enabled
46 |
47 | sensory_dim = cfg.model.sensory_dim
48 | embed_dim = cfg.model.embed_dim
49 |
50 | if use_sensory_aux:
51 | self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)
52 | else:
53 | self.sensory_aux = None
54 |
55 | def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
56 | prob = torch.sigmoid(logits)
57 | if selector is not None:
58 | prob = prob * selector
59 | logits = aggregate(prob, dim=1)
60 | return logits
61 |
62 | def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
63 | selector: torch.Tensor) -> Dict[str, torch.Tensor]:
64 | sensory = aux_input['sensory']
65 | q_logits = aux_input['q_logits']
66 |
67 | aux_output = {}
68 | aux_output['attn_mask'] = aux_input['attn_mask']
69 |
70 | if self.sensory_aux is not None:
71 | # B*num_objects*H*W
72 | logits = self.sensory_aux(pix_feat, sensory)
73 | aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
74 | if self.use_query_aux:
75 | # B*num_objects*num_levels*H*W
76 | aux_output['q_logits'] = self._aggregate_with_selector(
77 | torch.stack(q_logits, dim=2),
78 | selector.unsqueeze(2) if selector is not None else None)
79 |
80 | return aux_output
--------------------------------------------------------------------------------
/models/tracker/model/channel_attn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class CAResBlock(nn.Module):
8 | def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
9 | super().__init__()
10 | self.residual = residual
11 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
12 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
13 |
14 | t = int((abs(math.log2(out_dim)) + 1) // 2)
15 | k = t if t % 2 else t + 1
16 | self.pool = nn.AdaptiveAvgPool2d(1)
17 | self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
18 |
19 | if self.residual:
20 | if in_dim == out_dim:
21 | self.downsample = nn.Identity()
22 | else:
23 | self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | r = x
27 | x = self.conv1(F.relu(x))
28 | x = self.conv2(F.relu(x))
29 |
30 | b, c = x.shape[:2]
31 | w = self.pool(x).view(b, 1, c)
32 | w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1
33 |
34 | if self.residual:
35 | x = x * w + self.downsample(r)
36 | else:
37 | x = x * w
38 |
39 | return x
40 |
--------------------------------------------------------------------------------
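CAResBlock is a residual block with ECA-style channel attention (the 1D conv kernel size is derived from the channel count); a shape check:

import torch
from models.tracker.model.channel_attn import CAResBlock

block = CAResBlock(in_dim=256, out_dim=256)
x = torch.randn(2, 256, 30, 54)
print(block(x).shape)   # torch.Size([2, 256, 30, 54])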
/models/tracker/model/group_modules.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .channel_attn import CAResBlock
6 |
7 |
8 | def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
9 | align_corners: bool) -> torch.Tensor:
10 | batch_size, num_objects = g.shape[:2]
11 | g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
12 | scale_factor=ratio,
13 | mode=mode,
14 | align_corners=align_corners)
15 | g = g.view(batch_size, num_objects, *g.shape[1:])
16 | return g
17 |
18 |
19 | def upsample_groups(g: torch.Tensor,
20 | ratio: float = 2,
21 | mode: str = 'bilinear',
22 | align_corners: bool = False) -> torch.Tensor:
23 | return interpolate_groups(g, ratio, mode, align_corners)
24 |
25 |
26 | def downsample_groups(g: torch.Tensor,
27 | ratio: float = 1 / 2,
28 | mode: str = 'area',
29 | align_corners: bool = None) -> torch.Tensor:
30 | return interpolate_groups(g, ratio, mode, align_corners)
31 |
32 |
33 | class GConv2d(nn.Conv2d):
34 | def forward(self, g: torch.Tensor) -> torch.Tensor:
35 | batch_size, num_objects = g.shape[:2]
36 | g = super().forward(g.flatten(start_dim=0, end_dim=1))
37 | return g.view(batch_size, num_objects, *g.shape[1:])
38 |
39 |
40 | class GroupResBlock(nn.Module):
41 | def __init__(self, in_dim: int, out_dim: int):
42 | super().__init__()
43 |
44 | if in_dim == out_dim:
45 | self.downsample = nn.Identity()
46 | else:
47 | self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)
48 |
49 | self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
50 | self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)
51 |
52 | def forward(self, g: torch.Tensor) -> torch.Tensor:
53 | out_g = self.conv1(F.relu(g))
54 | out_g = self.conv2(F.relu(out_g))
55 |
56 | g = self.downsample(g)
57 |
58 | return out_g + g
59 |
60 |
61 | class MainToGroupDistributor(nn.Module):
62 | def __init__(self,
63 | x_transform: Optional[nn.Module] = None,
64 | g_transform: Optional[nn.Module] = None,
65 | method: str = 'cat',
66 | reverse_order: bool = False):
67 | super().__init__()
68 |
69 | self.x_transform = x_transform
70 | self.g_transform = g_transform
71 | self.method = method
72 | self.reverse_order = reverse_order
73 |
74 | def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
75 | num_objects = g.shape[1]
76 |
77 | if self.x_transform is not None:
78 | x = self.x_transform(x)
79 |
80 | if self.g_transform is not None:
81 | g = self.g_transform(g)
82 |
83 | if not skip_expand:
84 | x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
85 | if self.method == 'cat':
86 | if self.reverse_order:
87 | g = torch.cat([g, x], 2)
88 | else:
89 | g = torch.cat([x, g], 2)
90 | elif self.method == 'add':
91 | g = x + g
92 | elif self.method == 'mulcat':
93 | g = torch.cat([x * g, g], dim=2)
94 | elif self.method == 'muladd':
95 | g = x * g + g
96 | else:
97 | raise NotImplementedError
98 |
99 | return g
100 |
101 |
102 | class GroupFeatureFusionBlock(nn.Module):
103 | def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
104 | super().__init__()
105 |
106 | x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
107 | g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)
108 |
109 | self.distributor = MainToGroupDistributor(x_transform=x_transform,
110 | g_transform=g_transform,
111 | method='add')
112 | self.block1 = CAResBlock(out_dim, out_dim)
113 | self.block2 = CAResBlock(out_dim, out_dim)
114 |
115 | def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
116 | batch_size, num_objects = g.shape[:2]
117 |
118 | g = self.distributor(x, g)
119 |
120 | g = g.flatten(start_dim=0, end_dim=1)
121 |
122 | g = self.block1(g)
123 | g = self.block2(g)
124 |
125 | g = g.view(batch_size, num_objects, *g.shape[1:])
126 |
127 | return g
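A shape-level sketch of GroupFeatureFusionBlock fusing shared pixel features with per-object group features (all dimensions below are illustrative, not taken from any config in this repo):

import torch
from models.tracker.model.group_modules import GroupFeatureFusionBlock

fuse = GroupFeatureFusionBlock(x_in_dim=1024, g_in_dim=256, out_dim=256)
x = torch.randn(2, 1024, 30, 30)      # shared pixel features: B*C*H*W
g = torch.randn(2, 3, 256, 30, 30)    # per-object features:   B*num_objects*C*H*W
out = fuse(x, g)
print(out.shape)                      # torch.Size([2, 3, 256, 30, 30])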
--------------------------------------------------------------------------------
/models/tracker/model/losses.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | from omegaconf import DictConfig
3 | from collections import defaultdict
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from models.tracker.utils.point_features import calculate_uncertainty, point_sample, get_uncertain_point_coords_with_randomness
8 | from models.tracker.utils.tensor_utils import cls_to_one_hot
9 |
10 |
11 | @torch.jit.script
12 | def ce_loss(logits: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor:
13 | # logits: T*C*num_points
14 | loss = F.cross_entropy(logits, soft_gt, reduction='none')
15 | # sum over temporal dimension
16 | return loss.sum(0).mean()
17 |
18 |
19 | @torch.jit.script
20 | def dice_loss(mask: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor:
21 | # mask: T*C*num_points
22 | # soft_gt: T*C*num_points
23 | # ignores the background
24 | mask = mask[:, 1:].flatten(start_dim=2)
25 | gt = soft_gt[:, 1:].float().flatten(start_dim=2)
26 | numerator = 2 * (mask * gt).sum(-1)
27 | denominator = mask.sum(-1) + gt.sum(-1)
28 | loss = 1 - (numerator + 1) / (denominator + 1)
29 | return loss.sum(0).mean()
30 |
31 |
32 | class LossComputer:
33 | def __init__(self, cfg: DictConfig, stage_cfg: DictConfig):
34 | super().__init__()
35 | self.point_supervision = stage_cfg.point_supervision
36 | self.num_points = stage_cfg.train_num_points
37 | self.oversample_ratio = stage_cfg.oversample_ratio
38 | self.importance_sample_ratio = stage_cfg.importance_sample_ratio
39 |
40 | self.sensory_weight = cfg.model.aux_loss.sensory.weight
41 | self.query_weight = cfg.model.aux_loss.query.weight
42 |
43 | def mask_loss(self, logits: torch.Tensor,
44 | soft_gt: torch.Tensor) -> (torch.Tensor, torch.Tensor):
45 | assert self.point_supervision
46 |
47 | with torch.no_grad():
48 | # sample point_coords
49 | point_coords = get_uncertain_point_coords_with_randomness(
50 | logits, lambda x: calculate_uncertainty(x), self.num_points, self.oversample_ratio,
51 | self.importance_sample_ratio)
52 | # get gt labels
53 | point_labels = point_sample(soft_gt, point_coords, align_corners=False)
54 | point_logits = point_sample(logits, point_coords, align_corners=False)
55 | # point_labels and point_logits: B*C*num_points
56 |
57 | loss_ce = ce_loss(point_logits, point_labels)
58 | loss_dice = dice_loss(point_logits.softmax(dim=1), point_labels)
59 |
60 | return loss_ce, loss_dice
61 |
62 | def compute(self, data: Dict[str, torch.Tensor],
63 | num_objects: List[int]) -> Dict[str, torch.Tensor]:
64 | batch_size, num_frames = data['rgb'].shape[:2]
65 | losses = defaultdict(float)
66 | t_range = range(1, num_frames)
67 |
68 | for bi in range(batch_size):
69 | logits = torch.stack([data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range],
70 | dim=0)
71 | cls_gt = data['cls_gt'][bi, 1:] # remove gt for the first frame
72 | soft_gt = cls_to_one_hot(cls_gt, num_objects[bi])
73 |
74 | loss_ce, loss_dice = self.mask_loss(logits, soft_gt)
75 | losses['loss_ce'] += loss_ce / batch_size
76 | losses['loss_dice'] += loss_dice / batch_size
77 |
78 | aux = [data[f'aux_{ti}'] for ti in t_range]
79 | if 'sensory_logits' in aux[0]:
80 | sensory_log = torch.stack(
81 | [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0)
82 | loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt)
83 | losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight
84 | losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight
85 | if 'q_logits' in aux[0]:
86 | num_levels = aux[0]['q_logits'].shape[2]
87 |
88 | for l in range(num_levels):
89 | query_log = torch.stack(
90 | [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0)
91 | loss_ce, loss_dice = self.mask_loss(query_log, soft_gt)
92 | losses[f'aux_query_ce_l{l}'] += loss_ce / batch_size * self.query_weight
93 | losses[f'aux_query_dice_l{l}'] += loss_dice / batch_size * self.query_weight
94 |
95 | losses['total_loss'] = sum(losses.values())
96 |
97 | return losses
98 |
--------------------------------------------------------------------------------
/models/tracker/model/modules.py:
--------------------------------------------------------------------------------
1 | from typing import List, Iterable
2 | import torch
3 | import torch.nn as nn
4 |
5 | from models.tracker.model.group_modules import *
6 |
7 |
8 | class MaskUpsampleBlock(nn.Module):
9 | def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
10 | super().__init__()
11 | self.distributor = MainToGroupDistributor(method='add')
12 | self.out_conv = GroupResBlock(in_dim, out_dim)
13 | self.scale_factor = scale_factor
14 |
15 | def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
16 | g = upsample_groups(in_g, ratio=self.scale_factor)
17 | g = self.distributor(skip_f, g)
18 | g = self.out_conv(g)
19 | return g
20 |
21 |
22 | class DecoderFeatureProcessor(nn.Module):
23 | def __init__(self, decoder_dims: List[int], out_dims: List[int]):
24 | super().__init__()
25 | self.transforms = nn.ModuleList([
26 | nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
27 | ])
28 |
29 | def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
30 | outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
31 | return outputs
32 |
33 |
34 | # @torch.jit.script
35 | def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
36 | # h: batch_size * num_objects * hidden_dim * h * w
37 | # values: batch_size * num_objects * (hidden_dim*3) * h * w
38 | dim = values.shape[2] // 3
39 | forget_gate = torch.sigmoid(values[:, :, :dim])
40 | update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
41 | new_value = torch.tanh(values[:, :, dim * 2:])
42 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
43 | return new_h
44 |
45 |
46 | class SensoryUpdater(nn.Module):
47 | # Used in the decoder, multi-scale feature + GRU
48 | def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
49 | super().__init__()
50 | self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
51 | self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
52 | self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
53 |
54 | self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
55 |
56 | nn.init.xavier_normal_(self.transform.weight)
57 |
58 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
59 | g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
60 | self.g4_conv(downsample_groups(g[2], ratio=1/4))
61 |
62 | with torch.cuda.amp.autocast(enabled=False):
63 | g = g.float()
64 | h = h.float()
65 | values = self.transform(torch.cat([g, h], dim=2))
66 | new_h = _recurrent_update(h, values)
67 |
68 | return new_h
69 |
70 |
71 | class SensoryDeepUpdater(nn.Module):
72 | def __init__(self, f_dim: int, sensory_dim: int):
73 | super().__init__()
74 | self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
75 |
76 | nn.init.xavier_normal_(self.transform.weight)
77 |
78 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
79 | with torch.cuda.amp.autocast(enabled=False):
80 | g = g.float()
81 | h = h.float()
82 | values = self.transform(torch.cat([g, h], dim=2))
83 | new_h = _recurrent_update(h, values)
84 |
85 | return new_h
86 |
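A small sketch of SensoryDeepUpdater applying the gated update from _recurrent_update; f_dim and sensory_dim are placeholder values:

import torch
from models.tracker.model.modules import SensoryDeepUpdater

updater = SensoryDeepUpdater(f_dim=512, sensory_dim=256)
g = torch.randn(2, 3, 512, 30, 30)    # per-object features:    B*num_objects*f_dim*H*W
h = torch.randn(2, 3, 256, 30, 30)    # previous sensory state: B*num_objects*sensory_dim*H*W
new_h = updater(g, h)                 # forget/update gates + tanh candidate, as in _recurrent_update
print(new_h.shape)                    # torch.Size([2, 3, 256, 30, 30])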
--------------------------------------------------------------------------------
/models/tracker/model/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/model/transformer/__init__.py
--------------------------------------------------------------------------------
/models/tracker/model/transformer/object_summarizer.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional
2 | from omegaconf import DictConfig
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from models.tracker.model.transformer.positional_encoding import PositionalEncoding
8 |
9 |
10 | # @torch.jit.script
11 | def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
12 | logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
13 | # value: B*num_objects*H*W*value_dim
14 | # logits: B*num_objects*H*W*num_summaries
15 | # masks: B*num_objects*H*W*num_summaries: 1 if allowed
16 | weights = logits.sigmoid() * masks
17 | # B*num_objects*num_summaries*value_dim
18 | sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
19 | # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
20 | area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)
21 |
22 | # B*num_objects*num_summaries*value_dim
23 | return sums, area
24 |
25 |
26 | class ObjectSummarizer(nn.Module):
27 | def __init__(self, model_cfg: DictConfig):
28 | super().__init__()
29 |
30 | this_cfg = model_cfg.object_summarizer
31 | self.value_dim = model_cfg.value_dim
32 | self.embed_dim = this_cfg.embed_dim
33 | self.num_summaries = this_cfg.num_summaries
34 | self.add_pe = this_cfg.add_pe
35 | self.pixel_pe_scale = model_cfg.pixel_pe_scale
36 | self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
37 |
38 | if self.add_pe:
39 | self.pos_enc = PositionalEncoding(self.embed_dim,
40 | scale=self.pixel_pe_scale,
41 | temperature=self.pixel_pe_temperature)
42 |
43 | self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
44 | self.feature_pred = nn.Sequential(
45 | nn.Linear(self.embed_dim, self.embed_dim),
46 | nn.ReLU(inplace=True),
47 | nn.Linear(self.embed_dim, self.embed_dim),
48 | )
49 | self.weights_pred = nn.Sequential(
50 | nn.Linear(self.embed_dim, self.embed_dim),
51 | nn.ReLU(inplace=True),
52 | nn.Linear(self.embed_dim, self.num_summaries),
53 | )
54 |
55 | def forward(self,
56 | masks: torch.Tensor,
57 | value: torch.Tensor,
58 | need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]):
59 | # masks: B*num_objects*(H0)*(W0)
60 | # value: B*num_objects*value_dim*H*W
61 | # -> B*num_objects*H*W*value_dim
62 | h, w = value.shape[-2:]
63 | masks = F.interpolate(masks, size=(h, w), mode='area')
64 | masks = masks.unsqueeze(-1)
65 | inv_masks = 1 - masks
66 | repeated_masks = torch.cat([
67 | masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
68 | inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
69 | ],
70 | dim=-1)
71 |
72 | value = value.permute(0, 1, 3, 4, 2)
73 | value = self.input_proj(value)
74 | if self.add_pe:
75 | pe = self.pos_enc(value)
76 | value = value + pe
77 |
78 | with torch.cuda.amp.autocast(enabled=False):
79 | value = value.float()
80 | feature = self.feature_pred(value)
81 | logits = self.weights_pred(value)
82 | sums, area = _weighted_pooling(repeated_masks, feature, logits)
83 |
84 | summaries = torch.cat([sums, area], dim=-1)
85 |
86 | if need_weights:
87 | return summaries, logits
88 | else:
89 | return summaries, None
--------------------------------------------------------------------------------
/models/tracker/model/transformer/positional_encoding.py:
--------------------------------------------------------------------------------
1 | # Reference:
2 | # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
3 | # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
4 |
5 | import math
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 |
12 | def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
13 | """
14 | Gets a base embedding for one dimension with sin and cos intertwined
15 | """
16 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
17 | return torch.flatten(emb, -2, -1)
18 |
19 |
20 | class PositionalEncoding(nn.Module):
21 | def __init__(self,
22 | dim: int,
23 | scale: float = math.pi * 2,
24 | temperature: float = 10000,
25 | normalize: bool = True,
26 | channel_last: bool = True,
27 | transpose_output: bool = False):
28 | super().__init__()
29 | dim = int(np.ceil(dim / 4) * 2)
30 | self.dim = dim
31 | inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
32 | self.register_buffer("inv_freq", inv_freq)
33 | self.normalize = normalize
34 | self.scale = scale
35 | self.eps = 1e-6
36 | self.channel_last = channel_last
37 | self.transpose_output = transpose_output
38 |
39 | self.cached_penc = None # the cache is irrespective of the number of objects
40 |
41 | def forward(self, tensor: torch.Tensor) -> torch.Tensor:
42 | """
43 | :param tensor: A 4/5d tensor of size
44 | channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
45 | channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
46 | :return: positional encoding tensor that has the same shape as the input if the input is 4d
47 | if the input is 5d, the output is broadcastable along the k-dimension
48 | """
49 | if len(tensor.shape) != 4 and len(tensor.shape) != 5:
50 | raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')
51 |
52 | if len(tensor.shape) == 5:
53 | # take a sample from the k dimension
54 | num_objects = tensor.shape[1]
55 | tensor = tensor[:, 0]
56 | else:
57 | num_objects = None
58 |
59 | if self.channel_last:
60 | batch_size, h, w, c = tensor.shape
61 | else:
62 | batch_size, c, h, w = tensor.shape
63 |
64 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
65 | if num_objects is None:
66 | return self.cached_penc
67 | else:
68 | return self.cached_penc.unsqueeze(1)
69 |
70 | self.cached_penc = None
71 |
72 | pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
73 | pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
74 | if self.normalize:
75 | pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
76 | pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale
77 |
78 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
79 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
80 | emb_y = get_emb(sin_inp_y).unsqueeze(1)
81 | emb_x = get_emb(sin_inp_x)
82 |
83 | emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
84 | emb[:, :, :self.dim] = emb_x
85 | emb[:, :, self.dim:] = emb_y
86 |
87 | if not self.channel_last and self.transpose_output:
88 | # cancelled out
89 | pass
90 | elif (not self.channel_last) or (self.transpose_output):
91 | emb = emb.permute(2, 0, 1)
92 |
93 | self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
94 | if num_objects is None:
95 | return self.cached_penc
96 | else:
97 | return self.cached_penc.unsqueeze(1)
98 |
99 |
100 | if __name__ == '__main__':
101 | pe = PositionalEncoding(8).cuda()
102 | input = torch.ones((1, 8, 8, 8)).cuda()
103 | output = pe(input)
104 | # print(output)
105 | print(output[0, :, 0, 0])
106 | print(output[0, :, 0, 5])
107 | print(output[0, 0, :, 0])
108 | print(output[0, 0, 0, :])
109 |
--------------------------------------------------------------------------------
/models/tracker/model/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/model/utils/__init__.py
--------------------------------------------------------------------------------
/models/tracker/model/utils/memory_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from typing import Optional, Union, Tuple
4 |
5 |
6 | # @torch.jit.script
7 | def get_similarity(mk: torch.Tensor,
8 | ms: torch.Tensor,
9 | qk: torch.Tensor,
10 | qe: torch.Tensor,
11 | add_batch_dim: bool = False) -> torch.Tensor:
12 | # used for training/inference and memory reading/memory potentiation
13 | # mk: B x CK x [N] - Memory keys
14 | # ms: B x 1 x [N] - Memory shrinkage
15 | # qk: B x CK x [HW/P] - Query keys
16 | # qe: B x CK x [HW/P] - Query selection
17 | # Dimensions in [] are flattened
18 | if add_batch_dim:
19 | mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
20 | qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)
21 |
22 | CK = mk.shape[1]
23 | mk = mk.flatten(start_dim=2)
24 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
25 | qk = qk.flatten(start_dim=2)
26 | qe = qe.flatten(start_dim=2) if qe is not None else None
27 |
28 | if qe is not None:
29 | # See XMem's appendix for derivation
30 | mk = mk.transpose(1, 2)
31 | a_sq = (mk.pow(2) @ qe)
32 | two_ab = 2 * (mk @ (qk * qe))
33 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
34 | similarity = (-a_sq + two_ab - b_sq)
35 | else:
36 | # similar to STCN if we don't have the selection term
37 | a_sq = mk.pow(2).sum(1).unsqueeze(2)
38 | two_ab = 2 * (mk.transpose(1, 2) @ qk)
39 | similarity = (-a_sq + two_ab)
40 |
41 | if ms is not None:
42 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW
43 | else:
44 | similarity = similarity / math.sqrt(CK) # B*N*HW
45 |
46 | return similarity
47 |
48 |
49 | def do_softmax(
50 | similarity: torch.Tensor,
51 | top_k: Optional[int] = None,
52 | inplace: bool = False,
53 | return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
54 | # normalize similarity with top-k softmax
55 | # similarity: B x N x [HW/P]
56 | # use inplace with care
57 | if top_k is not None:
58 | values, indices = torch.topk(similarity, k=top_k, dim=1)
59 |
60 | x_exp = values.exp_()
61 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
62 | if inplace:
63 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
64 | affinity = similarity
65 | else:
66 | affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
67 | else:
68 | maxes = torch.max(similarity, dim=1, keepdim=True)[0]
69 | x_exp = torch.exp(similarity - maxes)
70 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
71 | affinity = x_exp / x_exp_sum
72 | indices = None
73 |
74 | if return_usage:
75 | return affinity, affinity.sum(dim=2)
76 |
77 | return affinity
78 |
79 |
80 | def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
81 | qe: torch.Tensor) -> torch.Tensor:
82 | # shorthand used in training with no top-k
83 | similarity = get_similarity(mk, ms, qk, qe)
84 | affinity = do_softmax(similarity)
85 | return affinity
86 |
87 |
88 | def readout(affinity: torch.Tensor, mv: torch.Tensor) -> torch.Tensor:
89 | B, CV, T, H, W = mv.shape
90 |
91 | mo = mv.view(B, CV, T * H * W)
92 | mem = torch.bmm(mo, affinity)
93 | mem = mem.view(B, CV, H, W)
94 |
95 | return mem
96 |
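A shape-level sketch of the memory read path (get_similarity -> do_softmax -> readout); every dimension below is illustrative:

import torch
from models.tracker.model.utils.memory_utils import get_similarity, do_softmax, readout

B, CK, CV, T, H, W = 1, 64, 512, 4, 30, 30
mk = torch.randn(B, CK, T, H, W)    # memory keys
ms = torch.rand(B, 1, T, H, W)      # memory shrinkage
qk = torch.randn(B, CK, H, W)       # query keys
qe = torch.rand(B, CK, H, W)        # query selection
mv = torch.randn(B, CV, T, H, W)    # memory values

similarity = get_similarity(mk, ms, qk, qe)    # B x (T*H*W) x (H*W)
affinity = do_softmax(similarity, top_k=30)    # top-k softmax over the memory axis
mem = readout(affinity, mv)                    # B x CV x H x W
print(similarity.shape, mem.shape)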
--------------------------------------------------------------------------------
/models/tracker/model/utils/parameter_groups.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | log = logging.getLogger()
4 |
5 |
6 | def get_parameter_groups(model, stage_cfg, print_log=False):
7 | """
8 | Assign different weight decays and learning rates to different parameters.
9 | Returns a parameter group which can be passed to the optimizer.
10 | """
11 | weight_decay = stage_cfg.weight_decay
12 | embed_weight_decay = stage_cfg.embed_weight_decay
13 | backbone_lr_ratio = stage_cfg.backbone_lr_ratio
14 | base_lr = stage_cfg.learning_rate
15 |
16 | backbone_params = []
17 | embed_params = []
18 | other_params = []
19 |
20 | embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
21 | embedding_names = [e + '.weight' for e in embedding_names]
22 |
23 | # inspired by detectron2
24 | memo = set()
25 | for name, param in model.named_parameters():
26 | if not param.requires_grad:
27 | continue
28 | # Avoid duplicating parameters
29 | if param in memo:
30 | continue
31 | memo.add(param)
32 |
33 | if name.startswith('module'):
34 | name = name[7:]
35 |
36 | inserted = False
37 | if name.startswith('pixel_encoder.'):
38 | backbone_params.append(param)
39 | inserted = True
40 | if print_log:
41 | log.info(f'{name} counted as a backbone parameter.')
42 | else:
43 | for e in embedding_names:
44 | if name.endswith(e):
45 | embed_params.append(param)
46 | inserted = True
47 | if print_log:
48 | log.info(f'{name} counted as an embedding parameter.')
49 | break
50 |
51 | if not inserted:
52 | other_params.append(param)
53 |
54 | parameter_groups = [
55 | {
56 | 'params': backbone_params,
57 | 'lr': base_lr * backbone_lr_ratio,
58 | 'weight_decay': weight_decay
59 | },
60 | {
61 | 'params': embed_params,
62 | 'lr': base_lr,
63 | 'weight_decay': embed_weight_decay
64 | },
65 | {
66 | 'params': other_params,
67 | 'lr': base_lr,
68 | 'weight_decay': weight_decay
69 | },
70 | ]
71 |
72 | return parameter_groups
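A sketch of wiring the returned groups into an optimizer; the stage-config keys mirror what get_parameter_groups reads, while the values and the stand-in model are placeholders:

import torch
import torch.nn as nn
from omegaconf import OmegaConf
from models.tracker.model.utils.parameter_groups import get_parameter_groups

stage_cfg = OmegaConf.create({
    'weight_decay': 0.05,
    'embed_weight_decay': 0.0,
    'backbone_lr_ratio': 0.1,
    'learning_rate': 1e-4,
})
model = nn.Linear(8, 8)                  # stand-in for the full network
groups = get_parameter_groups(model, stage_cfg, print_log=False)
optimizer = torch.optim.AdamW(groups)    # per-group lr and weight_decay are respected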
--------------------------------------------------------------------------------
/models/tracker/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/models/tracker/utils/__init__.py
--------------------------------------------------------------------------------
/models/tracker/utils/load_subset.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def load_subset(path):
5 | with open(path, mode='r') as f:
6 | subset = set(f.read().splitlines())
7 | return subset
8 |
9 |
10 | def load_empty_masks(path):
11 | with open(path, mode='r') as f:
12 | empty_masks = json.load(f)
13 | return empty_masks
14 |
--------------------------------------------------------------------------------
/models/tracker/utils/log_integrator.py:
--------------------------------------------------------------------------------
1 | """
2 | Integrate numerical values for some iterations
3 | Typically used for loss computation / logging to tensorboard
4 | Call finalize and create a new Integrator when you want to display/log
5 | """
6 | from typing import Dict, Callable, Tuple
7 | import torch
8 | from tracker.utils.logger import TensorboardLogger
9 |
10 |
11 | class Integrator:
12 | def __init__(self, logger: TensorboardLogger, distributed: bool = True):
13 | self.values = {}
14 | self.counts = {}
15 | self.hooks = [] # List is used here to maintain insertion order
16 |
17 | self.logger = logger
18 |
19 | self.distributed = distributed
20 | self.local_rank = torch.distributed.get_rank()
21 | self.world_size = torch.distributed.get_world_size()
22 |
23 | def add_tensor(self, key: str, tensor: torch.Tensor):
24 | if key not in self.values:
25 | self.counts[key] = 1
26 | if type(tensor) == float or type(tensor) == int:
27 | self.values[key] = tensor
28 | else:
29 | self.values[key] = tensor.mean().item()
30 | else:
31 | self.counts[key] += 1
32 | if type(tensor) == float or type(tensor) == int:
33 | self.values[key] += tensor
34 | else:
35 | self.values[key] += tensor.mean().item()
36 |
37 | def add_dict(self, tensor_dict: Dict[str, torch.Tensor]):
38 | for k, v in tensor_dict.items():
39 | self.add_tensor(k, v)
40 |
41 | def add_hook(self, hook: Callable[[torch.Tensor], Tuple[str, torch.Tensor]]):
42 | """
43 | Adds a custom hook, i.e. compute new metrics using values in the dict
44 | The hook takes the dict as argument, and returns a (k, v) tuple
45 | e.g. for computing IoU
46 | """
47 | if type(hook) == list:
48 | self.hooks.extend(hook)
49 | else:
50 | self.hooks.append(hook)
51 |
52 | def reset_except_hooks(self):
53 | self.values = {}
54 | self.counts = {}
55 |
56 | # Average and output the metrics
57 | def finalize(self, exp_id: str, prefix: str, it: int) -> None:
58 |
59 | for hook in self.hooks:
60 | k, v = hook(self.values)
61 | self.add_tensor(k, v)
62 |
63 | outputs = {}
64 | for k, v in self.values.items():
65 |
66 | if k[:4] == 'hide':
67 | continue
68 |
69 | avg = v / self.counts[k]
70 |
71 | if self.distributed:
72 | # Inplace operation
73 | avg = torch.tensor(avg).cuda()
74 | torch.distributed.reduce(avg, dst=0)
75 |
76 | if self.local_rank == 0:
77 | avg = (avg / self.world_size).cpu().item()
78 | outputs[k] = avg
79 | else:
80 | # Simple does it
81 | outputs[k] = avg
82 |
83 | if (not self.distributed) or (self.local_rank == 0):
84 | self.logger.log_metrics(exp_id, prefix, outputs, it)
85 |
--------------------------------------------------------------------------------
/models/tracker/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Dumps things to tensorboard and console
3 | """
4 |
5 | import os
6 | import logging
7 | import datetime
8 | from typing import Dict
9 | import numpy as np
10 | from PIL import Image
11 |
12 | from torch.utils.tensorboard import SummaryWriter
13 | from tracker.utils.time_estimator import TimeEstimator
14 |
15 |
16 | def tensor_to_numpy(image):
17 | image_np = (image.numpy() * 255).astype('uint8')
18 | return image_np
19 |
20 |
21 | def detach_to_cpu(x):
22 | return x.detach().cpu()
23 |
24 |
25 | def fix_width_trunc(x):
26 | return ('{:.9s}'.format('{:0.9f}'.format(x)))
27 |
28 |
29 | class TensorboardLogger:
30 | def __init__(self, run_dir, py_logger: logging.Logger, *, enabled_tb):
31 | self.run_dir = run_dir
32 | self.py_log = py_logger
33 | if enabled_tb:
34 | self.tb_log = SummaryWriter(run_dir)
35 | else:
36 | self.tb_log = None
37 |
38 | # Get current git info for logging
39 | try:
40 | import git
41 | repo = git.Repo(".")
42 | git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)
43 | except (ImportError, RuntimeError):
44 | print('Failed to fetch git info. Defaulting to None')
45 | git_info = 'None'
46 |
47 | self.log_string('git', git_info)
48 |
49 | # used when logging metrics
50 | self.time_estimator: TimeEstimator = None
51 |
52 | def log_scalar(self, tag, x, it):
53 | if self.tb_log is None:
54 | return
55 | self.tb_log.add_scalar(tag, x, it)
56 |
57 | def log_metrics(self, exp_id, prefix, metrics: Dict, it):
58 | msg = f'{exp_id}-{prefix} - it {it:6d}: '
59 | metrics_msg = ''
60 | for k, v in sorted(metrics.items()):
61 | self.log_scalar(f'{prefix}/{k}', v, it)
62 | metrics_msg += f'{k: >10}:{v:.7f},\t'
63 |
64 | if self.time_estimator is not None:
65 | self.time_estimator.update()
66 | avg_time = self.time_estimator.get_and_reset_avg_time()
67 | est = self.time_estimator.get_est_remaining(it)
68 | est = datetime.timedelta(seconds=est)
69 | if est.days > 0:
70 | remaining_str = f'{est.days}d {est.seconds // 3600}h'
71 | else:
72 | remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m'
73 | eta = datetime.datetime.now() + est
74 | eta_str = eta.strftime('%Y-%m-%d %H:%M:%S')
75 | time_msg = f'avg_time:{avg_time:.3f},remaining:{remaining_str},eta:{eta_str},\t'
76 | msg = f'{msg} {time_msg}'
77 |
78 | msg = f'{msg} {metrics_msg}'
79 | self.py_log.info(msg)
80 |
81 | def log_image(self, stage_name, tag, image, it):
82 | image_dir = os.path.join(self.run_dir, f'{stage_name}_images')
83 | os.makedirs(image_dir, exist_ok=True)
84 |
85 | image = Image.fromarray(image)
86 | image.save(os.path.join(image_dir, f'{tag}_{it}.png'))
87 |
88 | def log_string(self, tag, x):
89 | self.py_log.info(f'{tag} - {x}')
90 | if self.tb_log is None:
91 | return
92 | self.tb_log.add_text(tag, x)
93 |
94 | def debug(self, x):
95 | self.py_log.debug(x)
96 |
97 | def info(self, x):
98 | self.py_log.info(x)
99 |
100 | def warning(self, x):
101 | self.py_log.warning(x)
102 |
103 | def error(self, x):
104 | self.py_log.error(x)
105 |
106 | def critical(self, x):
107 | self.py_log.critical(x)
108 |
--------------------------------------------------------------------------------
/models/tracker/utils/mask_mapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def all_to_onehot(masks, labels):
5 | if len(masks.shape) == 3:
6 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
7 | else:
8 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
9 |
10 | for ni, l in enumerate(labels):
11 | Ms[ni] = (masks == l).astype(np.uint8)
12 |
13 | return Ms
14 |
15 | class MaskMapper:
16 | """
    17 |     This class is used to convert an indexed mask to a one-hot representation.
    18 |     It also takes care of remapping non-contiguous indices.
19 | It has two modes:
20 | 1. Default. Only masks with new indices are supposed to go into the remapper.
21 | This is also the case for YouTubeVOS.
22 | i.e., regions with index 0 are not "background", but "don't care".
23 |
24 | 2. Exhaustive. Regions with index 0 are considered "background".
25 | Every single pixel is considered to be "labeled".
26 | """
27 | def __init__(self):
28 | self.labels = []
29 | self.remappings = {}
30 |
31 | # if coherent, no mapping is required
32 | self.coherent = True
33 |
34 | def clear_labels(self):
35 | self.labels = []
36 | self.remappings = {}
37 | # if coherent, no mapping is required
38 | self.coherent = True
39 |
40 | def convert_mask(self, mask, exhaustive=False):
41 | # mask is in index representation, H*W numpy array
42 | labels = np.unique(mask).astype(np.uint8)
43 | labels = labels[labels!=0].tolist()
44 |
45 | new_labels = list(set(labels) - set(self.labels))
46 | if not exhaustive:
47 | assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'
48 |
49 | # add new remappings
50 | for i, l in enumerate(new_labels):
51 | self.remappings[l] = i+len(self.labels)+1
52 | if self.coherent and i+len(self.labels)+1 != l:
53 | self.coherent = False
54 |
55 | if exhaustive:
56 | new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1)
57 | else:
58 | if self.coherent:
59 | new_mapped_labels = new_labels
60 | else:
61 | new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1)
62 |
63 | self.labels.extend(new_labels)
64 | # mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
65 | mask = torch.from_numpy(mask).float()
66 | # mask num_objects*H*W
67 | return mask, new_mapped_labels
68 |
69 |
70 | def remap_index_mask(self, mask):
71 | # mask is in index representation, H*W numpy array
72 | if self.coherent:
73 | return mask
74 |
75 | new_mask = np.zeros_like(mask)
76 | for l, i in self.remappings.items():
77 | new_mask[mask==i] = l
78 | return new_mask
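A sketch of the default (non-exhaustive) mode with non-contiguous object ids; all values are illustrative:

import numpy as np
from models.tracker.utils.mask_mapper import MaskMapper

mask = np.zeros((4, 4), dtype=np.uint8)
mask[0, 0], mask[1, 1] = 3, 7               # objects labeled 3 and 7 (non-contiguous)

mapper = MaskMapper()
_, new_labels = mapper.convert_mask(mask)
print(list(new_labels))                     # [1, 2]: remapped, so the mapper is no longer coherent

pred = np.array([[1, 0], [0, 2]], dtype=np.uint8)   # a prediction in the remapped index space
print(mapper.remap_index_mask(pred))                # remapped indices converted back to ids 3 and 7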
--------------------------------------------------------------------------------
/models/tracker/utils/palette.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
4 |
5 | youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'
6 |
7 | davis_palette_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3)
8 |
9 | youtube_palette_np = np.frombuffer(youtube_palette, dtype=np.uint8).reshape(-1, 3)
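A sketch of colorizing an index mask with the DAVIS palette via PIL; the mask contents and file name are made up:

import numpy as np
from PIL import Image
from models.tracker.utils.palette import davis_palette

mask = np.zeros((64, 64), dtype=np.uint8)
mask[8:24, 8:24] = 1             # object 1
mask[32:48, 32:48] = 2           # object 2

img = Image.fromarray(mask)      # mode 'L'
img.putpalette(davis_palette)    # attaches the palette and switches the image to mode 'P'
img.save('mask_vis.png')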
--------------------------------------------------------------------------------
/models/tracker/utils/pano_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from threading import Lock
3 |
4 |
5 | class ID2RGBConverter:
6 | def __init__(self):
7 | self.all_id = []
8 | self.obj_to_id = {}
9 | self.lock = Lock()
10 |
11 | def _id_to_rgb(self, id: int):
12 | rgb = np.zeros((3, ), dtype=np.uint8)
13 | for i in range(3):
14 | rgb[i] = id % 256
15 | id = id // 256
16 | return rgb
17 |
18 | def convert(self, obj: int):
19 | with self.lock:
20 | if obj in self.obj_to_id:
21 | id = self.obj_to_id[obj]
22 | else:
23 | while True:
24 | id = np.random.randint(255, 256**3)
25 | if id not in self.all_id:
26 | break
27 | self.obj_to_id[obj] = id
28 | self.all_id.append(id)
29 |
30 | return id, self._id_to_rgb(id)
31 |
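A sketch of ID2RGBConverter: each object gets a cached random id whose base-256 digits (least-significant byte first) form the RGB triple:

from models.tracker.utils.pano_utils import ID2RGBConverter

converter = ID2RGBConverter()
obj_id, rgb = converter.convert(obj=1)
again_id, _ = converter.convert(obj=1)     # repeated calls reuse the cached id
assert obj_id == again_id
assert obj_id == int(rgb[0]) + int(rgb[1]) * 256 + int(rgb[2]) * 256 ** 2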
--------------------------------------------------------------------------------
/models/tracker/utils/point_features.py:
--------------------------------------------------------------------------------
1 | # This file is copied from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
2 | # such that users do not need to install detectron2 just for these two functions
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | from typing import List
6 | import torch
7 | from torch.nn import functional as F
8 |
9 |
10 | def cat(tensors: List[torch.Tensor], dim: int = 0):
11 | """
12 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list
13 | """
14 | assert isinstance(tensors, (list, tuple))
15 | if len(tensors) == 1:
16 | return tensors[0]
17 | return torch.cat(tensors, dim)
18 |
19 |
20 | def calculate_uncertainty(sem_seg_logits):
21 | """
    22 |     For each location of the prediction `sem_seg_logits` we estimate uncertainty as the
23 | difference between top first and top second predicted logits.
24 | Args:
    25 |         sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
26 | C is the number of foreground classes. The values are logits.
27 | Returns:
28 | scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
29 | the most uncertain locations having the highest uncertainty score.
30 | """
31 | if sem_seg_logits.shape[1] == 2:
32 | # binary segmentation
33 | return -(torch.abs(sem_seg_logits[:, 1:2]))
34 | top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
35 | return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
36 |
37 |
38 | def point_sample(input, point_coords, **kwargs):
39 | """
40 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
41 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
42 | [0, 1] x [0, 1] square.
43 | Args:
44 | input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
45 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
46 | [0, 1] x [0, 1] normalized point coordinates.
47 | Returns:
48 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
49 | features for points in `point_coords`. The features are obtained via bilinear
50 | interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
51 | """
52 | add_dim = False
53 | if point_coords.dim() == 3:
54 | add_dim = True
55 | point_coords = point_coords.unsqueeze(2)
56 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
57 | if add_dim:
58 | output = output.squeeze(3)
59 | return output
60 |
61 |
62 | def get_uncertain_point_coords_with_randomness(coarse_logits, uncertainty_func, num_points,
63 | oversample_ratio, importance_sample_ratio):
64 | """
65 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
66 | are calculated for each point using 'uncertainty_func' function that takes point's logit
67 | prediction as input.
68 | See PointRend paper for details.
69 | Args:
70 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
71 | class-specific or class-agnostic prediction.
72 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
73 | contains logit predictions for P points and returns their uncertainties as a Tensor of
74 | shape (N, 1, P).
75 | num_points (int): The number of points P to sample.
76 | oversample_ratio (int): Oversampling parameter.
    77 |         importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
78 | Returns:
79 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
80 | sampled points.
81 | """
82 | assert oversample_ratio >= 1
83 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
84 | num_boxes = coarse_logits.shape[0]
85 | num_sampled = int(num_points * oversample_ratio)
86 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
87 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
88 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
89 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads
90 | # to incorrect results.
91 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
92 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
93 | # However, if we calculate uncertainties for the coarse predictions first,
94 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
95 | point_uncertainties = uncertainty_func(point_logits)
96 | num_uncertain_points = int(importance_sample_ratio * num_points)
97 | num_random_points = num_points - num_uncertain_points
98 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
99 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
100 | idx += shift[:, None]
101 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points,
102 | 2)
103 | if num_random_points > 0:
104 | point_coords = cat(
105 | [
106 | point_coords,
107 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
108 | ],
109 | dim=1,
110 | )
111 | return point_coords
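A sketch of the sampling path used by LossComputer.mask_loss in losses.py above; shapes and hyper-parameters are illustrative:

import torch
from models.tracker.utils.point_features import (
    calculate_uncertainty, point_sample, get_uncertain_point_coords_with_randomness)

logits = torch.randn(2, 4, 96, 96)      # N*C*H*W class logits
soft_gt = torch.rand(2, 4, 96, 96)      # soft targets with the same layout

with torch.no_grad():
    coords = get_uncertain_point_coords_with_randomness(
        logits, calculate_uncertainty, num_points=1024,
        oversample_ratio=3, importance_sample_ratio=0.75)
point_logits = point_sample(logits, coords, align_corners=False)    # N*C*num_points
point_labels = point_sample(soft_gt, coords, align_corners=False)
print(point_logits.shape)               # torch.Size([2, 4, 1024])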
--------------------------------------------------------------------------------
/models/tracker/utils/range_transform.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as transforms
2 |
3 | im_mean = (124, 116, 104)
4 |
5 | im_normalization = transforms.Normalize(
6 | mean=[0.485, 0.456, 0.406],
7 | std=[0.229, 0.224, 0.225]
8 | )
9 |
10 | inv_im_trans = transforms.Normalize(
11 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
12 | std=[1/0.229, 1/0.224, 1/0.225])
13 |
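A quick check that inv_im_trans undoes im_normalization (up to floating-point error):

import torch
from models.tracker.utils.range_transform import im_normalization, inv_im_trans

img = torch.rand(3, 64, 64)                       # RGB in [0, 1]
restored = inv_im_trans(im_normalization(img))    # normalize, then invert
print(torch.allclose(img, restored, atol=1e-5))   # True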
--------------------------------------------------------------------------------
/models/tracker/utils/tensor_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Iterable
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | # STM
7 | def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
8 | h, w = in_img.shape[-2:]
9 |
10 | if h % d > 0:
11 | new_h = h + d - h % d
12 | else:
13 | new_h = h
14 | if w % d > 0:
15 | new_w = w + d - w % d
16 | else:
17 | new_w = w
18 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
19 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
20 | pad_array = (int(lw), int(uw), int(lh), int(uh))
21 | out = F.pad(in_img, pad_array)
22 | return out, pad_array
23 |
24 |
25 | def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
26 | if len(img.shape) == 4:
27 | if pad[2] + pad[3] > 0:
28 | img = img[:, :, pad[2]:-pad[3], :]
29 | if pad[0] + pad[1] > 0:
30 | img = img[:, :, :, pad[0]:-pad[1]]
31 | elif len(img.shape) == 3:
32 | if pad[2] + pad[3] > 0:
33 | img = img[:, pad[2]:-pad[3], :]
34 | if pad[0] + pad[1] > 0:
35 | img = img[:, :, pad[0]:-pad[1]]
36 | elif len(img.shape) == 5:
37 | if pad[2] + pad[3] > 0:
38 | img = img[:, :, :, pad[2]:-pad[3], :]
39 | if pad[0] + pad[1] > 0:
40 | img = img[:, :, :, :, pad[0]:-pad[1]]
41 | else:
42 | raise NotImplementedError
43 | return img
44 |
45 |
46 | # @torch.jit.script
47 | def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
48 | with torch.cuda.amp.autocast(enabled=False):
49 | prob = prob.float()
50 | new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
51 | dim).clamp(1e-7, 1 - 1e-7)
52 | logits = torch.log((new_prob / (1 - new_prob)))
53 |
54 | return logits
55 |
56 |
57 | # @torch.jit.script
58 | def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
59 | # cls_gt: B*1*H*W
60 | B, _, H, W = cls_gt.shape
61 | one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
62 | return one_hot
--------------------------------------------------------------------------------
/models/tracker/utils/time_estimator.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class TimeEstimator:
5 | def __init__(self, total_iter, step_size):
6 | self.avg_time_window = [] # window-based average
7 | self.exp_avg_time = None # exponential moving average
8 | self.alpha = 0.7 # for exponential moving average
9 |
    10 |         self.last_time = time.time()  # not accurate for the first iteration, but that is acceptable
11 | self.total_iter = total_iter
12 | self.step_size = step_size
13 |
14 | self.buffering_exp = True
15 |
16 | # call this at a fixed interval
17 | # does not have to be every step
18 | def update(self):
19 | curr_time = time.time()
20 | time_per_iter = curr_time - self.last_time
21 | self.last_time = curr_time
22 |
23 | self.avg_time_window.append(time_per_iter)
24 |
25 | if self.buffering_exp:
26 | if self.exp_avg_time is not None:
27 | # discard the first iteration call to not pollute the ema
28 | self.buffering_exp = False
29 | self.exp_avg_time = time_per_iter
30 | else:
31 | self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter
32 |
33 | def get_est_remaining(self, it):
34 | if self.exp_avg_time is None:
35 | return 0
36 |
37 | remaining_iter = self.total_iter - it
38 | return remaining_iter * self.exp_avg_time / self.step_size
39 |
40 | def get_and_reset_avg_time(self):
41 | avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size
42 | self.avg_time_window = []
43 | return avg
44 |
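A sketch of TimeEstimator in a loop that reports every step_size iterations; the sleep stands in for real training work:

import time
from models.tracker.utils.time_estimator import TimeEstimator

estimator = TimeEstimator(total_iter=1000, step_size=10)
for it in range(10, 40, 10):
    time.sleep(0.01)                   # stand-in for 10 training iterations
    estimator.update()
    avg = estimator.get_and_reset_avg_time()
    print(f'it={it} avg_time={avg:.4f}s remaining~{estimator.get_est_remaining(it):.1f}s')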
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | edge-tts==6.1.9
2 | ultralytics == 8.0.120
3 | einops
4 | av
5 | sentencepiece
6 | cpm_kernels
8 | ffmpeg-python==0.2
9 | clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
10 | diffusers==0.21.4
11 | loguru
12 | numpy
13 | pandas
14 | opencv-python
15 | torchvision
16 | torch==2.1.1
17 | Pillow
18 | protobuf==4.23.4
19 | PyYAML
20 | segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
21 | transformers
22 | gradio==3.39.0
23 | openai-whisper==20231117
24 | anyconfig
25 | mdtex2html
26 | librosa
27 | moviepy
28 | omegaconf
29 | SwissArmyTransformer>=0.4.4
30 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .chatglm_handler import ChatGLMHandler
3 | from .whisper_handler import WhisperHandler
4 | from .edgetts_handler import EdgeTTSHandler
5 | from .gpt_handler import GPTHandler
6 | from .ai_wrapper import AIWrapper
7 | from .fastsam_handler import FastSAMHandler
--------------------------------------------------------------------------------
/tools/base.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class BaseHandler:
4 | def __init__(self, args) -> None:
5 | self.model = None
6 | self.handle_args = args.get(type(self).__name__, {})
7 | if self.handle_args.get('init_model_when_start_server'):
8 | self.init_model()
9 |
10 | def infer(self, input_data, **kwargs):
11 | raise NotImplementedError
12 |
13 | def init_model(self):
14 | raise NotImplementedError
--------------------------------------------------------------------------------
/tools/chatglm_handler.py:
--------------------------------------------------------------------------------
1 | from .base import BaseHandler
2 | from utils.chatglm_utils import *
3 |
4 |
5 | class ChatGLMHandler(BaseHandler):
6 | def __init__(self, args):
7 | super().__init__(args)
8 | self.llm_model_path = self.handle_args.get('llm_model_path')
9 | self.num_gpus = self.handle_args.get('num_gpus', 2)
10 | if os.getenv('CUDA_VISIBLE_DEVICES'):
    11 |             self.num_gpus = min(self.num_gpus, len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')))  # count device ids, not characters
12 |
13 | self.trust_remote_code = self.handle_args.get('trust_remote_code', True)
14 | self.device = self.handle_args.get('device', 'cuda:0')
15 |
16 | def init_model(self):
17 | if self.model is None:
18 | self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_path, trust_remote_code=True)
19 | if 'int' in os.path.basename(self.llm_model_path):
20 | self.model = AutoModel.from_pretrained(self.llm_model_path, trust_remote_code=True).to(self.device)
21 | else:
22 | self.model = load_model_on_gpus(self.llm_model_path, num_gpus=self.num_gpus)
23 | self.model = self.model.eval()
24 |
25 | def infer(self, input_text, history=[], max_length=2048, top_p=90, temperature=95, **kwargs):
26 | self.init_model()
27 | response, history = self.model.chat(self.tokenizer, input_text, history=history, max_length=max_length, top_p=top_p, temperature=temperature)
28 | return response, history
29 |
30 | def stream_chat(self, input_text, chatbot, max_length, top_p, temperature, history, past_key_values):
31 | self.init_model()
32 | chatbot.append((parse_text(input_text), ""))
33 | for response, history, past_key_values in self.model.stream_chat(self.tokenizer, input_text, history, past_key_values=past_key_values,
34 | return_past_key_values=True,
35 | max_length=max_length, top_p=top_p,
36 | temperature=temperature):
37 | chatbot[-1] = (parse_text(input_text), parse_text(response))
38 |
39 | yield chatbot, history, past_key_values
40 |
41 |
42 |
--------------------------------------------------------------------------------
/tools/chatvlm_handler.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM
2 | from .base import BaseHandler
3 | from utils.chatglm_utils import *
4 |
5 |
6 | class ChatVLMHandler(BaseHandler):
7 | def __init__(self, args):
8 | super().__init__(args)
9 | self.vlm_model_path = self.handle_args.get('vlm_model_path')
10 | self.num_gpus = self.handle_args.get('num_gpus', 2)
11 | if os.getenv('CUDA_VISIBLE_DEVICES'):
    12 |             self.num_gpus = min(self.num_gpus, len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')))  # count device ids, not characters
13 |
14 | self.trust_remote_code = self.handle_args.get('trust_remote_code', True)
15 | self.device = self.handle_args.get('device', 'cuda:0')
16 |
17 | def init_model(self):
18 | if self.model is None:
19 | self.tokenizer = AutoTokenizer.from_pretrained(self.vlm_model_path, trust_remote_code=True)
20 | self.model = AutoModelForCausalLM.from_pretrained(self.vlm_model_path, device_map="cuda", trust_remote_code=self.trust_remote_code)
21 | self.model = self.model.eval()
22 |
23 | def chat(self, user_input, image, chatbot, history=None, **kwargs):
24 | self.init_model()
25 |
26 | query = self.tokenizer.from_list_format([
27 | {'image': self.preprocess(image)},
28 | {'text': user_input},
29 | ])
30 | response, history = self.model.chat(self.tokenizer, query=query, history=history)
31 | chatbot.append((parse_text(user_input), parse_text(response)))
32 | # image = tokenizer.draw_bbox_on_latest_picture(response, history)
33 | return chatbot, history
34 |
35 | def chat_stream(self, user_input, image, chatbot, history=None, **kwargs):
36 | self.init_model()
37 | chatbot.append((parse_text(user_input), ""))
38 | query = self.tokenizer.from_list_format([
39 | {'image': self.preprocess(image)},
40 | {'text': user_input},
41 | ])
42 | for response in self.model.chat_stream(self.tokenizer, query=query, history=history):
43 | chatbot[-1] = (parse_text(user_input), parse_text(response))
44 | yield chatbot, history
45 |
46 | def preprocess(self, image):
47 | return image
48 |
49 |
50 |
--------------------------------------------------------------------------------
/tools/edgetts_handler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import edge_tts
3 | import random
4 | import asyncio
5 | import pandas as pd
6 | from datetime import datetime
7 | from .base import BaseHandler
8 | from loguru import logger
9 |
10 |
11 | async def streaming_with_subtitles(text, audio_file, webvtt_file, voice, rate, volume, pitch) -> None:
12 | communicate = edge_tts.Communicate(text=text, voice=voice, rate=rate, volume=volume, pitch=pitch)
13 | submaker = edge_tts.SubMaker()
14 | with open(audio_file, "wb") as file:
15 | async for chunk in communicate.stream():
16 | if chunk["type"] == "audio":
17 | file.write(chunk["data"])
18 | elif chunk["type"] == "WordBoundary":
19 | submaker.create_sub((chunk["offset"], chunk["duration"]), chunk["text"])
20 |
21 | with open(webvtt_file, "w", encoding="utf-8") as file:
22 | subtitles_info = submaker.generate_subs()
23 | file.write(subtitles_info)
24 | return subtitles_info
25 |
26 |
27 | class EdgeTTSHandler(BaseHandler):
28 | def __init__(self, args, **kwargs):
29 | # output_dir='/data1/zjx/ai_webui/products/audio_tmp',
30 | super().__init__(args)
31 | self.output_dir = self.handle_args.get("output_dir", "/tmp")
32 |
33 | def infer(self, tts_text_file=None, tts_text=None, voice='zh-CN-YunxiNeural', rate=0, volume=0, pitch=0, **kwargs):
    34 |         # Format adaptation: convert numeric rate/volume/pitch into the strings edge-tts expects
35 | if rate >= 0:
36 | rate = f"+{rate}%"
37 | else:
38 | rate = f"{rate}%"
39 | if volume >= 0:
40 | volume = f"+{volume}%"
41 | else:
42 | volume = f"{volume}%"
43 | if pitch >= 0:
44 | pitch = f"+{pitch}Hz"
45 | else:
46 | pitch = f"{pitch}Hz"
47 |
48 | if tts_text_file:
49 | tts_text_file_path = tts_text_file.name
50 | text = ""
51 | with open(tts_text_file_path, "r") as f:
52 | for line in f:
53 | text += ' ' + line.rstrip()
54 | self.output_dir = os.path.dirname(tts_text_file_path)
55 | file_tag = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + "_" + os.path.basename(tts_text_file_path).split('.')[0]
56 | elif tts_text:
57 | file_tag = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + "_" + str(random.randint(0, 1000))
58 | text = tts_text
59 |
60 | audio_file = os.path.join(self.output_dir, file_tag + ".mp3")
61 | webvtt_file = os.path.join(self.output_dir, file_tag + ".vtt")
62 | loop = asyncio.new_event_loop() # .get_event_loop_policy()
63 | asyncio.set_event_loop(loop)
64 | subtitles_info = loop.run_until_complete(streaming_with_subtitles(text, audio_file, webvtt_file, voice, rate, volume, pitch))
65 | loop.close()
66 |
    67 |         # Post-processing: parse the WebVTT subtitles into SRT and Excel outputs
68 | proc_subtitles = []
69 | contents = subtitles_info.split("\r\n")
70 | for idx in range(1, len(contents)):
71 | if ' --> ' not in contents[idx]:
72 | continue
73 | start, end = contents[idx].split(' --> ')
74 | sentence = contents[idx+1].replace(' ', '')
75 | proc_subtitles.append({
76 | "start": start,
77 | "end": end,
78 | "sentence": sentence
79 | })
80 | df = pd.DataFrame(proc_subtitles)
81 | srt_file = webvtt_file.replace(".vtt", ".srt")
82 | excel_file = webvtt_file.replace(".vtt", ".xlsx")
83 | with open(srt_file, "w", encoding="utf-8") as f:
84 | for idx, row in enumerate(proc_subtitles):
85 | f.write(f"{idx+1}\n")
86 | f.write(f"{row['start']} --> {row['end']}\n")
87 | f.write(f"{row['sentence']}\n\n")
88 |
89 | df.to_excel(excel_file, index=False)
90 | file_list = [audio_file, webvtt_file, srt_file, excel_file]
91 | return audio_file, file_list
92 |
93 | def init_model(self):
94 | # Nothing to do
95 | logger.warning("## 无需初始化EdgeTTSHandler的模型!")
96 |
97 | if __name__ == "__main__":
98 | text_file = "/data1/zjx/ai_webui/raw_materials/test.txt"
99 | text_type = "file"
100 | audio_file = "test.mp3"
101 | webvtt_file = "test.vtt"
102 | voice = "zh-CN-YunxiNeural"
103 |     edge_tts_handle = EdgeTTSHandler({})
104 | # edge_tts_engine.streaming_with_subtitles(text_file, text_type, audio_file, webvtt_file, voice)
105 | # asyncio.run(amain())
106 |
107 |
108 |
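
Usage sketch for the handler above: a minimal standalone call that mirrors streaming_with_subtitles, assuming an edge-tts release that accepts the rate/volume/pitch keyword strings built in infer(); the voice and output path are placeholders.

import asyncio
import edge_tts

async def synth(text: str, out_mp3: str) -> None:
    # rate/volume are signed percentages, pitch is a signed Hz offset,
    # matching the strings EdgeTTSHandler.infer builds above
    communicate = edge_tts.Communicate(text=text, voice="zh-CN-YunxiNeural",
                                       rate="+10%", volume="+0%", pitch="+0Hz")
    with open(out_mp3, "wb") as f:
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                f.write(chunk["data"])

if __name__ == "__main__":
    asyncio.run(synth("Hello from ai_webui", "demo_tts.mp3"))
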
--------------------------------------------------------------------------------
/tools/fastsam_handler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ultralytics import YOLO
3 | from .base import BaseHandler
4 | from PIL import ImageDraw
5 | import gradio as gr
6 | from utils.fastsam.tools import format_results, fast_process, point_prompt, text_prompt
7 |
8 |
9 | class FastSAMHandler(BaseHandler):
10 | """
11 | FastSAMHandler is a handler for FastSAM.
12 | """
13 |
14 | def __init__(self, args):
15 | super().__init__(args)
16 | self.fastsam_model_path = self.handle_args.get('fastsam_model_path')
17 | self.device = self.handle_args.get('device')
18 | self.global_points = []
19 | self.global_point_label = []
20 |
21 | def segment_everything(
22 | self,
23 | input,
24 | input_size=1024,
25 | iou_threshold=0.7,
26 | conf_threshold=0.25,
27 | better_quality=False,
28 | withContours=True,
29 | use_retina=True,
30 | text="",
31 | wider=False,
32 | mask_random_color=True,
33 | ):
34 | self.init_model()
35 |         input_size = int(input_size)  # make sure imgsz is an integer
36 | # Thanks for the suggestion by hysts in HuggingFace.
37 | w, h = input.size
38 | scale = input_size / max(w, h)
39 | new_w = int(w * scale)
40 | new_h = int(h * scale)
41 | input = input.resize((new_w, new_h))
42 |
43 | results = self.model(input,
44 | device=self.device,
45 | retina_masks=True,
46 | iou=iou_threshold,
47 | conf=conf_threshold,
48 | imgsz=input_size,)
49 |
50 | if len(text) > 0:
51 | results = format_results(results[0], 0)
52 | annotations, _ = text_prompt(results, text, input, device=self.device, wider=wider)
53 | annotations = np.array([annotations])
54 | else:
55 | annotations = results[0].masks.data
56 |
57 | fig = fast_process(annotations=annotations,
58 | image=input,
59 | device=self.device,
60 | scale=(1024 // input_size),
61 | better_quality=better_quality,
62 | mask_random_color=mask_random_color,
63 | bbox=None,
64 | use_retina=use_retina,
65 | withContours=withContours,)
66 | return fig
67 |
68 | def segment_with_points(
69 | self,
70 | input,
71 | input_size=1024,
72 | iou_threshold=0.7,
73 | conf_threshold=0.25,
74 | better_quality=False,
75 | withContours=True,
76 | use_retina=True,
77 | mask_random_color=True,
78 | ):
79 | self.init_model()
80 |         input_size = int(input_size)  # make sure imgsz is an integer
81 | # Thanks for the suggestion by hysts in HuggingFace.
82 | w, h = input.size
83 | scale = input_size / max(w, h)
84 | new_w = int(w * scale)
85 | new_h = int(h * scale)
86 | input = input.resize((new_w, new_h))
87 |
88 | scaled_points = [[int(x * scale) for x in point] for point in self.global_points]
89 |
90 | results = self.model(input,
91 | device=self.device,
92 | retina_masks=True,
93 | iou=iou_threshold,
94 | conf=conf_threshold,
95 | imgsz=input_size,)
96 |
97 | results = format_results(results[0], 0)
98 | annotations, _ = point_prompt(results, scaled_points, self.global_point_label, new_h, new_w)
99 | annotations = np.array([annotations])
100 |
101 | fig = fast_process(annotations=annotations,
102 | image=input,
103 | device=self.device,
104 | scale=(1024 // input_size),
105 | better_quality=better_quality,
106 | mask_random_color=mask_random_color,
107 | bbox=None,
108 | use_retina=use_retina,
109 | withContours=withContours,)
110 |
111 | self.global_points = []
112 | self.global_point_label = []
113 | return fig, None
114 |
115 | def get_points_with_draw(self, image, label, evt: gr.SelectData):
116 | x, y = evt.index[0], evt.index[1]
117 | point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
118 | self.global_points.append([x, y])
119 | self.global_point_label.append(1 if label == 'Add Mask' else 0)
120 |
121 | print(x, y, label == 'Add Mask')
122 |
123 |         # create a drawing handle on the image
124 | draw = ImageDraw.Draw(image)
125 | draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
126 | return image
127 |
128 |
129 | def init_model(self):
130 | if self.model is None:
131 | self.model = YOLO(self.fastsam_model_path)
132 |
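
For orientation, a minimal everything-mode sketch of the same ultralytics calls made by segment_everything above; the checkpoint path is a placeholder and assumes a FastSAM weight file loadable by ultralytics' YOLO class.

from PIL import Image
from ultralytics import YOLO

model = YOLO("model_weights/FastSAM-x.pt")          # placeholder checkpoint path
img = Image.open("demo/fastsam/examples/dogs.jpg")
results = model(img, device="cpu", retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
# results[0].masks is None when nothing is detected
masks = results[0].masks.data if results[0].masks is not None else None
print(None if masks is None else masks.shape)       # (num_masks, H, W)
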
--------------------------------------------------------------------------------
/tools/gpt_handler.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import json
3 | try:
4 | from .base import BaseHandler
5 | except ImportError:
6 | from base import BaseHandler
7 |
8 |
9 | key="""Bearer eyJhbGciOiJSUzI1NiIsInNvdXJjZSI6IklBTSIsImtpZCI6IjIwTml0b3h6S1FxMHJPdTlvYVk2UzA3bFBnUTBjSnFQIn0.eyJzdWIiOiI2MDAwMDI5NTIiLCJpc3MiOiJodHRwczpcL1wvZGV2LWEuc3p6aGlqaW5nLmNvbVwvaWFtLW9uZWlkIiwiYmFpQnVVc2VySWQiOiIyMDkxOCIsImZpbmFuY2VVc2VySWQiOiIxMDUyNjI4NTM2MTY1MDA3MzYwIiwiZmVpU2h1SWQiOiJvbl8xOTdkMTRmZjM5MmFhYjZhM2I2ZWQ0NmQ4MWY3NjI4NSIsImF1ZCI6IjQ0MTEyODg1MTkwNTU4MTA1NiIsInNjb3BlIjoib3BlbmlkIiwibmFtZSI6IuacsemUpuelpSIsImZpbmFuY2VTaWduVXNlcklkIjoiMTAyMzI1MzM2MjI4OTI1MDMwNCIsInVzZXJUeXBlIjoiSU5ORVIiLCJleHAiOjE3MDIzNjcxODIsImlhdCI6MTcwMTc2MjM4MiwianRpIjoiY2ZhMmE1NjBhMWU4NGY5NmE0NmMyZmZhMDM4YjUwYTYiLCJlbWFpbCI6InpodWppbnhpYW5nQHpqLnRlY2giLCJhY2NvdW50Ijoiemh1amlueGlhbmcifQ.RWw-8EaD3hMoiNmcXEfR0Y-A1R35aMqCWIKWnse50YM4INBtQNivVHFbsZZRNwup8XRTQe7PWQl0dPfoGH-Kz7eR3xCIZ6--ATW_PP8CbqGTCnbWAbVxT3enXBZjpgx5qE-JN5g-ko5bwpPylah4Wg2B8n6T87wC8Iczc-Aps4L7oevWjPCrne8tha4g3AWmuWGgn00LJjy4cQlIK9ETqLeVJJAZEDm82GLpYePISFERzs4-olnGz5IKcALtVXnX_KAXIloy5q1f3TgeNfPNkfdL1_TeXsrj-TZkJFW9OUVmNtiW7yFOdBN9t9Gv6HfP-Onva9pnOankMSLNXyOLQQ"""
10 | class GPTHandler(BaseHandler):
11 | def __init__(self, args={}, **kwargs):
12 | super().__init__(args)
13 | self.api_url = self.handle_args.get("api_url")
14 | self.headers = {
15 | 'Content-Type': 'application/json',
16 | 'Authorization': key
17 | }
18 |
19 | def infer(self, input_text, **kwargs):
20 | data = {
21 | 'messages':[
22 | {
23 | 'role': 'system',
24 | 'content': input_text,
25 | # 'history': history
26 | }
27 | ]
28 | }
29 | result = requests.post(self.api_url, headers=self.headers, data=json.dumps(data))
30 | text = result.json()['data']['choices'][0]['message']['content']
31 | return text, None
32 |
33 |
34 | if __name__ == "__main__":
35 | gpt_handler = GPTHandler({})
36 | input_text = '帮我分析这个报错:ImportError: attempted relative import with no known parent package'
37 | text,_ = gpt_handler.infer(input_text)
38 | print(text)
39 |
40 |
41 |
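
The bearer token above is a credential baked into source; a sketch of the same request with the token and endpoint read from environment variables instead (GPT_API_URL and GPT_API_TOKEN are illustrative names, and the response layout is the one the handler unpacks above):

import json
import os
import requests

API_URL = os.environ.get("GPT_API_URL", "")   # illustrative environment variable names
TOKEN = os.environ.get("GPT_API_TOKEN", "")

def ask(prompt: str, timeout: int = 30) -> str:
    payload = {"messages": [{"role": "system", "content": prompt}]}
    resp = requests.post(API_URL,
                         headers={"Content-Type": "application/json",
                                  "Authorization": f"Bearer {TOKEN}"},
                         data=json.dumps(payload), timeout=timeout)
    resp.raise_for_status()
    # same response schema GPTHandler.infer reads above
    return resp.json()["data"]["choices"][0]["message"]["content"]
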
--------------------------------------------------------------------------------
/tools/visualglm_handler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base import BaseHandler
3 | from utils.chatglm_utils import *
4 |
5 |
6 | class VisualGLMHandler(BaseHandler):
7 | def __init__(self, args):
8 | super().__init__(args)
9 | self.model_path = self.handle_args.get('model_path')
10 | self.trust_remote_code = self.handle_args.get('trust_remote_code', True)
11 | self.device = self.handle_args.get('device', 'cuda:0')
12 | self.quant = self.handle_args.get('quant')
13 |
14 | def init_model(self):
15 | if self.model is None:
16 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=self.trust_remote_code)
17 | if self.quant in [4, 8]:
18 | self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=self.trust_remote_code).quantize(self.quant).half().to(self.device)
19 | else:
20 | self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=self.trust_remote_code).half().to(self.device)
21 | self.model = self.model.eval()
22 |
23 | def stream_chat(self, input, image_path, chatbot, max_length, top_p, temperature, history):
24 | self.init_model()
25 | if image_path is None:
26 | return [(input, "图片不能为空。请重新上传图片并重试。")], []
27 | chatbot.append((parse_text(input), ""))
28 | with torch.no_grad():
29 | for response, history in self.model.stream_chat(self.tokenizer, image_path, input, history, max_length=max_length, top_p=top_p,
30 | temperature=temperature):
31 | chatbot[-1] = (parse_text(input), parse_text(response))
32 |
33 | yield chatbot, history
34 |
35 | def stream_chat2(self, image_path, chatbot, max_length, top_p, temperature):
36 | self.init_model()
37 | input, history = "描述这张图片。", []
38 | chatbot.append((parse_text(input), ""))
39 | with torch.no_grad():
40 | for response, history in self.model.stream_chat(self.tokenizer, image_path, input, history, max_length=max_length,
41 | top_p=top_p,
42 | temperature=temperature):
43 | chatbot[-1] = (parse_text(input), parse_text(response))
44 |
45 | yield chatbot, history
46 |
47 |
48 |
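
A hedged sketch of the load/stream loop that VisualGLMHandler wraps: the stream_chat(tokenizer, image_path, query, history, ...) signature comes from the VisualGLM-6B remote code loaded via trust_remote_code, so it only applies to that checkpoint; the model id and image path are placeholders and a CUDA GPU is assumed.

import torch
from transformers import AutoModel, AutoTokenizer

model_path = "THUDM/visualglm-6b"   # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().to("cuda:0").eval()

history, answer = [], ""
with torch.no_grad():
    # the generator yields the partially decoded answer at every step
    for answer, history in model.stream_chat(tokenizer, "demo.jpg", "描述这张图片。", history,
                                             max_length=2048, top_p=0.4, temperature=0.8):
        pass
print(answer)
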
--------------------------------------------------------------------------------
/tools/whisper_handler.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .base import BaseHandler
3 | import whisper
4 | import pandas as pd
5 |
6 | class WhisperHandler(BaseHandler):
7 | def __init__(self, args, **kwargs):
8 | # model_name, model_dir, device='cuda:0'
9 | super().__init__(args)
10 | self.model_name = self.handle_args.get('model_name')
11 | self.device = self.handle_args.get('device', 'cuda:0')
12 | self.model_dir = self.handle_args.get('model_dir', os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models'))
13 | self.language_map = self.handle_args.get('language_map', {'英语': 'en', '普通话': 'zh'})
14 |
15 | def infer(self, audio_file, language=None):
16 | self.init_model()
17 | if not isinstance(audio_file, str):
18 | audio_file = audio_file.name
19 | save_dir = os.path.dirname(audio_file)
20 | audio_name = os.path.basename(audio_file).split('.')[0]
21 | os.makedirs(save_dir, exist_ok=True)
22 | excel_file = os.path.join(save_dir, f"{audio_name}_speech_res.xlsx")
23 | result = self.model.transcribe(audio_file, language=self.language_map.get(language), verbose=True)
24 | # rec_language = result['language']
25 | df = pd.DataFrame(result['segments'])
26 | df = df[['id', 'no_speech_prob', 'start', 'end', 'text']]
27 | df.to_excel(excel_file, index=False)
28 | text = ' '.join(df['text'].tolist())
29 | result['text'] = text
30 | result['segments'] = excel_file
31 | return result
32 |
33 | def predict_v2(self, audio_file):
34 |         self.init_model()  # make sure the Whisper model is loaded before decoding
35 | audio = whisper.load_audio(audio_file)
36 | audio = whisper.pad_or_trim(audio)
37 |
38 | # make log-Mel spectrogram and move to the same device as the model
39 | mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
40 |
41 | # detect the spoken language
42 | _, probs = self.model.detect_language(mel)
43 | print(f"Detected language: {max(probs, key=probs.get)}")
44 |
45 | # decode the audio
46 | options = whisper.DecodingOptions()
47 | result = whisper.decode(self.model, mel, options)
48 |
49 | # print the recognized text
50 | print(result.text)
51 |
52 | def init_model(self):
53 | if self.model is None:
54 | self.model = whisper.load_model(name=self.model_name, device=self.device, download_root=self.model_dir)
55 |
56 |
57 | LANGUAGES = {
58 | "en": "english",
59 | "zh": "chinese",
60 | "de": "german",
61 | "es": "spanish",
62 | "ru": "russian",
63 | "ko": "korean",
64 | "fr": "french",
65 | "ja": "japanese",
66 | "pt": "portuguese",
67 | "tr": "turkish",
68 | "pl": "polish",
69 | "ca": "catalan",
70 | "nl": "dutch",
71 | "ar": "arabic",
72 | "sv": "swedish",
73 | "it": "italian",
74 | "id": "indonesian",
75 | "hi": "hindi",
76 | "fi": "finnish",
77 | "vi": "vietnamese",
78 | "he": "hebrew",
79 | "uk": "ukrainian",
80 | "el": "greek",
81 | "ms": "malay",
82 | "cs": "czech",
83 | "ro": "romanian",
84 | "da": "danish",
85 | "hu": "hungarian",
86 | "ta": "tamil",
87 | "no": "norwegian",
88 | "th": "thai",
89 | "ur": "urdu",
90 | "hr": "croatian",
91 | "bg": "bulgarian",
92 | "lt": "lithuanian",
93 | "la": "latin",
94 | "mi": "maori",
95 | "ml": "malayalam",
96 | "cy": "welsh",
97 | "sk": "slovak",
98 | "te": "telugu",
99 | "fa": "persian",
100 | "lv": "latvian",
101 | "bn": "bengali",
102 | "sr": "serbian",
103 | "az": "azerbaijani",
104 | "sl": "slovenian",
105 | "kn": "kannada",
106 | "et": "estonian",
107 | "mk": "macedonian",
108 | "br": "breton",
109 | "eu": "basque",
110 | "is": "icelandic",
111 | "hy": "armenian",
112 | "ne": "nepali",
113 | "mn": "mongolian",
114 | "bs": "bosnian",
115 | "kk": "kazakh",
116 | "sq": "albanian",
117 | "sw": "swahili",
118 | "gl": "galician",
119 | "mr": "marathi",
120 | "pa": "punjabi",
121 | "si": "sinhala",
122 | "km": "khmer",
123 | "sn": "shona",
124 | "yo": "yoruba",
125 | "so": "somali",
126 | "af": "afrikaans",
127 | "oc": "occitan",
128 | "ka": "georgian",
129 | "be": "belarusian",
130 | "tg": "tajik",
131 | "sd": "sindhi",
132 | "gu": "gujarati",
133 | "am": "amharic",
134 | "yi": "yiddish",
135 | "lo": "lao",
136 | "uz": "uzbek",
137 | "fo": "faroese",
138 | "ht": "haitian creole",
139 | "ps": "pashto",
140 | "tk": "turkmen",
141 | "nn": "nynorsk",
142 | "mt": "maltese",
143 | "sa": "sanskrit",
144 | "lb": "luxembourgish",
145 | "my": "myanmar",
146 | "bo": "tibetan",
147 | "tl": "tagalog",
148 | "mg": "malagasy",
149 | "as": "assamese",
150 | "tt": "tatar",
151 | "haw": "hawaiian",
152 | "ln": "lingala",
153 | "ha": "hausa",
154 | "ba": "bashkir",
155 | "jw": "javanese",
156 | "su": "sundanese",
157 | "yue": "cantonese",
158 | }
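
A short transcription sketch using the same openai-whisper API the handler calls; the model size and audio path are placeholders, and load_model downloads weights on first use.

import whisper

model = whisper.load_model("base")                        # placeholder model size
result = model.transcribe("sample.wav", language="zh", verbose=True)
for seg in result["segments"]:
    print(f"[{seg['start']:.2f} -> {seg['end']:.2f}] {seg['text']}")
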
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonaidm/ai_webui/45474a80d073004df7f451aadb00a470ada1fc6f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/chatglm_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Union, Optional, Tuple
3 | from torch.nn import Module
4 | from transformers import AutoModel, AutoTokenizer
5 |
6 |
7 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
8 |     # transformer.word_embeddings takes 1 layer
9 |     # transformer.final_layernorm and lm_head take 1 layer
10 |     # transformer.layers takes 28 layers
11 |     # 30 layers in total, distributed across num_gpus GPUs
12 | num_trans_layers = 28
13 | per_gpu_layers = 30 / num_gpus
14 |
15 |     # bugfix: on Linux, torch.embedding can receive weight and input on different devices, raising a RuntimeError
16 |     # on Windows, model.device is set to transformer.word_embeddings.device
17 |     # on Linux, model.device is set to lm_head.device
18 |     # when chat or stream_chat is called, input_ids is placed on model.device
19 |     # if transformer.word_embeddings.device differs from model.device, a RuntimeError is raised
20 |     # so transformer.word_embeddings, transformer.final_layernorm and lm_head are all kept on the first GPU
21 |     # this file is adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
22 |     # with only minor changes to support ChatGLM3
23 | device_map = {
24 | 'transformer.embedding.word_embeddings': 0,
25 | 'transformer.encoder.final_layernorm': 0,
26 | 'transformer.output_layer': 0,
27 | 'transformer.rotary_pos_emb': 0,
28 | 'lm_head': 0
29 | }
30 |
31 | used = 2
32 | gpu_target = 0
33 | for i in range(num_trans_layers):
34 | if used >= per_gpu_layers:
35 | gpu_target += 1
36 | used = 0
37 | assert gpu_target < num_gpus
38 | device_map[f'transformer.encoder.layers.{i}'] = gpu_target
39 | used += 1
40 |
41 | return device_map
42 |
43 |
44 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
45 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
46 | if num_gpus < 2 and device_map is None:
47 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
48 | else:
49 | from accelerate import dispatch_model
50 |
51 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
52 |
53 | if device_map is None:
54 | device_map = auto_configure_device_map(num_gpus)
55 |
56 | model = dispatch_model(model, device_map=device_map)
57 |
58 | return model
59 |
60 |
61 | def parse_text(text):
62 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
63 | lines = text.split("\n")
64 | lines = [line for line in lines if line != ""]
65 | count = 0
66 | for i, line in enumerate(lines):
67 | if "```" in line:
68 | count += 1
69 | items = line.split('`')
70 |             if count % 2 == 1:
71 |                 lines[i] = f'<pre><code class="language-{items[-1]}">'
72 |             else:
73 |                 lines[i] = f'<br></code></pre>'
74 |         else:
75 |             if i > 0:
76 |                 if count % 2 == 1:
77 |                     line = line.replace("`", "\`")
78 |                     line = line.replace("<", "&lt;")
79 |                     line = line.replace(">", "&gt;")
80 |                     line = line.replace(" ", "&nbsp;")
81 |                     line = line.replace("*", "&ast;")
82 |                     line = line.replace("_", "&lowbar;")
83 |                     line = line.replace("-", "&#45;")
84 |                     line = line.replace(".", "&#46;")
85 |                     line = line.replace("!", "&#33;")
86 |                     line = line.replace("(", "&#40;")
87 |                     line = line.replace(")", "&#41;")
88 |                     line = line.replace("$", "&#36;")
89 |                 lines[i] = "<br>"+line
90 | text = "".join(lines)
91 | return text
92 |
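
A small sketch of how the helpers above are typically used; auto_configure_device_map alone is cheap to run, while the commented load assumes a local ChatGLM checkpoint (path is a placeholder) plus transformers and accelerate.

from utils.chatglm_utils import auto_configure_device_map, load_model_on_gpus

# 30 logical layers (embeddings + 28 transformer layers + output) split over 2 GPUs
print(auto_configure_device_map(2))

# model = load_model_on_gpus("model_weights/chatglm/chatglm2-6b-int4", num_gpus=2)
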
--------------------------------------------------------------------------------
/utils/fastsam/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import FastSAM
2 | from .predict import FastSAMPredictor
3 | from .prompt import FastSAMPrompt
4 | # from .val import FastSAMValidator
5 | from .decoder import FastSAMDecoder
6 | from .tools import format_results, fast_process, point_prompt, text_prompt
7 | from PIL import ImageDraw
8 | import numpy as np
9 |
10 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder'
11 |
--------------------------------------------------------------------------------
/utils/fastsam/decoder.py:
--------------------------------------------------------------------------------
1 | from .model import FastSAM
2 | import numpy as np
3 | from PIL import Image
4 | import clip
5 | from typing import Optional, List, Tuple, Union
6 |
7 |
8 | class FastSAMDecoder:
9 | def __init__(
10 | self,
11 | model: FastSAM,
12 | device: str='cpu',
13 | conf: float=0.4,
14 | iou: float=0.9,
15 | imgsz: int=1024,
16 | retina_masks: bool=True,
17 | ):
18 | self.model = model
19 | self.device = device
20 | self.retina_masks = retina_masks
21 | self.imgsz = imgsz
22 | self.conf = conf
23 | self.iou = iou
24 | self.image = None
25 | self.image_embedding = None
26 |
27 | def run_encoder(self, image):
28 | if isinstance(image,str):
29 | image = np.array(Image.open(image))
30 | self.image = image
31 | image_embedding = self.model(
32 | self.image,
33 | device=self.device,
34 | retina_masks=self.retina_masks,
35 | imgsz=self.imgsz,
36 | conf=self.conf,
37 | iou=self.iou
38 | )
39 | return image_embedding[0].numpy()
40 |
41 | def run_decoder(
42 | self,
43 | image_embedding,
44 | point_prompt: Optional[np.ndarray]=None,
45 | point_label: Optional[np.ndarray]=None,
46 | box_prompt: Optional[np.ndarray]=None,
47 | text_prompt: Optional[str]=None,
48 | )->np.ndarray:
49 | self.image_embedding = image_embedding
50 | if point_prompt is not None:
51 | ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
52 | return ann
53 | elif box_prompt is not None:
54 | ann = self.box_prompt(bbox=box_prompt)
55 | return ann
56 | elif text_prompt is not None:
57 | ann = self.text_prompt(text=text_prompt)
58 | return ann
59 | else:
60 | return None
61 |
62 | def box_prompt(self, bbox):
63 | assert (bbox[2] != 0 and bbox[3] != 0)
64 | masks = self.image_embedding.masks.data
65 | target_height = self.image.shape[0]
66 | target_width = self.image.shape[1]
67 | h = masks.shape[1]
68 | w = masks.shape[2]
69 | if h != target_height or w != target_width:
70 | bbox = [
71 | int(bbox[0] * w / target_width),
72 | int(bbox[1] * h / target_height),
73 | int(bbox[2] * w / target_width),
74 | int(bbox[3] * h / target_height), ]
75 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
76 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
77 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
78 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
79 |
80 | # IoUs = torch.zeros(len(masks), dtype=torch.float32)
81 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
82 |
83 | masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
84 | orig_masks_area = np.sum(masks, axis=(1, 2))
85 |
86 | union = bbox_area + orig_masks_area - masks_area
87 | IoUs = masks_area / union
88 | max_iou_index = np.argmax(IoUs)
89 |
90 | return np.array([masks[max_iou_index].cpu().numpy()])
91 |
92 | def point_prompt(self, points, pointlabel): # numpy
93 |
94 | masks = self._format_results(self.image_embedding[0], 0)
95 | target_height = self.image.shape[0]
96 | target_width = self.image.shape[1]
97 | h = masks[0]['segmentation'].shape[0]
98 | w = masks[0]['segmentation'].shape[1]
99 | if h != target_height or w != target_width:
100 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
101 | onemask = np.zeros((h, w))
102 | masks = sorted(masks, key=lambda x: x['area'], reverse=True)
103 | for i, annotation in enumerate(masks):
104 | if type(annotation) == dict:
105 | mask = annotation['segmentation']
106 | else:
107 | mask = annotation
108 | for i, point in enumerate(points):
109 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
110 | onemask[mask] = 1
111 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
112 | onemask[mask] = 0
113 | onemask = onemask >= 1
114 | return np.array([onemask])
115 |
116 | def _format_results(self, result, filter=0):
117 | annotations = []
118 | n = len(result.masks.data)
119 | for i in range(n):
120 | annotation = {}
121 | mask = result.masks.data[i] == 1.0
122 |
123 | if np.sum(mask) < filter:
124 | continue
125 | annotation['id'] = i
126 | annotation['segmentation'] = mask
127 | annotation['bbox'] = result.boxes.data[i]
128 | annotation['score'] = result.boxes.conf[i]
129 | annotation['area'] = annotation['segmentation'].sum()
130 | annotations.append(annotation)
131 | return annotations
132 |
--------------------------------------------------------------------------------
/utils/fastsam/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | FastSAM model interface.
4 |
5 | Usage - Predict:
6 | from ultralytics import FastSAM
7 |
8 | model = FastSAM('last.pt')
9 | results = model.predict('ultralytics/assets/bus.jpg')
10 | """
11 |
12 | from ultralytics.yolo.cfg import get_cfg
13 | from ultralytics.yolo.engine.exporter import Exporter
14 | from ultralytics.yolo.engine.model import YOLO
15 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
16 | from ultralytics.yolo.utils.checks import check_imgsz
17 |
18 | from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
19 | from .predict import FastSAMPredictor
20 |
21 |
22 | class FastSAM(YOLO):
23 |
24 | @smart_inference_mode()
25 | def predict(self, source=None, stream=False, **kwargs):
26 | """
27 | Perform prediction using the YOLO model.
28 |
29 | Args:
30 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
31 | Accepts all source types accepted by the YOLO model.
32 | stream (bool): Whether to stream the predictions or not. Defaults to False.
33 | **kwargs : Additional keyword arguments passed to the predictor.
34 | Check the 'configuration' section in the documentation for all available options.
35 |
36 | Returns:
37 | (List[ultralytics.yolo.engine.results.Results]): The prediction results.
38 | """
39 | if source is None:
40 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
41 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
42 | overrides = self.overrides.copy()
43 | overrides['conf'] = 0.25
44 | overrides.update(kwargs) # prefer kwargs
45 | overrides['mode'] = kwargs.get('mode', 'predict')
46 | assert overrides['mode'] in ['track', 'predict']
47 | overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
48 | self.predictor = FastSAMPredictor(overrides=overrides)
49 | self.predictor.setup_model(model=self.model, verbose=False)
50 | try:
51 | return self.predictor(source, stream=stream)
52 | except Exception as e:
53 | return None
54 |
55 | def train(self, **kwargs):
56 | """Function trains models but raises an error as FastSAM models do not support training."""
57 | raise NotImplementedError("Currently, the training codes are on the way.")
58 |
59 | def val(self, **kwargs):
60 | """Run validation given dataset."""
61 | overrides = dict(task='segment', mode='val')
62 | overrides.update(kwargs) # prefer kwargs
63 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
64 | args.imgsz = check_imgsz(args.imgsz, max_dim=1)
65 | validator = FastSAM(args=args)
66 | validator(model=self.model)
67 | self.metrics = validator.metrics
68 | return validator.metrics
69 |
70 | @smart_inference_mode()
71 | def export(self, **kwargs):
72 | """
73 | Export model.
74 |
75 | Args:
76 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
77 | """
78 | overrides = dict(task='detect')
79 | overrides.update(kwargs)
80 | overrides['mode'] = 'export'
81 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
82 | args.task = self.task
83 | if args.imgsz == DEFAULT_CFG.imgsz:
84 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
85 | if args.batch == DEFAULT_CFG.batch:
86 | args.batch = 1 # default to 1 if not modified
87 | return Exporter(overrides=args)(model=self.model)
88 |
89 | def info(self, detailed=False, verbose=True):
90 | """
91 | Logs model info.
92 |
93 | Args:
94 | detailed (bool): Show detailed information about model.
95 | verbose (bool): Controls verbosity.
96 | """
97 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
98 |
99 | def __call__(self, source=None, stream=False, **kwargs):
100 | """Calls the 'predict' function with given arguments to perform object detection."""
101 | return self.predict(source, stream, **kwargs)
102 |
103 | def __getattr__(self, attr):
104 | """Raises error if object has no requested attribute."""
105 | name = self.__class__.__name__
106 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
107 |
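
Following the module docstring, a minimal predict call; the checkpoint path is a placeholder, and the class relies on the older ultralytics.yolo package layout imported at the top of this file.

from utils.fastsam import FastSAM

model = FastSAM("model_weights/FastSAM-x.pt")   # placeholder checkpoint path
results = model.predict("demo/fastsam/examples/dogs.jpg",
                        imgsz=1024, conf=0.4, iou=0.9, retina_masks=True)
# predict returns None on failure and an empty list when nothing is detected
print(results[0].masks.data.shape if results and results[0].masks is not None else None)
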
--------------------------------------------------------------------------------
/utils/fastsam/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ultralytics.yolo.engine.results import Results
4 | from ultralytics.yolo.utils import DEFAULT_CFG, ops
5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor
6 | from .utils import bbox_iou
7 |
8 | class FastSAMPredictor(DetectionPredictor):
9 |
10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
11 | super().__init__(cfg, overrides, _callbacks)
12 | self.args.task = 'segment'
13 |
14 | def postprocess(self, preds, img, orig_imgs):
15 | """TODO: filter by classes."""
16 | p = ops.non_max_suppression(preds[0],
17 | self.args.conf,
18 | self.args.iou,
19 | agnostic=self.args.agnostic_nms,
20 | max_det=self.args.max_det,
21 | nc=len(self.model.names),
22 | classes=self.args.classes)
23 |
24 | results = []
25 | if len(p) == 0 or len(p[0]) == 0:
26 | print("No object detected.")
27 | return results
28 |
29 | full_box = torch.zeros_like(p[0][0])
30 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
31 | full_box = full_box.view(1, -1)
32 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
33 | if critical_iou_index.numel() != 0:
34 | full_box[0][4] = p[0][critical_iou_index][:,4]
35 | full_box[0][6:] = p[0][critical_iou_index][:,6:]
36 | p[0][critical_iou_index] = full_box
37 |
38 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
39 | for i, pred in enumerate(p):
40 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
41 | path = self.batch[0]
42 | img_path = path[i] if isinstance(path, list) else path
43 | if not len(pred): # save empty boxes
44 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
45 | continue
46 | if self.args.retina_masks:
47 | if not isinstance(orig_imgs, torch.Tensor):
48 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
49 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
50 | else:
51 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
52 | if not isinstance(orig_imgs, torch.Tensor):
53 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
54 | results.append(
55 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
56 | return results
57 |
--------------------------------------------------------------------------------
/utils/fastsam/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image
4 |
5 |
6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
7 | '''Adjust bounding boxes to stick to image border if they are within a certain threshold.
8 | Args:
9 | boxes: (n, 4)
10 | image_shape: (height, width)
11 | threshold: pixel threshold
12 | Returns:
13 | adjusted_boxes: adjusted bounding boxes
14 | '''
15 |
16 | # Image dimensions
17 | h, w = image_shape
18 |
19 | # Adjust boxes
20 | boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(
21 | 0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1
22 | boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(
23 | 0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1
24 | boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(
25 | w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2
26 | boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(
27 | h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2
28 |
29 | return boxes
30 |
31 |
32 |
33 | def convert_box_xywh_to_xyxy(box):
34 | x1 = box[0]
35 | y1 = box[1]
36 | x2 = box[0] + box[2]
37 | y2 = box[1] + box[3]
38 | return [x1, y1, x2, y2]
39 |
40 |
41 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
42 | '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
43 | Args:
44 | box1: (4, )
45 | boxes: (n, 4)
46 | Returns:
47 | high_iou_indices: Indices of boxes with IoU > thres
48 | '''
49 | boxes = adjust_bboxes_to_image_border(boxes, image_shape)
50 | # obtain coordinates for intersections
51 | x1 = torch.max(box1[0], boxes[:, 0])
52 | y1 = torch.max(box1[1], boxes[:, 1])
53 | x2 = torch.min(box1[2], boxes[:, 2])
54 | y2 = torch.min(box1[3], boxes[:, 3])
55 |
56 | # compute the area of intersection
57 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
58 |
59 | # compute the area of both individual boxes
60 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
61 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
62 |
63 | # compute the area of union
64 | union = box1_area + box2_area - intersection
65 |
66 | # compute the IoU
67 | iou = intersection / union # Should be shape (n, )
68 | if raw_output:
69 | if iou.numel() == 0:
70 | return 0
71 | return iou
72 |
73 | # get indices of boxes with IoU > thres
74 | high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
75 |
76 | return high_iou_indices
77 |
78 |
79 | def image_to_np_ndarray(image):
80 | if type(image) is str:
81 | return np.array(Image.open(image))
82 | elif issubclass(type(image), Image.Image):
83 | return np.array(image)
84 | elif type(image) is np.ndarray:
85 | return image
86 | return None
87 |
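
A quick check of bbox_iou above with two toy xyxy boxes; run from the repository root, and note that importing through utils.fastsam pulls in the FastSAM dependencies (older ultralytics and clip).

import torch
from utils.fastsam.utils import bbox_iou

box = torch.tensor([100., 100., 300., 300.])
candidates = torch.tensor([[110., 110., 290., 290.],
                           [400., 400., 500., 500.]])
# raw_output=True returns the IoU values themselves instead of the index filter
print(bbox_iou(box, candidates, iou_thres=0.5, image_shape=(640, 640), raw_output=True))
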
--------------------------------------------------------------------------------
/utils/gradio_tabs/__init__.py:
--------------------------------------------------------------------------------
1 | from .image_tabs import *
2 | from .video_tabs import *
3 | from .chat_tabs import *
4 | from .audio_tabs import *
--------------------------------------------------------------------------------
/utils/gradio_tabs/audio_tabs.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 |
4 | # Speech recognition (ASR) tab
5 | def asr_tab(asr_args, ai_handler):
6 | with gr.Tab(asr_args['name']):
7 | audio_file = gr.Audio(type="filepath", autoplay=True)
8 | # text_input = gr.Textbox(label="Question")
9 | audio_asr_text = gr.Textbox(label="text")
10 | audio_asr_files = gr.components.File(label="音频识别输出的文件")
11 | audio_b1 = gr.Button("Recognize speech", variant="primary")
12 |
13 | audio_b1.click(ai_handler.asr_infer, inputs=[audio_file], outputs=[audio_asr_text, audio_asr_files])
14 |
15 | def tts_tab(tts_args, ai_handler):
16 | with gr.Tab(tts_args['name']):
17 | model_type = gr.Dropdown(label="模型类型", choices=tts_args['model_type']['choices'], value=tts_args['model_type']['value'])
18 | with gr.Row():
19 | tts_text_file = gr.components.File(label="上传txt文本")
20 | tts_text = gr.Textbox(label="text")
21 |
22 | tts_voice = gr.Dropdown(label="选择发音人", choices=tts_args['tts_voice']['choices'], value=tts_args['tts_voice']['value'])
23 |         tts_rate = gr.Slider(label="语速", minimum=tts_args['tts_rate']['minimum'], maximum=tts_args['tts_rate']['maximum'], value=tts_args['tts_rate']['value'], step=tts_args['tts_rate']['step'])
24 |         tts_volume = gr.Slider(label="音量", minimum=tts_args['tts_volume']['minimum'], maximum=tts_args['tts_volume']['maximum'], value=tts_args['tts_volume']['value'], step=tts_args['tts_volume']['step'])
25 |         tts_pitch = gr.Slider(label="语调", minimum=tts_args['tts_pitch']['minimum'], maximum=tts_args['tts_pitch']['maximum'], value=tts_args['tts_pitch']['value'], step=tts_args['tts_pitch']['step'])
26 | tts_b1 = gr.Button("合成", variant="primary")
27 |         # downloadable output files
28 | tts_audio_file = gr.Audio(type="filepath", autoplay=True, label="合成音频")
29 | tts_out_files = gr.components.File(label="合成文件下载列表")
30 |
31 | tts_b1.click(ai_handler.tts_infer,
32 | inputs=[tts_text_file, tts_text, tts_voice, tts_rate, tts_volume, tts_pitch],
33 | outputs=[tts_audio_file, tts_out_files]
34 | )
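
The tabs above all follow the same wiring pattern: build components inside gr.Tab, then bind a button click to a handler method. A self-contained sketch of that pattern with a stand-in function in place of ai_handler.asr_infer (component names are illustrative):

import gradio as gr

def fake_asr(audio_path):
    # stand-in for ai_handler.asr_infer: returns (text, output files)
    return f"(transcript of {audio_path})", None

with gr.Blocks() as demo:
    with gr.Tab("ASR demo"):
        audio_file = gr.Audio(type="filepath")
        asr_text = gr.Textbox(label="text")
        asr_files = gr.File(label="output files")
        run_btn = gr.Button("Recognize speech", variant="primary")
        run_btn.click(fake_asr, inputs=[audio_file], outputs=[asr_text, asr_files])

# demo.launch()
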
--------------------------------------------------------------------------------
/utils/gradio_tabs/chat_tabs.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from utils import misc
3 |
4 | def chat_tab(chatbot_args, tts_args, ai_handler):
5 |     # Chatbot tab
6 | with gr.Tab(chatbot_args['name']):
7 | gr.Chatbot.postprocess = misc.postprocess
8 | chatbot = gr.Chatbot(height=chatbot_args['chatbot_win']['height'],)
9 | with gr.Row():
10 | with gr.Column(scale=4):
11 | with gr.Column(scale=12):
12 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=5).style(
13 | container=False)
14 |
15 | with gr.Column(min_width=32, scale=1):
16 | text_submitBtn = gr.Button("发送文字聊天信息", variant="primary")
17 | # speech_file = gr.Audio(type="filepath", min_width=25, autoplay=False, label="语音文件or录音")
18 | speech_file = gr.Audio(label="Audio", sources="microphone", type="filepath", elem_id='audio')
19 | speech_submitBtn = gr.Button("发送语音信息")
20 | chat_replay_audio = gr.Audio(type="filepath", autoplay=True, label="AI回话内容")
21 |
22 | with gr.Column(scale=1):
23 | llm_model_type = gr.Dropdown(choices=chatbot_args['llm_model_type']['choices'],
24 | value=chatbot_args['llm_model_type']['value'],
25 | label="大语言模型")
26 | chat_tts_voice = gr.Dropdown(label="选择发音人", choices=tts_args['tts_voice']['choices'], value=tts_args['tts_voice']['value'])
27 | text_emptyBtn = gr.Button("Clear History")
28 | max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
29 | top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P", interactive=True)
30 | temperature = gr.Slider(0, 1, value=0.85, step=0.01, label="Temperature", interactive=True)
31 |
32 | history = gr.State([])
33 | past_key_values = gr.State(None)
34 |
35 | text_submitBtn.click(ai_handler.chatglm_handler.stream_chat, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
36 | [chatbot, history, past_key_values], show_progress=True)
37 | text_submitBtn.click(misc.reset_user_input, [], [user_input])
38 |
39 | # speech_submitBtn.click(ai_handler.llm_infer, [llm_model_type, user_input, speech_file, history, max_length, top_p, temperature],
40 | # [chat_replay_audio, history], show_progress=True)
41 | speech_submitBtn.click(ai_handler.audio_chat, [llm_model_type, chat_tts_voice, speech_file, history, max_length, top_p, temperature],
42 | [chat_replay_audio, history], show_progress=True)
43 |
44 | # speech_submitBtn.click(fn=action, inputs=speech_submitBtn, outputs=speech_submitBtn).\
45 | # then(fn=lambda: None, _js=click_js()).\
46 | # then(fn=check_btn, inputs=speech_submitBtn).\
47 | # success(fn=ai_handler.llm_infer, inputs=[llm_model_type, user_input, speech_file, history, max_length, top_p, temperature], outputs=[chat_replay_audio, history])
48 |
49 |
50 | text_emptyBtn.click(misc.reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
51 |
52 |
53 | def visualchat_tab(visualchat_args, ai_handler):
54 | MAINTENANCE_NOTICE = 'Hint 1: If the app report "Something went wrong, connection error out", \
55 | please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, \
56 | it may take some time to upload and process. Please be patient and wait.'
57 | with gr.Tab(visualchat_args['name']):
58 | with gr.Row():
59 | with gr.Column(scale=2):
60 | image_path = gr.Image(type="filepath", label="Image Prompt", value=None).style(height=480)
61 | with gr.Column(scale=4):
62 | chatbot = gr.Chatbot().style(height=480)
63 | with gr.Row():
64 | with gr.Column(scale=2, min_width=100):
65 | max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
66 | top_p = gr.Slider(0, 1, value=0.4, step=0.01, label="Top P", interactive=True)
67 | temperature = gr.Slider(0, 1, value=0.8, step=0.01, label="Temperature", interactive=True)
68 | with gr.Column(scale=4):
69 | with gr.Box():
70 | with gr.Row():
71 | with gr.Column(scale=2):
72 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=6).style(
73 | container=False)
74 | with gr.Column(scale=1, min_width=64):
75 | submitBtn = gr.Button("Submit", variant="primary")
76 | emptyBtn = gr.Button("Clear History")
77 | gr.Markdown('\n' + MAINTENANCE_NOTICE + '\n')
78 | history = gr.State([])
79 |
80 | # submitBtn.click(ai_handler.visualchat_handler.stream_chat, [user_input, image_path, chatbot, max_length, top_p, temperature, history], [chatbot, history],
81 | # show_progress=True)
82 | # image_path.upload(ai_handler.visualchat_handler.stream_chat2, [image_path, chatbot, max_length, top_p, temperature], [chatbot, history],
83 | # show_progress=True)
84 |
85 | # submitBtn.click(ai_handler.chatvlm_handler.chat, [user_input, image_path, chatbot, history], [chatbot, history], show_progress=True)
86 | submitBtn.click(ai_handler.chatvlm_handler.chat_stream, [user_input, image_path, chatbot, history], [chatbot, history], show_progress=True)
87 | # image_path.clear(misc.reset_state, outputs=[image_path, chatbot, history], show_progress=True)
88 | # submitBtn.click(misc.reset_user_input, [], [user_input])
89 | emptyBtn.click(misc.reset_state2, outputs=[chatbot, history], show_progress=True)
90 |
--------------------------------------------------------------------------------
/utils/gradio_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Iterable
3 | import gradio as gr
4 | from gradio.themes.base import Base
5 | from gradio.themes.utils import colors, fonts, sizes
6 | import time
7 |
8 |
9 | class Seafoam(Base):
10 | def __init__(
11 | self,
12 | *,
13 | primary_hue: colors.Color | str = colors.gray,
14 | secondary_hue: colors.Color | str = colors.blue,
15 | neutral_hue: colors.Color | str = colors.neutral,
16 | spacing_size: sizes.Size | str = sizes.spacing_md,
17 | radius_size: sizes.Size | str = sizes.radius_md,
18 | text_size: sizes.Size | str = sizes.text_lg,
19 | font: fonts.Font
20 | | str
21 | | Iterable[fonts.Font | str] = (
22 | fonts.GoogleFont("Quicksand"),
23 | "ui-sans-serif",
24 | "sans-serif",
25 | ),
26 | font_mono: fonts.Font
27 | | str
28 | | Iterable[fonts.Font | str] = (
29 | fonts.GoogleFont("IBM Plex Mono"),
30 | "ui-monospace",
31 | "monospace",
32 | ),
33 | ):
34 | super().__init__(
35 | primary_hue=primary_hue,
36 | secondary_hue=secondary_hue,
37 | neutral_hue=neutral_hue,
38 | spacing_size=spacing_size,
39 | radius_size=radius_size,
40 | text_size=text_size,
41 | font=font,
42 | font_mono=font_mono,
43 | )
44 |
45 |
46 | def click_js():
47 | return """function audioRecord() {
48 | var xPathRes = document.evaluate ('//*[contains(@class, "record")]', document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null);
49 | xPathRes.singleNodeValue.click();}"""
50 |
51 |
52 | def action(btn):
53 | """Changes button text on click"""
54 | if btn == 'Speak': return 'Stop'
55 | else: return 'Speak'
56 |
57 |
58 | def check_btn(btn):
59 | """Checks for correct button text before invoking transcribe()"""
60 | if btn != 'Speak': raise Exception('Recording...')
61 |
62 |
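
Applying the Seafoam theme defined above is a one-liner when building a Blocks app, for example:

import gradio as gr
from utils.gradio_utils import Seafoam

with gr.Blocks(theme=Seafoam()) as demo:
    gr.Markdown("Seafoam-themed page")
# demo.launch()
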
--------------------------------------------------------------------------------
/webui.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import anyconfig
4 | from tools import AIWrapper
5 |
6 | from utils.misc import parse_config
7 | from utils.gradio_utils import *
8 | import gradio as gr
9 | from gradio.themes.utils import colors
10 | from utils.gradio_tabs import *
11 | import torch; torch.manual_seed(1234)
12 |
13 |
14 | os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
15 |
16 | def launch_webui(yaml_file, device_ids=None, share=None, **kwargs):
17 | if device_ids is not None:
18 | os.environ['CUDA_VISIBLE_DEVICES'] = device_ids
19 | with open(yaml_file, 'rb') as f:
20 | args = anyconfig.load(f)
21 | if 'base' in args:
22 | args = parse_config(args)
23 |
24 | video_convertor_args = args.get('video_convertor', {})
25 | video_inpainter_args = args.get('video_inpainter', {})
26 | segmentation_args = args.get('segmentation_task', {})
27 | chat_args = args.get('chatbot', {})
28 | visualchat_args = args.get('visualchat', {})
29 | asr_args = args.get('asr_task', {})
30 | tts_args = args.get('tts_task', {})
31 |
32 |     # Initialize the AI engine
33 | ai_handler = AIWrapper(args)
34 |
35 | # seafoam = gradio_utils.Seafoam()
36 | theme=gr.themes.Soft(primary_hue=colors.gray, neutral_hue=colors.neutral)
37 | with gr.Blocks(theme=theme) as web:
38 | # gr.Markdown(args["home_name"])
39 |         gr.HTML(f"""{args["home_desc"]}""")
40 | # Process text, audio or video file using this web
41 |
42 |         # Video clipping / conversion
43 | if video_convertor_args.get('switch'):
44 | video_convertor_tab(video_convertor_args, ai_handler)
45 |
46 |         # Video inpainting
47 | if video_inpainter_args.get('switch'):
48 | video_inpainter_tab(video_inpainter_args, ai_handler)
49 |
50 |         # Image segmentation
51 | if segmentation_args.get('switch'):
52 | sam_tab(segmentation_args, ai_handler)
53 |
54 |         # Chat Q&A
55 | if chat_args.get('switch'):
56 | chat_tab(chat_args, tts_args, ai_handler)
57 |         # Multimodal (visual) Q&A
58 | if visualchat_args.get('switch'):
59 | visualchat_tab(visualchat_args, ai_handler)
60 |
61 |         # Speech recognition
62 | if asr_args.get('switch'):
63 | asr_tab(asr_args, ai_handler)
64 |
65 |         # Speech synthesis (TTS)
66 | if tts_args.get('switch'):
67 | tts_tab(tts_args, ai_handler)
68 |
69 | web.queue().launch(share=share, server_name=args["server_name"], server_port=args["server_port"])
70 |
71 |
72 | def parse_opt():
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument('-f', '-c', '-cfg', '--yaml_file', type=str, default='configs/webui_configs.yml',
75 | help='yaml config file'
76 | )
77 | parser.add_argument('-d', '--device_ids', type=str, default=None, help='device ids')
78 | parser.add_argument('-s', "--share", action="store_true", help='whether public url')
79 | opt = parser.parse_args()
80 | return opt
81 |
82 |
83 | if __name__ == "__main__":
84 | opt = parse_opt()
85 | launch_webui(**opt.__dict__)
86 |
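
A typical launch from the repository root, using the flags defined in parse_opt above: python webui.py -f configs/webui_configs.yml -d 0 -s, where -d sets CUDA_VISIBLE_DEVICES and -s asks Gradio for a public share link.
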
--------------------------------------------------------------------------------