├── .gitignore
├── 0、install pwsh.sh
├── 1、install-uv-qinglong.ps1
├── 2.1、image_watermark_detect.ps1
├── 2、video_spliter.ps1
├── 3、tagger.ps1
├── 4、run.ps1
├── LICENSE
├── README.md
├── config
├── __init__.py
├── config.py
└── config.toml
├── datasets
└── put datasets here
├── lanceExport.ps1
├── lanceImport.ps1
├── module
├── __init__.py
├── api_handler.py
├── captioner.py
├── lanceImport.py
├── lanceexport.py
├── scenedetect.py
└── waterdetect.py
├── requirements-uv-linux.txt
├── requirements-uv.txt
├── requirements.txt
└── utils
├── __init__.py
├── console_util.py
├── stream_util.py
└── wdtagger.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 | datasets/
173 | wd14_tagger_model/
174 | data/
175 | huggingface/
176 | watermark_detection/
177 |
178 | # windsurf rules
179 | .windsurfrules
180 |
--------------------------------------------------------------------------------
/0、install pwsh.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Install PowerShell 7 (pwsh) from the official Linux tarball when `pwsh` is not
# already on PATH. Writes to /opt and /usr/bin, so it must run as root, e.g.:
#   sudo sh "./0、install pwsh.sh"
# (Quote the path: the file name contains a space.)
# The body is plain POSIX sh so it behaves the same under sh (dash) and bash.

# Abort on the first failing command or unset variable, instead of printing
# "Install completed" after a failed download or extraction.
set -eu

PWSH_VERSION="7.5.1"
PWSH_DIR="/opt/microsoft/powershell/7"
PWSH_TARBALL="/tmp/powershell.tar.gz"

echo "检查是否已安装 PowerShell..."
# Portable redirection. The bash-only "&>" is parsed by dash as "run in the
# background" plus an empty redirect, which made this test always true.
if ! command -v pwsh > /dev/null 2>&1
then
    echo "PowerShell 未安装,正在安装..."

    # Pick the release asset that matches this machine's CPU.
    case "$(uname -m)" in
        x86_64 | amd64) PWSH_ARCH="x64" ;;
        aarch64 | arm64) PWSH_ARCH="arm64" ;;
        armv7l) PWSH_ARCH="arm32" ;;
        *)
            echo "Unsupported architecture: $(uname -m)" >&2
            exit 1
            ;;
    esac

    # Download the PowerShell .tar.gz archive. -f makes curl fail on HTTP
    # errors instead of saving an error page as the archive.
    curl -fL -o "$PWSH_TARBALL" \
        "https://github.com/PowerShell/PowerShell/releases/download/v${PWSH_VERSION}/powershell-${PWSH_VERSION}-linux-${PWSH_ARCH}.tar.gz"

    # Create the target folder and unpack PowerShell into it.
    mkdir -p "$PWSH_DIR"
    tar zxf "$PWSH_TARBALL" -C "$PWSH_DIR"
    rm -f "$PWSH_TARBALL"

    # Make the pwsh binary executable.
    chmod +x "$PWSH_DIR/pwsh"

    # Put pwsh on PATH. -f replaces a link left over from an earlier run
    # instead of failing with "File exists".
    ln -sf "$PWSH_DIR/pwsh" /usr/bin/pwsh

    echo "PowerShell 安装完成"
else
    echo "PowerShell 已安装"
fi

echo "Install completed"
29 |
--------------------------------------------------------------------------------
/1、install-uv-qinglong.ps1:
--------------------------------------------------------------------------------
# Run from the repository root so the relative paths below (venv, requirements) resolve.
Set-Location $PSScriptRoot

# Environment for pip / uv / Hugging Face during installation.
$Env:HF_HOME = "huggingface"
#$Env:HF_ENDPOINT="https://hf-mirror.com"
$Env:PIP_DISABLE_PIP_VERSION_CHECK = 1
$Env:PIP_NO_CACHE_DIR = 1
#$Env:PIP_INDEX_URL="https://pypi.mirrors.ustc.edu.cn/simple"
#$Env:UV_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple/"
$Env:UV_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cu128"
# LOCALAPPDATA is a Windows-only variable. Without this guard the cache dir
# would expand to "/uv/cache" (the filesystem root) on Linux. When it is
# absent, uv's own default cache location is used.
if ($env:LOCALAPPDATA) {
    $Env:UV_CACHE_DIR = "${env:LOCALAPPDATA}/uv/cache"
}
$Env:UV_NO_BUILD_ISOLATION = 1
$Env:UV_NO_CACHE = 0
$Env:UV_LINK_MODE = "symlink"
$Env:GIT_LFS_SKIP_SMUDGE = 1
# Mirror CUDA_PATH into CUDA_HOME only when it is defined. In PowerShell,
# assigning an empty string to an environment variable deletes it, which
# would wipe a CUDA_HOME the user already exported (e.g. on Linux).
if ($env:CUDA_PATH) {
    $Env:CUDA_HOME = "${env:CUDA_PATH}"
}
16 |
# Print a failure message, wait for Enter so the console window stays open,
# then stop the script with a non-zero exit code. A bare `Exit` would report
# success (0) to whatever launched the installer.
function InstallFail {
    Write-Output "Install failed|安装失败。"
    Read-Host | Out-Null ;
    Exit 1
}

# Abort the install if the command executed immediately before this call
# failed. $? is read on entry, so call it right after the command being checked.
#   $ErrorInfo - message printed before aborting
function Check {
    param (
        $ErrorInfo
    )
    if (!($?)) {
        Write-Output $ErrorInfo
        InstallFail
    }
}
32 |
# Make sure uv is available at ~/.local/bin/uv. If invoking it throws
# (e.g. command not found), install it with the official installer for this platform.
try {
    ~/.local/bin/uv --version
    Write-Output "uv installed|UV模块已安装."
}
catch {
    Write-Output "Installing uv|安装uv模块中..."
    if ($Env:OS -ilike "*windows*") {
        powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
        Check "uv install failed|安装uv模块失败。"
    }
    else {
        curl -LsSf https://astral.sh/uv/install.sh | sh
        Check "uv install failed|安装uv模块失败。"
    }
}

if ($env:OS -ilike "*windows*") {
    # Switch the console code page to UTF-8 (65001).
    chcp 65001

    # If the uv cache directory does not exist yet and C: has less than 10 GB
    # free, point UV_CACHE_DIR at a project-local .cache directory instead.
    if (Test-Path -Path "${env:LOCALAPPDATA}/uv/cache") {
        Write-Host "UV cache directory already exists, skipping disk space check"
    }
    else {
        try {
            # Get-CimInstance exists in both Windows PowerShell 5.1 and PowerShell 7.
            # Get-WmiObject was removed in PowerShell 6+, so under pwsh it always
            # threw and this check silently fell into the catch block.
            $CDrive = Get-CimInstance -ClassName Win32_LogicalDisk -Filter "DeviceID='C:'" -ErrorAction Stop
            if ($CDrive) {
                $FreeSpaceGB = [math]::Round($CDrive.FreeSpace / 1GB, 2)
                Write-Host "C: drive free space: ${FreeSpaceGB}GB"

                if ($FreeSpaceGB -lt 10) {
                    Write-Host "Low disk space detected. Using local .cache directory"
                    $Env:UV_CACHE_DIR = ".cache"
                }
            }
            else {
                Write-Warning "C: drive not found. Using local .cache directory"
                $Env:UV_CACHE_DIR = ".cache"
            }
        }
        catch {
            Write-Warning "Failed to check disk space: $_. Using local .cache directory"
            $Env:UV_CACHE_DIR = ".cache"
        }
    }

    # Activate an existing venv, or create .venv with uv.
    if (Test-Path "./venv/Scripts/activate") {
        Write-Output "Windows venv"
        . ./venv/Scripts/activate
    }
    elseif (Test-Path "./.venv/Scripts/activate") {
        Write-Output "Windows .venv"
        . ./.venv/Scripts/activate
    }
    else {
        Write-Output "Create .venv"
        ~/.local/bin/uv venv -p 3.11 --seed
        Check "Create .venv failed"
        . ./.venv/Scripts/activate
    }
}
else {
    # `python -m venv` writes bin/Activate.ps1 while uv writes bin/activate.ps1.
    # Linux paths are case-sensitive, so accept either spelling.
    $ActivateScript = $null
    foreach ($VenvName in @("venv", ".venv")) {
        $ActivateScript = @("./$VenvName/bin/Activate.ps1", "./$VenvName/bin/activate.ps1") |
            Where-Object { Test-Path $_ } |
            Select-Object -First 1
        if ($ActivateScript) {
            Write-Output "Linux $VenvName"
            break
        }
    }
    if (-not $ActivateScript) {
        Write-Output "Create .venv"
        ~/.local/bin/uv venv -p 3.11 --seed
        Check "Create .venv failed"
        $ActivateScript = "./.venv/bin/activate.ps1"
    }
    # Dot-source so the activation applies to this script's scope.
    . $ActivateScript
}

Write-Output "Installing main requirements"

~/.local/bin/uv pip install --upgrade setuptools wheel pip wheel_stub
Check "Install build tools failed"

# Windows and Linux have separate pinned requirement files.
if ($env:OS -ilike "*windows*") {
    $RequirementsFile = "requirements-uv.txt"
}
else {
    $RequirementsFile = "requirements-uv-linux.txt"
}
~/.local/bin/uv pip sync $RequirementsFile --index-strategy unsafe-best-match
Check "Install main requirements failed"

Write-Output "Install finished"
Read-Host | Out-Null ;
122 |
--------------------------------------------------------------------------------
/2.1、image_watermark_detect.ps1:
--------------------------------------------------------------------------------
#region Configuration
# Model settings
$Config = @{
    train_data_dir = "./datasets"                                   # Input images path
    # Available models:
    #   bdsqlsz/Watermark-Detection-SigLIP2-onnx
    #   bdsqlsz/joycaption-watermark-detection-onnx
    repo_id        = "bdsqlsz/joycaption-watermark-detection-onnx"  # Model repo ID on Hugging Face
    model_dir      = "watermark_detection"                          # Local model folder path
    batch_size     = 12                                             # Batch size for inference
    thresh         = 1.0                                            # Detection threshold; passed as --thresh only when it is not 1.0
}
#endregion

#region Environment Setup
Set-Location $PSScriptRoot
# Prepend the repo root to PYTHONPATH using the platform list separator
# (";" on Windows, ":" elsewhere). A hard-coded ";" yields an invalid entry on
# Linux, and an unset PYTHONPATH no longer leaves a trailing empty entry.
$env:PYTHONPATH = (@($PSScriptRoot, $env:PYTHONPATH) | Where-Object { $_ }) -join [System.IO.Path]::PathSeparator

# Activate the first python venv found. On Linux the PowerShell script is
# "Activate.ps1" (python -m venv) or "activate.ps1" (uv), and paths are case-sensitive.
$VenvPaths = @(
    "./venv/Scripts/activate",
    "./.venv/Scripts/activate",
    "./venv/bin/Activate.ps1",
    "./venv/bin/activate.ps1",
    "./.venv/bin/Activate.ps1",
    "./.venv/bin/activate.ps1"
)

foreach ($Path in $VenvPaths) {
    if (Test-Path $Path) {
        Write-Output "Activating venv: $Path"
        & $Path
        break
    }
}

# Set environment variables
$Env:HF_HOME = "huggingface"
#$Env:HF_ENDPOINT = "https://hf-mirror.com"
$Env:XFORMERS_FORCE_DISABLE_TRITON = "1"
# Mirror CUDA_PATH only when defined: assigning "" would delete an existing CUDA_HOME.
if ($env:CUDA_PATH) {
    $Env:CUDA_HOME = "${env:CUDA_PATH}"
}
$Env:TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION = "1"
$Env:TF_CUDNN_USE_AUTOTUNE = "1"
$Env:TF_TRT_ALLOW_TF32 = "1"

#endregion

#region Build Arguments
$ExtArgs = [System.Collections.ArrayList]::new()

# Add configuration arguments
if ($Config.repo_id) { [void]$ExtArgs.Add("--repo_id=$($Config.repo_id)") }
if ($Config.model_dir) { [void]$ExtArgs.Add("--model_dir=$($Config.model_dir)") }
if ($Config.batch_size) { [void]$ExtArgs.Add("--batch_size=$($Config.batch_size)") }
if ($Config.thresh -ne 1.0) { [void]$ExtArgs.Add("--thresh=$($Config.thresh)") }

#endregion

#region Execute Watermark Detection
Write-Output "Starting Watermark Detection..."

# Run watermark detection
accelerate launch --num_cpu_threads_per_process=8 "./module/waterdetect.py" `
    $Config.train_data_dir `
    $ExtArgs
# Capture the result immediately; the next statement overwrites $?.
$Succeeded = $?

if ($Succeeded) {
    Write-Output "Watermark Detection finished"
}
else {
    Write-Output "Watermark Detection failed (exit code: $LASTEXITCODE)"
}
# Keep the console window open until Enter is pressed.
Read-Host | Out-Null

#endregion
67 |
--------------------------------------------------------------------------------
/2、video_spliter.ps1:
--------------------------------------------------------------------------------
#region Configuration
# Scene detection settings
$Config = @{
    input_video_dir         = "./datasets"        # Input video directory
    output_dir              = ""                  # Output directory; defaults to the input directory when empty
    detector                = "AdaptiveDetector"  # One of "ContentDetector","AdaptiveDetector","HashDetector","HistogramDetector","ThresholdDetector"
    threshold               = 0.0                 # Detection threshold, lower = more sensitive; 0.0 = not passed. Typical: ContentDetector: 27.0, AdaptiveDetector: 3.0, HashDetector: 0.395, HistogramDetector: 0.05, ThresholdDetector: 12
    min_scene_len           = 16                  # Minimum scene length; smaller = more sensitive
    luma_only               = $false              # Detect using luminance changes only
    save_html               = $true               # Save an HTML report
    video2images_min_number = 1                   # Images saved per scene; 0 = save none
    recursive               = $false              # Search subdirectories recursively
}
#endregion

#region Environment Setup
# Activate the Python virtual environment
Set-Location $PSScriptRoot
# Prepend the repo root to PYTHONPATH using the platform list separator
# (";" on Windows, ":" elsewhere); a hard-coded ";" is wrong on Linux.
$env:PYTHONPATH = (@($PSScriptRoot, $env:PYTHONPATH) | Where-Object { $_ }) -join [System.IO.Path]::PathSeparator

# On Linux the PowerShell script is "Activate.ps1" (python -m venv) or
# "activate.ps1" (uv), and paths are case-sensitive, so try both.
$VenvPaths = @(
    "./venv/Scripts/activate",
    "./.venv/Scripts/activate",
    "./venv/bin/Activate.ps1",
    "./venv/bin/activate.ps1",
    "./.venv/bin/Activate.ps1",
    "./.venv/bin/activate.ps1"
)

foreach ($Path in $VenvPaths) {
    if (Test-Path $Path) {
        Write-Output "Activating venv: $Path"
        & $Path
        break
    }
}
#endregion

#region Build Arguments
$Env:HF_HOME = "huggingface"
#$Env:HF_ENDPOINT = "https://hf-mirror.com"
$ExtArgs = [System.Collections.ArrayList]::new()

# Add configuration arguments (defaults are omitted)
if ($Config.output_dir) { [void]$ExtArgs.Add("--output_dir=$($Config.output_dir)") }
if ($Config.detector -ne "AdaptiveDetector") { [void]$ExtArgs.Add("--detector=$($Config.detector)") }
if ($Config.threshold -ne 0.0) { [void]$ExtArgs.Add("--threshold=$($Config.threshold)") }
if ($Config.min_scene_len) { [void]$ExtArgs.Add("--min_scene_len=$($Config.min_scene_len)") }
if ($Config.luma_only) { [void]$ExtArgs.Add("--luma_only") }
if ($Config.save_html) { [void]$ExtArgs.Add("--save_html") }
if ($Config.video2images_min_number -gt 0) { [void]$ExtArgs.Add("--video2images_min_number=$($Config.video2images_min_number)") }
if ($Config.recursive) { [void]$ExtArgs.Add("--recursive") }
#endregion

#region Execute Scene Detection
Write-Output "Starting scene detection..."

# Run the scene detection module
python -m module.scenedetect `
    $Config.input_video_dir `
    $ExtArgs
# Capture the result immediately; the next statement overwrites $?.
$Succeeded = $?

if ($Succeeded) {
    Write-Output "Scene detection finished"
}
else {
    Write-Output "Scene detection failed (exit code: $LASTEXITCODE)"
}
# Keep the console window open until Enter is pressed.
Read-Host | Out-Null
#endregion
63 |
--------------------------------------------------------------------------------
/3、tagger.ps1:
--------------------------------------------------------------------------------
#region Configuration
# Model settings
$Config = @{
    train_data_dir      = "./datasets"                            # Input images path
    repo_id             = "SmilingWolf/wd-eva02-large-tagger-v3"  # Model repo ID on Hugging Face
    model_dir           = "wd14_tagger_model"                     # Local model folder path
    batch_size          = 12                                      # Batch size for inference
    thresh              = 0.3                                     # Concept threshold
    general_threshold   = 0.3                                     # General threshold
    character_threshold = 1.0                                     # Character threshold
}

# Feature flags
$Features = @{
    frequency_tags              = $false  # Order by frequency tags
    remove_underscore           = $true   # Convert underscore to space
    use_rating_tags             = $false  # Use rating tags
    use_rating_tags_as_last_tag = $false  # Put rating tags at the end; when set, --use_rating_tags is not also passed
    character_tags_first        = $false  # Put character tags first
    character_tag_expand        = $false  # Split character_(series) into character, series
    remove_parents_tag          = $true   # Remove parent tags
    overwrite                   = $true   # Overwrite existing tag files
    recursive                   = $false  # Search subdirectories recursively (read by the argument builder)
}

# Tag settings
$TagConfig = @{
    undesired_tags    = ""  # Tags to exclude
    # Tags placed at the front of the caption
    always_first_tags = "1girl,1boy,2girls,3girls,4girls,5girls,6girls,2boys,3boys,4boys,5boys,6boys"
    # Presumably "source,replacement" pairs separated by ";" - confirm against utils/wdtagger.py
    tag_replacement   = "1girl,1woman;2girls,2women;3girls,3women;4girls,4women;5girls,5women;1boy,1man"
}

#endregion
33 |
#region Environment Setup
# Activate python venv
Set-Location $PSScriptRoot
# Prepend the repo root to PYTHONPATH using the platform list separator
# (";" on Windows, ":" elsewhere); a hard-coded ";" is wrong on Linux.
$env:PYTHONPATH = (@($PSScriptRoot, $env:PYTHONPATH) | Where-Object { $_ }) -join [System.IO.Path]::PathSeparator

# On Linux the PowerShell script is "Activate.ps1" (python -m venv) or
# "activate.ps1" (uv), and paths are case-sensitive, so try both.
$VenvPaths = @(
    "./venv/Scripts/activate",
    "./.venv/Scripts/activate",
    "./venv/bin/Activate.ps1",
    "./venv/bin/activate.ps1",
    "./.venv/bin/Activate.ps1",
    "./.venv/bin/activate.ps1"
)

foreach ($Path in $VenvPaths) {
    if (Test-Path $Path) {
        Write-Output "Activating venv: $Path"
        & $Path
        break
    }
}

# Set environment variables
$Env:HF_HOME = "huggingface"
#$Env:HF_ENDPOINT = "https://hf-mirror.com"
$Env:XFORMERS_FORCE_DISABLE_TRITON = "1"
# Mirror CUDA_PATH only when defined: assigning "" would delete an existing CUDA_HOME.
if ($env:CUDA_PATH) {
    $Env:CUDA_HOME = "${env:CUDA_PATH}"
}
$Env:TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION = "1"
$Env:TF_CUDNN_USE_AUTOTUNE = "1"
$Env:TF_TRT_ALLOW_TF32 = "1"

#endregion

#region Build Arguments
$ExtArgs = [System.Collections.ArrayList]::new()

# Add configuration arguments
if ($Config.repo_id) { [void]$ExtArgs.Add("--repo_id=$($Config.repo_id)") }
if ($Config.model_dir) { [void]$ExtArgs.Add("--model_dir=$($Config.model_dir)") }
if ($Config.batch_size) { [void]$ExtArgs.Add("--batch_size=$($Config.batch_size)") }
if ($Config.general_threshold) { [void]$ExtArgs.Add("--general_threshold=$($Config.general_threshold)") }
if ($Config.character_threshold) { [void]$ExtArgs.Add("--character_threshold=$($Config.character_threshold)") }

# Add feature flags
if ($Features.remove_underscore) { [void]$ExtArgs.Add("--remove_underscore") }
if ($Features.recursive) { [void]$ExtArgs.Add("--recursive") }
if ($Features.frequency_tags) { [void]$ExtArgs.Add("--frequency_tags") }
if ($Features.character_tags_first) { [void]$ExtArgs.Add("--character_tags_first") }
if ($Features.character_tag_expand) { [void]$ExtArgs.Add("--character_tag_expand") }
# When both rating options are set, only --use_rating_tags_as_last_tag is passed.
if ($Features.use_rating_tags_as_last_tag) { [void]$ExtArgs.Add("--use_rating_tags_as_last_tag") }
elseif ($Features.use_rating_tags) { [void]$ExtArgs.Add("--use_rating_tags") }
if ($Features.remove_parents_tag) { [void]$ExtArgs.Add("--remove_parents_tag") }
if ($Features.overwrite) { [void]$ExtArgs.Add("--overwrite") }

# Add tag configuration
if ($TagConfig.undesired_tags) { [void]$ExtArgs.Add("--undesired_tags=$($TagConfig.undesired_tags)") }
if ($TagConfig.always_first_tags) { [void]$ExtArgs.Add("--always_first_tags=$($TagConfig.always_first_tags)") }
if ($TagConfig.tag_replacement) { [void]$ExtArgs.Add("--tag_replacement=$($TagConfig.tag_replacement)") }

#endregion

#region Execute Tagger
Write-Output "Starting tagger..."

# Run tagger
accelerate launch --num_cpu_threads_per_process=8 "./utils/wdtagger.py" `
    $Config.train_data_dir `
    --thresh=$($Config.thresh) `
    --caption_extension .txt `
    $ExtArgs
# Capture the result immediately; the next statement overwrites $?.
$Succeeded = $?

if ($Succeeded) {
    Write-Output "Tagger finished"
}
else {
    Write-Output "Tagger failed (exit code: $LASTEXITCODE)"
}
# Keep the console window open until Enter is pressed.
Read-Host | Out-Null

#endregion
106 |
--------------------------------------------------------------------------------
/4、run.ps1:
--------------------------------------------------------------------------------
# ---- User settings -------------------------------------------------------
# Empty API keys / model paths are simply not passed to module.captioner.
$dataset_path = "./datasets"
$gemini_api_key = ""
$gemini_model_path = "gemini-2.5-pro-exp-03-25"
$gemini_task = ""                                 # Only passed when a Gemini API key is set
$pixtral_api_key = ""
$pixtral_model_path = "pixtral-large-2411"
$step_api_key = ""
$step_model_path = "step-1.5v-mini"
$qwenVL_api_key = ""
$qwenVL_model_path = "qwen-vl-max-latest"         # qwen2.5-vl-72b-instruct: videos < 10 min; qwen-vl-max-latest: < 1 min
$glm_api_key = ""
$glm_model_path = "GLM-4V-Plus-0111"
$dir_name = $false
$mode = "long"                                    # Passed as --mode unless set to "all"
$not_clip_with_caption = $false                   # Do not clip based on the caption
$wait_time = 1                                    # Passed only when not 1
$max_retries = 100                                # Passed only when not 20
$segment_time = 600                               # Passed only when not 600
$ocr = $false
$document_image = $true
$scene_detector = "AdaptiveDetector"              # from ["ContentDetector","AdaptiveDetector","HashDetector","HistogramDetector","ThresholdDetector"]
$scene_threshold = 0.0                            # 0.0 = not passed; typical values ["ContentDetector": 27.0, "AdaptiveDetector": 3.0, "HashDetector": 0.395, "HistogramDetector": 0.05, "ThresholdDetector": 12]
$scene_min_len = 15                               # Passed only when not 15
$scene_luma_only = $false
$tags_highlightrate = 0.38                        # Passed only when not 0.4

# ============= DO NOT MODIFY CONTENTS BELOW =====================
# Activate python venv
Set-Location $PSScriptRoot
if ($env:OS -ilike "*windows*") {
    if (Test-Path "./venv/Scripts/activate") {
        Write-Output "Windows venv"
        ./venv/Scripts/activate
    }
    elseif (Test-Path "./.venv/Scripts/activate") {
        Write-Output "Windows .venv"
        ./.venv/Scripts/activate
    }
}
else {
    # `python -m venv` writes bin/Activate.ps1 while uv writes bin/activate.ps1.
    # Linux paths are case-sensitive, so accept either spelling.
    foreach ($VenvName in @("venv", ".venv")) {
        $ActivateScript = @("./$VenvName/bin/Activate.ps1", "./$VenvName/bin/activate.ps1") |
            Where-Object { Test-Path $_ } |
            Select-Object -First 1
        if ($ActivateScript) {
            Write-Output "Linux $VenvName"
            & $ActivateScript
            break
        }
    }
}

$Env:HF_HOME = "huggingface"
$Env:XFORMERS_FORCE_DISABLE_TRITON = "1"
#$Env:HF_ENDPOINT = "https://hf-mirror.com"
$Env:PILLOW_IGNORE_XMP_DATA_IS_TOO_LONG = "1"
$ext_args = [System.Collections.ArrayList]::new()
#$Env:HTTP_PROXY = "http://127.0.0.1:7890"
#$Env:HTTPS_PROXY = "http://127.0.0.1:7890"

# NOTE(review): API keys are passed on the command line and may be visible to
# other local users through the process list.
if ($gemini_api_key) {
    [void]$ext_args.Add("--gemini_api_key=$gemini_api_key")
    if ($gemini_task) {
        [void]$ext_args.Add("--gemini_task=$gemini_task")
    }
}

if ($gemini_model_path) {
    [void]$ext_args.Add("--gemini_model_path=$gemini_model_path")
}

if ($pixtral_api_key) {
    [void]$ext_args.Add("--pixtral_api_key=$pixtral_api_key")
}

if ($pixtral_model_path) {
    [void]$ext_args.Add("--pixtral_model_path=$pixtral_model_path")
}

if ($step_api_key) {
    [void]$ext_args.Add("--step_api_key=$step_api_key")
}

if ($step_model_path) {
    [void]$ext_args.Add("--step_model_path=$step_model_path")
}

if ($qwenVL_api_key) {
    [void]$ext_args.Add("--qwenVL_api_key=$qwenVL_api_key")
}

if ($qwenVL_model_path) {
    [void]$ext_args.Add("--qwenVL_model_path=$qwenVL_model_path")
}

if ($glm_api_key) {
    [void]$ext_args.Add("--glm_api_key=$glm_api_key")
}

if ($glm_model_path) {
    [void]$ext_args.Add("--glm_model_path=$glm_model_path")
}

if ($dir_name) {
    [void]$ext_args.Add("--dir_name")
}

if ($mode -ine "all") {
    [void]$ext_args.Add("--mode=$mode")
}

if ($not_clip_with_caption) {
    [void]$ext_args.Add("--not_clip_with_caption")
}

if ($wait_time -ine 1) {
    [void]$ext_args.Add("--wait_time=$wait_time")
}

if ($max_retries -ine 20) {
    [void]$ext_args.Add("--max_retries=$max_retries")
}

if ($segment_time -ine 600) {
    [void]$ext_args.Add("--segment_time=$segment_time")
}

if ($ocr) {
    [void]$ext_args.Add("--ocr")
}

if ($document_image) {
    [void]$ext_args.Add("--document_image")
}

if ($scene_detector -ne "AdaptiveDetector") {
    [void]$ext_args.Add("--scene_detector=$($scene_detector)")
}

if ($scene_threshold -ne 0.0) {
    [void]$ext_args.Add("--scene_threshold=$scene_threshold")
}

if ($scene_min_len -ne 15) {
    [void]$ext_args.Add("--scene_min_len=$scene_min_len")
}

if ($scene_luma_only) {
    [void]$ext_args.Add("--scene_luma_only")
}

if ($tags_highlightrate -ne 0.4) {
    [void]$ext_args.Add("--tags_highlightrate=$tags_highlightrate")
}

# Run the captioner
python -m module.captioner $dataset_path $ext_args
# Capture the result immediately; the next statement overwrites $?.
$Succeeded = $?

if ($Succeeded) {
    Write-Output "Captioner finished"
}
else {
    Write-Output "Captioner failed (exit code: $LASTEXITCODE)"
}
# Keep the console window open until Enter is pressed.
Read-Host | Out-Null ;
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://ko-fi.com/N4N1NOO2K)
2 |
3 | # qinglong-captioner (2.6)
4 |
5 | A Python toolkit for generating video captions using the Lance database format and Gemini API for automatic captioning.
6 |
7 | ## Changelog
8 |
9 | ### 2.6
10 | 
11 |
12 | A new watermark detection script has been added, initially supporting two watermark detection models, which can quickly classify images in the dataset into watermarked/unwatermarked categories.
13 | It generates two folders and separates the data using symbolic links, so no additional disk space is used. If needed, you can copy the corresponding folder to move the data elsewhere without deleting the originals.
14 | (Creating symbolic links requires elevated permissions, so you must run PowerShell as administrator.)
15 |
16 | Finally, it will generate a JSON file report listing the watermark detection results for all images in the original path, including detection values and results.
17 | The watermark threshold can be modified in the script to correspondingly change the detection results.
18 |
19 |
20 | ### 2.5
21 | 
22 |
23 | We officially support the tags highlight captions feature! Currently unlocked in the pixtral model, and we are considering adding it to other models such as gemini in the future.
24 |
25 | What are tags highlight?
26 |
27 | Non-state-of-the-art VLMs are known to make mistakes, so we first annotate tags with wdtagger and then feed those tags to the VLM as hints, which improves accuracy.
28 |
29 | Currently, the tags have been categorized, and it is also possible to quickly check the annotation quality (e.g., purple is for character names and copyright, red is for clothing, brown is for body features, light yellow is for actions, etc.)
30 |
31 | The annotation quality obtained in the end is comparable to some closed-source models!
32 |
33 | Additionally, we have added check parameters: you can use the parent folder name as the character's name, and you can set a check on the tags highlight rate. Generally, good captions should have a highlight rate of over 35%.
34 |
35 | You can also specify different highlight rates to change the default standard.
36 |
37 | How to use: first run 3、tagger.ps1 to generate tags for your image dataset,
38 |
39 | then run 4、run.ps1 with your Pixtral API key.
40 |
41 | ### 2.4
42 |
43 | We support Gemini image caption and rating.
44 | It also supports gemini2.5-flash-preview-04-17.
45 |
46 | However, in our tests the flash version performed poorly at both captioning and image rating, so we recommend the pro version.
47 |
48 | 
49 | flash↑
50 |
51 | 
52 | pro ↑
53 |
54 | ### 2.3
55 |
56 | Well, we forgot to release version 2.2, so we directly released version 2.3!
57 |
58 | Version 2.3 updated the GLM4V model for video captions
59 |
60 | ### 2.2
61 |
62 | Version 2.2 adds TensorRT acceleration for the local ONNX WDtagger model.
63 |
64 | In our tests, tagging 10,000 samples takes 30 minutes with standard CUDA,
65 |
66 | while using TensorRT, it can be completed in just 15 to 20 minutes.
67 |
68 | However, the first run takes longer because the model has to be compiled.
69 |
70 | If TensorRT fails, it will automatically revert to CUDA without worry.
71 |
72 | If you get an error saying that TensorRT libraries are missing, some TensorRT components may not be installed.
73 |
74 | Please install version 10.7.x manually from [here](https://developer.nvidia.com/tensorrt/download/10x)
75 |
76 | ### 2.1
77 |
78 | Added support for Gemini 2.5 Pro Exp. Videos are now cut into 600-second segments by default.
79 |
80 | ### 2.0 Big Update!
81 |
82 | Now we support video segmentation! A new video segmentation module has been added, which detects key timestamps based on scene changes and then outputs the corresponding images and video clips!
83 | It also exports an HTML report for reference, and the results are very impressive!
84 | 
85 |
86 | We have also added subtitle alignment algorithms, which automatically align Gemini's timestamp subtitles to the millisecond level after detecting scene change frames (there are still some errors, but the effect is much better).
87 |
88 | Finally, we added the image output feature of the latest gemini-2.0-flash-exp model!
89 |
90 | You can customize the task, add the task name in the [`config.toml`](https://github.com/sdbds/qinglong-captions/blob/main/config/config.toml), which will automatically handle the corresponding images (and then label them)
91 |
92 | Some simple task descriptions are shown below. We welcome the community to keep improving these task prompts and to contribute!
93 | https://github.com/sdbds/qinglong-captions/blob/12b7750ee0bca7e41168e98775cd95c7b9c57173/config/config.toml#L239-L249
94 |
95 | 
96 |
97 | 
98 |
99 |
100 | ### 1.9
101 |
102 | Now with Mistral OCR functionality!
103 | Utilizing Mistral's advanced OCR capabilities to extract text information from videos and images.
104 |
105 | This feature is particularly useful when processing media files containing subtitles, signs, or other text elements, enhancing the accuracy and completeness of captions.
106 |
107 | The OCR functionality is integrated into the existing workflow and can be used without additional configuration.
108 |
109 | ### 1.8
110 |
111 | Now added WDtagger!
112 | Even if you cannot use a GPU, you can use the CPU for tagging.
113 |
114 | It has multi-threading and various optimizations, processing large-scale data quickly.
115 |
116 | It uses ONNX for accelerated model inference.
117 |
118 | Code reference: @kohya-ss
119 | https://github.com/sdbds/sd-scripts/blob/main/finetune/tag_images_by_wd14_tagger.py
120 |
121 | Version 2.0 will add dual-caption functionality: wdtagger's tags are used as input, and natural-language captions are output.
122 | 
123 |
124 |
125 | ### 1.7
126 |
127 | Now we support the qwen-VL series video caption model!
128 |
129 | - qwen-vl-max-latest
130 | - qwen2.5-vl-72b-instruct
131 | - qwen2.5-vl-7b-instruct
132 | - qwen2.5-vl-3b-instruct
133 |
134 | qwen2.5-vl accepts videos from 2 seconds to 10 minutes long; qwen-vl-max-latest has a 1-minute limit.
135 | These models are not good at capturing timestamps; it is recommended to use segmented video clips for captions and to modify the prompts.
136 |
137 | The video upload feature requires submitting an application to Alibaba Cloud; please submit it [here](https://smartservice.console.aliyun.com/service/create-ticket?spm=a2c4g.11186623.0.0.3489b0a8Ql486b).
138 |
139 | We are considering adding local model inference in the future, such as qwen2.5-vl-7b-instruct.
140 |
141 | Additionally, inference output is now streamed to the log, so you can watch the model's output in real time before the complete result is displayed.
142 |
143 | ### 1.6
144 |
145 | The Google Gemini SDK has been updated, and the new SDK supports the new Gemini 2.0 models!
146 |
147 | The new SDK is more powerful; notably, it can verify videos that have already been uploaded.
148 |
149 | If you tag the same video repeatedly, it no longer needs to be re-uploaded: the video name and file size/hash are verified automatically.
150 |
151 | At the same time, the millisecond-level alignment function has been updated. After the subtitles of long video segmentation are merged, the timeline is automatically aligned to milliseconds, which is very neat!
152 |
153 | ## Features
154 |
155 | - Automatic video/audio/image description using Google's Gemini API or only image with pixtral-large 124B
156 | - Export captions in SRT format
157 | - Support for multiple video formats
158 | - Batch processing with progress tracking
159 | - Maintains original directory structure
160 | - Configurable through TOML files
161 | - Lance database integration for efficient data management
162 |
163 | ## Modules
164 |
165 | ### Dataset Import (`lanceImport.py`)
166 | - Import videos into Lance database format
167 | - Preserve original directory structure
168 | - Support for both single directory and paired directory structures
169 |
170 | ### Dataset Export (`lanceexport.py`)
171 | - Extract videos and captions from Lance datasets
172 | - Maintains original file structure
173 | - Exports captions as SRT files in the same directory as source videos
174 | - Auto Clip with SRT timestamps
175 |
176 | ### Auto Captioning (`captioner.py` & `api_handler.py`)
177 | - Automatic video scene description using Gemini API or Pixtral API
178 | - Batch processing support
179 | - SRT format output with timestamps
180 | - Robust error handling and retry mechanisms
181 | - Progress tracking for batch operations
182 |
183 | ### Configuration (`config.py` & `config.toml`)
184 | - API prompt configuration management
185 | - Customizable batch processing parameters
186 | - Default schema includes file paths and metadata
187 |
188 | ## Installation
189 |
190 | Give unrestricted script access to powershell so venv can work:
191 |
192 | - Open an administrator powershell window
193 | - Type Set-ExecutionPolicy Unrestricted and answer A
194 | - Close admin powershell window
195 |
196 | 
197 |
198 | ### Windows
199 | Run the following PowerShell script:
200 | ```powershell
201 | ./1、install-uv-qinglong.ps1
202 | ```
203 |
204 | ### Linux
205 | 1. First install PowerShell:
206 | ```bash
207 | sudo sh ./0、install pwsh.sh
208 | ```
209 | 2. Then run the installation script using PowerShell:
210 | ```powershell
211 | sudo pwsh ./1、install-uv-qinglong.ps1
212 | ```
213 | Use `sudo pwsh` if you are on Linux.
214 |
215 | ### TensorRT (Optional)
216 | On Windows, the TensorRT libraries must be installed manually from [here](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.9.0/zip/TensorRT-10.9.0.34.Windows.win10.cuda-12.8.zip).
217 | TensorRT speeds up WD14Tagger (it does not affect the API-based captioners).
218 | We currently use version 10.9.
219 |
220 | ## Usage
221 |
222 | video example:
223 | https://files.catbox.moe/8fudnf.mp4
224 |
225 | ### Just put video or audio files into the datasets folder
226 |
227 | ### Importing Media
228 | Use the PowerShell script to import your videos:
229 | ```powershell
230 | ./lanceImport.ps1
231 | ```
232 |
233 | ### Exporting Media
234 | Use the PowerShell script to export data from Lance format:
235 | ```powershell
236 | ./lanceExport.ps1
237 | ```
238 |
239 | ### Auto Captioning
240 | Use the PowerShell script to generate captions for your videos:
241 |
242 | ```powershell
243 | ./run.ps1
244 | ```
245 |
246 | Note: You'll need to configure your [Gemini API key](https://aistudio.google.com/apikey) in `run.ps1` before using the auto-captioning feature.
247 | A [Pixtral API key](https://console.mistral.ai/api-keys/) is optional, for image captioning.
248 |
249 | We now support [step-1.5v-mini](https://platform.stepfun.com/) as an optional video captioner.
250 |
251 | We now support the [qwen-VL](https://bailian.console.aliyun.com/#/model-market) series as optional video captioners.
252 |
253 | We now support [Mistral OCR](https://console.mistral.ai/api-keys/) as an optional OCR engine for PDFs and images.
254 |
255 | We now support the [GLM](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) series as optional video captioners.
256 |
257 | ```
258 | $dataset_path = "./datasets"
259 | $gemini_api_key = ""
260 | $gemini_model_path = "gemini-2.0-pro-exp-02-05"
261 | $pixtral_api_key = ""
262 | $pixtral_model_path = "pixtral-large-2411"
263 | $step_api_key = ""
264 | $step_model_path = "step-1.5v-mini"
265 | $qwenVL_api_key = ""
266 | $qwenVL_model_path = "qwen-vl-max-latest" # qwen2.5-vl-72b-instruct<10mins qwen-vl-max-latest <1min
267 | $glm_api_key = ""
268 | $glm_model_path = "GLM-4V-Plus-0111"
269 | $dir_name = $true
270 | $mode = "long"
271 | $not_clip_with_caption = $false # Not clip with caption | 不根据caption裁剪
272 | $wait_time = 1
273 | $max_retries = 100
274 | $segment_time = 600
275 | $ocr = $false
276 | $document_image = $true
277 | $scene_detector = "AdaptiveDetector" # from ["ContentDetector","AdaptiveDetector","HashDetector","HistogramDetector","ThresholdDetector"]
278 | $scene_threshold = 0.0 # default value ["ContentDetector": 27.0, "AdaptiveDetector": 3.0, "HashDetector": 0.395, "HistogramDetector": 0.05, "ThresholdDetector": 12]
279 | $scene_min_len = 15
280 | $scene_luma_only = $false
281 | ```
282 | ---
283 |
284 | # 青龙数据集工具 (2.6)
285 |
286 | 基于 Lance 数据库格式的视频自动字幕生成工具,使用 Gemini API 进行场景描述生成。
287 |
288 | ## 功能特点
289 |
290 | - 使用 Google Gemini API 进行视频场景自动描述
291 | - 导出 SRT 格式字幕文件
292 | - 支持多种视频格式
293 | - 批量处理并显示进度
294 | - 保持原始目录结构
295 | - 通过 TOML 文件配置
296 | - 集成 Lance 数据库实现高效数据管理
297 |
298 | ## 模块说明
299 |
300 | ### 数据集导入 (`lanceImport.py`)
301 | - 将视频导入 Lance 数据库格式
302 | - 保持原始目录结构
303 | - 支持单目录和配对目录结构
304 |
305 | ### 数据集导出 (`lanceexport.py`)
306 | - 从 Lance 数据集中提取视频和字幕
307 | - 保持原有文件结构
308 | - 在源视频所在目录导出 SRT 格式字幕
309 |
310 | ### 自动字幕生成 (`captioner.py` & `api_handler.py`)
311 | - 使用 Gemini API 进行视频场景描述
312 | - 支持批量处理
313 | - 生成带时间戳的 SRT 格式字幕
314 | - 健壮的错误处理和重试机制
315 | - 批处理进度跟踪
316 |
317 | ### 配置模块 (`config.py` & `config.toml`)
318 | - API 配置管理
319 | - 可自定义批处理参数
320 | - 默认结构包含文件路径和元数据
321 |
322 | ## 安装方法
323 |
324 | ### Windows 系统
325 | 运行以下 PowerShell 脚本:
326 | ```powershell
327 | ./1、install-uv-qinglong.ps1
328 | ```
329 |
330 | ### Linux 系统
331 | 1. 首先安装 PowerShell:
332 | ```bash
333 | sudo sh ./0、install pwsh.sh
334 | ```
335 | 2. 然后使用 PowerShell 运行安装脚本:
336 | ```powershell
337 | pwsh ./1、install-uv-qinglong.ps1
338 | ```
339 |
340 | ## 使用方法
341 |
342 | ### 把媒体文件放到datasets文件夹下
343 |
344 | ### 导入视频
345 | 使用 PowerShell 脚本导入视频:
346 | ```powershell
347 | ./lanceImport.ps1
348 | ```
349 |
350 | ### 导出数据
351 | 使用 PowerShell 脚本从 Lance 格式导出数据:
352 | ```powershell
353 | ./lanceExport.ps1
354 | ```
355 |
356 | ### 自动字幕生成
357 | 使用 PowerShell 脚本为视频生成字幕:
358 | ```powershell
359 | ./run.ps1
360 | ```
361 |
362 | 注意:使用自动字幕生成功能前,需要在 `run.ps1` 中配置 [Gemini API 密钥](https://aistudio.google.com/apikey)。
363 | [Pixtral API 密钥](https://console.mistral.ai/api-keys/) 为可选项,用于图片打标。
364 |
365 | 现在我们支持使用[阶跃星辰](https://platform.stepfun.com/)的视频模型进行视频标注。
366 |
367 | 现在我们支持使用[通义千问VL](https://bailian.console.aliyun.com/#/model-market)的视频模型进行视频标注。
368 |
369 | 现在我们支持使用[Mistral OCR](https://console.mistral.ai/api-keys/)的OCR功能对PDF和图片进行文字识别。
370 |
371 | 现在我们支持使用[智谱GLM](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys)的视频模型进行视频标注。
372 |
373 | ```
374 | $dataset_path = "./datasets"
375 | $gemini_api_key = ""
376 | $gemini_model_path = "gemini-2.0-pro-exp-02-05"
377 | $pixtral_api_key = ""
378 | $pixtral_model_path = "pixtral-large-2411"
379 | $step_api_key = ""
380 | $step_model_path = "step-1.5v-mini"
381 | $qwenVL_api_key = ""
382 | $qwenVL_model_path = "qwen-vl-max-latest" # qwen2.5-vl-72b-instruct<10mins qwen-vl-max-latest <1min
383 | $glm_api_key = ""
384 | $glm_model_path = "GLM-4V-Plus-0111"
385 | $dir_name = $true
386 | $mode = "long"
387 | $not_clip_with_caption = $false # Not clip with caption | 不根据caption裁剪
388 | $wait_time = 1
389 | $max_retries = 100
390 | $segment_time = 600
391 | $ocr = $false
392 | $document_image = $true
393 | $scene_detector = "AdaptiveDetector" # from ["ContentDetector","AdaptiveDetector","HashDetector","HistogramDetector","ThresholdDetector"]
394 | $scene_threshold = 0.0 # default value ["ContentDetector": 27.0, "AdaptiveDetector": 3.0, "HashDetector": 0.395, "HistogramDetector": 0.05, "ThresholdDetector": 12]
395 | $scene_min_len = 15
396 | $scene_luma_only = $false
397 | ```
398 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/qinglong-captions/c9376f3bf325675c7dc10b5da702e1e0eb629947/config/__init__.py
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | """Configuration constants for the dataset processing."""
2 |
3 | from typing import Tuple, List, Dict, Any
4 | import os
5 | import toml
6 | import pyarrow as pa
7 |
8 | # Base image extensions
def _with_uppercase_variants(extensions: List[str]) -> List[str]:
    """Return *extensions* followed by their upper-case spellings, in the same order."""
    return extensions + [ext.upper() for ext in extensions]


# Base image extensions
BASE_IMAGE_EXTENSIONS: List[str] = _with_uppercase_variants(
    [".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".ico", ".tif", ".tiff"]
)

# Image formats that may contain animation
BASE_ANIMATION_EXTENSIONS: List[str] = _with_uppercase_variants([".gif", ".webp"])

# Base video container extensions
BASE_VIDEO_EXTENSIONS: List[str] = _with_uppercase_variants(
    [".mp4", ".webm", ".avi", ".mkv", ".mov", ".flv", ".wmv", ".m4v", ".mpg", ".mpeg"]
)

# Base audio extensions
BASE_AUDIO_EXTENSIONS: List[str] = _with_uppercase_variants(
    [
        ".mp3",
        ".wav",
        ".ogg",
        ".flac",
        ".m4a",
        ".wma",
        ".aac",
        ".aiff",
        ".aifc",
        ".aif",
        ".au",
        ".snd",
        ".mid",
        ".midi",
        ".mka",
    ]
)

# Document ("application/*") extensions
BASE_APPLICATION_EXTENSIONS: List[str] = _with_uppercase_variants([".pdf"])
97 |
98 |
def get_supported_extensions(media_type: str = "image") -> Tuple[str, ...]:
    """Get all supported media extensions including optional formats.

    Args:
        media_type: One of "image", "animation", "video", "audio" or
            "application".

    Returns:
        Tuple of extensions (leading dot, lower- and upper-case spellings).
        For "image" and "animation", AVIF / JPEG-XL / HEIF extensions are
        appended when the corresponding Pillow plugin can be imported, and
        "animation" additionally gets APNG when the ``apng`` package is present.

    Raises:
        ValueError: If *media_type* is not one of the supported values.
    """
    # Reject unknown types up front; otherwise ``extensions`` would never be
    # assigned and the caller would get an opaque UnboundLocalError.
    if media_type not in ("image", "animation", "video", "audio", "application"):
        raise ValueError(f"Unsupported media type: {media_type!r}")

    simple_types = {
        "video": BASE_VIDEO_EXTENSIONS,
        "audio": BASE_AUDIO_EXTENSIONS,
        "application": BASE_APPLICATION_EXTENSIONS,
    }
    if media_type in simple_types:
        return tuple(simple_types[media_type])

    extensions = list(
        BASE_IMAGE_EXTENSIONS if media_type == "image" else BASE_ANIMATION_EXTENSIONS
    )

    # Optional Pillow plugins: importing them registers the extra decoders.
    try:
        import pillow_avif  # noqa: F401

        extensions.extend([".avif", ".AVIF"])
    except ImportError:
        pass

    try:
        import pillow_jxl  # noqa: F401

        extensions.extend([".jxl", ".JXL"])
    except ImportError:
        pass

    try:
        from pillow_heif import register_heif_opener

        register_heif_opener()
        extensions.extend([".heic", ".heif", ".HEIC", ".HEIF"])
    except ImportError:
        pass

    if media_type == "animation":
        try:
            from apng import APNG  # noqa: F401

            extensions.extend([".apng", ".APNG"])
        except ImportError:
            pass

    return tuple(extensions)
148 |
149 |
def load_toml_config(config_path: str, section: str) -> Dict[str, Any]:
    """Load a configuration section from a TOML file.

    Args:
        config_path: Path to the TOML file
        section: Name of the section to load

    Returns:
        Dictionary containing the configuration data

    Raises:
        FileNotFoundError: If *config_path* does not exist.
        ValueError: If the file cannot be parsed, or the section is missing
            or empty.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Only the parse step is wrapped, so a missing section is reported as
    # such instead of as "Failed to parse config file: ...".
    try:
        config = toml.load(config_path)
    except Exception as e:
        raise ValueError(f"Failed to parse config file: {str(e)}") from e

    section_data = config.get(section, {})
    if not section_data:
        raise ValueError(f"No {section} configuration found in TOML file")

    return section_data
173 |
174 |
def load_schema_from_toml(config_path: str) -> List[Tuple[str, str]]:
    """Read the ``[schema]`` section of a TOML file as (name, type) pairs.

    Args:
        config_path: Location of the TOML file whose ``schema.fields`` array
            lists the dataset columns.

    Returns:
        One ``(field_name, field_type)`` tuple per entry of ``schema.fields``,
        in file order; an empty list when ``fields`` is absent.

    Raises:
        FileNotFoundError / ValueError: Propagated from ``load_toml_config``.
        KeyError: If a field entry lacks ``name`` or ``type``.
    """
    schema_section = load_toml_config(config_path, "schema")
    field_specs = schema_section.get("fields", [])
    return [(spec["name"], spec["type"]) for spec in field_specs]
187 |
188 |
def load_colors_from_toml(config_path: str) -> Dict[str, str]:
    """Read the ``[colors]`` section of a TOML file.

    Args:
        config_path: Location of the TOML file that defines console colors.

    Returns:
        Mapping of media type (e.g. ``"image"``, ``"video"``) to a color/style name.

    Raises:
        FileNotFoundError / ValueError: Propagated from ``load_toml_config``.
    """
    colors_section = load_toml_config(config_path, "colors")
    return colors_section
199 |
200 |
def load_prompts_from_toml(config_path: str) -> Dict[str, str]:
    """Read the ``[prompts]`` section of a TOML file.

    Args:
        config_path: Location of the TOML file that defines the prompts.

    Returns:
        The raw ``prompts`` table (e.g. it may contain ``system_prompt``).

    Raises:
        FileNotFoundError / ValueError: Propagated from ``load_toml_config``.
    """
    prompts_section = load_toml_config(config_path, "prompts")
    return prompts_section
211 |
212 |
213 | # Default schema definition
# Default schema definition: one (column name, Arrow type) pair per dataset
# column, kept in declaration order.
DEFAULT_DATASET_SCHEMA = list(
    {
        "uris": pa.string(),
        "mime": pa.string(),
        "width": pa.int32(),
        "height": pa.int32(),
        "channels": pa.int32(),
        "depth": pa.int32(),
        "hash": pa.string(),
        "size": pa.int64(),
        "has_audio": pa.bool_(),
        "duration": pa.int32(),
        "num_frames": pa.int32(),
        "frame_rate": pa.float32(),
        "blob": pa.large_binary(),
        "captions": pa.list_(pa.string()),
    }.items()
)

# Default console colors, keyed by media/content type.
DEFAULT_CONSOLE_COLORS = dict(
    image="green",
    animation="bold green",
    video="magenta",
    audio="orange1",
    application="bright_red",
    text="yellow",
    caption="yellow",
    unknown="cyan",
)

# Currently active configuration; starts as copies of the built-in defaults
# and is overwritten by load_config().
DATASET_SCHEMA = list(DEFAULT_DATASET_SCHEMA)
CONSOLE_COLORS = dict(DEFAULT_CONSOLE_COLORS)
SYSTEM_PROMPT = ""  # Will be loaded from config
247 |
248 |
def load_config(config_path: str) -> None:
    """Load schema, console colors and system prompt from a TOML file.

    Each section is loaded independently. If a section fails to load, a
    warning is printed and that section keeps its current value instead of
    aborting the others.

    Args:
        config_path: Path to the TOML file containing configurations
    """
    global DATASET_SCHEMA, CONSOLE_COLORS, SYSTEM_PROMPT

    try:
        schema = load_schema_from_toml(config_path)
    except Exception as e:
        print(f"Warning: Failed to load schema configuration: {e}")
    else:
        if schema:
            DATASET_SCHEMA = schema
        else:
            # A [schema] section without "fields" would otherwise silently
            # replace the active schema with an empty list.
            print("Warning: Schema configuration has no fields; keeping current schema")

    try:
        colors = load_colors_from_toml(config_path)
        if colors:
            # Merge so colors not mentioned in the file keep their defaults.
            CONSOLE_COLORS.update(colors)
    except Exception as e:
        print(f"Warning: Failed to load colors configuration: {e}")

    try:
        prompts = load_prompts_from_toml(config_path)
        if prompts and "system_prompt" in prompts:
            SYSTEM_PROMPT = prompts["system_prompt"]
    except Exception as e:
        print(f"Warning: Failed to load prompts configuration: {e}")
275 |
--------------------------------------------------------------------------------
/datasets/put datasets here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/qinglong-captions/c9376f3bf325675c7dc10b5da702e1e0eb629947/datasets/put datasets here
--------------------------------------------------------------------------------
/lanceExport.ps1:
--------------------------------------------------------------------------------
# Input parameters | 输入参数
param(
    [string]$lance_file = "./datasets/dataset.lance", # Lance dataset path | Lance数据集路径
    [string]$output_dir = "./datasets", # Output directory | 输出目录
    [string]$version = "gemini", # Dataset version (gemini/WDtagger/pixtral)
    [bool]$not_clip_with_caption = $false # Not clip with caption | 不根据caption裁剪
)

# Set working directory | 设置工作目录
Set-Location $PSScriptRoot

# Activate virtual environment | 激活虚拟环境
# Candidates are tried in order and the first existing one is used.
# venv names its PowerShell activation script "Activate.ps1" (capital A), so
# the path must match that exactly on case-sensitive file systems (Linux).
$venvPaths = @(
    "venv/Scripts/activate",
    ".venv/Scripts/activate",
    "venv/bin/Activate.ps1",
    ".venv/bin/Activate.ps1"
)

$activated = $false
foreach ($path in $venvPaths) {
    if (Test-Path $path) {
        Write-Output "Activating virtual environment: $path"
        & $path
        $activated = $true
        break
    }
}

if (-not $activated) {
    Write-Error "No virtual environment found. Please create one first."
    exit 1
}

# Set environment variables | 设置环境变量
$env:HF_HOME = "huggingface"
$env:XFORMERS_FORCE_DISABLE_TRITON = "1"
$env:HF_ENDPOINT = "https://hf-mirror.com"

# Prepare arguments | 准备参数
$arguments = @(
    "-m",
    "module.lanceexport",
    $lance_file,
    "--output_dir=$output_dir"
)

if ($version) {
    $arguments += "--version=$version"
}

if ($not_clip_with_caption) {
    $arguments += "--not_clip_with_caption"
}

# Run export script | 运行导出脚本
Write-Output "Starting export from $lance_file to $output_dir"
& python $arguments
# Capture python's exit code before any other command can overwrite it.
$exitCode = $LASTEXITCODE

# Wait for user input before closing | 等待用户输入后关闭
if ($exitCode -ne 0) {
    Write-Output "`nExport failed with exit code $exitCode. Press Enter to exit..."
}
else {
    Write-Output "`nExport finished. Press Enter to exit..."
}
$null = Read-Host
exit $exitCode
--------------------------------------------------------------------------------
/lanceImport.ps1:
--------------------------------------------------------------------------------
# Input parameters
$train_data_dir = "./datasets" # input images path | 图片输入路径
$caption_dir = $null # Optional caption files directory | 可选的描述文件目录
$output_name = "dataset" # Output dataset name | 输出数据集名称
$no_save_binary = $false # Don't save binary data | 不保存二进制数据
$not_save_disk = $false # Load into memory only | 仅加载到内存
$import_mode = 0 # Video import mode: 0=All, 1=Video only, 2=Audio only, 3=Split | 视频导入模式
$tag = "gemini" # Dataset tag | 数据集标签

# Activate virtual environment
Set-Location $PSScriptRoot
if ($env:OS -ilike "*windows*") {
    if (Test-Path "./venv/Scripts/activate") {
        Write-Output "Windows venv"
        ./venv/Scripts/activate
    }
    elseif (Test-Path "./.venv/Scripts/activate") {
        Write-Output "Windows .venv"
        ./.venv/Scripts/activate
    }
}
# On Linux/macOS check for and run the same file: venv names it "Activate.ps1"
# (capital A), and the path is case-sensitive there.
elseif (Test-Path "./venv/bin/Activate.ps1") {
    Write-Output "Linux venv"
    ./venv/bin/Activate.ps1
}
elseif (Test-Path "./.venv/bin/Activate.ps1") {
    Write-Output "Linux .venv"
    ./.venv/bin/Activate.ps1
}

# Run the import script
# $args is a PowerShell automatic variable (the script's unbound arguments),
# so the argument list gets its own name instead of overwriting it.
$importArgs = @(
    $train_data_dir
)
if ($caption_dir) { $importArgs += "--caption_dir=$caption_dir" }
if ($output_name) { $importArgs += "--output_name=$output_name" }
if ($no_save_binary) { $importArgs += "--no_save_binary" }
if ($not_save_disk) { $importArgs += "--not_save_disk" }
$importArgs += "--import_mode=$import_mode"
$importArgs += "--tag=$tag"

python -m module.lanceImport @importArgs

Write-Output "Import finished"
Read-Host | Out-Null
--------------------------------------------------------------------------------
/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/qinglong-captions/c9376f3bf325675c7dc10b5da702e1e0eb629947/module/__init__.py
--------------------------------------------------------------------------------
/module/captioner.py:
--------------------------------------------------------------------------------
1 | import lance
2 | from rich.progress import (
3 | Progress,
4 | SpinnerColumn,
5 | TextColumn,
6 | BarColumn,
7 | TaskProgressColumn,
8 | TimeRemainingColumn,
9 | TimeElapsedColumn,
10 | TransferSpeedColumn,
11 | MofNCompleteColumn,
12 | )
13 | from rich.console import Console
14 | from PIL import Image
15 | import json
16 | import pyarrow as pa
17 | from module.lanceImport import transform2lance
18 | from module.lanceexport import extract_from_lance
19 | from module.api_handler import api_process_batch, process_llm_response
20 | from module.scenedetect import SceneDetector, run_async_in_thread
21 | from utils.stream_util import (
22 | split_media_stream_clips,
23 | split_video_with_imageio_ffmpeg,
24 | get_video_duration,
25 | )
26 | from config.config import (
27 | BASE_VIDEO_EXTENSIONS,
28 | BASE_AUDIO_EXTENSIONS,
29 | BASE_APPLICATION_EXTENSIONS,
30 | )
31 | import re
32 | import argparse
33 | import toml
34 | import pysrt
35 | from pathlib import Path
36 | import base64
37 | import asyncio
38 |
# Disable Pillow's image size limit (decompression-bomb) check so very large
# images can be opened for captioning.
Image.MAX_IMAGE_PIXELS = None  # Disable image size limit check

# Module-level Rich console used for status output throughout this module.
console = Console()
42 |
43 |
def _load_dataset(args):
    """Return the Lance dataset to caption.

    If ``args.dataset_dir`` is already a ``lance.LanceDataset`` it is used as-is.
    Otherwise the directory is imported with ``transform2lance``. Media binaries
    are stored only when neither a Gemini nor a Pixtral API key is set;
    otherwise ``save_binary=False`` is passed.
    """
    if isinstance(args.dataset_dir, lance.LanceDataset):
        # Without this branch ``dataset`` would be unbound (NameError).
        return args.dataset_dir
    if args.gemini_api_key == "" and args.pixtral_api_key == "":
        return transform2lance(dataset_dir=args.dataset_dir)
    return transform2lance(dataset_dir=args.dataset_dir, save_binary=False)


def _start_scene_detection(args, filepath, mime, progress):
    """Start background scene detection for a video and return the detector.

    Returns None when scene detection is disabled (threshold or minimum
    length not positive) or the media is not a video.
    """
    if not (
        args.scene_threshold > 0
        and args.scene_min_len > 0
        and mime.startswith("video")
    ):
        return None

    scene_detector = SceneDetector(
        detector=args.scene_detector,
        threshold=args.scene_threshold,
        min_scene_len=args.scene_min_len,
        luma_only=args.scene_luma_only,
        console=progress,
    )
    # Run detection on a background thread so captioning is not blocked.
    run_async_in_thread(scene_detector.detect_scenes_async(filepath))
    return scene_detector


def _caption_long_media(args, config, filepath, mime, duration, sha256hash, progress, task):
    """Split long media into ``segment_time``-second clips, caption each, and merge.

    ``duration`` is in milliseconds. Each clip's captions are shifted by the
    accumulated duration of the preceding clips, so the merged timeline
    matches the original file. The temporary clip files are deleted afterwards.

    Returns:
        The merged captions as an SRT-formatted string ("" if none).
    """
    global clip_task_id

    console.print(f"[blue]{filepath} video > {args.segment_time} seconds[/blue]")
    console.print(f"[blue]split video[/blue]")

    # One subtitle entry per chunk; these entries drive the splitting.
    subs = pysrt.SubRipFile()
    duration_seconds = duration / 1000  # milliseconds -> seconds
    chunk_duration = args.segment_time
    num_chunks = int((duration_seconds + chunk_duration - 1) // chunk_duration)
    for i in range(num_chunks):
        start_time = i * chunk_duration
        end_time = min((i + 1) * chunk_duration, duration_seconds)
        subs.append(
            pysrt.SubRipItem(
                index=i,
                start=pysrt.SubRipTime(seconds=start_time),
                end=pysrt.SubRipTime(seconds=end_time),
                text=f"Chunk {i + 1}",
            )
        )

    for sub in subs:
        console.print(f"[blue]Subtitles created:[/blue] {sub}")
    try:
        split_video_with_imageio_ffmpeg(
            Path(filepath),
            subs,
            save_caption_func=None,
            segment_time=args.segment_time,
        )
    except Exception as e:
        # Fall back to stream-copy splitting.
        meta_type = "video" if mime.startswith("video") else "audio"
        console.print(f"[red]Error splitting video with imageio-ffmpeg: {e}[/red]")
        split_media_stream_clips(Path(filepath), meta_type, subs)

    pathfile = Path(filepath)
    clip_dir = pathfile.parent / f"{pathfile.stem}_clip"
    files = sorted(clip_dir.glob(f"*{pathfile.suffix}"))

    # Reuse a single "Processing clips" progress task across files rather
    # than adding a new bar for every long file.
    if "clip_task_id" in globals() and clip_task_id in [t.id for t in progress.tasks]:
        progress.reset(
            clip_task_id,
            total=num_chunks,
            visible=True,
            description=f"[cyan]Processing clips...",
        )
        clip_task = clip_task_id
    else:
        clip_task = progress.add_task(f"[cyan]Processing clips...", total=num_chunks)
        clip_task_id = clip_task

    merged_subs = pysrt.SubRipFile()
    total_duration = 0  # accumulated length of earlier clips, in ms
    for i in range(num_chunks):
        if i >= len(files):
            # Splitting produced fewer clips than expected; keep what we have
            # instead of crashing with IndexError.
            console.print(
                f"[red]Expected {num_chunks} clips but found {len(files)} in {clip_dir}; stopping early[/red]"
            )
            break

        # NOTE(review): this reads the original file's .srt on every chunk
        # iteration, so an existing caption file is merged num_chunks times
        # — confirm this is intended.
        sub_path = Path(filepath).with_suffix(".srt")
        if sub_path.exists():
            merged_subs.extend(pysrt.open(sub_path, encoding="utf-8"))

        console.print(f"[yellow]Processing chunk {i+1}/{num_chunks}[/yellow]")
        uri = files[i]

        chunk_output = api_process_batch(
            uri=uri,
            mime=mime,
            config=config,
            args=args,
            sha256hash=sha256hash,
            progress=progress,
            task_id=task,
        )
        console.print(f"[green]API processing complete for chunk {i+1}[/green]")

        console.print(f"[yellow]Post-processing chunk output...[/yellow]")
        chunk_output = _postprocess_caption_content(chunk_output, uri, args)

        chunk_subs = pysrt.from_string(chunk_output)
        # Drop captions that start beyond the chunk's nominal length
        # (ordinal is in milliseconds); iterate a copy to remove safely.
        for sub in list(chunk_subs):
            if sub.start.ordinal > args.segment_time * 1000:
                chunk_subs.remove(sub)

        if i > 0:
            last_duration = get_video_duration(files[i - 1])
            total_duration += int(float(last_duration))
            # Split the millisecond offset into minutes, seconds and ms.
            last_duration_minutes = int(total_duration / 60000)
            last_duration_seconds = int((total_duration % 60000) / 1000)
            last_duration_milliseconds = total_duration % 1000

            console.print(
                f"[yellow]Total shift duration: {last_duration_minutes}m {last_duration_seconds}s {last_duration_milliseconds}ms[/yellow]"
            )
            # Shift all subtitles in the chunk onto the original timeline.
            chunk_subs.shift(
                minutes=last_duration_minutes,
                seconds=last_duration_seconds,
                milliseconds=last_duration_milliseconds,
            )

        merged_subs.extend(chunk_subs)
        console.print(
            f"[green]Successfully merged chunk {i+1}. Total subtitles: {len(merged_subs)}[/green]"
        )

        progress.update(
            clip_task,
            advance=1,
            refresh=True,
            description=f"[yellow]merging complete for chunk [/yellow]",
        )

    # Mark the clip task as completed and hide it.
    progress.update(clip_task, completed=num_chunks, visible=False)

    # Renumber the merged entries, then build the SRT text manually.
    merged_subs.clean_indexes()
    output = "".join(
        f"{n}\n{sub.start} --> {sub.end}\n{sub.text}\n\n"
        for n, sub in enumerate(merged_subs, start=1)
    )
    if output:
        console.print(
            f"[green]All subtitles merged successfully. Total: {len(merged_subs)}[/green]"
        )

    for file in files:
        file.unlink(missing_ok=True)

    return output


def _caption_path_for(filepath_path):
    """Choose the caption file path: .srt for video/audio, .md for documents, else .txt."""
    suffix = filepath_path.suffix
    if suffix in BASE_VIDEO_EXTENSIONS or suffix in BASE_AUDIO_EXTENSIONS:
        return filepath_path.with_suffix(".srt")
    if suffix in BASE_APPLICATION_EXTENSIONS:
        return filepath_path.with_suffix(".md")
    return filepath_path.with_suffix(".txt")


def _save_caption(output, filepath, caption_path, scene_detector, args):
    """Write *output* to *caption_path* in the format implied by its suffix.

    - ``.srt``: validated with pysrt and, when a scene detector is given,
      aligned to detected scene changes. On any failure the raw text is
      written instead.
    - ``.md``: written verbatim.
    - otherwise: a list is written one item per line. A dict is dumped to a
      ``.json`` sidecar and its ``description`` is written as text. Anything
      else is written as text.

    Errors are reported on the console, not raised.
    """
    filepath_path = Path(filepath)
    console.print(f"[blue]Processing caption for:[/blue] {filepath_path}")
    if isinstance(output, dict):
        # .get: a dict without "description" must not abort the whole batch.
        console.print(
            f"[blue]Caption content length:[/blue] {len(output.get('description') or '')}"
        )
    else:
        console.print(f"[blue]Caption content length:[/blue] {len(output)}")

    if caption_path.suffix == ".srt":
        try:
            subs = pysrt.from_string(output)
            if scene_detector:
                console.print(f"[bold cyan]{scene_detector.get_scene_list()}...[/bold cyan]")
                scene_list = scene_detector.get_scene_list()
                if scene_list is None:
                    # No result yet: wait for the background detection to finish.
                    scene_list = asyncio.run(
                        scene_detector.ensure_detection_complete(filepath)
                    )
                console.print(f"[bold cyan]Aligning subtitles with scene changes...[/bold cyan]")
                subs = scene_detector.align_subtitle(
                    subs,
                    scene_list=scene_list,
                    console=console,
                    segment_time=args.segment_time,
                )
            subs.save(str(caption_path), encoding="utf-8")
            console.print(f"[green]Saved captions to {caption_path}[/green]")
        except Exception as e:
            console.print(
                f"[yellow]pysrt validation failed: {e}, falling back to direct file write[/yellow]"
            )
            try:
                caption_path.write_text(output, encoding="utf-8")
                console.print(f"[green]Saved captions to {caption_path}[/green]")
            except Exception as e:
                console.print(f"[red]Error saving SRT file: {e}[/red]")
    elif caption_path.suffix == ".md":
        try:
            with open(caption_path, "w", encoding="utf-8") as f:
                f.write(output)
            console.print(f"[green]Saved captions to {caption_path}[/green]")
        except Exception as e:
            console.print(f"[red]Error saving MD file: {e}[/red]")
    else:
        try:
            if isinstance(output, list):
                with open(caption_path, "w", encoding="utf-8") as f:
                    for line in output:
                        f.write(line + "\n")
            elif isinstance(output, dict):
                with open(filepath_path.with_suffix(".json"), "w", encoding="utf-8") as f:
                    json.dump(output, f, indent=2, ensure_ascii=False)
                with open(caption_path, "w", encoding="utf-8") as f:
                    if "description" in output and output["description"]:
                        f.write(output["description"])
                    else:
                        f.write("No description available")
            else:
                caption_path.write_text(output, encoding="utf-8")
            console.print(f"[green]Saved captions to {caption_path}[/green]")
        except Exception as e:
            console.print(f"[red]Error saving TXT file: {e}[/red]")


def _merge_captions_into_dataset(dataset, filepaths, captions):
    """Write *captions* into the dataset's ``captions`` column, matched on ``uris``.

    List captions are joined with newlines and dict captions are JSON-encoded,
    so every row stores a single string. Afterwards the dataset is tagged
    ``"gemini"``.
    """
    processed_captions = []
    for caption in captions:
        if isinstance(caption, list):
            processed_captions.append("\n".join(caption))
        elif isinstance(caption, dict):
            processed_captions.append(json.dumps(caption, ensure_ascii=False))
        else:
            processed_captions.append(caption)

    table = pa.table(
        {
            "uris": pa.array(filepaths, type=pa.string()),
            "captions": pa.array(
                [[caption] for caption in processed_captions],
                type=pa.list_(pa.string()),
            ),
        }
    )
    dataset.merge_insert(on="uris").when_matched_update_all().execute(table)

    # Create the tag on the first run, otherwise update the existing one.
    # NOTE(review): the tag always targets version 1 rather than the version
    # produced by merge_insert above — confirm this is intended.
    try:
        dataset.tags.create("gemini", 1)
    except Exception:
        dataset.tags.update("gemini", 1)

    console.print("[green]Successfully updated dataset with new captions[/green]")


def process_batch(args, config):
    """Caption every media row of a dataset and write the results back.

    For each row this:
    1. optionally starts background scene detection (videos only);
    2. captions the file with ``api_process_batch``, first splitting media
       longer than ``args.segment_time`` seconds into clips;
    3. writes a sidecar caption file next to the media.

    Finally, all captions are merged into the Lance dataset's ``captions``
    column and the dataset is exported with ``extract_from_lance``.

    Args:
        args: Parsed command-line arguments (dataset path, API keys, scene
            detection and segmentation options).
        config: Configuration passed through to ``api_process_batch``.
    """
    dataset = _load_dataset(args)

    scanner = dataset.scanner(
        columns=["uris", "blob", "mime", "captions", "duration", "hash"],
        scan_in_order=True,
        late_materialization=["blob"],
        batch_size=1,
    )
    total_rows = dataset.count_rows()

    results = []
    processed_filepaths = []

    with Progress(
        "[progress.description]{task.description}",
        SpinnerColumn(spinner_name="dots"),
        MofNCompleteColumn(separator="/"),
        BarColumn(bar_width=40, complete_style="green", finished_style="bold green"),
        TextColumn("•"),
        TaskProgressColumn(),
        TextColumn("•"),
        TransferSpeedColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        expand=True,
    ) as progress:
        task = progress.add_task("[bold cyan]Processing media...", total=total_rows)

        for batch in scanner.to_batches():
            rows = zip(
                batch["uris"].to_pylist(),
                batch["mime"].to_pylist(),
                batch["duration"].to_pylist(),
                batch["hash"].to_pylist(),
            )
            for filepath, mime, duration, sha256hash in rows:
                scene_detector = _start_scene_detection(args, filepath, mime, progress)

                # Images, and media no longer than segment_time (+1 s
                # tolerance), are captioned in one request; duration is in ms.
                if mime.startswith("image") or duration <= (args.segment_time + 1) * 1000:
                    output = api_process_batch(
                        uri=filepath,
                        mime=mime,
                        config=config,
                        args=args,
                        sha256hash=sha256hash,
                        progress=progress,
                        task_id=task,
                    )
                    output = _postprocess_caption_content(output, filepath, args)
                else:
                    output = _caption_long_media(
                        args, config, filepath, mime, duration, sha256hash, progress, task
                    )

                results.append(output)
                processed_filepaths.append(filepath)

                caption_path = _caption_path_for(Path(filepath))
                _save_caption(output, filepath, caption_path, scene_detector, args)

            progress.update(task, advance=len(batch))

        # All batches processed: hide the main progress bar.
        progress.update(task, visible=False)

    if results:
        _merge_captions_into_dataset(dataset, processed_filepaths, results)

    # NOTE(review): when args.dataset_dir is a LanceDataset, this passes the
    # dataset object as the output location — confirm extract_from_lance accepts it.
    extract_from_lance(
        dataset, args.dataset_dir, clip_with_caption=not args.not_clip_with_caption
    )
437 |
438 |
439 | def _postprocess_caption_content(output, filepath, args):
440 | """
441 | postprocess_caption_content
442 | """
443 | if not output:
444 | console.print(f"[red]No caption content generated for {filepath}[/red]")
445 | return ""
446 |
447 | if isinstance(output, list):
448 | # 检查是否为OCRPageObject对象列表
449 | if (
450 | len(output) > 0
451 | and hasattr(output[0], "markdown")
452 | and hasattr(output[0], "index")
453 | ):
454 | combined_output = ""
455 | for page in output:
456 | # 添加页面索引作为HTML页眉和页脚
457 | page_index = page.index if hasattr(page, "index") else "unknown"
458 | # 添加HTML页眉
459 | combined_output += f'\n Page {page_index+1} \n\n\n'
460 | # 添加页面内容
461 | page_markdown = (
462 | page.markdown if hasattr(page, "markdown") else str(page)
463 | )
464 | # 替换图片路径,将图片路径改为上一级目录
465 | if hasattr(page, "images") and args.document_image:
466 | # 查找并替换所有图片引用格式 
467 | img_pattern = r"!\[(.*?)\]\(([^/)]+)\)"
468 | parent_dir = Path(filepath).stem
469 | page_markdown = re.sub(
470 | img_pattern,
471 | lambda m: f"})",
472 | page_markdown,
473 | )
474 |
475 | if hasattr(page, "images") and args.document_image:
476 | for image in page.images:
477 | if hasattr(image, "image_base64") and image.image_base64:
478 | try:
479 | base64_str = image.image_base64
480 | # 处理data URL格式
481 | if base64_str.startswith("data:"):
482 | # 提取实际的base64内容
483 | base64_content = base64_str.split(",", 1)[1]
484 | image_data = base64.b64decode(base64_content)
485 | else:
486 | image_data = base64.b64decode(base64_str)
487 |
488 | image_filename = image.id
489 | image_dir = Path(filepath).with_suffix("")
490 | image_dir.mkdir(parents=True, exist_ok=True)
491 | image_path = image_dir / image_filename
492 | with open(image_path, "wb") as img_file:
493 | img_file.write(image_data)
494 | except Exception as e:
495 | console.print(
496 | f"[yellow]Error saving OCR image: {e}[/yellow]"
497 | )
498 | # 这里添加页面内容,只添加一次
499 | combined_output += f"{page_markdown}\n\n"
500 | # 添加HTML页脚和分隔符
501 | combined_output += f'\n\n'
502 | combined_output += '
\n\n'
503 | output = combined_output
504 | else:
505 | output = "\n".join(output)
506 |
507 | # 确保字幕内容格式正确
508 | output = output.strip()
509 | if not output.strip():
510 | console.print(f"[red]Empty caption content for {filepath}[/red]")
511 | return ""
512 |
513 | # 格式化时间戳 - 只处理视频和音频文件的字幕
514 | if (
515 | Path(filepath).suffix in BASE_VIDEO_EXTENSIONS
516 | or Path(filepath).suffix in BASE_AUDIO_EXTENSIONS
517 | ):
518 | # 确保字幕内容格式正确
519 | output = output.strip()
520 | if not output.strip():
521 | console.print(f"[red]Empty caption content for {filepath}[/red]")
522 | return ""
523 |
524 | # 使用单一的正则表达式和处理函数来修复时间戳格式
525 | # 创建一个匹配各种时间戳格式的模式
526 | timestamp_pattern = re.compile(
527 | # 匹配格式1: M:SS,mmm (单位数分钟)
528 | r"(? argparse.ArgumentParser:
576 | parser = argparse.ArgumentParser()
577 |
578 | parser.add_argument("dataset_dir", type=str, help="directory for dataset")
579 |
580 | parser.add_argument(
581 | "--gemini_api_key",
582 | type=str,
583 | default="",
584 | help="API key for gemini API",
585 | )
586 |
587 | parser.add_argument(
588 | "--gemini_model_path",
589 | type=str,
590 | default="gemini-exp-1206",
591 | help="Model path for gemini",
592 | )
593 |
594 | parser.add_argument(
595 | "--step_api_key",
596 | type=str,
597 | default="",
598 | help="API key for step API",
599 | )
600 |
601 | parser.add_argument(
602 | "--step_model_path",
603 | type=str,
604 | default="step-1.5v-mini",
605 | help="video model for step",
606 | )
607 |
608 | parser.add_argument(
609 | "--qwenVL_api_key",
610 | type=str,
611 | default="",
612 | help="API key for qwenVL API",
613 | )
614 |
615 | parser.add_argument(
616 | "--qwenVL_model_path",
617 | type=str,
618 | default="qwen-vl-max-latest",
619 | help="video model for qwenVL",
620 | )
621 |
622 | parser.add_argument(
623 | "--pixtral_api_key",
624 | type=str,
625 | default="",
626 | help="API key for pixtral API",
627 | )
628 |
629 | parser.add_argument(
630 | "--pixtral_model_path",
631 | type=str,
632 | default="pixtral-large-2411",
633 | help="Model path for pixtral",
634 | )
635 |
636 | parser.add_argument(
637 | "--glm_api_key",
638 | type=str,
639 | default="",
640 | help="API key for glm API",
641 | )
642 |
643 | parser.add_argument(
644 | "--glm_model_path",
645 | type=str,
646 | default="glm-4v-plus-0111",
647 | help="Model path for glm",
648 | )
649 |
650 | parser.add_argument(
651 | "--dir_name",
652 | action="store_true",
653 | help="Use the directory name as the dataset name",
654 | )
655 |
656 | parser.add_argument(
657 | "--mode",
658 | type=str,
659 | default="all",
660 | help="Mode for processing the dataset",
661 | )
662 |
663 | parser.add_argument(
664 | "--config",
665 | type=str,
666 | default="config/config.toml",
667 | help="Path to config file",
668 | )
669 |
670 | parser.add_argument(
671 | "--not_clip_with_caption",
672 | action="store_true",
673 | help="Not clip with caption",
674 | )
675 |
676 | parser.add_argument(
677 | "--wait_time",
678 | type=int,
679 | default=1,
680 | help="Wait time",
681 | )
682 |
683 | parser.add_argument(
684 | "--max_retries",
685 | type=int,
686 | default=20,
687 | help="Max retries",
688 | )
689 |
690 | parser.add_argument(
691 | "--segment_time",
692 | type=int,
693 | default=600,
694 | help="Segment time",
695 | )
696 |
697 | parser.add_argument(
698 | "--ocr",
699 | action="store_true",
700 | help="Use OCR to extract text from image",
701 | )
702 |
703 | parser.add_argument(
704 | "--document_image",
705 | action="store_true",
706 | help="Use OCR to extract image from document",
707 | )
708 |
709 | parser.add_argument(
710 | "--scene_detector",
711 | type=str,
712 | choices=[
713 | "ContentDetector",
714 | "AdaptiveDetector",
715 | "HashDetector",
716 | "HistogramDetector",
717 | "ThresholdDetector",
718 | ],
719 | default="AdaptiveDetector",
720 | help="Detector to use for scene detection",
721 | )
722 |
723 | parser.add_argument(
724 | "--scene_threshold",
725 | type=float,
726 | default=0.0,
727 | help="Threshold for scene detection",
728 | )
729 |
730 | parser.add_argument(
731 | "--scene_min_len",
732 | type=int,
733 | default=15,
734 | help="Minimum length(frames) for scene detection",
735 | )
736 |
737 | parser.add_argument(
738 | "--scene_luma_only",
739 | action="store_true",
740 | help="Only use luma (brightness) without color changes for scene detection.",
741 | )
742 |
743 | parser.add_argument(
744 | "--gemini_task",
745 | type=str,
746 | default="",
747 | help="Task for gemini-2.0-flash-exp",
748 | )
749 |
750 | parser.add_argument(
751 | "--tags_highlightrate",
752 | type=float,
753 | default=0.4,
754 | help="tags_highlightrate for check captions",
755 | )
756 |
757 | return parser
758 |
759 |
if __name__ == "__main__":
    # Script entry point: build the CLI parser defined above in this module.
    parser = setup_parser()

    # Parse command-line options (dataset_dir plus the API/model/scene flags).
    args = parser.parse_args()

    # Load the TOML configuration named by --config
    # (default "config/config.toml").
    config = toml.load(args.config)

    # Run captioning over the dataset with the parsed options and config.
    process_batch(args, config)
768 |
--------------------------------------------------------------------------------
/module/lanceImport.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset processing utilities for image-caption pairs using Lance format.
3 | This module provides tools for converting image-caption datasets to Lance format
4 | and accessing the data through PyTorch datasets.
5 | """
6 |
import argparse
import hashlib
import mimetypes
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import imageio.v3 as iio
import lance
import numpy as np
import pyarrow as pa
from mutagen import File as MutagenFile
from PIL import Image, ImageMode
from rich.console import Console
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeRemainingColumn,
    TimeElapsedColumn,
    TransferSpeedColumn,
    MofNCompleteColumn,
)

from config.config import (
    get_supported_extensions,
    DATASET_SCHEMA,
    CONSOLE_COLORS,
)
38 |
39 |
# Shared rich console for status output. process() rebinds this global to its
# progress bar's console while it runs, so messages printed by the helpers in
# this module go through that console.
console = Console()
# Supported file extensions per media category, loaded from config. They are
# passed to str.endswith() in FileProcessor.load_metadata and compared against
# Path(...).suffix.lower() in process(), so presumably tuples of lower-case
# dotted suffixes such as ".png" -- TODO confirm in config.config.
image_extensions = get_supported_extensions("image")
animation_extensions = get_supported_extensions("animation")
video_extensions = get_supported_extensions("video")
audio_extensions = get_supported_extensions("audio")
application_extensions = get_supported_extensions("application")
46 |
47 |
@dataclass
class Metadata:
    """Descriptive properties (and optionally the raw bytes) of one media file."""

    uris: str  # source location of the file (path or URL)
    mime: str  # MIME type string, e.g. "image/png"
    width: int = 0  # horizontal resolution in pixels (0 when not visual)
    height: int = 0  # vertical resolution in pixels (0 when not visual)
    depth: int = 0  # bits per sample
    channels: int = 0  # channel count (e.g. 3 for RGB, 2 for stereo)
    hash: str = ""  # hex SHA-256 digest of the content
    size: int = 0  # size in bytes
    has_audio: bool = False  # whether an audio stream is present
    duration: Optional[int] = None  # length in milliseconds, if known
    num_frames: Optional[int] = 1  # frame (or sample) count
    frame_rate: float = 0.0  # frames or samples per second
    blob: bytes = b""  # raw content bytes; empty when not stored

    @property
    def filename(self) -> str:
        """Name component of ``uris`` with the final extension stripped."""
        return Path(self.uris).stem

    @property
    def ext(self) -> str:
        """Final extension of ``uris``, leading dot included ("" if none)."""
        return Path(self.uris).suffix

    @property
    def bits_per_channel(self) -> int:
        """``depth`` when at least one channel exists, otherwise 0."""
        if self.channels <= 0:
            return 0
        return self.depth

    @property
    def bit_rate(self) -> int:
        """Estimated uncompressed bit rate in bits per second.

        Visual media (non-zero width and height) use
        ``channels * depth * width * height * frame_rate``; everything else
        uses ``channels * depth * frame_rate``. A duration of exactly 0
        yields 0.
        """
        if self.duration == 0:
            return 0

        sample_bits = self.channels * self.bits_per_channel
        pixels_per_frame = self.width * self.height if (self.width and self.height) else 1
        return int(sample_bits * pixels_per_frame * self.frame_rate)
98 |
99 |
class VideoImportMode(Enum):
    """How a video file's streams should be imported."""

    ALL = 0  # whole video including its audio
    VIDEO_ONLY = 1  # video stream without audio
    AUDIO_ONLY = 2  # audio stream only
    VIDEO_SPLIT_AUDIO = 3  # video and audio imported as separate entries

    @classmethod
    def from_int(cls, value: int) -> "VideoImportMode":
        """Return the member whose value equals ``value``.

        Raises:
            ValueError: if no member has that value.
        """
        matched = next((member for member in cls if member.value == value), None)
        if matched is None:
            raise ValueError(f"Invalid import mode value: {value}")
        return matched
115 |
116 |
class FileProcessor:
    """Load metadata (and optionally the raw bytes) of supported media files.

    Files are classified by extension into image/animation, video, audio and
    application categories; each category is handled by a dedicated helper.

    Example:
        processor = FileProcessor()
        metadata = processor.load_metadata("path/to/image.jpg")
        if metadata:
            print(f"Image size: {metadata.width}x{metadata.height}")
    """

    def _extract_audio_metadata(
        self, video: Any, file_path: str, meta: Dict[str, Any], save_binary: bool = True
    ) -> Optional[Tuple[Metadata, bytes]]:
        """Extract a video's audio track as 16-bit PCM metadata.

        Args:
            video: Reader object exposing ``read_audio()``, which returns a
                numpy array of samples or None.
            file_path: Path to the video file (stored as ``uris``).
            meta: Video metadata; reads ``has_audio``, ``duration`` (seconds)
                and ``audio_fps`` (defaults to 44100).
            save_binary: Whether to keep the PCM bytes in ``Metadata.blob``.

        Returns:
            ``(metadata, pcm_bytes)`` on success; None when the video has no
            audio track or no samples could be read.
        """
        if not meta.get("has_audio"):
            console.print(
                f"[yellow]Warning: Video file {file_path} has no audio track[/yellow]"
            )
            return None

        audio_data = video.read_audio()
        if audio_data is None:
            console.print(
                f"[yellow]Warning: Failed to extract audio data from {file_path}[/yellow]"
            )
            return None

        # Hash the PCM samples themselves, not the container file.
        binary_data = audio_data.astype(np.int16).tobytes()
        audio_hash = hashlib.sha256(binary_data).hexdigest()

        # A mono track can arrive as a 1-D array; shape[1] would raise there.
        channels = audio_data.shape[1] if audio_data.ndim > 1 else 1

        sample_rate = meta.get("audio_fps", 44100)
        duration = int(meta.get("duration", 0) * 1000)  # seconds -> milliseconds
        audio_metadata = Metadata(
            uris=file_path,
            mime="audio/wav",
            width=0,
            height=0,
            channels=channels,
            depth=16,  # samples were converted to int16 above
            hash=audio_hash,
            size=len(binary_data),
            has_audio=True,
            duration=duration,
            num_frames=int(duration * sample_rate / 1000),
            frame_rate=sample_rate,
            blob=binary_data if save_binary else b"",
        )
        return audio_metadata, binary_data

    @staticmethod
    def load_metadata(
        file_path: str,
        save_binary: bool = True,
        import_mode: VideoImportMode = VideoImportMode.ALL,
    ) -> Optional[Metadata]:
        """Load metadata for one media file, dispatching on its extension.

        Args:
            file_path: Path to the media file.
            save_binary: If True, keep the file's bytes in ``Metadata.blob``.
            import_mode: Video import mode. Not consulted by this method at
                present; accepted so callers can already pass it.

        Returns:
            A Metadata object, or None when the extension is not supported or
            the file could not be processed. Errors are printed to the console
            instead of being raised.
        """
        try:
            if file_path.endswith(image_extensions + animation_extensions):
                return FileProcessor._load_image_metadata(file_path, save_binary)
            if file_path.endswith(video_extensions):
                return FileProcessor._load_video_metadata(file_path, save_binary)
            if file_path.endswith(audio_extensions):
                return FileProcessor._load_audio_metadata(file_path, save_binary)
            if file_path.endswith(application_extensions):
                return FileProcessor._load_application_metadata(file_path, save_binary)
        except Exception as e:
            console.print(
                f"[red]Unexpected error processing {file_path}: {str(e)}[/red]"
            )
        return None

    @staticmethod
    def _load_image_metadata(file_path: str, save_binary: bool) -> Metadata:
        """Read metadata of a still or animated image with PIL (errors propagate)."""
        # Hash the file straight from disk rather than rewinding PIL's
        # internal file pointer (img.fp), which is a PIL implementation detail.
        binary_data = Path(file_path).read_bytes()
        image_hash = hashlib.sha256(binary_data).hexdigest()

        with Image.open(file_path) as img:
            duration = 0  # accumulated per-frame display time, milliseconds
            n_frames = None
            frame_rate = 0.0
            if hasattr(img, "n_frames") and img.n_frames > 1:
                n_frames = img.n_frames
                for frame in range(n_frames):
                    img.seek(frame)
                    duration += img.info.get("duration", 0)
                if duration > 0:
                    frame_rate = (n_frames * 1000) / duration  # frames per second
            else:
                duration = None  # single-frame image: no timing information

            # Prefer the extension-based MIME type; fall back to PIL's format,
            # or to the extension if PIL reports no format.
            mime_type, _ = mimetypes.guess_type(file_path)
            image_format = img.format or Path(file_path).suffix.lstrip(".")
            mime = mime_type or f"image/{image_format.lower()}"

            return Metadata(
                uris=file_path,
                mime=mime,
                width=img.size[0],
                height=img.size[1],
                depth=FileProcessor._image_bit_depth(img),
                channels=len(img.getbands()),
                hash=image_hash,
                size=Path(file_path).stat().st_size,
                has_audio=False,
                duration=duration,
                num_frames=n_frames,
                frame_rate=frame_rate,
                blob=binary_data if save_binary else b"",
            )

    @staticmethod
    def _image_bit_depth(img: Any) -> int:
        """Best-effort bits per sample of an open PIL image."""
        depth = None
        if hasattr(img, "bits"):
            # Set by some PIL plugins; the most accurate source (can be 12-bit).
            depth = img.bits
        elif hasattr(img, "tag_v2"):
            bits = img.tag_v2.get(258)  # TIFF BitsPerSample tag
            if bits:
                depth = bits[0] if isinstance(bits, tuple) else bits
        if depth is not None:
            return depth

        # Fall back to the storage size of the PIL mode. Bilevel mode "1" must
        # be matched by name: its mode descriptor has basetype "L", so the
        # typestr route below would wrongly report 8 bits.
        if img.mode == "1":
            return 1
        mode_info = ImageMode.getmode(img.mode)
        # typestr ends with the byte width ("|u1", "<u2", ...); data stored
        # in 16-bit containers (e.g. 12-bit) is therefore reported as 16.
        return int(mode_info.typestr[-1]) * 8

    @staticmethod
    def _load_video_metadata(file_path: str, save_binary: bool) -> Optional[Metadata]:
        """Read video metadata via imageio; prints and returns None on error."""
        try:
            meta = iio.immeta(file_path) or {}

            mime_type, _ = mimetypes.guess_type(file_path)
            extension = Path(file_path).suffix.lstrip(".")
            mime = mime_type or f"video/{extension}"

            size = meta.get("size", (0, 0))
            is_sequence = isinstance(size, (tuple, list))
            width = int(size[0]) if is_sequence and len(size) > 0 else 0
            height = int(size[1]) if is_sequence and len(size) > 1 else 0

            def is_usable(value: Any) -> bool:
                # Non-zero and finite: imageio may report 0, None, inf or nan.
                return bool(value) and not np.isinf(value) and not np.isnan(value)

            fps = meta.get("fps", 0)
            frame_rate = float(fps) if is_usable(fps) else 0.0

            dur = meta.get("duration", 0)  # seconds
            duration = int(dur * 1000) if is_usable(dur) else 0  # milliseconds

            n_frames = meta.get("nframes", 0)
            if not is_usable(n_frames):
                # Estimate the frame count from rate x duration when possible.
                if frame_rate > 0 and duration > 0:
                    n_frames = int(frame_rate * (duration / 1000))
                else:
                    n_frames = 0

            # Channel count and sample depth come from the first decoded frame.
            try:
                with iio.imopen(file_path, "r") as file:
                    first_frame = file.read(index=0)
                channels = first_frame.shape[2] if len(first_frame.shape) > 2 else 1
                depth = first_frame.dtype.itemsize * 8
            except Exception as e:
                console.print(
                    f"[yellow]Warning: Could not read first frame from {file_path}: {e}[/yellow]"
                )
                channels = 3  # assume RGB
                depth = 8  # assume 8-bit

            # Hash in 8 KiB chunks; keep the bytes only when they are stored.
            hasher = hashlib.sha256()
            binary_data = bytearray()
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(8192), b""):
                    hasher.update(chunk)
                    if save_binary:
                        binary_data.extend(chunk)

            return Metadata(
                uris=file_path,
                mime=mime,
                width=width,
                height=height,
                depth=depth,
                channels=channels,
                hash=hasher.hexdigest(),
                size=Path(file_path).stat().st_size,
                has_audio=meta.get("has_audio", False),
                duration=duration,
                num_frames=int(n_frames),  # metadata may report a float
                frame_rate=frame_rate,
                blob=bytes(binary_data) if save_binary else b"",
            )
        except Exception as e:
            console.print(f"[red]Error processing video {file_path}: {str(e)}[/red]")
            return None

    @staticmethod
    def _load_audio_metadata(file_path: str, save_binary: bool) -> Optional[Metadata]:
        """Read audio metadata via mutagen; prints and returns None on error."""
        try:
            binary_data = Path(file_path).read_bytes()
            audio_hash = hashlib.sha256(binary_data).hexdigest()

            mime_type, _ = mimetypes.guess_type(file_path)
            extension = Path(file_path).suffix.lstrip(".")
            mime = mime_type or f"audio/{extension}"

            audio = MutagenFile(file_path)
            if audio is None:
                raise ValueError("Could not read audio metadata")

            length_seconds = audio.info.length
            frame_rate = getattr(audio.info, "sample_rate", 44100)
            return Metadata(
                uris=file_path,
                mime=mime,
                width=0,
                height=0,
                depth=getattr(audio.info, "bits_per_sample", 16),
                channels=getattr(audio.info, "channels", 2),
                hash=audio_hash,
                size=Path(file_path).stat().st_size,
                has_audio=True,
                duration=int(length_seconds * 1000),  # milliseconds
                num_frames=int(length_seconds * frame_rate),  # total samples
                frame_rate=frame_rate,
                blob=binary_data if save_binary else b"",
            )
        except Exception as e:
            console.print(f"[red]Error processing audio {file_path}: {str(e)}[/red]")
            return None

    @staticmethod
    def _load_application_metadata(file_path: str, save_binary: bool) -> Optional[Metadata]:
        """Hash an application/document file; prints and returns None on error."""
        try:
            binary_data = Path(file_path).read_bytes()
            application_hash = hashlib.sha256(binary_data).hexdigest()

            mime_type, _ = mimetypes.guess_type(file_path)
            extension = Path(file_path).suffix.lstrip(".")
            mime = mime_type or f"application/{extension}"

            return Metadata(
                uris=file_path,
                mime=mime,
                width=0,
                height=0,
                depth=0,
                channels=0,
                hash=application_hash,
                size=Path(file_path).stat().st_size,
                has_audio=False,
                duration=0,
                num_frames=0,
                frame_rate=0,
                blob=binary_data if save_binary else b"",
            )
        except Exception as e:
            console.print(
                f"[red]Error processing application {file_path}: {str(e)}[/red]"
            )
            return None
466 |
467 |
def _read_caption(base_path: Path) -> List[str]:
    """Read the caption file that shares ``base_path``'s name (suffix replaced).

    Lookup order is ``.txt`` (returned as a list of lines), then ``.md``,
    then ``.srt`` (each returned whole, as a one-element list). Returns an
    empty list when none of these files exists.
    """
    text_path = base_path.with_suffix(".txt")
    if text_path.exists():
        with open(text_path, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    # Markdown and SRT content is kept as a single string.
    for suffix in (".md", ".srt"):
        candidate = base_path.with_suffix(suffix)
        if candidate.exists():
            with open(candidate, "r", encoding="utf-8") as f:
                return [f.read()]
    return []


def load_data(
    datasets_dir: str, texts_dir: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Collect supported media files together with their captions.

    Args:
        datasets_dir: Directory containing the media files. Without
            ``texts_dir`` it is searched recursively and the collected paths
            are absolute; with ``texts_dir`` only its top level is scanned and
            paths are kept as given.
        texts_dir: Optional separate directory holding caption files named
            after each media file's stem.

    Returns:
        One ``{"file_path": str, "caption": list[str]}`` dict per supported
        media file (see ``_read_caption`` for how captions are read).
    """
    # Build the combined extension collection once instead of once per file.
    supported_extensions = (
        image_extensions
        + animation_extensions
        + video_extensions
        + audio_extensions
        + application_extensions
    )

    def is_supported_file(path: Path) -> bool:
        return path.is_file() and any(
            str(path).endswith(ext) for ext in supported_extensions
        )

    data = []
    if texts_dir:
        # Paired layout: captions live in texts_dir, matched by file stem.
        caption_dir = Path(texts_dir)
        for file in Path(datasets_dir).iterdir():
            if is_supported_file(file):
                caption = _read_caption(caption_dir / file.name)
                data.append({"file_path": str(file), "caption": caption})
    else:
        # Single layout: each caption sits next to its media file.
        datasets_path = Path(datasets_dir).absolute()
        for file_path in datasets_path.rglob("*"):
            if is_supported_file(file_path):
                caption = _read_caption(file_path)
                data.append({"file_path": str(file_path), "caption": caption})

    return data
552 |
553 |
def _media_type_and_color(suffix: str) -> Tuple[str, str]:
    """Return ``(media_type, console_color)`` for a lower-cased file suffix.

    Animation is tested before image, so an animated format is labelled
    "animation" whether or not it is also listed among the image extensions
    (previously an animation-only suffix fell through to "unknown").
    """
    if suffix in animation_extensions:
        media_type = "animation"
    elif suffix in image_extensions:
        media_type = "image"
    elif suffix in video_extensions:
        media_type = "video"
    elif suffix in audio_extensions:
        media_type = "audio"
    elif suffix in application_extensions:
        media_type = "application"
    else:
        media_type = "unknown"
    return media_type, CONSOLE_COLORS[media_type]


def _metadata_to_batch(
    file_path: str, caption: Any, metadata: Metadata
) -> pa.RecordBatch:
    """Build a one-row RecordBatch whose columns follow DATASET_SCHEMA."""
    arrays = []
    for field_name, field_type in DATASET_SCHEMA:
        if field_name == "filepath":
            value = str(Path(file_path).absolute())
        elif field_name == "captions":
            value = caption
        elif field_name == "blob":
            value = metadata.blob
        else:
            value = getattr(metadata, field_name)
            # Replace missing metadata with a default matching the column type.
            if value is None:
                if pa.types.is_integer(field_type):
                    value = 0
                elif pa.types.is_floating(field_type):
                    value = 0.0
                elif pa.types.is_boolean(field_type):
                    value = False
                elif pa.types.is_string(field_type):
                    value = ""
        arrays.append(pa.array([value], type=field_type))

    return pa.RecordBatch.from_arrays(
        arrays, names=[name for name, _ in DATASET_SCHEMA]
    )


def process(
    data: List[Dict[str, Any]],
    save_binary: bool = True,
    import_mode: VideoImportMode = VideoImportMode.ALL,
) -> Iterator[pa.RecordBatch]:
    """Convert media-caption pairs into Lance-ready record batches, lazily.

    Args:
        data: Dicts with ``"file_path"`` and ``"caption"`` keys (as produced
            by ``load_data``).
        save_binary: Whether to store each file's bytes in the ``blob`` column.
        import_mode: Video import mode, forwarded to
            ``FileProcessor.load_metadata``.

    Yields:
        One single-row ``pyarrow.RecordBatch`` per file whose metadata could
        be loaded; files that fail to load are skipped.
    """
    # Rebind the module-level console so messages printed through it (e.g.
    # by FileProcessor) use the progress display's console.
    global console

    processor = FileProcessor()

    with Progress(
        "[progress.description]{task.description}",
        SpinnerColumn(spinner_name="dots"),
        MofNCompleteColumn(separator="/"),
        BarColumn(bar_width=40, complete_style="green", finished_style="bold green"),
        TextColumn("•"),
        TaskProgressColumn(),
        TextColumn("•"),
        TransferSpeedColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        expand=True,
        transient=False,  # keep the finished bar on screen
    ) as progress:
        console = progress.console
        task = progress.add_task("[green]Processing file...", total=len(data))

        for item in data:
            file_path = item["file_path"]
            caption = item["caption"]

            console.print()
            media_type, color = _media_type_and_color(Path(file_path).suffix.lower())
            console.print(
                f"Processing {media_type} file [{color}]'{file_path}'[/{color}]"
            )
            console.print(f"Caption: {caption}", style=CONSOLE_COLORS["caption"])

            metadata = processor.load_metadata(file_path, save_binary, import_mode)
            if metadata:
                yield _metadata_to_batch(file_path, caption, metadata)
            progress.update(task, advance=1)
670 |
671 |
def transform2lance(
    dataset_dir: str,
    caption_dir: Optional[str] = None,
    output_name: str = "dataset",
    save_binary: bool = True,
    not_save_disk: bool = False,
    import_mode: VideoImportMode = VideoImportMode.ALL,
    tag: str = "gemini",
    load_condition: Callable[[str, Optional[str]], List[Dict[str, Any]]] = load_data,
) -> Optional[lance.LanceDataset]:
    """Write media-caption pairs into a Lance dataset.

    The dataset is written to ``<dataset_dir>/<output_name>.lance``.

    Args:
        dataset_dir: Directory containing the media files; also the parent
            directory of the written dataset.
        caption_dir: Optional directory containing caption files.
        output_name: Dataset name (without the ``.lance`` suffix).
        save_binary: Whether to store the file bytes in the ``blob`` column.
        not_save_disk: Selects the write mode. When False (the default), an
            existing dataset at the target path is overwritten. When True,
            batches are appended if the dataset exists; otherwise it is
            created. NOTE(review): the CLI help describes this flag as "load
            into memory instead of saving to disk", which the code does not
            do. Confirm which behavior is intended.
        import_mode: Video import mode:
            - ALL: Complete video with audio
            - VIDEO_ONLY: Video only
            - AUDIO_ONLY: Audio only
            - VIDEO_SPLIT_AUDIO: Split video and audio
        tag: Name of the Lance tag that is pointed at dataset version 1
            after writing.
        load_condition: Callable ``(dataset_dir, caption_dir)`` returning the
            records to import (defaults to ``load_data``).

    Returns:
        The written Lance dataset, or None if an AttributeError occurs.
    """
    data = load_condition(dataset_dir, caption_dir)

    # The blob column is flagged for Lance's blob encoding.
    schema = pa.schema(
        [
            pa.field(
                name,
                pa_type,
                metadata={b"lance-encoding:blob": b"true"} if name == "blob" else None,
            )
            for name, pa_type in DATASET_SCHEMA
        ]
    )

    try:
        reader = pa.RecordBatchReader.from_batches(
            schema, process(data, save_binary, import_mode)
        )

        dataset_path = Path(dataset_dir) / f"{output_name}.lance"
        mode = "append" if dataset_path.exists() else "create"

        lancedataset = lance.write_dataset(
            reader,
            str(dataset_path),
            schema,
            mode=mode if not_save_disk else "overwrite",
        )

        # NOTE(review): the tag always targets version 1, even after an
        # append/overwrite has produced a newer version. Confirm this is
        # intended rather than lancedataset.version.
        try:
            lancedataset.tags.create(tag, 1)
        except Exception:
            # Presumably the tag already exists; move it instead. A narrow
            # catch (not a bare except) lets KeyboardInterrupt/SystemExit
            # propagate.
            lancedataset.tags.update(tag, 1)

        return lancedataset

    except AttributeError as e:
        console.print(f"[red]AttributeError: {e}[/red]")
        return None
739 |
740 |
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Lance import script."""
    parser = argparse.ArgumentParser(description="Transform dataset into Lance format")
    parser.add_argument(
        "dataset_dir", type=str, help="Directory containing training images"
    )

    # (flag, add_argument keyword options), registered in this order so the
    # generated help lists them the same way.
    option_specs = (
        (
            "--caption_dir",
            {"type": str, "default": None, "help": "Directory containing caption files"},
        ),
        (
            "--output_name",
            {"type": str, "default": "dataset", "help": "Name of output dataset"},
        ),
        (
            "--no_save_binary",
            {"action": "store_true", "help": "Don't save binary data in the dataset"},
        ),
        (
            "--not_save_disk",
            {
                "action": "store_true",
                "help": "Load dataset into memory instead of saving to disk",
            },
        ),
        (
            "--import_mode",
            {
                "type": int,
                "default": 0,
                "choices": [0, 1, 2, 3],
                "help": "Video import mode: 0=Complete video with audio, 1=Video only, "
                "2=Audio only, 3=Split video and audio",
            },
        ),
        (
            "--tag",
            {"type": str, "default": "gemini", "help": "Tag for the dataset"},
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser
781 |
782 |
if __name__ == "__main__":
    # Parse the CLI and run the import with the requested options.
    cli = setup_parser().parse_args()

    transform2lance(
        dataset_dir=cli.dataset_dir,
        caption_dir=cli.caption_dir,
        output_name=cli.output_name,
        save_binary=not cli.no_save_binary,
        not_save_disk=cli.not_save_disk,
        import_mode=VideoImportMode.from_int(cli.import_mode),
        tag=cli.tag,
    )
796 |
--------------------------------------------------------------------------------
/module/lanceexport.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import lance
3 | import re
4 | from typing import Optional, Union, List, Dict, Any
5 | from rich.console import Console
6 | from rich.progress import (
7 | Progress,
8 | SpinnerColumn,
9 | TextColumn,
10 | BarColumn,
11 | TaskProgressColumn,
12 | TimeRemainingColumn,
13 | TimeElapsedColumn,
14 | TransferSpeedColumn,
15 | MofNCompleteColumn,
16 | )
17 | from config.config import get_supported_extensions, DATASET_SCHEMA, CONSOLE_COLORS
18 | from utils.stream_util import split_media_stream_clips, split_video_with_imageio_ffmpeg
19 | from pathlib import Path
20 | import pysrt
21 | import json
22 |
# Module-wide rich console; extract_from_lance rebinds this global to its
# progress bar's console so messages print cleanly above the bar.
console = Console()
# Supported extensions per media category, read from the project config.
# extract_from_lance classifies each exported URI by testing its lowercase
# suffix for membership in these collections.
image_extensions = get_supported_extensions("image")
animation_extensions = get_supported_extensions("animation")
video_extensions = get_supported_extensions("video")
audio_extensions = get_supported_extensions("audio")
application_extensions = get_supported_extensions("application")
29 |
30 |
def format_duration(duration_ms: int) -> str:
    """Convert a duration in milliseconds to a ``M:SS`` string.

    Minutes are not folded into hours, so two hours render as ``"120:00"``.

    Args:
        duration_ms: Duration in milliseconds. Float values (e.g. read from
            media metadata) are truncated to whole milliseconds. ``None`` and
            negative values are treated as 0.

    Returns:
        str: The formatted duration, e.g. ``"1:05"``.
    """
    # int() is required: a float survives floor division, and the ":02d"
    # format spec below raises ValueError for floats.
    total_seconds = max(int(duration_ms or 0), 0) // 1000
    minutes, seconds = divmod(total_seconds, 60)
    return f"{minutes}:{seconds:02d}"
37 |
38 |
def _metadata_value(metadata: Dict[str, Any], key: str, default: Any = 0) -> Any:
    """Return ``metadata[key]``, or *default* when the key is missing or ``None``."""
    value = metadata.get(key)
    return default if value is None else value


def _describe_media(metadata: Dict[str, Any], media_type: str) -> str:
    """Build the comma-separated summary logged after a blob is saved.

    Missing or ``None`` metadata values fall back to 0 (``num_frames`` falls
    back to 1), so sparse metadata yields zeros instead of an exception.
    """
    dims = f"{_metadata_value(metadata, 'width')}x{_metadata_value(metadata, 'height')}"
    if media_type in ("image", "animation"):
        num_frames = _metadata_value(metadata, "num_frames", 1)
        parts = [
            dims,
            f"{_metadata_value(metadata, 'channels')}ch",
            f"{num_frames} frames" if num_frames > 1 else None,
        ]
    elif media_type == "video":
        parts = [
            dims,
            format_duration(_metadata_value(metadata, "duration")),
            f"{_metadata_value(metadata, 'frame_rate'):.1f}fps",
        ]
    elif media_type == "audio":
        parts = [
            f"{_metadata_value(metadata, 'channels')}ch",
            f"{_metadata_value(metadata, 'frame_rate'):.1f}Hz",
            format_duration(_metadata_value(metadata, "duration")),
        ]
    elif media_type == "application":
        parts = [f"{_metadata_value(metadata, 'size') / (1024 * 1024):.2f} MB"]
    else:
        parts = []
    return ", ".join(filter(None, parts))


def save_blob(
    uri: Path,
    blob: Union[bytes, lance.BlobFile],
    metadata: Dict[str, Any],
    media_type: str,
) -> bool:
    """Write a binary blob to disk and log a short media summary.

    Args:
        uri: Target file path; parent directories are created as needed.
        blob: Raw bytes, or a ``lance.BlobFile`` that is streamed to disk in
            8 KiB chunks.
        metadata: Row metadata. It is only used for the log line (width,
            height, channels, num_frames, duration, frame_rate, size).
        media_type: Media category (e.g. "image", "animation", "video",
            "audio", "application"); selects the summary format and color.

    Returns:
        bool: True once the file has been written, False if saving failed.
    """
    try:
        uri.parent.mkdir(parents=True, exist_ok=True)

        # Handle both bytes and BlobFile
        if isinstance(blob, lance.BlobFile):
            with open(uri, "wb") as f:
                while chunk := blob.read(8192):
                    f.write(chunk)
        else:
            uri.write_bytes(blob)

        # The file is already on disk. A malformed metadata value (None,
        # float duration, non-numeric frame rate) must only degrade the log
        # line, not turn a successful save into a reported failure; callers
        # skip caption export when this function returns False.
        try:
            meta_str = _describe_media(metadata, media_type)
        except (TypeError, ValueError):
            meta_str = ""

        console.print()
        # Use the configured color for this media type
        color = CONSOLE_COLORS.get(media_type, "white")
        console.print(
            f"[{color}]{media_type}: {uri} ({meta_str}) saved successfully.[/{color}]"
        )
        return True
    except Exception as e:
        console.print(f"[red]Error saving {media_type} {uri}: {e}[/red]")
        return False
122 |
123 |
124 | def save_caption(caption_path: str, caption_lines: List[str], media_type: str) -> bool:
125 | """Save caption data to disk.
126 |
127 | Args:
128 | caption_path: Path to save caption file
129 | caption_lines: List of caption lines
130 | media_type: Type of media (image/video/audio)
131 |
132 | Returns:
133 | bool: True if save successful, False otherwise
134 | """
135 | try:
136 | if not len(caption_lines):
137 | console.print(f"[red]No caption content found for {caption_path}[/red]")
138 | return False
139 | caption_path = Path(caption_path)
140 | caption_path.parent.mkdir(parents=True, exist_ok=True)
141 |
142 | if media_type == "audio" or media_type == "video":
143 | caption_path = caption_path.with_suffix(".srt")
144 | elif media_type == "application":
145 | caption_path = caption_path.with_suffix(".md")
146 | else:
147 | caption_path = caption_path.with_suffix(".txt")
148 |
149 | with open(caption_path, "w", encoding="utf-8") as f:
150 | if caption_path.suffix == ".srt":
151 | # For SRT files, preserve all lines including empty ones
152 | f.write("\n".join(caption_lines))
153 | elif caption_path.suffix == ".md":
154 | # For MD files, preserve original markdown formatting
155 | f.write("".join(caption_lines))
156 | else:
157 | # For TXT files, strip empty lines and whitespace
158 | for line in caption_lines:
159 | if "', "")
162 | .replace("", "")
163 | .replace("", "")
164 | )
165 |
166 | # Check if the content is in JSON format and parse it if possible
167 | if line.strip().startswith("{") and line.strip().endswith("}"):
168 | try:
169 | json_content = line.strip()
170 | parsed_json = json.loads(json_content)
171 | # Format JSON content with indentation for better readability
172 | with open(
173 | caption_path.with_suffix(".json"), "w", encoding="utf-8"
174 | ) as j:
175 | json.dump(parsed_json, j, indent=2, ensure_ascii=False)
176 | f.write(parsed_json["description"])
177 | except json.JSONDecodeError:
178 | # If not valid JSON, continue with normal text processing
179 | if line and line.strip():
180 | f.write(line.strip() + "\n")
181 | else:
182 | if line and line.strip():
183 | f.write(line.strip() + "\n")
184 |
185 | console.print()
186 | console.print(
187 | f"[{CONSOLE_COLORS['text']}]text: {caption_path} saved successfully.[/{CONSOLE_COLORS['text']}]"
188 | )
189 | return True
190 | except Exception as e:
191 | console.print(f"[red]Error saving caption: {e}[/red]")
192 | return False
193 |
194 |
195 | def save_caption_by_pages(caption_path: Path, caption_lines: List[str]) -> bool:
196 | """将多页文档分割为单独的页面并分别保存
197 |
198 | Args:
199 | caption_path: 保存路径
200 | caption_lines: 包含多页内容的文本列表
201 |
202 | Returns:
203 | bool: 成功返回True,失败返回False
204 | """
205 | try:
206 | # 合并文本行为单个字符串
207 | if len(caption_lines) == 1:
208 | # 如果只有一个元素(来自.md或.srt文件),直接使用该元素
209 | combined_text = caption_lines[0]
210 | else:
211 | # 如果是多行(来自.txt文件),用换行符连接
212 | combined_text = "\n".join(caption_lines)
213 |
214 | # 使用页眉作为分隔符来分割多个页面
215 | header_pattern = r'(?s).*?\s*Page\s+(\d+)\s*.*?'
217 | page_break_pattern = r''
218 |
219 | # 分割所有页面
220 | page_contents = []
221 | page_numbers = []
222 |
223 | # 查找所有页头位置
224 | header_matches = list(re.finditer(header_pattern, combined_text))
225 | footer_matches = list(re.finditer(footer_pattern, combined_text))
226 |
227 | # 如果没有找到页头,整体保存
228 | if not header_matches:
229 | # 没有找到页头,尝试其他方式分割内容
230 | # 尝试使用Markdown标题作为分割点
231 | md_header_pattern = r"^#{1,6}\s+(.+?)$"
232 | md_headers = list(
233 | re.finditer(md_header_pattern, combined_text, re.MULTILINE)
234 | )
235 |
236 | if md_headers:
237 | # 使用Markdown标题分割内容
238 | console.print(
239 | f"[yellow]No HTML headers found, splitting by Markdown headers.[/yellow]"
240 | )
241 |
242 | for i in range(len(md_headers)):
243 | header_match = md_headers[i]
244 | header_text = header_match.group(1).strip()
245 |
246 | # 计算当前部分内容的开始位置
247 | start_pos = header_match.start()
248 |
249 | # 计算当前部分内容的结束位置
250 | if i < len(md_headers) - 1:
251 | end_pos = md_headers[i + 1].start()
252 | else:
253 | end_pos = len(combined_text)
254 |
255 | # 提取部分内容
256 | section_content = combined_text[start_pos:end_pos]
257 |
258 | # 创建文件名 (使用标题的前20个字符,去除特殊字符)
259 | safe_header = re.sub(r"[^\w\s-]", "", header_text)[:20].strip()
260 | safe_header = re.sub(r"[-\s]+", "_", safe_header)
261 |
262 | section_filename = (
263 | f"{caption_path.stem}_{safe_header}{caption_path.suffix}"
264 | )
265 | section_file_path = caption_path.with_suffix("") / section_filename
266 |
267 | # 保存部分内容
268 | section_file_path.parent.mkdir(parents=True, exist_ok=True)
269 | with open(section_file_path, "w", encoding="utf-8") as f:
270 | f.write(section_content)
271 |
272 | console.print(
273 | f"[{CONSOLE_COLORS['text']}]text: {section_file_path} saved successfully.[/{CONSOLE_COLORS['text']}]"
274 | )
275 |
276 | return True
277 | else:
278 | # 如果没有任何分割点,保存为单个文件
279 | output_dir = caption_path.with_suffix(".md")
280 | output_dir.mkdir(parents=True, exist_ok=True)
281 | single_file_path = output_dir / f"{caption_path.stem}.md"
282 |
283 | with open(single_file_path, "w", encoding="utf-8") as f:
284 | f.write(combined_text)
285 | console.print(
286 | f"[{CONSOLE_COLORS['text']}]text: {single_file_path} saved successfully.[/{CONSOLE_COLORS['text']}]"
287 | )
288 | return True
289 | # 分割每个页面的内容
290 | for i in range(len(header_matches)):
291 | header_match = header_matches[i]
292 | page_number = int(header_match.group(1))
293 | page_numbers.append(page_number)
294 |
295 | # 计算当前页面内容的开始位置(从页头开始)
296 | start_pos = header_match.start()
297 |
298 | # 计算当前页面内容的结束位置
299 | # 先尝试查找对应的页脚
300 | end_pos = None
301 |
302 | # 寻找这个页码对应的页脚
303 | for footer_match in footer_matches:
304 | footer_page = int(footer_match.group(1))
305 | if footer_page == page_number:
306 | # 结束位置是这个页脚的结束位置
307 | end_pos = footer_match.end()
308 | break
309 |
310 | # 如果没找到对应页脚,则使用下一个页头作为结束位置
311 | if end_pos is None:
312 | if i < len(header_matches) - 1:
313 | end_pos = header_matches[i + 1].start()
314 | else:
315 | end_pos = len(combined_text)
316 |
317 | # 提取页面内容
318 | page_content = combined_text[start_pos:end_pos]
319 |
320 | # 移除页面分隔符 (确保使用多行模式)
321 | page_content = re.sub(page_break_pattern, "", page_content, flags=re.DOTALL)
322 |
323 | # 移除页眉
324 | page_content = re.sub(
325 | r'(?s).*?', "", page_content
333 | )
334 |
335 | # 清理可能的多余空行
336 | page_content = re.sub(r"\n{3,}", "\n\n", page_content)
337 |
338 | # 添加到页面内容列表
339 | page_contents.append((page_number, page_content))
340 |
341 | # 创建输出目录
342 | output_dir = caption_path.with_suffix("")
343 | output_dir.mkdir(parents=True, exist_ok=True)
344 |
345 | # 保存每一页为独立文件
346 | for page_number, page_content in page_contents:
347 | # 处理图片路径,将路径从子文件夹改为同级
348 | img_pattern = r"!\[(.*?)\]\(([^/]+)/([^/)]+)\)"
349 |
350 | # 检查是否有重复引用的图片
351 | processed_page_content = page_content
352 | matches = list(re.finditer(img_pattern, page_content))
353 |
354 | if matches:
355 | # 使用字典记录每个图片第一次出现的位置
356 | first_occurrence = {}
357 |
358 | # 找出每个图片第一次出现的位置
359 | for match in matches:
360 | alt_text = match.group(1)
361 | folder = match.group(2)
362 | img_name = match.group(3)
363 | if img_name not in first_occurrence:
364 | first_occurrence[img_name] = match
365 |
366 | # 先处理图片路径,统一格式
367 | processed_page_content = re.sub(
368 | img_pattern, r"", processed_page_content
369 | )
370 |
371 | # 移除所有重复的图片,但保留第一次出现的位置
372 | for img_name, match in first_occurrence.items():
373 | # 计算该图片在文本中所有出现的位置
374 | all_matches = [m for m in matches if m.group(3) == img_name]
375 |
376 | # 如果有多次出现,移除除了第一次之外的所有引用
377 | if len(all_matches) > 1:
378 | # 排序匹配,按位置从前向后处理
379 | sorted_matches = sorted(all_matches, key=lambda m: m.start())
380 |
381 | # 跳过第一次出现的匹配
382 | for m in sorted_matches[1:]:
383 | # 构建要移除的模式
384 | alt_text = m.group(1)
385 | pattern_to_remove = f"!\\[{re.escape(alt_text)}\\]\\({re.escape(img_name)}\\)"
386 | # 从处理后的内容中移除该模式
387 | processed_page_content = re.sub(
388 | pattern_to_remove, "", processed_page_content, count=1
389 | )
390 | else:
391 | # 如果没有匹配到图片,只进行路径格式转换
392 | processed_page_content = re.sub(img_pattern, r"", page_content)
393 |
394 | page_filename = f"{caption_path.stem}_{page_number}.md"
395 | page_file_path = output_dir / page_filename
396 |
397 | # 保存页面内容
398 | with open(page_file_path, "w", encoding="utf-8") as f:
399 | f.write(processed_page_content)
400 |
401 | console.print(
402 | f"[{CONSOLE_COLORS['text']}]text: {page_file_path} saved successfully.[/{CONSOLE_COLORS['text']}]"
403 | )
404 |
405 | return True
406 | except Exception as e:
407 | console.print(f"[red]Error saving pages: {e}[/red]")
408 | return False
409 |
410 |
411 | def split_md_document(uri: Path, caption_lines: List[str], save_caption_func) -> None:
412 | """分割多页Markdown文档并单独保存每一页
413 |
414 | Args:
415 | uri: 文件路径
416 | caption_lines: 包含多页内容的文本列表
417 | save_caption_func: 用于保存单页内容的函数
418 | """
419 | try:
420 | # 检查是否包含多页内容
421 | if any(
422 | ' None:
447 | """
448 | Extract images and captions from Lance dataset.
449 |
450 | Args:
451 | lance_or_path: Path to Lance dataset or Lance dataset object
452 | output_dir: Directory to save extracted images
453 | caption_dir: Optional directory to save caption files
454 | save_binary: Whether to save binary data
455 | """
456 | ds = (
457 | lance.dataset(lance_or_path, version=version)
458 | if isinstance(lance_or_path, str)
459 | else lance_or_path
460 | )
461 |
462 | # Create output directories
463 | output_path = Path(output_dir)
464 | output_path.mkdir(parents=True, exist_ok=True)
465 |
466 | if caption_dir:
467 | captions_dir_path = Path(caption_dir)
468 | captions_dir_path.mkdir(parents=True, exist_ok=True)
469 |
470 | with Progress(
471 | "[progress.description]{task.description}",
472 | SpinnerColumn(spinner_name="dots"),
473 | MofNCompleteColumn(separator="/"),
474 | BarColumn(bar_width=40, complete_style="green", finished_style="bold green"),
475 | TextColumn("•"),
476 | TaskProgressColumn(),
477 | TextColumn("•"),
478 | TransferSpeedColumn(),
479 | TextColumn("•"),
480 | TimeElapsedColumn(),
481 | TextColumn("•"),
482 | TimeRemainingColumn(),
483 | expand=True,
484 | transient=False, # 防止进度条随刷新滚动
485 | ) as progress:
486 |
487 | global console
488 |
489 | console = progress.console
490 |
491 | task = progress.add_task("[green]Extracting files...", total=ds.count_rows())
492 |
493 | for batch in ds.to_batches():
494 | # Get all metadata columns
495 | metadata_batch = {
496 | field[0]: batch.column(field[0]).to_pylist()
497 | for field in DATASET_SCHEMA
498 | if field[0] != "blob" # Skip blob to save memory
499 | }
500 | indices = list(range(len(batch)))
501 | blobs = ds.take_blobs(indices, "blob")
502 |
503 | for i in range(len(batch)):
504 | # Create metadata dict for current item
505 | metadata = {key: values[i] for key, values in metadata_batch.items()}
506 | uri = Path(metadata["uris"])
507 | blob = blobs[i]
508 |
509 | media_type = None
510 | suffix = uri.suffix.lower()
511 | if suffix in image_extensions:
512 | media_type = "image"
513 | elif suffix in animation_extensions:
514 | media_type = "animation"
515 | elif suffix in video_extensions:
516 | media_type = "video"
517 | elif suffix in audio_extensions:
518 | media_type = "audio"
519 | elif suffix in application_extensions:
520 | media_type = "application"
521 | if not uri.exists() and blob:
522 | if media_type:
523 | if not save_blob(uri, blob, metadata, media_type):
524 | progress.advance(task)
525 | continue
526 | else:
527 | console.print(
528 | f"[yellow]Unsupported file format: {suffix}[/yellow]"
529 | )
530 | progress.advance(task)
531 | continue
532 |
533 | # Save caption if available
534 | caption = metadata.get("captions", [])
535 | if caption:
536 | caption_file_path = captions_dir_path if caption_dir else uri
537 | caption_file_path.parent.mkdir(parents=True, exist_ok=True)
538 | save_caption(caption_file_path, caption, media_type)
539 |
540 | if clip_with_caption and (uri.with_suffix(".srt")).exists():
541 | subs = pysrt.open(uri.with_suffix(".srt"), encoding="utf-8")
542 | try:
543 | split_video_with_imageio_ffmpeg(uri, subs, save_caption)
544 | except Exception as e:
545 | console.print(f"[red]Error splitting video: {e}[/red]")
546 | split_media_stream_clips(
547 | uri, media_type, subs, save_caption
548 | )
549 | elif clip_with_caption and (uri.with_suffix(".md")).exists():
550 | split_md_document(uri, caption, save_caption)
551 |
552 | progress.advance(task)
553 |
554 |
def main():
    """Command-line entry point: export media and captions from a Lance dataset.

    Parses ``lance_file``, ``--output_dir``, ``--version`` and
    ``--not_clip_with_caption``, then hands them to ``extract_from_lance``.
    """
    cli = argparse.ArgumentParser(
        description="Extract images and captions from a Lance dataset"
    )
    cli.add_argument("lance_file", help="Path to the .lance file")
    cli.add_argument(
        "--output_dir", default="./dataset", help="Directory to save extracted data"
    )
    cli.add_argument("--version", default="gemini", help="Dataset version")
    cli.add_argument(
        "--not_clip_with_caption", action="store_true", help="Not clip with caption"
    )
    opts = cli.parse_args()

    # Clipping by caption is enabled unless the negative switch was passed.
    clip_with_caption = not opts.not_clip_with_caption
    extract_from_lance(
        opts.lance_file, opts.output_dir, opts.version, clip_with_caption
    )
582 |
583 |
if __name__ == "__main__":
    # Run the exporter CLI only when executed as a script, not on import.
    main()
586 |
--------------------------------------------------------------------------------
/module/waterdetect.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import time
4 | from PIL import Image
5 | from pathlib import Path
6 | import lance
7 | import pyarrow as pa
8 | from rich.console import Console
9 | from rich.pretty import Pretty
10 | from rich.progress import (
11 | Progress,
12 | SpinnerColumn,
13 | TextColumn,
14 | BarColumn,
15 | TaskProgressColumn,
16 | TimeRemainingColumn,
17 | TimeElapsedColumn,
18 | TransferSpeedColumn,
19 | MofNCompleteColumn,
20 | )
21 | import torch
22 | import onnxruntime as ort
23 | from huggingface_hub import hf_hub_download
24 | from module.lanceImport import transform2lance
25 | import concurrent.futures
26 | from transformers import AutoImageProcessor
27 | import shutil
28 | import json
29 |
# Module-wide rich console; main() rebinds this global to its progress bar's
# console so messages print cleanly above the bar.
console = Console()

# Files that load_model() downloads from the Hugging Face model repository.
FILES = ["model.onnx"]
33 |
34 |
def preprocess_image(image):
    """Convert one image into the model's input array.

    Relies on the module-level ``processor`` that ``load_model`` assigns.

    Args:
        image: A PIL image, a numpy array, or a filesystem path (``str`` or
            ``Path``). Non-PIL inputs are converted to RGB first.

    Returns:
        The first entry of the processor's ``pixel_values`` as a numpy array,
        or ``None`` if loading or preprocessing failed (the error is printed).
    """
    try:
        # Convert the input to an RGB PIL image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        elif isinstance(image, (str, Path)):
            # Use a context manager so the file handle is released right away.
            # Pillow keeps multi-frame files (e.g. GIF) open after loading, and
            # this function may run many times in parallel over a dataset.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif not isinstance(image, Image.Image):
            raise TypeError("Input must be a PIL image, numpy array, or file path")

        # Preprocess with the shared processor
        inputs = processor(images=image, return_tensors="pt")

        # Return a numpy array for onnxruntime
        return inputs["pixel_values"][0].numpy()
    except Exception as e:
        console.print(f"[red]preprocess_image error: {str(e)}[/red]")
        return None
54 |
55 |
def load_and_preprocess_batch(uris, *, max_workers: int = 16, return_indices: bool = False):
    """Load and preprocess a batch of images in parallel threads.

    Images that fail to load are dropped. By default the returned list can
    therefore be shorter than *uris* and is then no longer positionally aligned
    with it. Pass ``return_indices=True`` to also receive the position in
    *uris* of every returned image, so results can be paired with the right path.

    Args:
        uris: Image paths to load.
        max_workers: Number of loader threads.
        return_indices: When True, return ``(indices, images)`` instead of
            just ``images``.

    Returns:
        list of preprocessed arrays, or a tuple ``(indices, images)`` when
        *return_indices* is True.
    """

    def load_single_image(uri):
        try:
            # Pass the path through; preprocess_image does the conversion
            return preprocess_image(uri)
        except Exception as e:
            console.print(f"[red]Error processing {uri}: {str(e)}[/red]")
            return None

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        batch_images = list(executor.map(load_single_image, uris))

    # Drop images that failed to load, remembering where the survivors came from
    indices = [i for i, img in enumerate(batch_images) if img is not None]
    images = [batch_images[i] for i in indices]

    if return_indices:
        return indices, images
    return images
75 |
76 |
def process_batch(images, session, input_name):
    """Run a single inference pass over a batch of preprocessed images.

    Args:
        images: Equally shaped numpy arrays, stacked along a new leading axis.
        session: Object exposing ``run(output_names, feed)``, such as an
            onnxruntime ``InferenceSession``.
        input_name: Name of the model input that receives the batch.

    Returns:
        The first model output, or ``None`` if stacking or inference failed
        (the error is printed).
    """
    try:
        feed = {input_name: np.ascontiguousarray(np.stack(images))}
        first_output, *_ = session.run(None, feed)
    except Exception as err:
        console.print(f"[red]Batch processing error: {str(err)}[/red]")
        return None
    return first_output
88 |
89 |
def load_model(args):
    """Load the image processor and the ONNX watermark-detection model.

    Side effect: assigns the module-level ``processor`` (an
    ``AutoImageProcessor`` for ``args.repo_id``), which ``preprocess_image``
    reads.

    When ``model.onnx`` is missing from
    ``<args.model_dir>/<repo_id with "/" replaced by "_">``, or
    ``args.force_download`` is set, every file in ``FILES`` is fetched from the
    Hugging Face repo into that directory. An execution provider is then
    chosen from ``ort.get_available_providers()`` in the order TensorRT, CUDA,
    ROCm, OpenVINO, falling back to CPU.

    Args:
        args: Namespace providing ``model_dir``, ``repo_id`` and
            ``force_download``.

    Returns:
        tuple: ``(ort_sess, input_name)`` — the onnxruntime
        ``InferenceSession`` and the name of its first input.
    """
    model_path = Path(args.model_dir) / args.repo_id.replace("/", "_") / "model.onnx"

    global processor
    processor = AutoImageProcessor.from_pretrained(args.repo_id, use_fast=True)

    # Download the model files when needed
    if not model_path.exists() or args.force_download:
        for file in FILES:
            file_path = Path(args.model_dir) / args.repo_id.replace("/", "_") / file
            if not file_path.exists() or args.force_download:
                file_path = Path(
                    hf_hub_download(
                        repo_id=args.repo_id,
                        filename=file,
                        local_dir=file_path.parent,
                        force_download=args.force_download,
                    )
                )
                console.print(f"[blue]Downloaded {file} to {file_path}[/blue]")
            else:
                console.print(f"[green]Using existing {file}[/green]")

    # Select the inference execution provider (first available wins)
    providers = []
    if "TensorrtExecutionProvider" in ort.get_available_providers():
        providers.append("TensorrtExecutionProvider")
        console.print("[green]Using TensorRT for inference[/green]")
        console.print("[yellow]compile may take a long time, please wait...[/yellow]")
    elif "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append("CUDAExecutionProvider")
        console.print("[green]Using CUDA for inference[/green]")
    elif "ROCMExecutionProvider" in ort.get_available_providers():
        providers.append("ROCMExecutionProvider")
        console.print("[green]Using ROCm for inference[/green]")
    elif "OpenVINOExecutionProvider" in ort.get_available_providers():
        # Stored as a (name, options) tuple; it is passed through unchanged below
        providers = [("OpenVINOExecutionProvider", {"device_type": "GPU_FP32"})]
        console.print("[green]Using OpenVINO for inference[/green]")
    else:
        providers.append("CPUExecutionProvider")
        console.print("[yellow]Using CPU for inference[/yellow]")

    # Create the inference session options
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )  # enable all graph optimizations

    if "CPUExecutionProvider" in providers:
        # CPU only: enable multi-threaded inference
        sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL  # parallel graph execution
        sess_options.inter_op_num_threads = 8  # threads across operators
        sess_options.intra_op_num_threads = 8  # threads within a single operator

    # TensorRT tuning; CUDA is listed second so onnxruntime can place nodes
    # TensorRT does not take on the CUDA provider (providers are tried in order)
    if "TensorrtExecutionProvider" in providers:
        sess_options.enable_mem_pattern = True
        sess_options.enable_mem_reuse = True
        providers_with_options = [
            (
                "TensorrtExecutionProvider",
                {
                    "trt_fp16_enable": True,  # Enable FP16 precision for faster inference
                    "trt_builder_optimization_level": 3,
                    "trt_max_partition_iterations": 1000,
                    "trt_engine_cache_enable": True,
                    "trt_engine_cache_path": f"{Path(args.model_dir) / args.repo_id.replace('/', '_')}/trt_engines",
                    "trt_engine_hw_compatible": True,
                    "trt_force_sequential_engine_build": False,
                    "trt_context_memory_sharing_enable": True,
                    "trt_timing_cache_enable": True,
                    "trt_timing_cache_path": f"{Path(args.model_dir) / args.repo_id.replace('/', '_')}",
                    "trt_sparsity_enable": True,
                    "trt_min_subgraph_size": 7,
                    # "trt_detailed_build_log": True,
                },
            ),
            (
                "CUDAExecutionProvider",
                {
                    "arena_extend_strategy": "kSameAsRequested",
                    "cudnn_conv_algo_search": "EXHAUSTIVE",
                    "do_copy_in_default_stream": True,
                    "cudnn_conv_use_max_workspace": "1",  # use the maximum cuDNN workspace
                    "tunable_op_enable": True,  # enable tunable ops
                    "tunable_op_tuning_enable": True,  # enable online tuning of tunable ops
                },
            ),
        ]

    elif "CUDAExecutionProvider" in providers:
        # CUDA GPU tuning
        sess_options.enable_mem_pattern = True
        sess_options.enable_mem_reuse = True
        providers_with_options = [
            (
                "CUDAExecutionProvider",
                {
                    "arena_extend_strategy": "kSameAsRequested",
                    "cudnn_conv_algo_search": "EXHAUSTIVE",
                    "do_copy_in_default_stream": True,
                    "cudnn_conv_use_max_workspace": "1",
                    "tunable_op_enable": True,
                    "tunable_op_tuning_enable": True,
                },
            ),
        ]
    else:
        # ROCm, OpenVINO and CPU: use the provider list as built above
        providers_with_options = providers

    console.print(f"[cyan]Providers with options:[/cyan]")
    console.print(Pretty(providers_with_options, indent_guides=True, expand_all=True))
    start_time = time.time()
    ort_sess = ort.InferenceSession(
        str(model_path), sess_options=sess_options, providers=providers_with_options
    )
    input_name = ort_sess.get_inputs()[0].name
    console.print(
        f"[green]Model loaded in {time.time() - start_time:.2f} seconds[/green]"
    )
    return ort_sess, input_name
212 |
213 |
def _preprocess_aligned(uris):
    """Preprocess *uris* in parallel, keeping each successful result paired with its path.

    Args:
        uris: Image paths from one scanner batch.

    Returns:
        list of ``(uri, array)`` tuples in input order. Images for which
        ``preprocess_image`` returned ``None`` are omitted.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        arrays = list(executor.map(preprocess_image, uris))
    return [(uri, arr) for uri, arr in zip(uris, arrays) if arr is not None]


def _link_into(path, data_root: Path, target_dir: Path) -> None:
    """Mirror image *path* under *target_dir* as a symlink, falling back to a copy.

    The image keeps its location relative to *data_root*. Images outside
    *data_root* are skipped with a warning instead of aborting the run.
    """
    source_path = Path(path).absolute()
    try:
        relative_path = source_path.relative_to(data_root)
    except ValueError:
        console.print(
            f"[yellow]Skipping link for {path}: not inside {data_root}[/yellow]"
        )
        return
    target_path = target_dir / relative_path
    target_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        target_path.symlink_to(source_path)
    except OSError as e:
        # Covers FileExistsError and PermissionError, plus the plain OSError
        # os.symlink raises for unprivileged users on Windows.
        console.print(f"[red]Unable to create symlink for {path}: {e}[/red]")
        # If a symlink cannot be created, copy the file instead
        try:
            shutil.copy2(source_path, target_path)
            console.print(
                f"[yellow]Created copy instead of symlink for {path}[/yellow]"
            )
        except Exception as copy_err:
            console.print(f"[red]Failed to copy file: {copy_err}[/red]")


def _result_tree(detection_results, data_root: Path, thresh: float) -> dict:
    """Nest ``(path, score)`` results into a dict keyed by path components.

    Each leaf holds the score (4 decimals) plus a watermark marker. Paths are
    resolved the same way as for the symlinks: made absolute, then taken
    relative to *data_root*. Paths outside *data_root* keep their absolute parts.
    """
    path_tree = {}
    for path, prob in detection_results:
        source_path = Path(path).absolute()
        try:
            parts = source_path.relative_to(data_root).parts
        except ValueError:
            parts = source_path.parts
        if not parts:
            continue
        *folders, filename = parts
        current = path_tree
        for folder in folders:
            current = current.setdefault(folder, {})
        # The last component is the file name; store its score
        current[filename] = f"{prob:.4f}" + (
            "🔴(Watermarked)🔖" if prob > thresh else "🟢(No Watermark)📄"
        )
    return path_tree


def main(args):
    """Detect watermarks in a dataset's images and sort them via symlinks.

    Images come from a Lance dataset. This is ``train_data_dir`` itself when
    it ends in ``.lance``; otherwise the first ``*.lance`` entry inside it;
    otherwise a dataset built from the directory with ``transform2lance``.
    Each image is scored by the model from ``load_model``. Images whose score
    (model output index 1) is above ``args.thresh`` are linked under
    ``watermarked/``, the rest under ``no_watermark/``; a copy is made when a
    symlink cannot be created. Scores are printed as a tree and saved to
    ``watermark_detection_results.json`` in ``train_data_dir``.

    Args:
        args: Parsed CLI namespace from ``setup_parser``.
    """
    global console

    data_root = Path(args.train_data_dir).absolute()
    watermark_dir = Path(args.train_data_dir) / "watermarked"
    no_watermark_dir = Path(args.train_data_dir) / "no_watermark"
    # Remove symlinks from a previous run so every image is re-sorted
    for output_dir in (watermark_dir, no_watermark_dir):
        if output_dir.exists():
            for entry in output_dir.rglob("*"):
                if entry.is_symlink():
                    entry.unlink()

    # Open (or build) the Lance dataset listing the images to scan
    if not isinstance(args.train_data_dir, lance.LanceDataset):
        if args.train_data_dir.endswith(".lance"):
            dataset = lance.dataset(args.train_data_dir)
        elif any(
            file.suffix == ".lance" for file in Path(args.train_data_dir).glob("*")
        ):
            lance_file = next(
                file
                for file in Path(args.train_data_dir).glob("*")
                if file.suffix == ".lance"
            )
            dataset = lance.dataset(str(lance_file))
        else:
            console.print("[yellow]Converting dataset to Lance format...[/yellow]")
            dataset = transform2lance(
                args.train_data_dir,
                output_name="dataset",
                save_binary=False,
                not_save_disk=False,
                tag="WatermarkDetection",
            )
            if dataset is None:
                # transform2lance can return None on failure
                console.print("[red]Failed to convert dataset to Lance format[/red]")
                return
            console.print("[green]Dataset converted to Lance format[/green]")

    else:
        dataset = args.train_data_dir
        console.print("[green]Using existing Lance dataset[/green]")

    ort_sess, input_name = load_model(args)

    # Count the images first so the progress bar has a total
    total_images = len(
        dataset.to_table(
            columns=["mime"],
            filter=("mime LIKE 'image/%'"),
        )
    )

    # Then stream only image rows, in order, in batches of args.batch_size
    scanner = dataset.scanner(
        columns=["uris", "mime", "captions"],
        filter=("mime LIKE 'image/%'"),
        scan_in_order=True,
        batch_size=args.batch_size,
        batch_readahead=16,
        fragment_readahead=4,
        io_buffer_size=32 * 1024 * 1024,  # 32MB buffer
        late_materialization=True,
    )

    # (uri, watermark score) for every image that was scored
    detection_results = []

    with Progress(
        "[progress.description]{task.description}",
        SpinnerColumn(spinner_name="dots"),
        MofNCompleteColumn(separator="/"),
        BarColumn(bar_width=40, complete_style="green", finished_style="bold green"),
        TextColumn("•"),
        TaskProgressColumn(),
        TextColumn("•"),
        TransferSpeedColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        expand=True,
    ) as progress:
        task = progress.add_task("[bold cyan]Processing images...", total=total_images)

        console = progress.console

        for batch in scanner.to_batches():
            uris = batch["uris"].to_pylist()

            # Images that fail to preprocess are dropped, so keep each array
            # paired with its own uri. Zipping the full uri list with the model
            # output would shift every later score onto the wrong file.
            loaded = _preprocess_aligned(uris)
            if loaded:
                probs = process_batch(
                    [arr for _, arr in loaded], ort_sess, input_name
                )
                if probs is not None:
                    for (path, _), prob in zip(loaded, probs):
                        # Index 1 is the "Watermark" label. Convert once so
                        # link sorting, statistics and the JSON agree.
                        watermark_prob = float(prob[1])
                        detection_results.append((path, watermark_prob))
                        target_dir = (
                            watermark_dir
                            if watermark_prob > args.thresh
                            else no_watermark_dir
                        )
                        _link_into(path, data_root, target_dir)

            progress.update(task, advance=len(uris))

    # Count watermarked images
    watermark_count = sum(1 for _, prob in detection_results if prob > args.thresh)
    total_count = len(detection_results)

    # Organize results by directory hierarchy and print them
    path_tree = _result_tree(detection_results, data_root, args.thresh)
    console.print("\n[bold green]Results:[/bold green]")
    console.print(Pretty(path_tree, indent_guides=True, expand_all=True))
    # Save the result tree as JSON
    result_json_path = Path(args.train_data_dir) / "watermark_detection_results.json"
    with open(result_json_path, "w", encoding="utf-8") as f:
        json.dump(path_tree, f, ensure_ascii=False, indent=2)
    console.print(f"[bold green]Results saved to:[/bold green] {result_json_path}")

    if not total_count:
        # Nothing was scored (no images, or every image failed): avoid dividing by zero
        console.print("[yellow]No images were processed.[/yellow]")
        return

    # Print detection statistics
    no_watermark_count = total_count - watermark_count
    console.print(
        f"🔴Watermarked🔖: {watermark_count} ({watermark_count/total_count*100:.2f}%)"
    )
    console.print(
        f"🟢No Watermark📄: {no_watermark_count} ({no_watermark_count/total_count*100:.2f}%)"
    )
388 |
389 |
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for watermark detection.

    Returns:
        argparse.ArgumentParser: parser for the positional ``train_data_dir``
        and the ``--repo_id``, ``--model_dir``, ``--force_download``,
        ``--batch_size`` and ``--thresh`` options.
    """
    parser = argparse.ArgumentParser(
        description="Detect watermarks in images and sort them into "
        "'watermarked' / 'no_watermark' symlink folders"
    )
    parser.add_argument(
        "train_data_dir",
        type=str,
        help="Directory containing images to process",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="bdsqlsz/joycaption-watermark-detection-onnx",
        help="Repository ID for Watermark Detection model on Hugging Face",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="watermark_detection",
        help="Directory to store Watermark Detection model",
    )
    parser.add_argument(
        "--force_download",
        action="store_true",
        help="Force downloading Watermark Detection model",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for inference",
    )
    # NOTE(review): main() compares the model's raw output at index 1 against
    # this value. If that output is a probability, the 1.0 default can never
    # be exceeded and no image is flagged. Confirm whether the ONNX model
    # emits logits or softmax probabilities before relying on the default.
    parser.add_argument(
        "--thresh",
        type=float,
        default=1.0,
        help="Watermark score threshold: images scoring above this value "
        "are classified as watermarked",
    )

    return parser
428 |
429 |
if __name__ == "__main__":
    # Script entry point: build the CLI, parse sys.argv, run the detection pipeline.
    parser = setup_parser()
    main(args := parser.parse_args())
436 |
--------------------------------------------------------------------------------
/requirements-uv-linux.txt:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by uv via the following command:
2 | # uv pip compile requirements.txt -o requirements-uv-linux.txt --index-strategy unsafe-best-match --no-build-isolation -p 3.11
3 | accelerate==1.6.0
4 | # via -r requirements.txt
5 | aiohappyeyeballs==2.6.1
6 | # via aiohttp
7 | aiohttp==3.11.18
8 | # via dashscope
9 | aiosignal==1.3.2
10 | # via aiohttp
11 | annotated-types==0.7.0
12 | # via pydantic
13 | anyio==4.9.0
14 | # via
15 | # google-genai
16 | # httpx
17 | # openai
18 | attrs==25.3.0
19 | # via aiohttp
20 | av==14.3.0
21 | # via -r requirements.txt
22 | bitmath==1.3.3.1
23 | # via hbutils
24 | cachetools==5.5.2
25 | # via
26 | # google-auth
27 | # zhipuai
28 | certifi==2025.4.26
29 | # via
30 | # httpcore
31 | # httpx
32 | # requests
33 | chardet==4.0.0
34 | # via
35 | # hbutils
36 | # pysrt
37 | charset-normalizer==3.4.2
38 | # via requests
39 | click==8.1.8
40 | # via scenedetect
41 | coloredlogs==15.0.1
42 | # via onnxruntime-gpu
43 | cssselect==1.3.0
44 | # via pyquery
45 | dashscope==1.23.2
46 | # via -r requirements.txt
47 | deprecation==2.1.0
48 | # via hbutils
49 | distro==1.9.0
50 | # via openai
51 | eval-type-backport==0.2.2
52 | # via mistralai
53 | filelock==3.18.0
54 | # via
55 | # huggingface-hub
56 | # torch
57 | # transformers
58 | flatbuffers==25.2.10
59 | # via onnxruntime-gpu
60 | frozenlist==1.6.0
61 | # via
62 | # aiohttp
63 | # aiosignal
64 | fsspec==2025.3.2
65 | # via
66 | # huggingface-hub
67 | # torch
68 | google-auth==2.40.1
69 | # via google-genai
70 | google-genai==1.14.0
71 | # via -r requirements.txt
72 | h11==0.16.0
73 | # via httpcore
74 | hbutils==0.11.0
75 | # via pyanimeinfo
76 | hf-xet==1.1.0
77 | # via huggingface-hub
78 | httpcore==1.0.9
79 | # via httpx
80 | httpx==0.28.1
81 | # via
82 | # google-genai
83 | # mistralai
84 | # openai
85 | # zhipuai
86 | huggingface-hub==0.31.1
87 | # via
88 | # -r requirements.txt
89 | # accelerate
90 | # tokenizers
91 | # transformers
92 | humanfriendly==10.0
93 | # via coloredlogs
94 | idna==3.10
95 | # via
96 | # anyio
97 | # httpx
98 | # requests
99 | # yarl
100 | imageio==2.37.0
101 | # via -r requirements.txt
102 | imageio-ffmpeg==0.6.0
103 | # via -r requirements.txt
104 | jinja2==3.1.6
105 | # via torch
106 | jiter==0.9.0
107 | # via openai
108 | lxml==4.9.4
109 | # via
110 | # pyanimeinfo
111 | # pyquery
112 | markdown-it-py==3.0.0
113 | # via rich
114 | markupsafe==3.0.2
115 | # via jinja2
116 | mdurl==0.1.2
117 | # via markdown-it-py
118 | mistralai==1.7.0
119 | # via -r requirements.txt
120 | mpmath==1.3.0
121 | # via sympy
122 | multidict==6.4.3
123 | # via
124 | # aiohttp
125 | # yarl
126 | mutagen==1.47.0
127 | # via -r requirements.txt
128 | networkx==3.4.2
129 | # via torch
130 | numpy==2.2.5
131 | # via
132 | # accelerate
133 | # imageio
134 | # onnxruntime-gpu
135 | # opencv-contrib-python-rolling
136 | # pandas
137 | # pylance
138 | # scenedetect
139 | # scipy
140 | # torchvision
141 | # transformers
142 | nvidia-cublas-cu12==12.6.4.1
143 | # via
144 | # nvidia-cudnn-cu12
145 | # nvidia-cusolver-cu12
146 | # torch
147 | nvidia-cuda-cupti-cu12==12.6.80
148 | # via torch
149 | nvidia-cuda-nvrtc-cu12==12.6.77
150 | # via torch
151 | nvidia-cuda-runtime-cu12==12.6.77
152 | # via
153 | # tensorrt-cu12-libs
154 | # torch
155 | nvidia-cudnn-cu12==9.5.1.17
156 | # via torch
157 | nvidia-cufft-cu12==11.3.0.4
158 | # via torch
159 | nvidia-cufile-cu12==1.11.1.6
160 | # via torch
161 | nvidia-curand-cu12==10.3.7.77
162 | # via torch
163 | nvidia-cusolver-cu12==11.7.1.2
164 | # via torch
165 | nvidia-cusparse-cu12==12.5.4.2
166 | # via
167 | # nvidia-cusolver-cu12
168 | # torch
169 | nvidia-cusparselt-cu12==0.6.3
170 | # via torch
171 | nvidia-nccl-cu12==2.26.2
172 | # via torch
173 | nvidia-nvjitlink-cu12==12.6.85
174 | # via
175 | # nvidia-cufft-cu12
176 | # nvidia-cusolver-cu12
177 | # nvidia-cusparse-cu12
178 | # torch
179 | nvidia-nvtx-cu12==12.6.77
180 | # via torch
181 | onnxruntime-gpu==1.20.2
182 | # via -r requirements.txt
183 | openai==1.78.0
184 | # via -r requirements.txt
185 | opencv-contrib-python-rolling @ https://github.com/cudawarped/opencv-python-cuda-wheels/releases/download/4.11.0.20250210/opencv_contrib_python_rolling-4.12.0.20250210-cp37-abi3-linux_x86_64.whl
186 | # via -r requirements.txt
187 | packaging==25.0
188 | # via
189 | # accelerate
190 | # deprecation
191 | # hbutils
192 | # huggingface-hub
193 | # onnxruntime-gpu
194 | # pillow-jxl-plugin
195 | # transformers
196 | pandas==2.2.3
197 | # via -r requirements.txt
198 | pillow==11.2.1
199 | # via
200 | # imageio
201 | # pillow-heif
202 | # pillow-jxl-plugin
203 | # rich-pixels
204 | # torchvision
205 | pillow-avif-plugin==1.5.2
206 | # via -r requirements.txt
207 | pillow-heif==0.22.0
208 | # via -r requirements.txt
209 | pillow-jxl-plugin==1.3.2
210 | # via -r requirements.txt
211 | platformdirs==4.3.8
212 | # via scenedetect
213 | propcache==0.3.1
214 | # via
215 | # aiohttp
216 | # yarl
217 | protobuf==6.30.2
218 | # via onnxruntime-gpu
219 | psutil==7.0.0
220 | # via accelerate
221 | pyanimeinfo==0.0.4
222 | # via -r requirements.txt
223 | pyarrow==20.0.0
224 | # via
225 | # -r requirements.txt
226 | # pylance
227 | pyasn1==0.6.1
228 | # via
229 | # pyasn1-modules
230 | # rsa
231 | pyasn1-modules==0.4.2
232 | # via google-auth
233 | pydantic==2.11.4
234 | # via
235 | # google-genai
236 | # mistralai
237 | # openai
238 | # zhipuai
239 | pydantic-core==2.33.2
240 | # via
241 | # pydantic
242 | # zhipuai
243 | pygments==2.19.1
244 | # via rich
245 | pyjwt==2.8.0
246 | # via zhipuai
247 | pylance==0.26.1
248 | # via -r requirements.txt
249 | pymediainfo==7.0.1
250 | # via -r requirements.txt
251 | pyparsing==3.0.9
252 | # via pyrfc6266
253 | pyquery==2.0.1
254 | # via pyanimeinfo
255 | pyrfc6266==1.0.2
256 | # via pyanimeinfo
257 | pysrt==1.1.2
258 | # via -r requirements.txt
259 | python-dateutil==2.9.0.post0
260 | # via
261 | # mistralai
262 | # pandas
263 | pytimeparse==1.1.8
264 | # via hbutils
265 | pytz==2025.2
266 | # via pandas
267 | pyyaml==6.0.2
268 | # via
269 | # accelerate
270 | # huggingface-hub
271 | # transformers
272 | regex==2024.11.6
273 | # via transformers
274 | requests==2.32.3
275 | # via
276 | # dashscope
277 | # google-genai
278 | # huggingface-hub
279 | # pyanimeinfo
280 | # transformers
281 | rich==14.0.0
282 | # via
283 | # -r requirements.txt
284 | # rich-pixels
285 | rich-pixels==3.0.1
286 | # via -r requirements.txt
287 | rsa==4.9.1
288 | # via google-auth
289 | safetensors==0.5.3
290 | # via
291 | # accelerate
292 | # transformers
293 | scenedetect==0.6.6
294 | # via -r requirements.txt
295 | scipy==1.15.3
296 | # via -r requirements.txt
297 | setuptools==80.3.1
298 | # via
299 | # hbutils
300 | # triton
301 | six==1.17.0
302 | # via python-dateutil
303 | sniffio==1.3.1
304 | # via
305 | # anyio
306 | # openai
307 | sympy==1.14.0
308 | # via
309 | # onnxruntime-gpu
310 | # torch
311 | tensorrt==10.10.0.31
312 | # via -r requirements.txt
313 | tensorrt-cu12==10.10.0.31
314 | # via tensorrt
315 | tensorrt-cu12-bindings==10.10.0.31
316 | # via tensorrt-cu12
317 | tensorrt-cu12-libs==10.10.0.31
318 | # via tensorrt-cu12
319 | tokenizers==0.21.1
320 | # via transformers
321 | toml==0.10.2
322 | # via -r requirements.txt
323 | torch==2.7.0
324 | # via
325 | # -r requirements.txt
326 | # accelerate
327 | # torchvision
328 | torchvision==0.22.0
329 | # via -r requirements.txt
330 | tqdm==4.67.1
331 | # via
332 | # huggingface-hub
333 | # openai
334 | # pyanimeinfo
335 | # scenedetect
336 | # transformers
337 | transformers==4.51.3
338 | # via -r requirements.txt
339 | triton==3.3.0
340 | # via torch
341 | typing-extensions==4.13.2
342 | # via
343 | # anyio
344 | # google-genai
345 | # huggingface-hub
346 | # openai
347 | # pydantic
348 | # pydantic-core
349 | # torch
350 | # typing-inspection
351 | typing-inspection==0.4.0
352 | # via
353 | # mistralai
354 | # pydantic
355 | tzdata==2025.2
356 | # via pandas
357 | urllib3==2.4.0
358 | # via requests
359 | websocket-client==1.8.0
360 | # via dashscope
361 | websockets==15.0.1
362 | # via google-genai
363 | yarl==1.20.0
364 | # via aiohttp
365 | zhipuai==2.1.5.20250421
366 | # via -r requirements.txt
367 |
--------------------------------------------------------------------------------
/requirements-uv.txt:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by uv via the following command:
2 | # uv pip compile requirements.txt -o requirements-uv.txt --index-strategy unsafe-best-match --no-build-isolation -p 3.11
3 | accelerate==1.6.0
4 | # via -r requirements.txt
5 | aiohappyeyeballs==2.4.6
6 | # via aiohttp
7 | aiohttp==3.11.12
8 | # via dashscope
9 | aiosignal==1.3.2
10 | # via aiohttp
11 | annotated-types==0.7.0
12 | # via pydantic
13 | anyio==4.8.0
14 | # via
15 | # google-genai
16 | # httpx
17 | # openai
18 | attrs==25.1.0
19 | # via aiohttp
20 | av==14.0.1
21 | # via -r requirements.txt
22 | bitmath==1.3.3.1
23 | # via hbutils
24 | cachetools==5.5.0
25 | # via
26 | # google-auth
27 | # zhipuai
28 | certifi==2024.12.14
29 | # via
30 | # httpcore
31 | # httpx
32 | # requests
33 | chardet==4.0.0
34 | # via
35 | # hbutils
36 | # pysrt
37 | charset-normalizer==3.4.1
38 | # via requests
39 | click==8.1.8
40 | # via scenedetect
41 | colorama==0.4.6
42 | # via
43 | # click
44 | # tqdm
45 | coloredlogs==15.0.1
46 | # via onnxruntime-gpu
47 | cssselect==1.3.0
48 | # via pyquery
49 | dashscope==1.22.1
50 | # via -r requirements.txt
51 | deprecation==2.1.0
52 | # via hbutils
53 | distro==1.9.0
54 | # via openai
55 | eval-type-backport==0.2.2
56 | # via mistralai
57 | filelock==3.18.0
58 | # via
59 | # huggingface-hub
60 | # torch
61 | # transformers
62 | flatbuffers==25.2.10
63 | # via onnxruntime-gpu
64 | frozenlist==1.5.0
65 | # via
66 | # aiohttp
67 | # aiosignal
68 | fsspec==2025.3.2
69 | # via
70 | # huggingface-hub
71 | # torch
72 | google-auth==2.37.0
73 | # via google-genai
74 | google-genai==1.11.0
75 | # via -r requirements.txt
76 | h11==0.14.0
77 | # via httpcore
78 | hbutils==0.11.0
79 | # via pyanimeinfo
80 | hf-xet==1.0.3
81 | # via huggingface-hub
82 | httpcore==1.0.7
83 | # via httpx
84 | httpx==0.28.1
85 | # via
86 | # google-genai
87 | # mistralai
88 | # openai
89 | # zhipuai
90 | huggingface-hub==0.30.2
91 | # via
92 | # -r requirements.txt
93 | # accelerate
94 | # tokenizers
95 | # transformers
96 | humanfriendly==10.0
97 | # via coloredlogs
98 | idna==3.10
99 | # via
100 | # anyio
101 | # httpx
102 | # requests
103 | # yarl
104 | imageio==2.36.1
105 | # via -r requirements.txt
106 | imageio-ffmpeg==0.5.1
107 | # via -r requirements.txt
108 | jinja2==3.1.6
109 | # via torch
110 | jiter==0.8.2
111 | # via openai
112 | lxml==4.9.4
113 | # via
114 | # pyanimeinfo
115 | # pyquery
116 | markdown-it-py==3.0.0
117 | # via rich
118 | markupsafe==3.0.2
119 | # via jinja2
120 | mdurl==0.1.2
121 | # via markdown-it-py
122 | mistralai==1.7.0
123 | # via -r requirements.txt
124 | mpmath==1.3.0
125 | # via sympy
126 | multidict==6.1.0
127 | # via
128 | # aiohttp
129 | # yarl
130 | mutagen==1.47.0
131 | # via -r requirements.txt
132 | networkx==3.4.2
133 | # via torch
134 | numpy==1.26.4
135 | # via
136 | # accelerate
137 | # imageio
138 | # onnxruntime-gpu
139 | # opencv-contrib-python-rolling
140 | # pandas
141 | # pylance
142 | # scenedetect
143 | # scipy
144 | # torchvision
145 | # transformers
146 | nvidia-cuda-runtime-cu12==12.8.90
147 | # via tensorrt-cu12-libs
148 | onnxruntime-gpu==1.20.2
149 | # via -r requirements.txt
150 | openai==1.59.7
151 | # via -r requirements.txt
152 | opencv-contrib-python-rolling @ https://github.com/cudawarped/opencv-python-cuda-wheels/releases/download/4.11.0.20250124/opencv_contrib_python_rolling-4.12.0.86-cp37-abi3-win_amd64.whl
153 | # via -r requirements.txt
154 | packaging==24.2
155 | # via
156 | # accelerate
157 | # deprecation
158 | # hbutils
159 | # huggingface-hub
160 | # onnxruntime-gpu
161 | # pillow-jxl-plugin
162 | # transformers
163 | pandas==2.2.3
164 | # via -r requirements.txt
165 | pillow==10.4.0
166 | # via
167 | # imageio
168 | # pillow-heif
169 | # pillow-jxl-plugin
170 | # rich-pixels
171 | # torchvision
172 | pillow-avif-plugin==1.4.6
173 | # via -r requirements.txt
174 | pillow-heif==0.21.0
175 | # via -r requirements.txt
176 | pillow-jxl-plugin==1.3.0
177 | # via -r requirements.txt
178 | platformdirs==4.3.6
179 | # via scenedetect
180 | propcache==0.2.1
181 | # via
182 | # aiohttp
183 | # yarl
184 | protobuf==6.30.2
185 | # via onnxruntime-gpu
186 | psutil==7.0.0
187 | # via accelerate
188 | pyanimeinfo==0.0.4
189 | # via -r requirements.txt
190 | pyarrow==18.1.0
191 | # via
192 | # -r requirements.txt
193 | # pylance
194 | pyasn1==0.6.1
195 | # via
196 | # pyasn1-modules
197 | # rsa
198 | pyasn1-modules==0.4.1
199 | # via google-auth
200 | pydantic==2.10.4
201 | # via
202 | # google-genai
203 | # mistralai
204 | # openai
205 | # zhipuai
206 | pydantic-core==2.27.2
207 | # via
208 | # pydantic
209 | # zhipuai
210 | pygments==2.18.0
211 | # via rich
212 | pyjwt==2.8.0
213 | # via zhipuai
214 | pylance==0.20.0
215 | # via -r requirements.txt
216 | pymediainfo==6.1.0
217 | # via -r requirements.txt
218 | pyparsing==3.0.9
219 | # via pyrfc6266
220 | pyquery==2.0.1
221 | # via pyanimeinfo
222 | pyreadline3==3.5.4
223 | # via humanfriendly
224 | pyrfc6266==1.0.2
225 | # via pyanimeinfo
226 | pysrt==1.1.2
227 | # via -r requirements.txt
228 | python-dateutil==2.9.0.post0
229 | # via
230 | # mistralai
231 | # pandas
232 | pytimeparse==1.1.8
233 | # via hbutils
234 | pytz==2024.2
235 | # via pandas
236 | pyyaml==6.0.2
237 | # via
238 | # accelerate
239 | # huggingface-hub
240 | # transformers
241 | regex==2024.11.6
242 | # via transformers
243 | requests==2.32.3
244 | # via
245 | # dashscope
246 | # google-genai
247 | # huggingface-hub
248 | # pyanimeinfo
249 | # transformers
250 | rich==13.9.4
251 | # via
252 | # -r requirements.txt
253 | # rich-pixels
254 | rich-pixels==3.0.1
255 | # via -r requirements.txt
256 | rsa==4.9
257 | # via google-auth
258 | safetensors==0.5.3
259 | # via
260 | # accelerate
261 | # transformers
262 | scenedetect==0.6.6
263 | # via -r requirements.txt
264 | scipy==1.15.3
265 | # via -r requirements.txt
266 | setuptools==75.6.0
267 | # via
268 | # hbutils
269 | # imageio-ffmpeg
270 | six==1.17.0
271 | # via python-dateutil
272 | sniffio==1.3.1
273 | # via
274 | # anyio
275 | # openai
276 | sympy==1.13.3
277 | # via
278 | # onnxruntime-gpu
279 | # torch
280 | tensorrt==10.9.0.34
281 | # via -r requirements.txt
282 | tensorrt-cu12==10.9.0.34
283 | # via tensorrt
284 | tensorrt-cu12-bindings==10.9.0.34
285 | # via tensorrt-cu12
286 | tensorrt-cu12-libs==10.9.0.34
287 | # via tensorrt-cu12
288 | tokenizers==0.21.1
289 | # via transformers
290 | toml==0.10.2
291 | # via -r requirements.txt
292 | torch==2.7.0
293 | # via
294 | # -r requirements.txt
295 | # accelerate
296 | # torchvision
297 | torchvision==0.22.0
298 | # via -r requirements.txt
299 | tqdm==4.67.1
300 | # via
301 | # huggingface-hub
302 | # openai
303 | # pyanimeinfo
304 | # scenedetect
305 | # transformers
306 | transformers==4.51.3
307 | # via -r requirements.txt
308 | typing-extensions==4.12.2
309 | # via
310 | # anyio
311 | # google-genai
312 | # huggingface-hub
313 | # openai
314 | # pydantic
315 | # pydantic-core
316 | # torch
317 | # typing-inspection
318 | typing-inspection==0.4.0
319 | # via mistralai
320 | tzdata==2024.2
321 | # via pandas
322 | urllib3==2.3.0
323 | # via requests
324 | websocket-client==1.8.0
325 | # via dashscope
326 | websockets==14.2
327 | # via google-genai
328 | yarl==1.18.3
329 | # via aiohttp
330 | zhipuai==2.1.5.20250410
331 | # via -r requirements.txt
332 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | av
2 | pylance>=0.20.0
3 | rich
4 | rich_pixels
5 | pandas
6 | pillow-avif-plugin
7 | pillow-heif
8 | pillow-jxl-plugin
9 | pysrt
10 | toml
11 | imageio>=2.31.1
12 | imageio-ffmpeg>=0.4.8
13 | pymediainfo
14 | mutagen
15 | pyarrow>=14.0.1
16 | mistralai>=1.6.0
17 | google-genai>=1.11.0
18 | openAI
19 | dashscope
20 | pyanimeinfo
21 | zhipuai
22 |
23 | #WDtagger
24 | accelerate>=1.6.0
25 | huggingface_hub[hf_xet]>=0.30.2
26 | https://github.com/cudawarped/opencv-python-cuda-wheels/releases/download/4.11.0.20250124/opencv_contrib_python_rolling-4.12.0.86-cp37-abi3-win_amd64.whl; sys_platform == 'win32'
27 | https://github.com/cudawarped/opencv-python-cuda-wheels/releases/download/4.11.0.20250210/opencv_contrib_python_rolling-4.12.0.20250210-cp37-abi3-linux_x86_64.whl; sys_platform == 'linux'
28 | torch>=2.7.0
29 | onnxruntime-gpu==1.20.2; sys_platform == 'win32'
30 | onnxruntime-gpu>=1.20.2; sys_platform == 'linux'
31 | tensorrt>=10.9
32 |
33 | scenedetect>=0.6.6
34 |
35 | #WatermarkDetect
36 | transformers>4.50
37 | scipy
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdbds/qinglong-captions/c9376f3bf325675c7dc10b5da702e1e0eb629947/utils/__init__.py
--------------------------------------------------------------------------------
/utils/console_util.py:
--------------------------------------------------------------------------------
1 | from rich.console import Console
2 | from rich.text import Text
3 | from rich.panel import Panel
4 | from rich.layout import Layout
5 | from rich.markdown import Markdown
6 | from rich.segment import Segment
7 | from rich.style import Style
8 | from rich_pixels import Pixels
9 | from utils.wdtagger import TagClassifier
10 |
# Module-level Rich console instance; BaseLayout falls back to it when no console is passed.
console = Console()
13 |
14 |
class BaseLayout:
    """Base class that wraps a Rich ``Layout`` in a fixed-height panel.

    Subclasses populate ``self.layout`` in :meth:`create_layout`.
    """

    def __init__(self, panel_height=32, console=None):
        """
        Initialize the base layout.

        Args:
            panel_height: Height of the content area. The outer panel produced
                by :meth:`render` is ``panel_height + 2`` rows tall (the extra
                rows presumably account for the panel border).
            console: Rich console used by :meth:`print`. Defaults to the
                module-level ``console``; a new ``Console`` is created only if
                that global is missing.
        """
        self.panel_height = panel_height
        if console is None:
            # The ``console`` parameter shadows the module global, so look the
            # global up explicitly. Only build a fresh Console as a last resort
            # instead of constructing (and discarding) one on every call.
            console = globals().get("console") or Console()
        self.console = console
        self.layout = Layout()

    def create_layout(self):
        """Build the layout structure; implemented by subclasses (no-op here)."""
        pass

    def render(self, title=""):
        """
        Wrap the layout in a panel and return it.

        Args:
            title: Panel title.

        Returns:
            Panel: Rich panel containing ``self.layout``.
        """
        panel = Panel(
            self.layout,
            title=title,
            height=self.panel_height + 2,
            padding=0,
        )
        return panel

    def print(self, title=""):
        """
        Print the rendered layout to the console, preceded by two blank lines.

        Args:
            title: Panel title.
        """
        panel = self.render(title)
        self.console.print()
        self.console.print()
        self.console.print(panel)
63 |
64 |
class CaptionLayout(BaseLayout):
    """Layout showing an image next to its tags, short description and long description."""

    def __init__(
        self,
        tag_description,
        short_description,
        long_description,
        pixels,
        short_highlight_rate=0,
        long_highlight_rate=0,
        panel_height=32,
        console=None,
    ):
        """
        Initialize the caption layout and build it immediately.

        Args:
            tag_description: Comma-separated tag string. "<" and ">" are
                stripped, and the tags are regrouped with
                ``TagClassifier.classify``. ``None`` is treated as empty.
            short_description: Text for the short-description panel.
            long_description: Text for the long-description panel.
            pixels: Rich Pixels object rendered in the image panel.
            short_highlight_rate: Value shown in the short-description panel title.
            long_highlight_rate: Value shown in the long-description panel title.
            panel_height: Height of the image panel; the text panels get half of it.
            console: Rich console instance (see BaseLayout).
        """
        super().__init__(panel_height, console)
        self.tag_description = self._format_tags(tag_description)
        self.short_description = short_description
        self.long_description = long_description
        self.pixels = pixels
        self.short_highlight_rate = short_highlight_rate
        self.long_highlight_rate = long_highlight_rate
        self.create_layout()

    @staticmethod
    def _format_tags(tag_description):
        """Normalize a comma-separated tag string and regroup it via TagClassifier."""
        cleaned = (tag_description or "").replace("<", "").replace(">", "")
        # Tolerate arbitrary spacing and repeated/trailing commas.
        tags = [tag.strip() for tag in cleaned.split(",") if tag.strip()]
        groups = TagClassifier().classify(tags).values()
        # Flatten the groups directly so that an empty group (if classify
        # returns one) does not leave an empty entry ("a,,b") in the result.
        return ",".join(tag for group in groups for tag in group)

    def create_layout(self):
        """Build the layout: image on the left, caption panels on the right."""
        # Right-hand side, stacked vertically.
        right_layout = Layout()

        # Top half: tags and short description side by side.
        top_layout = Layout()
        top_layout.split_row(
            Layout(
                Panel(
                    self.tag_description,
                    title="tags",
                    height=self.panel_height // 2,
                    padding=0,
                    expand=True,
                ),
                ratio=1,
            ),
            Layout(
                Panel(
                    self.short_description,
                    title=f"short_description - [yellow]hr:[/yellow] {self.short_highlight_rate}",
                    height=self.panel_height // 2,
                    padding=0,
                    expand=True,
                ),
                ratio=1,
            ),
        )

        # Bottom half: the long description.
        right_layout.split_column(
            Layout(top_layout, ratio=1),
            Layout(
                Panel(
                    self.long_description,
                    title=f"long_description - [yellow]highlight rate:[/yellow] {self.long_highlight_rate}",
                    height=self.panel_height // 2,
                    padding=0,
                    expand=True,
                )
            ),
        )

        # Main layout: image (1 part) | captions (2 parts).
        self.layout.split_row(
            Layout(
                Panel(self.pixels, height=self.panel_height, padding=0, expand=True),
                name="image",
                ratio=1,
            ),
            Layout(right_layout, name="caption", ratio=2),
        )
159 |
160 |
class MarkdownLayout(BaseLayout):
    """Layout showing Markdown content, optionally next to an image."""

    def __init__(self, pixels, markdown_content, panel_height=32, console=None):
        """
        Initialize the Markdown layout and build it immediately.

        Args:
            pixels: Rich Pixels object for the image panel, or ``None`` to let
                the Markdown panel take the whole layout.
            markdown_content: Markdown source text to render.
            panel_height: Height of the image panel.
            console: Rich console instance (see BaseLayout).
        """
        super().__init__(panel_height, console)
        self.pixels = pixels
        self.markdown_content = markdown_content
        self.create_layout()

    def create_layout(self):
        """Populate ``self.layout`` with the Markdown panel, plus the image when available."""
        # Build the Markdown panel once; the original constructed a second,
        # unused copy on the no-image path.
        markdown_panel = Panel(
            Markdown(self.markdown_content),
            title="markdown",
            padding=0,
            expand=True,
        )

        if self.pixels is None:
            # No image: render the Markdown panel across the whole layout.
            self.layout.update(Layout(markdown_panel, name="markdown"))
            return

        # Image (1 part) | Markdown (2 parts).
        self.layout.split_row(
            Layout(
                Panel(self.pixels, height=self.panel_height, padding=0, expand=True),
                name="image",
                ratio=1,
            ),
            Layout(Layout(markdown_panel), name="markdown", ratio=2),
        )
215 |
216 |
class CaptionAndRateLayout(BaseLayout):
    """Layout showing an image next to its tags, a per-dimension rating chart and a long description."""

    # Non-whitespace ASCII control characters (0x01-0x08, 0x0E-0x1B). Each
    # rating bar is first drawn with one of these placeholders, which
    # Pixels.from_ascii then maps to a coloured "■". Whitespace controls are
    # excluded so a placeholder can never be "\n" (the grid's line separator).
    # None of these characters survive dimension-name cleaning.
    _BAR_PLACEHOLDERS = tuple(chr(code) for code in range(1, 32) if not chr(code).isspace())

    # Rainbow colour per dimension. Dimensions past the last entry reuse it.
    _RAINBOW_COLORS = (
        "bright_red",  # red
        "orange3",  # orange
        "yellow",  # yellow
        "green",  # green
        "spring_green3",  # cyan
        "bright_blue",  # blue
        "blue_violet",  # indigo
        "purple",  # violet
        "magenta",  # magenta
    )

    def __init__(
        self,
        tag_description,
        rating,
        average_score,
        long_description,
        pixels,
        short_highlight_rate=0,
        long_highlight_rate=0,
        panel_height=32,
        console=None,
    ):
        """
        Initialize the caption-and-rating layout and build it immediately.

        Args:
            tag_description: Tag text, shown in magenta in the "tags" panel.
            rating: Mapping of dimension name -> score used to draw the rating
                chart (see create_rating_chart).
            average_score: Value shown in the rating panel title.
            long_description: Text for the long-description panel.
            pixels: Rich Pixels object rendered in the image panel.
            short_highlight_rate: Not used by this layout (presumably kept for
                signature parity with CaptionLayout).
            long_highlight_rate: Value shown in the long-description panel title.
            panel_height: Height of the image panel; the text panels get half of it.
            console: Rich console instance (see BaseLayout).
        """
        super().__init__(panel_height, console)
        self.tag_description = tag_description
        self.long_description = long_description
        self.long_highlight_rate = long_highlight_rate
        self.pixels = pixels
        self.rating_chart = self.create_rating_chart(rating)
        self.average_score = average_score
        self.create_layout()

    def create_layout(self):
        """Build the layout: image on the left; tags, rating chart and long description on the right."""
        # Right-hand side, stacked vertically.
        right_layout = Layout()

        # Top half: tags and rating chart side by side.
        top_layout = Layout()
        top_layout.split_row(
            Layout(
                Panel(
                    Text(self.tag_description, style="magenta"),
                    title="tags",
                    height=self.panel_height // 2,
                    padding=0,
                    expand=True,
                ),
                ratio=1,
            ),
            Layout(
                Panel(
                    self.rating_chart,
                    title=f"rating - [yellow]average score:[/yellow] {self.average_score}",
                    height=self.panel_height // 2,
                    padding=0,
                    expand=True,
                ),
                ratio=1,
            ),
        )

        # Bottom half: the long description.
        right_layout.split_column(
            Layout(top_layout, ratio=1),
            Layout(
                Panel(
                    self.long_description,
                    title=f"long_description - [yellow]highlight rate:[/yellow] {self.long_highlight_rate}",
                    height=self.panel_height // 2,
                    padding=0,
                    expand=True,
                )
            ),
        )

        # Main layout: image (1 part) | captions (2 parts).
        self.layout.split_row(
            Layout(
                Panel(self.pixels, height=self.panel_height, padding=0, expand=True),
                name="image",
                ratio=1,
            ),
            Layout(right_layout, name="caption", ratio=2),
        )

    @staticmethod
    def _clean_dimension_name(key):
        """Keep letters, digits, whitespace and ``&/,.:-_()``; map every whitespace char to a plain space.

        Normalizing whitespace keeps characters such as "\\n" out of the ASCII
        grid, where they would break the chart into extra lines.
        """
        return "".join(
            " " if char.isspace() else char
            for char in str(key)
            if char.isalnum() or char.isspace() or char in "&/,.:-_()"
        )

    def create_rating_chart(self, ratings, max_rating=10):
        """Render a bar chart with one coloured row per rating dimension.

        Args:
            ratings: Mapping of dimension name -> score. Each bar is
                ``int(score)`` blocks long.
            max_rating: Maximum score shown after each bar. "Storytelling &
                Concept" and "Setting & Environment Integration" always show 5.

        Returns:
            Pixels: The chart, or a "No ratings available" placeholder when
            ``ratings`` is empty.
        """
        if not ratings:
            return Pixels.from_ascii("No ratings available")

        # Sanitize dimension names. Names that clean to nothing get a
        # generated "Dimension N" label.
        clean_ratings = {}
        for key, value in ratings.items():
            clean_key = self._clean_dimension_name(key)
            if not clean_key:
                clean_key = f"Dimension {len(clean_ratings) + 1}"
            clean_ratings[clean_key] = value

        colors = self._RAINBOW_COLORS
        lines = []
        mapping = {}

        for i, (dimension, rating) in enumerate(clean_ratings.items()):
            color_index = min(i, len(colors) - 1)
            # One placeholder per colour. Indexing by the clamped colour index
            # keeps the pool bounded however many dimensions there are.
            placeholder = self._BAR_PLACEHOLDERS[color_index]
            mapping[placeholder] = Segment("■", Style(color=colors[color_index]))

            # Display only the part of the name before the first "&".
            short_dim = dimension.split("&")[0].strip()
            bar = placeholder * int(rating)

            current_max_rating = (
                5
                if dimension
                in ["Storytelling & Concept", "Setting & Environment Integration"]
                else max_rating
            )
            lines.append(f"{short_dim}:" + bar + f" {rating}/{current_max_rating}")

        return Pixels.from_ascii("\n".join(lines), mapping)
387 |
--------------------------------------------------------------------------------
/utils/stream_util.py:
--------------------------------------------------------------------------------
1 | import av
2 | import re
3 | from typing import Tuple
4 | from rich.progress import Progress, BarColumn, TimeRemainingColumn
5 | from rich.console import Console
6 | from av.audio.format import AudioFormat
7 | from av.audio.layout import AudioLayout
8 | import subprocess
9 | import imageio_ffmpeg
10 | from pymediainfo import MediaInfo
11 |
# Module-level Rich console passed to the progress bars created in this module.
console = Console()
13 |
14 |
def _srt_time_to_seconds(t):
    """Convert a pysrt-style time (hours/minutes/seconds/milliseconds attributes) to float seconds."""
    return t.hours * 3600 + t.minutes * 60 + t.seconds + t.milliseconds / 1000


def _add_output_audio_stream(out_container, uri, media_type):
    """Add a mono audio output stream whose codec, sample rate and sample format depend on the input kind."""
    if media_type == "video":
        # Video clips: AAC at 48 kHz with float-planar samples.
        codec_name, rate, sample_format = "aac", 48000, "fltp"
    elif uri.suffix == ".mp3":
        codec_name, rate, sample_format = "mp3", 16000, "s16p"
    else:
        codec_name, rate, sample_format = "pcm_s16le", 16000, "s16p"
    stream = out_container.add_stream(codec_name=codec_name, rate=rate)
    stream.layout = AudioLayout("mono")
    stream.format = AudioFormat(sample_format)
    return stream


def split_media_stream_clips(uri, media_type, subs, save_caption_func=None, **kwargs):
    """
    Cut a media file into one re-encoded clip per subtitle entry using PyAV.

    Each clip is written to
    ``<uri.parent>/<uri.stem>_clip/<uri.stem>_<sub.index><uri.suffix>``.

    Args:
        uri (pathlib.Path): Path to the media file (its ``name``, ``stem``,
            ``suffix`` and ``parent`` are used).
        media_type (str): ``"video"`` to include the first video stream;
            any other value processes the audio stream only.
        subs (pysrt.SubRipFile): Subtitles whose start/end times (including
            milliseconds) delimit the clips.
        save_caption_func (callable, optional): Called as
            ``save_caption_func(clip_path, [sub.text], "image")`` after each
            clip is written.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        None

    Raises:
        ValueError: If the input has neither a usable video nor an audio stream.
    """
    with av.open(str(uri)) as in_container:
        video_stream = (
            next((s for s in in_container.streams if s.type == "video"), None)
            if media_type == "video"
            else None
        )
        # An audio stream is used whenever one exists.
        audio_stream = next(
            (s for s in in_container.streams if s.type == "audio"), None
        )

        # Seeks are expressed in the video stream's time base when there is
        # one, otherwise in the audio stream's.
        seek_stream = video_stream if video_stream is not None else audio_stream
        if seek_stream is None:
            raise ValueError(f"No usable video or audio stream found in {uri}")
        # Hand decode() only the streams that exist, never None placeholders.
        decode_streams = [s for s in (video_stream, audio_stream) if s is not None]

        with Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeRemainingColumn(),
            console=console,
        ) as sub_progress:
            sub_task = sub_progress.add_task(
                f"[cyan]Processing subtitles for {uri.name}",
                total=len(subs),
            )

            for sub in subs:
                clip_path = (
                    uri.parent / f"{uri.stem}_clip/{uri.stem}_{sub.index}{uri.suffix}"
                )
                clip_path.parent.mkdir(parents=True, exist_ok=True)

                # Include milliseconds so clip boundaries match the subtitle
                # timings (consistent with split_video_with_imageio_ffmpeg).
                start_seconds = _srt_time_to_seconds(sub.start)
                end_seconds = _srt_time_to_seconds(sub.end)

                with av.open(str(clip_path), mode="w") as out_container:
                    # Video: copy the input stream's encoder settings.
                    out_video_stream = (
                        out_container.add_stream_from_template(template=video_stream)
                        if video_stream is not None
                        else None
                    )
                    out_audio_stream = (
                        _add_output_audio_stream(out_container, uri, media_type)
                        if audio_stream is not None
                        else None
                    )

                    # Convert the start time to the reference stream's
                    # time_base units and seek there.
                    time_base = seek_stream.time_base
                    start_offset = int(
                        start_seconds * time_base.denominator / time_base.numerator
                    )
                    in_container.seek(start_offset, stream=seek_stream)

                    for frame in in_container.decode(*decode_streams):
                        if frame.time is None:
                            # No timestamp: cannot place the frame in the clip.
                            continue
                        if frame.time > end_seconds:
                            break
                        if frame.time < start_seconds:
                            # Decoded from the seek keyframe before the clip start.
                            continue

                        if isinstance(frame, av.VideoFrame) and out_video_stream is not None:
                            for packet in out_video_stream.encode(frame):
                                out_container.mux(packet)
                        elif isinstance(frame, av.AudioFrame) and out_audio_stream is not None:
                            for packet in out_audio_stream.encode(frame):
                                out_container.mux(packet)

                    # Flush encoders.
                    if out_video_stream is not None:
                        for packet in out_video_stream.encode():
                            out_container.mux(packet)
                    if out_audio_stream is not None:
                        for packet in out_audio_stream.encode():
                            out_container.mux(packet)

                if save_caption_func:
                    save_caption_func(clip_path, [sub.text], "image")
                sub_progress.advance(sub_task)
163 |
164 |
def _ffmpeg_audio_codec(suffix):
    """Return the ffmpeg audio encoder to use for a clip with file *suffix*."""
    # Compare case-insensitively so "x.MP3" / "x.WAV" get the same encoder
    # as their lowercase spellings instead of falling through to AAC.
    suffix = suffix.lower()
    if suffix == ".mp3":
        return "mp3"
    if suffix == ".wav":
        return "pcm_s16le"
    return "aac"


def _ffmpeg_segment_command(ffmpeg_exe, uri, segment_time):
    """Build an ffmpeg command that cuts the whole of *uri* into fixed-length pieces.

    Streams are copied (no re-encode). Pieces are written to
    ``{stem}_clip/{stem}_000{suffix}``, ``{stem}_001{suffix}``, ... next to *uri*.
    """
    output_template = str(
        uri.parent / f"{uri.stem}_clip/{uri.stem}_%03d{uri.suffix}"
    )
    return [
        ffmpeg_exe,
        "-i",
        str(uri),  # input file
        "-f",
        "segment",  # use the segment muxer
        "-c",
        "copy",  # copy the original streams: fast, no re-encode
        "-segment_time",
        str(segment_time),  # length of each piece, in seconds
        "-reset_timestamps",
        "1",  # each piece starts at timestamp 0
        "-y",  # overwrite existing outputs
        "-break_non_keyframes",
        "0",
        output_template,  # numbered output file pattern
    ]


def _ffmpeg_clip_command(
    ffmpeg_exe, uri, clip_path, start_time, duration, seek_before_input
):
    """Build an ffmpeg command that re-encodes one clip of *uri* into *clip_path*.

    Args:
        start_time: Clip start as an ``HH:MM:SS.mmm`` string.
        duration: Clip length in seconds.
        seek_before_input: Put ``-ss``/``-t`` before ``-i`` (input seeking)
            instead of after it (output seeking: decode and discard up to
            ``start_time``).
    """
    seek = ["-ss", start_time, "-t", str(duration)]
    source = ["-i", str(uri)]
    head = seek + source if seek_before_input else source + seek
    return [
        ffmpeg_exe,
        *head,
        "-c:v",
        "libx264",  # re-encode video
        "-c:a",
        _ffmpeg_audio_codec(uri.suffix),  # re-encode audio
        "-vf",
        "setpts=PTS-STARTPTS",  # restart video timestamps at 0
        "-af",
        "asetpts=PTS-STARTPTS",  # restart audio timestamps at 0
        "-y",  # overwrite existing output
        str(clip_path),
    ]


def _ffmpeg_run(command):
    """Run an ffmpeg *command*, reporting and raising on failure.

    Raises:
        OSError: If the ffmpeg executable cannot be started.
        RuntimeError: If ffmpeg exits with a non-zero status; the message
            contains ffmpeg's stderr.
    """
    try:
        process = subprocess.run(
            command,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
    except OSError as e:
        console.print(f"[red]Failed to run ffmpeg:[/red] {str(e)}")
        raise
    # Checked outside the try so a failure is reported exactly once
    # (previously the raise was caught by its own except and printed twice).
    if process.returncode != 0:
        console.print(f"[red]Error running ffmpeg:[/red] {process.stderr}")
        raise RuntimeError(f"FFmpeg failed: {process.stderr}")


def split_video_with_imageio_ffmpeg(
    uri, subs, save_caption_func=None, segment_time=120, **kwargs
):
    """Cut a media file into clips, one per subtitle, using imageio's ffmpeg.

    Each subtitle's time range is re-encoded into
    ``{stem}_clip/{stem}_{sub.index}{suffix}`` next to *uri*. If a subtitle
    lasts exactly *segment_time* seconds, the whole file is instead split in a
    single stream-copy pass into *segment_time*-second pieces, and the
    remaining subtitles are skipped.

    Args:
        uri (Path): Path to the media file.
        subs (pysrt.SubRipFile): Subtitles to process.
        save_caption_func (callable, optional): Called as
            ``save_caption_func(clip_path, [sub.text], "image")`` after each
            ffmpeg run.
        segment_time (int | float): Subtitle length, in seconds, that triggers
            the single-pass segment mode; also the length of each segment.
        **kwargs: Ignored.

    Raises:
        OSError: If ffmpeg cannot be started.
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
    with Progress(
        "[progress.description]{task.description}",
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        TimeRemainingColumn(),
        console=console,
    ) as sub_progress:
        sub_task = sub_progress.add_task(
            f"[cyan]Processing subtitles for {uri.name}",
            total=len(subs),
        )

        for i, sub in enumerate(subs):
            clip_path = (
                uri.parent / f"{uri.stem}_clip/{uri.stem}_{sub.index}{uri.suffix}"
            )
            clip_path.parent.mkdir(parents=True, exist_ok=True)

            # Clip start as HH:MM:SS.mmm and length in seconds.
            start_time = f"{int(sub.start.hours):02d}:{int(sub.start.minutes):02d}:{int(sub.start.seconds):02d}.{int(sub.start.milliseconds):03d}"
            # Field-by-field difference, kept in this exact form because the
            # resulting float is compared for equality with segment_time below.
            duration = (
                (sub.end.hours - sub.start.hours) * 3600
                + (sub.end.minutes - sub.start.minutes) * 60
                + (sub.end.seconds - sub.start.seconds)
                + (sub.end.milliseconds - sub.start.milliseconds) / 1000
            )

            segment_mode = duration == segment_time
            if segment_mode:
                command = _ffmpeg_segment_command(ffmpeg_exe, uri, segment_time)
            else:
                # The first clip uses input seeking (-ss before -i); later
                # clips use output seeking (-ss after -i), as before.
                command = _ffmpeg_clip_command(
                    ffmpeg_exe,
                    uri,
                    clip_path,
                    start_time,
                    duration,
                    seek_before_input=(i == 0),
                )

            console.print(f"Running command: {' '.join(command)}")
            _ffmpeg_run(command)

            if save_caption_func:
                # NOTE(review): in segment mode ffmpeg writes {stem}_%03d files,
                # so clip_path ({stem}_{sub.index}) may not name a produced
                # file -- confirm what save_caption_func expects here.
                save_caption_func(clip_path, [sub.text], "image")
            sub_progress.advance(sub_task)

            if segment_mode:
                # One ffmpeg pass already covered the whole file: mark the bar
                # complete (rather than a fixed +4) and skip remaining subs.
                # Keyed on the branch actually taken, not on a minutes-only diff.
                sub_progress.update(sub_task, completed=len(subs))
                break
307 |
308 |
def sanitize_filename(name: str) -> str:
    """Normalize *name* into a short lowercase identifier.

    Guarantees on the result:
    - only lowercase ASCII letters, digits and single dashes (``-``);
    - never begins or ends with a dash;
    - at most 40 characters;
    - never empty (falls back to ``"file"``).

    Longer names keep their first 20 and last 19 characters, joined by a
    single dash.
    """
    # Lowercase, then collapse every run of characters outside [a-z0-9]
    # (existing dashes included) into one dash, and trim the ends.
    sanitized = re.sub(r"[^a-z0-9]+", "-", name.lower()).strip("-")
    if not sanitized:
        return "file"
    if len(sanitized) > 40:
        # Trim the dash on the inner side of each half. Without this, a dash
        # at the start of the tail produced "--" in the joined result.
        head = sanitized[:20].rstrip("-")
        tail = sanitized[-19:].lstrip("-")
        sanitized = f"{head}-{tail}"
    return sanitized
338 |
339 |
def get_video_duration(file_path):
    """Return a media file's duration, used to offset subtitles between clips.

    The first video track that reports a duration wins. Otherwise (e.g.
    audio-only files) the first audio track that reports one is used. Tracks
    whose duration MediaInfo does not report are skipped, instead of leaking
    ``None`` to callers.

    Args:
        file_path: Path to the media file.

    Returns:
        The duration as reported by MediaInfo (milliseconds), or 0 when no
        video or audio track reports one.
    """
    audio_duration = None
    for track in MediaInfo.parse(file_path).tracks:
        if track.duration is None:
            continue
        if track.track_type == "Video":
            return track.duration
        # Remember the first audio duration but keep looking for a video
        # track, which may be listed after the audio one.
        if track.track_type == "Audio" and audio_duration is None:
            audio_duration = track.duration
    return audio_duration if audio_duration is not None else 0
356 |
357 |
358 | def _round_to_16(value: int) -> int:
359 | """将值四舍五入为最接近的16的倍数"""
360 | return (value // 16) * 16
361 |
362 |
363 | def calculate_dimensions(
364 | width, height, max_long_edge: int = None, max_short_edge: int = None
365 | ) -> Tuple[int, int]:
366 | """
367 | 根据原始尺寸、最长边和最短边的最大值限制计算新尺寸
368 |
369 | Args:
370 | width: 原始宽度
371 | height: 原始高度
372 | max_long_edge: 最长边的最大值
373 | max_short_edge: 最短边的最大值
374 |
375 | Returns:
376 | 调整后的宽度和高度组成的元组
377 | """
378 | # 设置默认值
379 | if max_long_edge is None and max_short_edge is None:
380 | max_long_edge = 1024
381 |
382 | # 计算原始纵横比
383 | aspect_ratio = width / height
384 |
385 | # 确定长边和短边
386 | is_width_longer = width >= height
387 |
388 | # 将原始尺寸调整为16的倍数
389 | new_width = _round_to_16(width)
390 | new_height = _round_to_16(height)
391 |
392 | # 对尺寸进行多轮调整,直到满足所有条件
393 | for _ in range(2): # 最多进行两轮调整就足够了
394 | # 处理最长边的最大值限制
395 | if max_long_edge is not None:
396 | if is_width_longer and new_width > max_long_edge:
397 | new_width = max_long_edge
398 | new_height = _round_to_16(int(new_width / aspect_ratio))
399 | elif not is_width_longer and new_height > max_long_edge:
400 | new_height = max_long_edge
401 | new_width = _round_to_16(int(new_height * aspect_ratio))
402 |
403 | # 处理最短边的最大值限制
404 | if max_short_edge is not None:
405 | if is_width_longer and new_height > max_short_edge:
406 | new_height = max_short_edge
407 | new_width = _round_to_16(int(new_height * aspect_ratio))
408 | elif not is_width_longer and new_width > max_short_edge:
409 | new_width = max_short_edge
410 | new_height = _round_to_16(int(new_width / aspect_ratio))
411 |
412 | return new_width, new_height
413 |
--------------------------------------------------------------------------------