├── .gitignore ├── 0、install pwsh.sh ├── 1、install-uv-qinglong.ps1 ├── 2.1、image_watermark_detect.ps1 ├── 2、video_spliter.ps1 ├── 3、tagger.ps1 ├── 4、run.ps1 ├── LICENSE ├── README.md ├── config ├── __init__.py ├── config.py └── config.toml ├── datasets └── put datasets here ├── lanceExport.ps1 ├── lanceImport.ps1 ├── module ├── __init__.py ├── api_handler.py ├── captioner.py ├── lanceImport.py ├── lanceexport.py ├── scenedetect.py └── waterdetect.py ├── requirements-uv-linux.txt ├── requirements-uv.txt ├── requirements.txt └── utils ├── __init__.py ├── console_util.py ├── stream_util.py └── wdtagger.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | datasets/ 173 | wd14_tagger_model/ 174 | data/ 175 | huggingface/ 176 | watermark_detection/ 177 | 178 | # windsurf rules 179 | .windsurfrules 180 | -------------------------------------------------------------------------------- /0、install pwsh.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | echo "检查是否已安装 PowerShell..." 4 | if ! command -v pwsh &> /dev/null 5 | then 6 | echo "PowerShell 未安装,正在安装..." 
7 | 8 | # 下载 PowerShell '.tar.gz' 压缩包 9 | curl -L -o /tmp/powershell.tar.gz https://github.com/PowerShell/PowerShell/releases/download/v7.5.1/powershell-7.5.1-linux-x64.tar.gz 10 | 11 | # 创建目标文件夹 12 | mkdir -p /opt/microsoft/powershell/7 13 | 14 | # 解压 PowerShell 到目标文件夹 15 | tar zxf /tmp/powershell.tar.gz -C /opt/microsoft/powershell/7 16 | 17 | # 设置执行权限 18 | chmod +x /opt/microsoft/powershell/7/pwsh 19 | 20 | # 创建指向 pwsh 的符号链接 21 | ln -s /opt/microsoft/powershell/7/pwsh /usr/bin/pwsh 22 | 23 | echo "PowerShell 安装完成" 24 | else 25 | echo "PowerShell 已安装" 26 | fi 27 | 28 | echo "Install completed" 29 | -------------------------------------------------------------------------------- /1、install-uv-qinglong.ps1: -------------------------------------------------------------------------------- 1 | Set-Location $PSScriptRoot 2 | 3 | $Env:HF_HOME = "huggingface" 4 | #$Env:HF_ENDPOINT="https://hf-mirror.com" 5 | $Env:PIP_DISABLE_PIP_VERSION_CHECK = 1 6 | $Env:PIP_NO_CACHE_DIR = 1 7 | #$Env:PIP_INDEX_URL="https://pypi.mirrors.ustc.edu.cn/simple" 8 | #$Env:UV_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple/" 9 | $Env:UV_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cu128" 10 | $Env:UV_CACHE_DIR = "${env:LOCALAPPDATA}/uv/cache" 11 | $Env:UV_NO_BUILD_ISOLATION = 1 12 | $Env:UV_NO_CACHE = 0 13 | $Env:UV_LINK_MODE = "symlink" 14 | $Env:GIT_LFS_SKIP_SMUDGE = 1 15 | $Env:CUDA_HOME = "${env:CUDA_PATH}" 16 | 17 | function InstallFail { 18 | Write-Output "Install failed|安装失败。" 19 | Read-Host | Out-Null ; 20 | Exit 21 | } 22 | 23 | function Check { 24 | param ( 25 | $ErrorInfo 26 | ) 27 | if (!($?)) { 28 | Write-Output $ErrorInfo 29 | InstallFail 30 | } 31 | } 32 | 33 | try { 34 | ~/.local/bin/uv --version 35 | Write-Output "uv installed|UV模块已安装." 36 | } 37 | catch { 38 | Write-Output "Installing uv|安装uv模块中..." 39 | if ($Env:OS -ilike "*windows*") { 40 | powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" 41 | Check "uv install failed|安装uv模块失败。" 42 | } 43 | else { 44 | curl -LsSf https://astral.sh/uv/install.sh | sh 45 | Check "uv install failed|安装uv模块失败。" 46 | } 47 | } 48 | 49 | if ($env:OS -ilike "*windows*") { 50 | chcp 65001 51 | # First check if UV cache directory already exists 52 | if (Test-Path -Path "${env:LOCALAPPDATA}/uv/cache") { 53 | Write-Host "UV cache directory already exists, skipping disk space check" 54 | } 55 | else { 56 | # Check C drive free space with error handling 57 | try { 58 | $CDrive = Get-WmiObject Win32_LogicalDisk -Filter "DeviceID='C:'" -ErrorAction Stop 59 | if ($CDrive) { 60 | $FreeSpaceGB = [math]::Round($CDrive.FreeSpace / 1GB, 2) 61 | Write-Host "C: drive free space: ${FreeSpaceGB}GB" 62 | 63 | # $Env:UV cache directory based on available space 64 | if ($FreeSpaceGB -lt 10) { 65 | Write-Host "Low disk space detected. Using local .cache directory" 66 | $Env:UV_CACHE_DIR = ".cache" 67 | } 68 | } 69 | else { 70 | Write-Warning "C: drive not found. Using local .cache directory" 71 | $Env:UV_CACHE_DIR = ".cache" 72 | } 73 | } 74 | catch { 75 | Write-Warning "Failed to check disk space: $_. Using local .cache directory" 76 | $Env:UV_CACHE_DIR = ".cache" 77 | } 78 | } 79 | if (Test-Path "./venv/Scripts/activate") { 80 | Write-Output "Windows venv" 81 | . ./venv/Scripts/activate 82 | } 83 | elseif (Test-Path "./.venv/Scripts/activate") { 84 | Write-Output "Windows .venv" 85 | . ./.venv/Scripts/activate 86 | } 87 | else { 88 | Write-Output "Create .venv" 89 | ~/.local/bin/uv venv -p 3.11 --seed 90 | . 
./.venv/Scripts/activate 91 | } 92 | } 93 | elseif (Test-Path "./venv/bin/activate") { 94 | Write-Output "Linux venv" 95 | . ./venv/bin/Activate.ps1 96 | } 97 | elseif (Test-Path "./.venv/bin/activate") { 98 | Write-Output "Linux .venv" 99 | . ./.venv/bin/activate.ps1 100 | } 101 | else { 102 | Write-Output "Create .venv" 103 | ~/.local/bin/uv venv -p 3.11 --seed 104 | . ./.venv/bin/activate.ps1 105 | } 106 | 107 | Write-Output "Installing main requirements" 108 | 109 | ~/.local/bin/uv pip install --upgrade setuptools wheel pip wheel_stub 110 | 111 | if ($env:OS -ilike "*windows*") { 112 | ~/.local/bin/uv pip sync requirements-uv.txt --index-strategy unsafe-best-match 113 | Check "Install main requirements failed" 114 | } 115 | else { 116 | ~/.local/bin/uv pip sync requirements-uv-linux.txt --index-strategy unsafe-best-match 117 | Check "Install main requirements failed" 118 | } 119 | 120 | Write-Output "Install finished" 121 | Read-Host | Out-Null ; 122 | -------------------------------------------------------------------------------- /2.1、image_watermark_detect.ps1: -------------------------------------------------------------------------------- 1 | #region Configuration 2 | # Model settings 3 | $Config = @{ 4 | train_data_dir = "./datasets" # Input images path | 图片输入路径 5 | # bdsqlsz/Watermark-Detection-SigLIP2-onnx 6 | # bdsqlsz/joycaption-watermark-detection-onnx 7 | repo_id = "bdsqlsz/joycaption-watermark-detection-onnx" # Model repo ID from Hugging Face 8 | model_dir = "watermark_detection" # Local model folder path | 本地模型文件夹路径 9 | batch_size = 12 # Batch size for inference 10 | thresh = 1.0 # Concept threshold 11 | } 12 | #endregion 13 | 14 | #region Environment Setup 15 | # Activate python venv 16 | Set-Location $PSScriptRoot 17 | $env:PYTHONPATH = "$PSScriptRoot;$env:PYTHONPATH" 18 | $VenvPaths = @( 19 | "./venv/Scripts/activate", 20 | "./.venv/Scripts/activate", 21 | "./venv/bin/Activate.ps1", 22 | "./.venv/bin/activate.ps1" 23 | ) 24 | 25 | foreach ($Path in $VenvPaths) { 26 | if (Test-Path $Path) { 27 | Write-Output "Activating venv: $Path" 28 | & $Path 29 | break 30 | } 31 | } 32 | 33 | # Set environment variables 34 | $Env:HF_HOME = "huggingface" 35 | #$Env:HF_ENDPOINT = "https://hf-mirror.com" 36 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1" 37 | $Env:CUDA_HOME = "${env:CUDA_PATH}" 38 | $Env:TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION = "1" 39 | $Env:TF_CUDNN_USE_AUTOTUNE = "1" 40 | $Env:TF_TRT_ALLOW_TF32 = "1" 41 | 42 | #endregion 43 | 44 | #region Build Arguments 45 | $ExtArgs = [System.Collections.ArrayList]::new() 46 | 47 | # Add configuration arguments 48 | if ($Config.repo_id) { [void]$ExtArgs.Add("--repo_id=$($Config.repo_id)") } 49 | if ($Config.model_dir) { [void]$ExtArgs.Add("--model_dir=$($Config.model_dir)") } 50 | if ($Config.batch_size) { [void]$ExtArgs.Add("--batch_size=$($Config.batch_size)") } 51 | if ($Config.thresh -ne 1.0) { [void]$ExtArgs.Add("--thresh=$($Config.thresh)") } 52 | 53 | #endregion 54 | 55 | #region Execute Watermark Detection 56 | Write-Output "Starting Watermark Detection..." 
57 | 58 | # Run tagger 59 | accelerate launch --num_cpu_threads_per_process=8 "./module/waterdetect.py" ` 60 | $Config.train_data_dir ` 61 | $ExtArgs 62 | 63 | Write-Output "Watermark Detection finished" 64 | Read-Host | Out-Null 65 | 66 | #endregion 67 | -------------------------------------------------------------------------------- /2、video_spliter.ps1: -------------------------------------------------------------------------------- 1 | #region Configuration 2 | # 场景检测设置 3 | $Config = @{ 4 | input_video_dir = "./datasets" # 输入视频目录路径 5 | output_dir = "" # 输出目录路径,如果不指定则默认为输入目录 6 | detector = "AdaptiveDetector" # 场景检测器,可选"ContentDetector","AdaptiveDetector","HashDetector","HistogramDetector","ThresholdDetector" 7 | threshold = 0.0 # 场景检测阈值,数值越低越敏感。ContentDetector: 27.0, AdaptiveDetector: 3.0, HashDetector: 0.395, HistogramDetector: 0.05, ThresholdDetector: 12 8 | min_scene_len = 16 # 最小场景长度,数值越小越敏感 9 | luma_only = $false # 是否只使用亮度变化检测 10 | save_html = $true # 是否保存HTML报告 11 | video2images_min_number = 1 # 每个场景保存的图像数量,为0则不保存 12 | recursive = $false # 是否递归搜索子目录 13 | } 14 | #endregion 15 | 16 | #region Environment Setup 17 | # 激活Python虚拟环境 18 | Set-Location $PSScriptRoot 19 | $env:PYTHONPATH = "$PSScriptRoot;$env:PYTHONPATH" 20 | $VenvPaths = @( 21 | "./venv/Scripts/activate", 22 | "./.venv/Scripts/activate", 23 | "./venv/bin/Activate.ps1", 24 | "./.venv/bin/activate.ps1" 25 | ) 26 | 27 | foreach ($Path in $VenvPaths) { 28 | if (Test-Path $Path) { 29 | Write-Output "Activating venv: $Path" 30 | & $Path 31 | break 32 | } 33 | } 34 | #endregion 35 | 36 | #region Build Arguments 37 | $Env:HF_HOME = "huggingface" 38 | #$Env:HF_ENDPOINT = "https://hf-mirror.com" 39 | $ExtArgs = [System.Collections.ArrayList]::new() 40 | 41 | # 添加配置参数 42 | if ($Config.output_dir) { [void]$ExtArgs.Add("--output_dir=$($Config.output_dir)") } 43 | if ($Config.detector -ne "AdaptiveDetector") { [void]$ExtArgs.Add("--detector=$($Config.detector)") } 44 | if ($Config.threshold -ne 0.0) { [void]$ExtArgs.Add("--threshold=$($Config.threshold)") } 45 | if ($Config.min_scene_len) { [void]$ExtArgs.Add("--min_scene_len=$($Config.min_scene_len)") } 46 | if ($Config.luma_only) { [void]$ExtArgs.Add("--luma_only") } 47 | if ($Config.save_html) { [void]$ExtArgs.Add("--save_html") } 48 | if ($Config.video2images_min_number -gt 0) { [void]$ExtArgs.Add("--video2images_min_number=$($Config.video2images_min_number)") } 49 | if ($Config.recursive) { [void]$ExtArgs.Add("--recursive") } 50 | #endregion 51 | 52 | #region Execute Scene Detection 53 | Write-Output "Starting scene detection..." 
54 | 55 | # 运行场景检测程序 56 | python -m module.scenedetect ` 57 | $Config.input_video_dir ` 58 | $ExtArgs 59 | 60 | Write-Output "Scene detection finished" 61 | Read-Host | Out-Null 62 | #endregion 63 | -------------------------------------------------------------------------------- /3、tagger.ps1: -------------------------------------------------------------------------------- 1 | #region Configuration 2 | # Model settings 3 | $Config = @{ 4 | train_data_dir = "./datasets" # Input images path | 图片输入路径 5 | repo_id = "SmilingWolf/wd-eva02-large-tagger-v3" # Model repo ID from Hugging Face 6 | model_dir = "wd14_tagger_model" # Local model folder path | 本地模型文件夹路径 7 | batch_size = 12 # Batch size for inference 8 | thresh = 0.3 # Concept threshold 9 | general_threshold = 0.3 # General threshold 10 | character_threshold = 1.0 # Character threshold 11 | } 12 | 13 | # Feature flags 14 | $Features = @{ 15 | frequency_tags = $false # Order by frequency tags 16 | remove_underscore = $true # Convert underscore to space 17 | use_rating_tags = $false # Use rating tags 18 | use_rating_tags_as_last_tag = $false # Put rating tags at the end 19 | character_tags_first = $false # Put character tags first 20 | character_tag_expand = $false # Split character_(series) into character, series 21 | remove_parents_tag = $true # Remove parent tags 22 | overwrite = $true # Overwrite existing tag files 23 | } 24 | 25 | # Tag settings 26 | $TagConfig = @{ 27 | undesired_tags = "" # Tags to exclude 28 | always_first_tags = "1girl,1boy,2girls,3girls,4girls,5girls,6girls,2boys,3boys,4boys,5boys,6boys" 29 | tag_replacement = "1girl,1woman;2girls,2women;3girls,3women;4girls,4women;5girls,5women;1boy,1man" 30 | } 31 | 32 | #endregion 33 | 34 | #region Environment Setup 35 | # Activate python venv 36 | Set-Location $PSScriptRoot 37 | $env:PYTHONPATH = "$PSScriptRoot;$env:PYTHONPATH" 38 | $VenvPaths = @( 39 | "./venv/Scripts/activate", 40 | "./.venv/Scripts/activate", 41 | "./venv/bin/Activate.ps1", 42 | "./.venv/bin/activate.ps1" 43 | ) 44 | 45 | foreach ($Path in $VenvPaths) { 46 | if (Test-Path $Path) { 47 | Write-Output "Activating venv: $Path" 48 | & $Path 49 | break 50 | } 51 | } 52 | 53 | # Set environment variables 54 | $Env:HF_HOME = "huggingface" 55 | #$Env:HF_ENDPOINT = "https://hf-mirror.com" 56 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1" 57 | $Env:CUDA_HOME = "${env:CUDA_PATH}" 58 | $Env:TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION = "1" 59 | $Env:TF_CUDNN_USE_AUTOTUNE = "1" 60 | $Env:TF_TRT_ALLOW_TF32 = "1" 61 | 62 | #endregion 63 | 64 | #region Build Arguments 65 | $ExtArgs = [System.Collections.ArrayList]::new() 66 | 67 | # Add configuration arguments 68 | if ($Config.repo_id) { [void]$ExtArgs.Add("--repo_id=$($Config.repo_id)") } 69 | if ($Config.model_dir) { [void]$ExtArgs.Add("--model_dir=$($Config.model_dir)") } 70 | if ($Config.batch_size) { [void]$ExtArgs.Add("--batch_size=$($Config.batch_size)") } 71 | if ($Config.general_threshold) { [void]$ExtArgs.Add("--general_threshold=$($Config.general_threshold)") } 72 | if ($Config.character_threshold) { [void]$ExtArgs.Add("--character_threshold=$($Config.character_threshold)") } 73 | 74 | # Add feature flags 75 | if ($Features.remove_underscore) { [void]$ExtArgs.Add("--remove_underscore") } 76 | if ($Features.recursive) { [void]$ExtArgs.Add("--recursive") } 77 | if ($Features.frequency_tags) { [void]$ExtArgs.Add("--frequency_tags") } 78 | if ($Features.character_tags_first) { [void]$ExtArgs.Add("--character_tags_first") } 79 | if ($Features.character_tag_expand) { 
[void]$ExtArgs.Add("--character_tag_expand") } 80 | if ($Features.use_rating_tags_as_last_tag) { [void]$ExtArgs.Add("--use_rating_tags_as_last_tag") } 81 | elseif ($Features.use_rating_tags) { [void]$ExtArgs.Add("--use_rating_tags") } 82 | if ($Features.remove_parents_tag) { [void]$ExtArgs.Add("--remove_parents_tag") } 83 | if ($Features.overwrite) { [void]$ExtArgs.Add("--overwrite") } 84 | 85 | # Add tag configuration 86 | if ($TagConfig.undesired_tags) { [void]$ExtArgs.Add("--undesired_tags=$($TagConfig.undesired_tags)") } 87 | if ($TagConfig.always_first_tags) { [void]$ExtArgs.Add("--always_first_tags=$($TagConfig.always_first_tags)") } 88 | if ($TagConfig.tag_replacement) { [void]$ExtArgs.Add("--tag_replacement=$($TagConfig.tag_replacement)") } 89 | 90 | #endregion 91 | 92 | #region Execute Tagger 93 | Write-Output "Starting tagger..." 94 | 95 | # Run tagger 96 | accelerate launch --num_cpu_threads_per_process=8 "./utils/wdtagger.py" ` 97 | $Config.train_data_dir ` 98 | --thresh=$($Config.thresh) ` 99 | --caption_extension .txt ` 100 | $ExtArgs 101 | 102 | Write-Output "Tagger finished" 103 | Read-Host | Out-Null 104 | 105 | #endregion 106 | -------------------------------------------------------------------------------- /4、run.ps1: -------------------------------------------------------------------------------- 1 | $dataset_path = "./datasets" 2 | $gemini_api_key = "" 3 | $gemini_model_path = "gemini-2.5-pro-exp-03-25" 4 | $gemini_task = "" 5 | $pixtral_api_key = "" 6 | $pixtral_model_path = "pixtral-large-2411" 7 | $step_api_key = "" 8 | $step_model_path = "step-1.5v-mini" 9 | $qwenVL_api_key = "" 10 | $qwenVL_model_path = "qwen-vl-max-latest" # qwen2.5-vl-72b-instruct<10mins qwen-vl-max-latest <1min 11 | $glm_api_key = "" 12 | $glm_model_path = "GLM-4V-Plus-0111" 13 | $dir_name = $false 14 | $mode = "long" 15 | $not_clip_with_caption = $false # Not clip with caption | 不根据caption裁剪 16 | $wait_time = 1 17 | $max_retries = 100 18 | $segment_time = 600 19 | $ocr = $false 20 | $document_image = $true 21 | $scene_detector = "AdaptiveDetector" # from ["ContentDetector","AdaptiveDetector","HashDetector","HistogramDetector","ThresholdDetector"] 22 | $scene_threshold = 0.0 # default value ["ContentDetector": 27.0, "AdaptiveDetector": 3.0, "HashDetector": 0.395, "HistogramDetector": 0.05, "ThresholdDetector": 12] 23 | $scene_min_len = 15 24 | $scene_luma_only = $false 25 | $tags_highlightrate = 0.38 26 | 27 | # ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 ===================== 28 | # Activate python venv 29 | Set-Location $PSScriptRoot 30 | if ($env:OS -ilike "*windows*") { 31 | if (Test-Path "./venv/Scripts/activate") { 32 | Write-Output "Windows venv" 33 | ./venv/Scripts/activate 34 | } 35 | elseif (Test-Path "./.venv/Scripts/activate") { 36 | Write-Output "Windows .venv" 37 | ./.venv/Scripts/activate 38 | } 39 | } 40 | elseif (Test-Path "./venv/bin/activate") { 41 | Write-Output "Linux venv" 42 | ./venv/bin/Activate.ps1 43 | } 44 | elseif (Test-Path "./.venv/bin/activate") { 45 | Write-Output "Linux .venv" 46 | ./.venv/bin/activate.ps1 47 | } 48 | 49 | $Env:HF_HOME = "huggingface" 50 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1" 51 | #$Env:HF_ENDPOINT = "https://hf-mirror.com" 52 | $Env:PILLOW_IGNORE_XMP_DATA_IS_TOO_LONG = "1" 53 | $ext_args = [System.Collections.ArrayList]::new() 54 | #$Env:HTTP_PROXY = "http://127.0.0.1:7890" 55 | #$Env:HTTPS_PROXY = "http://127.0.0.1:7890" 56 | 57 | 58 | 59 | if ($gemini_api_key) { 60 | [void]$ext_args.Add("--gemini_api_key=$gemini_api_key") 61 | 
if ($gemini_task) { 62 | [void]$ext_args.Add("--gemini_task=$gemini_task") 63 | } 64 | } 65 | 66 | if ($gemini_model_path) { 67 | [void]$ext_args.Add("--gemini_model_path=$gemini_model_path") 68 | } 69 | 70 | if ($pixtral_api_key) { 71 | [void]$ext_args.Add("--pixtral_api_key=$pixtral_api_key") 72 | } 73 | 74 | if ($pixtral_model_path) { 75 | [void]$ext_args.Add("--pixtral_model_path=$pixtral_model_path") 76 | } 77 | 78 | if ($step_api_key) { 79 | [void]$ext_args.Add("--step_api_key=$step_api_key") 80 | } 81 | 82 | if ($step_model_path) { 83 | [void]$ext_args.Add("--step_model_path=$step_model_path") 84 | } 85 | 86 | if ($qwenVL_api_key) { 87 | [void]$ext_args.Add("--qwenVL_api_key=$qwenVL_api_key") 88 | } 89 | 90 | if ($qwenVL_model_path) { 91 | [void]$ext_args.Add("--qwenVL_model_path=$qwenVL_model_path") 92 | } 93 | 94 | if ($glm_api_key) { 95 | [void]$ext_args.Add("--glm_api_key=$glm_api_key") 96 | } 97 | 98 | if ($glm_model_path) { 99 | [void]$ext_args.Add("--glm_model_path=$glm_model_path") 100 | } 101 | 102 | if ($dir_name) { 103 | [void]$ext_args.Add("--dir_name") 104 | } 105 | 106 | if ($mode -ine "all") { 107 | [void]$ext_args.Add("--mode=$mode") 108 | } 109 | 110 | if ($not_clip_with_caption) { 111 | [void]$ext_args.Add("--not_clip_with_caption") 112 | } 113 | 114 | if ($wait_time -ine 1) { 115 | [void]$ext_args.Add("--wait_time=$wait_time") 116 | } 117 | 118 | if ($max_retries -ine 20) { 119 | [void]$ext_args.Add("--max_retries=$max_retries") 120 | } 121 | 122 | if ($segment_time -ine 600) { 123 | [void]$ext_args.Add("--segment_time=$segment_time") 124 | } 125 | 126 | if ($ocr) { 127 | [void]$ext_args.Add("--ocr") 128 | } 129 | 130 | if ($document_image) { 131 | [void]$ext_args.Add("--document_image") 132 | } 133 | 134 | if ($scene_detector -ne "AdaptiveDetector") { 135 | [void]$ext_args.Add("--scene_detector=$($scene_detector)") 136 | } 137 | 138 | if ($scene_threshold -ne 0.0) { 139 | [void]$ext_args.Add("--scene_threshold=$scene_threshold") 140 | } 141 | 142 | if ($scene_min_len -ne 15) { 143 | [void]$ext_args.Add("--scene_min_len=$scene_min_len") 144 | } 145 | 146 | if ($scene_luma_only) { 147 | [void]$ext_args.Add("--scene_luma_only") 148 | } 149 | 150 | if ($tags_highlightrate -ne 0.4) { 151 | [void]$ext_args.Add("--tags_highlightrate=$tags_highlightrate") 152 | } 153 | 154 | # run captioner 155 | python -m module.captioner $dataset_path $ext_args 156 | 157 | Write-Output "Captioner finished" 158 | Read-Host | Out-Null ; 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/N4N1NOO2K) 2 | 3 | # qinglong-captioner (2.6) 4 | 5 | A Python toolkit for generating video captions using the Lance database format and the Gemini API for automatic captioning. 6 | 7 | ## Changelog 8 | 9 | ### 2.6 10 | ![image](https://github.com/user-attachments/assets/34f8150b-3414-4e0c-9ade-b9406cd1602b) 11 | 12 | A new watermark detection script has been added, initially supporting two watermark detection models, which can quickly classify the images in a dataset into watermarked/unwatermarked categories. 13 | It creates two folders and separates the data with symbolic links, so no extra disk space is used; if needed, you can copy the corresponding folder to move the data without deleting anything. 14 | (Because creating symbolic links requires elevated permissions, you must run PowerShell as administrator.)
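A rough sketch of the symlink-based separation idea in plain Python (this is not the actual `module/waterdetect.py` implementation; the folder names, the `results` mapping, and the 0.5 threshold are illustrative assumptions only):

```python
from pathlib import Path


def link_by_watermark(results: dict[str, float], out_root: Path, thresh: float = 0.5) -> None:
    """Sort images into watermarked/unwatermarked folders using symlinks.

    `results` maps image paths to hypothetical detection scores; the folder
    names and threshold are placeholders, not the script's real defaults.
    """
    for bucket in ("watermarked", "unwatermarked"):
        (out_root / bucket).mkdir(parents=True, exist_ok=True)

    for image_path, score in results.items():
        src = Path(image_path).resolve()
        bucket = "watermarked" if score >= thresh else "unwatermarked"
        link = out_root / bucket / src.name
        if not link.exists():
            # Creating symlinks on Windows needs an elevated PowerShell,
            # which is why the README says to run the script as admin.
            link.symlink_to(src)


if __name__ == "__main__":
    link_by_watermark(
        {"./datasets/a.png": 0.91, "./datasets/b.png": 0.07},  # fake scores
        out_root=Path("./datasets"),
    )
```

Because only links are created, copying one of the two folders elsewhere transfers the split without duplicating or deleting the original files.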
15 | 16 | Finally, it generates a JSON report listing the watermark detection results for all images in the original path, including the detection scores and verdicts. 17 | The watermark threshold can be modified in the script to change the detection results accordingly. 18 | 19 | 20 | ### 2.5 21 | ![image](https://github.com/user-attachments/assets/bffd2120-6868-4a6e-894b-05c4ff5fd98f) 22 | 23 | We now officially support the tag-highlight caption feature! It is currently enabled for the Pixtral model, and we are considering adding it to other models such as Gemini in the future. 24 | 25 | What is tag highlighting? 26 | 27 | As is well known, non-state-of-the-art VLMs make mistakes, so we first use WDtagger to produce tag annotations and then feed those tags to the VLM as assistance, which improves accuracy. 28 | 29 | The tags are also categorized, which makes it easy to spot-check annotation quality (e.g., purple for character names and copyright, red for clothing, brown for body features, light yellow for actions, etc.). 30 | 31 | The resulting annotation quality is comparable to some closed-source models! 32 | 33 | Additionally, we have added check parameters: the parent folder name can be used as the character name, and the tag highlight rate can be checked. Generally, good captions should have a highlight rate above 35%. 34 | 35 | You can also specify a different highlight rate to change the default standard. 36 | 37 | How to use it? First run 3、tagger.ps1 to generate tags for your image dataset, 38 | 39 | then run 4、run.ps1 with a Pixtral API key. 40 | 41 | ### 2.4 42 | 43 | We support Gemini image captioning and rating. 44 | It also supports gemini-2.5-flash-preview-04-17. 45 | 46 | However, after testing, the Flash version performs poorly and has image review issues, so we recommend the Pro version. 47 | 48 | ![image](https://github.com/user-attachments/assets/6ae9ed38-e67a-41d2-aa1d-4caf0e0db394) 49 | flash↑ 50 | 51 | ![image](https://github.com/user-attachments/assets/c83682aa-3a37-4198-b117-ffe7f74ff812) 52 | pro ↑ 53 | 54 | ### 2.3 55 | 56 | Well, we forgot to release version 2.2, so we went straight to version 2.3! 57 | 58 | Version 2.3 adds the GLM-4V model for video captions. 59 | 60 | ### 2.2 61 | 62 | Version 2.2 adds TensorRT acceleration for the local ONNX WDtagger model. 63 | 64 | In our testing, tagging 10,000 samples takes 30 minutes with the standard CUDA provider, 65 | 66 | while with TensorRT it finishes in just 15 to 20 minutes. 67 | 68 | However, the first run takes longer because the engine must be compiled. 69 | 70 | If TensorRT fails, it automatically falls back to CUDA, so there is nothing to worry about. 71 | 72 | If it reports that TensorRT libraries are missing, some components may not be installed. 73 | 74 | Please install version 10.7.x manually from [here](https://developer.nvidia.com/tensorrt/download/10x). 75 | 76 | ### 2.1 77 | 78 | Added support for Gemini 2.5 Pro Exp. Videos are now cut into 600-second segments by default. 79 | 80 | ### 2.0 Big Update! 81 | 82 | We now support video segmentation! A new video segmentation module has been added, which detects key timestamps based on scene changes and then outputs the corresponding images and video clips! 83 | It also exports an HTML report for reference; the effect is very noticeable!
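For reference, a minimal sketch of the underlying scene-detection step using the PySceneDetect package that the splitter builds on (the file path and parameters below are placeholders, and the repo's `module/scenedetect.py` wraps this with async detection and subtitle alignment on top):

```python
from scenedetect import AdaptiveDetector, detect

# Placeholder input path; AdaptiveDetector's default threshold is 3.0, matching
# the comment in 2、video_spliter.ps1, and min_scene_len mirrors its default of 16.
scenes = detect("./datasets/example.mp4", AdaptiveDetector(min_scene_len=16))

for i, (start, end) in enumerate(scenes, start=1):
    # Each boundary is a FrameTimecode; these timestamps are what the exported
    # images/clips and the aligned SRT subtitles are cut against.
    print(f"scene {i}: {start.get_timecode()} -> {end.get_timecode()}")
```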
84 | ![image](https://github.com/user-attachments/assets/94407fec-92af-4a34-a15e-bc02bf45d2cd) 85 | 86 | We have also added a subtitle alignment algorithm that automatically aligns Gemini's timestamped subtitles to the millisecond level after detecting scene-change frames (there are still some errors, but the result is much better). 87 | 88 | Finally, we added the image output feature of the latest gemini-2.0-flash-exp model! 89 | 90 | You can customize tasks: add the task name in [`config.toml`](https://github.com/sdbds/qinglong-captions/blob/main/config/config.toml), and the corresponding images will be handled automatically (and then labeled). 91 | 92 | Some simple task descriptions are listed below; the community is welcome to keep optimizing these task prompts and contribute improvements! 93 | https://github.com/sdbds/qinglong-captions/blob/12b7750ee0bca7e41168e98775cd95c7b9c57173/config/config.toml#L239-L249 94 | 95 | ![image](https://github.com/user-attachments/assets/7e5ae1a9-b635-4705-b664-1c20934d12bc) 96 | 97 | ![image](https://github.com/user-attachments/assets/58527298-34f8-496d-8c4e-1a1c1c965b73) 98 | 99 | 100 | ### 1.9 101 | 102 | Now with Mistral OCR functionality! 103 | It uses Mistral's advanced OCR capabilities to extract text information from videos and images. 104 | 105 | This feature is particularly useful when processing media files containing subtitles, signs, or other text elements, enhancing the accuracy and completeness of captions. 106 | 107 | The OCR functionality is integrated into the existing workflow and requires no additional configuration. 108 | 109 | ### 1.8 110 | 111 | WDtagger has been added! 112 | Even if you cannot use a GPU, you can use the CPU for labeling. 113 | 114 | It has multi-threading and various optimizations, so it processes large-scale data quickly. 115 | 116 | It uses ONNX for inference, with model acceleration. 117 | 118 | Code reference: @kohya-ss 119 | https://github.com/sdbds/sd-scripts/blob/main/finetune/tag_images_by_wd14_tagger.py 120 | 121 | Version 2.0 will add dual-caption functionality: feed in WDtagger's tags, then output natural language. 122 | ![image](https://github.com/user-attachments/assets/f14d4a69-9c79-4ffb-aff7-84d103dfeff4) 123 | 124 | 125 | ### 1.7 126 | 127 | We now support the qwen-VL series of video caption models! 128 | 129 | - qwen-vl-max-latest 130 | - qwen2.5-vl-72b-instruct 131 | - qwen2.5-vl-7b-instruct 132 | - qwen2.5-vl-3b-instruct 133 | 134 | qwen2.5-vl accepts videos from 2 seconds up to about 10 minutes; qwen-vl-max-latest has a 1-minute limit. 135 | These models are not good at capturing timestamps; it is recommended to caption segmented video clips and to adjust the prompts. 136 | 137 | The video upload feature requires an application to be submitted to the provider; please submit the application [here](https://smartservice.console.aliyun.com/service/create-ticket?spm=a2c4g.11186623.0.0.3489b0a8Ql486b). 138 | 139 | We are considering adding local model inference in the future, such as qwen2.5-vl-7b-instruct. 140 | 141 | Additionally, logs now use streaming inference output, so you can see the model's output in real time before the complete result is displayed. 142 | 143 | ### 1.6 144 | 145 | The Google Gemini SDK has been updated, and the new version of the SDK supports the new Gemini 2.0 models! 146 | 147 | The new SDK is more powerful; in particular, it supports verification of uploaded videos.
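As the next paragraph describes, re-uploads are skipped when the name and file size/hash already match a previously uploaded video. A simplified sketch of that kind of check (not the SDK's or this repo's exact logic; the cache-file name is made up):

```python
import hashlib
import json
from pathlib import Path

CACHE = Path("upload_cache.json")  # hypothetical local record of past uploads


def sha256_of(path: Path) -> str:
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


def needs_upload(video: Path) -> bool:
    """Return True only if this name/size/hash combination was never seen."""
    cache = json.loads(CACHE.read_text()) if CACHE.exists() else {}
    key = f"{video.name}:{video.stat().st_size}"
    digest = sha256_of(video)
    if cache.get(key) == digest:
        return False  # same name, size, and hash -> reuse the earlier upload
    cache[key] = digest
    CACHE.write_text(json.dumps(cache, indent=2))
    return True
```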
148 | 149 | If you tag the same video repeatedly, you no longer need to upload it again: the video name and file size/hash are verified automatically. 150 | 151 | The millisecond-level alignment function has also been updated. After the subtitles of a segmented long video are merged, the timeline is automatically aligned to the millisecond, which is very neat! 152 | 153 | ## Features 154 | 155 | - Automatic video/audio/image description using Google's Gemini API, or image-only captioning with pixtral-large 124B 156 | - Export captions in SRT format 157 | - Support for multiple video formats 158 | - Batch processing with progress tracking 159 | - Maintains original directory structure 160 | - Configurable through TOML files 161 | - Lance database integration for efficient data management 162 | 163 | ## Modules 164 | 165 | ### Dataset Import (`lanceImport.py`) 166 | - Import videos into Lance database format 167 | - Preserve original directory structure 168 | - Support for both single directory and paired directory structures 169 | 170 | ### Dataset Export (`lanceexport.py`) 171 | - Extract videos and captions from Lance datasets 172 | - Maintains original file structure 173 | - Exports captions as SRT files in the same directory as source videos 174 | - Auto-clips media using SRT timestamps 175 | 176 | ### Auto Captioning (`captioner.py` & `api_handler.py`) 177 | - Automatic video scene description using the Gemini API or the Pixtral API 178 | - Batch processing support 179 | - SRT format output with timestamps 180 | - Robust error handling and retry mechanisms 181 | - Progress tracking for batch operations 182 | 183 | ### Configuration (`config.py` & `config.toml`) 184 | - API prompt configuration management 185 | - Customizable batch processing parameters 186 | - Default schema includes file paths and metadata 187 | 188 | ## Installation 189 | 190 | Give PowerShell unrestricted script access so the venv can work: 191 | 192 | - Open an administrator PowerShell window 193 | - Type `Set-ExecutionPolicy Unrestricted` and answer A 194 | - Close the admin PowerShell window 195 | 196 | ![Video Preview](https://files.catbox.moe/jr5n3e.gif) 197 | 198 | ### Windows 199 | Run the following PowerShell script: 200 | ```powershell 201 | ./1、install-uv-qinglong.ps1 202 | ``` 203 | 204 | ### Linux 205 | 1. First install PowerShell: 206 | ```bash 207 | sudo sh "./0、install pwsh.sh" 208 | ``` 209 | 2. Then run the installation script using PowerShell: 210 | ```powershell 211 | sudo pwsh ./1、install-uv-qinglong.ps1 212 | ``` 213 | Use `sudo pwsh` if you are on Linux. 214 | 215 | ### TensorRT (Optional) 216 | Windows users need to install the TensorRT libraries manually from [here](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.9.0/zip/TensorRT-10.9.0.34.Windows.win10.cuda-12.8.zip).
217 | TensorRT can faster use WD14Tagger (not effect API part) 218 | Now we use 10.9 version 219 | 220 | ## Usage 221 | 222 | video example: 223 | https://files.catbox.moe/8fudnf.mp4 224 | 225 | ### Just put Video or audio files into datasets folders 226 | 227 | ### Importing Media 228 | Use the PowerShell script to import your videos: 229 | ```powershell 230 | ./lanceImport.ps1 231 | ``` 232 | 233 | ### Exporting Media 234 | Use the PowerShell script to export data from Lance format: 235 | ```powershell 236 | ./lanceExport.ps1 237 | ``` 238 | 239 | ### Auto Captioning 240 | Use the PowerShell script to generate captions for your videos: 241 | 242 | ```powershell 243 | ./run.ps1 244 | ``` 245 | 246 | Note: You'll need to configure your [Gemini API key](https://aistudio.google.com/apikey) in `run.ps1` before using the auto-captioning feature. 247 | [Pixtral API key](https://console.mistral.ai/api-keys/) optional for image caption. 248 | 249 | Now we support [step-1.5v-mini](https://platform.stepfun.com/) optional for video captioner. 250 | 251 | Now we support [qwen-VL](https://bailian.console.aliyun.com/#/model-market) series optional for video captioner. 252 | 253 | Now we support [Mistral OCR](https://console.mistral.ai/api-keys/) optional for PDF and image OCR. 254 | 255 | Now we support [GLM](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) series optional for video captioner. 256 | 257 | ``` 258 | $dataset_path = "./datasets" 259 | $gemini_api_key = "" 260 | $gemini_model_path = "gemini-2.0-pro-exp-02-05" 261 | $pixtral_api_key = "" 262 | $pixtral_model_path = "pixtral-large-2411" 263 | $step_api_key = "" 264 | $step_model_path = "step-1.5v-mini" 265 | $qwenVL_api_key = "" 266 | $qwenVL_model_path = "qwen-vl-max-latest" # qwen2.5-vl-72b-instruct<10mins qwen-vl-max-latest <1min 267 | $glm_api_key = "" 268 | $glm_model_path = "GLM-4V-Plus-0111" 269 | $dir_name = $true 270 | $mode = "long" 271 | $not_clip_with_caption = $false # Not clip with caption | 不根据caption裁剪 272 | $wait_time = 1 273 | $max_retries = 100 274 | $segment_time = 600 275 | $ocr = $false 276 | $document_image = $true 277 | $scene_detector = "AdaptiveDetector" # from ["ContentDetector","AdaptiveDetector","HashDetector","HistogramDetector","ThresholdDetector"] 278 | $scene_threshold = 0.0 # default value ["ContentDetector": 27.0, "AdaptiveDetector": 3.0, "HashDetector": 0.395, "HistogramDetector": 0.05, "ThresholdDetector": 12] 279 | $scene_min_len = 15 280 | $scene_luma_only = $false 281 | ``` 282 | --- 283 | 284 | # 青龙数据集工具 (2.6) 285 | 286 | 基于 Lance 数据库格式的视频自动字幕生成工具,使用 Gemini API 进行场景描述生成。 287 | 288 | ## 功能特点 289 | 290 | - 使用 Google Gemini API 进行视频场景自动描述 291 | - 导出 SRT 格式字幕文件 292 | - 支持多种视频格式 293 | - 批量处理并显示进度 294 | - 保持原始目录结构 295 | - 通过 TOML 文件配置 296 | - 集成 Lance 数据库实现高效数据管理 297 | 298 | ## 模块说明 299 | 300 | ### 数据集导入 (`lanceImport.py`) 301 | - 将视频导入 Lance 数据库格式 302 | - 保持原始目录结构 303 | - 支持单目录和配对目录结构 304 | 305 | ### 数据集导出 (`lanceexport.py`) 306 | - 从 Lance 数据集中提取视频和字幕 307 | - 保持原有文件结构 308 | - 在源视频所在目录导出 SRT 格式字幕 309 | 310 | ### 自动字幕生成 (`captioner.py` & `api_handler.py`) 311 | - 使用 Gemini API 进行视频场景描述 312 | - 支持批量处理 313 | - 生成带时间戳的 SRT 格式字幕 314 | - 健壮的错误处理和重试机制 315 | - 批处理进度跟踪 316 | 317 | ### 配置模块 (`config.py` & `config.toml`) 318 | - API 配置管理 319 | - 可自定义批处理参数 320 | - 默认结构包含文件路径和元数据 321 | 322 | ## 安装方法 323 | 324 | ### Windows 系统 325 | 运行以下 PowerShell 脚本: 326 | ```powershell 327 | ./1、install-uv-qinglong.ps1 328 | ``` 329 | 330 | ### Linux 系统 331 | 1. 
首先安装 PowerShell: 332 | ```bash 333 | sudo sh ./0、install pwsh.sh 334 | ``` 335 | 2. 然后使用 PowerShell 运行安装脚本: 336 | ```powershell 337 | pwsh ./1、install-uv-qinglong.ps1 338 | ``` 339 | 340 | ## 使用方法 341 | 342 | ### 把媒体文件放到datasets文件夹下 343 | 344 | ### 导入视频 345 | 使用 PowerShell 脚本导入视频: 346 | ```powershell 347 | ./lanceImport.ps1 348 | ``` 349 | 350 | ### 导出数据 351 | 使用 PowerShell 脚本从 Lance 格式导出数据: 352 | ```powershell 353 | ./lanceExport.ps1 354 | ``` 355 | 356 | ### 自动字幕生成 357 | 使用 PowerShell 脚本为视频生成字幕: 358 | ```powershell 359 | ./run.ps1 360 | ``` 361 | 362 | 注意:使用自动字幕生成功能前,需要在 `run.ps1` 中配置 [Gemini API 密钥](https://aistudio.google.com/apikey)。 363 | [Pixtral API 秘钥](https://console.mistral.ai/api-keys/) 可选为图片打标。 364 | 365 | 现在我们支持使用[阶跃星辰](https://platform.stepfun.com/)的视频模型进行视频标注。 366 | 367 | 现在我们支持使用[通义千问VL](https://bailian.console.aliyun.com/#/model-market)的视频模型进行视频标注。 368 | 369 | 现在我们支持使用[Mistral OCR](https://console.mistral.ai/api-keys/)的OCR功能进行图片字幕生成。 370 | 371 | 现在我们支持使用[智谱GLM](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys)的视频模型进行视频标注。 372 | 373 | ``` 374 | $dataset_path = "./datasets" 375 | $gemini_api_key = "" 376 | $gemini_model_path = "gemini-2.0-pro-exp-02-05" 377 | $pixtral_api_key = "" 378 | $pixtral_model_path = "pixtral-large-2411" 379 | $step_api_key = "" 380 | $step_model_path = "step-1.5v-mini" 381 | $qwenVL_api_key = "" 382 | $qwenVL_model_path = "qwen-vl-max-latest" # qwen2.5-vl-72b-instruct<10mins qwen-vl-max-latest <1min 383 | $glm_api_key = "" 384 | $glm_model_path = "GLM-4V-Plus-0111" 385 | $dir_name = $true 386 | $mode = "long" 387 | $not_clip_with_caption = $false # Not clip with caption | 不根据caption裁剪 388 | $wait_time = 1 389 | $max_retries = 100 390 | $segment_time = 600 391 | $ocr = $false 392 | $document_image = $true 393 | $scene_detector = "AdaptiveDetector" # from ["ContentDetector","AdaptiveDetector","HashDetector","HistogramDetector","ThresholdDetector"] 394 | $scene_threshold = 0.0 # default value ["ContentDetector": 27.0, "AdaptiveDetector": 3.0, "HashDetector": 0.395, "HistogramDetector": 0.05, "ThresholdDetector": 12] 395 | $scene_min_len = 15 396 | $scene_luma_only = $false 397 | ``` 398 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/qinglong-captions/c9376f3bf325675c7dc10b5da702e1e0eb629947/config/__init__.py -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | """Configuration constants for the dataset processing.""" 2 | 3 | from typing import Tuple, List, Dict, Any 4 | import os 5 | import toml 6 | import pyarrow as pa 7 | 8 | # Base image extensions 9 | BASE_IMAGE_EXTENSIONS: List[str] = [ 10 | ".png", 11 | ".jpg", 12 | ".jpeg", 13 | ".gif", 14 | ".webp", 15 | ".bmp", 16 | ".ico", 17 | ".tif", 18 | ".tiff", 19 | ".PNG", 20 | ".JPG", 21 | ".JPEG", 22 | ".GIF", 23 | ".WEBP", 24 | ".BMP", 25 | ".ICO", 26 | ".TIF", 27 | ".TIFF", 28 | ] 29 | 30 | BASE_ANIMATION_EXTENSIONS: List[str] = [ 31 | ".gif", 32 | ".webp", 33 | ".GIF", 34 | ".WEBP", 35 | ] 36 | 37 | BASE_VIDEO_EXTENSIONS: List[str] = [ 38 | ".mp4", 39 | ".webm", 40 | ".avi", 41 | ".mkv", 42 | ".mov", 43 | ".flv", 44 | ".wmv", 45 | ".m4v", 46 | ".mpg", 47 | ".mpeg", 48 | ".MP4", 49 | ".WEBM", 50 | ".AVI", 51 | ".MKV", 52 | ".MOV", 53 | ".FLV", 54 | ".WMV", 55 | ".M4V", 56 | ".MPG", 57 | 
".MPEG", 58 | ] 59 | 60 | BASE_AUDIO_EXTENSIONS: List[str] = [ 61 | ".mp3", 62 | ".wav", 63 | ".ogg", 64 | ".flac", 65 | ".m4a", 66 | ".wma", 67 | ".aac", 68 | ".aiff", 69 | ".aifc", 70 | ".aif", 71 | ".au", 72 | ".snd", 73 | ".mid", 74 | ".midi", 75 | ".mka", 76 | ".MP3", 77 | ".WAV", 78 | ".OGG", 79 | ".FLAC", 80 | ".M4A", 81 | ".WMA", 82 | ".AAC", 83 | ".AIFF", 84 | ".AIFC", 85 | ".AIF", 86 | ".AU", 87 | ".SND", 88 | ".MID", 89 | ".MIDI", 90 | ".MKA", 91 | ] 92 | 93 | BASE_APPLICATION_EXTENSIONS: List[str] = [ 94 | ".pdf", 95 | ".PDF", 96 | ] 97 | 98 | 99 | def get_supported_extensions(media_type: str = "image") -> Tuple[str, ...]: 100 | """Get all supported media extensions including optional formats.""" 101 | if media_type == "image" or media_type == "animation": 102 | extensions = ( 103 | BASE_IMAGE_EXTENSIONS.copy() 104 | if media_type == "image" 105 | else BASE_ANIMATION_EXTENSIONS.copy() 106 | ) 107 | 108 | # Try to add AVIF support 109 | try: 110 | import pillow_avif 111 | 112 | extensions.extend([".avif", ".AVIF"]) 113 | except ImportError: 114 | pass 115 | 116 | # Try to add JPEG-XL support 117 | try: 118 | import pillow_jxl 119 | 120 | extensions.extend([".jxl", ".JXL"]) 121 | except ImportError: 122 | pass 123 | 124 | try: 125 | from pillow_heif import register_heif_opener 126 | 127 | register_heif_opener() 128 | extensions.extend([".heic", ".heif", ".HEIC", ".HEIF"]) 129 | except ImportError: 130 | pass 131 | 132 | if media_type == "animation": 133 | try: 134 | from apng import APNG 135 | 136 | extensions.extend([".apng", ".APNG"]) 137 | except ImportError: 138 | pass 139 | 140 | elif media_type == "video": 141 | extensions = BASE_VIDEO_EXTENSIONS.copy() 142 | elif media_type == "audio": 143 | extensions = BASE_AUDIO_EXTENSIONS.copy() 144 | elif media_type == "application": 145 | extensions = BASE_APPLICATION_EXTENSIONS.copy() 146 | 147 | return tuple(extensions) 148 | 149 | 150 | def load_toml_config(config_path: str, section: str) -> Dict[str, Any]: 151 | """Load a configuration section from a TOML file. 152 | 153 | Args: 154 | config_path: Path to the TOML file 155 | section: Name of the section to load 156 | 157 | Returns: 158 | Dictionary containing the configuration data 159 | """ 160 | if not os.path.exists(config_path): 161 | raise FileNotFoundError(f"Config file not found: {config_path}") 162 | 163 | try: 164 | config = toml.load(config_path) 165 | section_data = config.get(section, {}) 166 | 167 | if not section_data: 168 | raise ValueError(f"No {section} configuration found in TOML file") 169 | 170 | return section_data 171 | except Exception as e: 172 | raise ValueError(f"Failed to parse config file: {str(e)}") 173 | 174 | 175 | def load_schema_from_toml(config_path: str) -> List[Tuple[str, str]]: 176 | """Load dataset schema from a TOML file. 177 | 178 | Args: 179 | config_path: Path to the TOML file containing schema definition 180 | 181 | Returns: 182 | List of tuples containing (field_name, field_type) 183 | """ 184 | schema_data = load_toml_config(config_path, "schema") 185 | fields = schema_data.get("fields", []) 186 | return [(field["name"], field["type"]) for field in fields] 187 | 188 | 189 | def load_colors_from_toml(config_path: str) -> Dict[str, str]: 190 | """Load console colors from a TOML file. 
191 | 192 | Args: 193 | config_path: Path to the TOML file containing colors definition 194 | 195 | Returns: 196 | Dictionary mapping media types to color names 197 | """ 198 | return load_toml_config(config_path, "colors") 199 | 200 | 201 | def load_prompts_from_toml(config_path: str) -> Dict[str, str]: 202 | """Load prompts from a TOML file. 203 | 204 | Args: 205 | config_path: Path to the TOML file containing prompts definition 206 | 207 | Returns: 208 | Dictionary containing prompt configurations 209 | """ 210 | return load_toml_config(config_path, "prompts") 211 | 212 | 213 | # Default schema definition 214 | DEFAULT_DATASET_SCHEMA = [ 215 | ("uris", pa.string()), 216 | ("mime", pa.string()), 217 | ("width", pa.int32()), 218 | ("height", pa.int32()), 219 | ("channels", pa.int32()), 220 | ("depth", pa.int32()), 221 | ("hash", pa.string()), 222 | ("size", pa.int64()), 223 | ("has_audio", pa.bool_()), 224 | ("duration", pa.int32()), 225 | ("num_frames", pa.int32()), 226 | ("frame_rate", pa.float32()), 227 | ("blob", pa.large_binary()), 228 | ("captions", pa.list_(pa.string())), 229 | ] 230 | 231 | # Default console colors 232 | DEFAULT_CONSOLE_COLORS = { 233 | "image": "green", 234 | "animation": "bold green", 235 | "video": "magenta", 236 | "audio": "orange1", 237 | "application": "bright_red", 238 | "text": "yellow", 239 | "caption": "yellow", 240 | "unknown": "cyan", 241 | } 242 | 243 | # Current active configurations - defaults to built-in values 244 | DATASET_SCHEMA = DEFAULT_DATASET_SCHEMA.copy() 245 | CONSOLE_COLORS = DEFAULT_CONSOLE_COLORS.copy() 246 | SYSTEM_PROMPT = "" # Will be loaded from config 247 | 248 | 249 | def load_config(config_path: str) -> None: 250 | """Load all configurations from a TOML file. 251 | 252 | Args: 253 | config_path: Path to the TOML file containing configurations 254 | """ 255 | global DATASET_SCHEMA, CONSOLE_COLORS, SYSTEM_PROMPT 256 | 257 | try: 258 | DATASET_SCHEMA = load_schema_from_toml(config_path) 259 | except Exception as e: 260 | print(f"Warning: Failed to load schema configuration: {e}") 261 | 262 | try: 263 | colors = load_colors_from_toml(config_path) 264 | if colors: 265 | CONSOLE_COLORS.update(colors) 266 | except Exception as e: 267 | print(f"Warning: Failed to load colors configuration: {e}") 268 | 269 | try: 270 | prompts = load_prompts_from_toml(config_path) 271 | if prompts and "system_prompt" in prompts: 272 | SYSTEM_PROMPT = prompts["system_prompt"] 273 | except Exception as e: 274 | print(f"Warning: Failed to load prompts configuration: {e}") 275 | -------------------------------------------------------------------------------- /datasets/put datasets here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/qinglong-captions/c9376f3bf325675c7dc10b5da702e1e0eb629947/datasets/put datasets here -------------------------------------------------------------------------------- /lanceExport.ps1: -------------------------------------------------------------------------------- 1 | # Input parameters | 输入参数 2 | param( 3 | [string]$lance_file = "./datasets/dataset.lance", # Lance dataset path | Lance数据集路径 4 | [string]$output_dir = "./datasets", # Output directory | 输出目录 5 | [string]$version = "gemini", # Dataset version (gemini/WDtagger/pixtral) 6 | [bool]$not_clip_with_caption = $false # Not clip with caption | 不根据caption裁剪 7 | ) 8 | 9 | # Set working directory | 设置工作目录 10 | Set-Location $PSScriptRoot 11 | 12 | # Activate virtual environment | 激活虚拟环境 13 | $venvPaths = @( 
14 | "venv/Scripts/activate", 15 | ".venv/Scripts/activate", 16 | "venv/bin/Activate.ps1", 17 | ".venv/bin/activate.ps1" 18 | ) 19 | 20 | $activated = $false 21 | foreach ($path in $venvPaths) { 22 | if (Test-Path $path) { 23 | Write-Output "Activating virtual environment: $path" 24 | & $path 25 | $activated = $true 26 | break 27 | } 28 | } 29 | 30 | if (-not $activated) { 31 | Write-Error "No virtual environment found. Please create one first." 32 | exit 1 33 | } 34 | 35 | # Set environment variables | 设置环境变量 36 | $env:HF_HOME = "huggingface" 37 | $env:XFORMERS_FORCE_DISABLE_TRITON = "1" 38 | $env:HF_ENDPOINT = "https://hf-mirror.com" 39 | 40 | # Prepare arguments | 准备参数 41 | $arguments = @( 42 | "-m", 43 | "module.lanceexport", 44 | $lance_file, 45 | "--output_dir=$output_dir" 46 | ) 47 | 48 | if ($version) { 49 | $arguments += "--version=$version" 50 | } 51 | 52 | if ($not_clip_with_caption) { 53 | $arguments += "--not_clip_with_caption" 54 | } 55 | 56 | # Run export script | 运行导出脚本 57 | Write-Output "Starting export from $lance_file to $output_dir" 58 | & python $arguments 59 | 60 | # Wait for user input before closing | 等待用户输入后关闭 61 | Write-Output "`nExport finished. Press Enter to exit..." 62 | $null = Read-Host -------------------------------------------------------------------------------- /lanceImport.ps1: -------------------------------------------------------------------------------- 1 | # Input parameters 2 | $train_data_dir = "./datasets" # input images path | 图片输入路径 3 | $caption_dir = $null # Optional caption files directory | 可选的描述文件目录 4 | $output_name = "dataset" # Output dataset name | 输出数据集名称 5 | $no_save_binary = $false # Don't save binary data | 不保存二进制数据 6 | $not_save_disk = $false # Load into memory only | 仅加载到内存 7 | $import_mode = 0 # Video import mode: 0=All, 1=Video only, 2=Audio only, 3=Split | 视频导入模式 8 | $tag = "gemini" # Dataset tag | 数据集标签 9 | 10 | # Activate virtual environment 11 | Set-Location $PSScriptRoot 12 | if ($env:OS -ilike "*windows*") { 13 | if (Test-Path "./venv/Scripts/activate") { 14 | Write-Output "Windows venv" 15 | ./venv/Scripts/activate 16 | } 17 | elseif (Test-Path "./.venv/Scripts/activate") { 18 | Write-Output "Windows .venv" 19 | ./.venv/Scripts/activate 20 | } 21 | } 22 | elseif (Test-Path "./venv/bin/activate") { 23 | Write-Output "Linux venv" 24 | ./venv/bin/Activate.ps1 25 | } 26 | elseif (Test-Path "./.venv/bin/activate") { 27 | Write-Output "Linux .venv" 28 | ./.venv/bin/activate.ps1 29 | } 30 | 31 | # Run the import script 32 | $args = @( 33 | $train_data_dir 34 | ) 35 | if ($caption_dir) { $args += "--caption_dir=$caption_dir" } 36 | if ($output_name) { $args += "--output_name=$output_name" } 37 | if ($no_save_binary) { $args += "--no_save_binary" } 38 | if ($not_save_disk) { $args += "--not_save_disk" } 39 | $args += "--import_mode=$import_mode" 40 | $args += "--tag=$tag" 41 | 42 | python -m module.lanceImport @args 43 | 44 | Write-Output "Import finished" 45 | Read-Host | Out-Null -------------------------------------------------------------------------------- /module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/qinglong-captions/c9376f3bf325675c7dc10b5da702e1e0eb629947/module/__init__.py -------------------------------------------------------------------------------- /module/captioner.py: -------------------------------------------------------------------------------- 1 | import lance 2 | from rich.progress import ( 3 | Progress, 4 | 
SpinnerColumn, 5 | TextColumn, 6 | BarColumn, 7 | TaskProgressColumn, 8 | TimeRemainingColumn, 9 | TimeElapsedColumn, 10 | TransferSpeedColumn, 11 | MofNCompleteColumn, 12 | ) 13 | from rich.console import Console 14 | from PIL import Image 15 | import json 16 | import pyarrow as pa 17 | from module.lanceImport import transform2lance 18 | from module.lanceexport import extract_from_lance 19 | from module.api_handler import api_process_batch, process_llm_response 20 | from module.scenedetect import SceneDetector, run_async_in_thread 21 | from utils.stream_util import ( 22 | split_media_stream_clips, 23 | split_video_with_imageio_ffmpeg, 24 | get_video_duration, 25 | ) 26 | from config.config import ( 27 | BASE_VIDEO_EXTENSIONS, 28 | BASE_AUDIO_EXTENSIONS, 29 | BASE_APPLICATION_EXTENSIONS, 30 | ) 31 | import re 32 | import argparse 33 | import toml 34 | import pysrt 35 | from pathlib import Path 36 | import base64 37 | import asyncio 38 | 39 | Image.MAX_IMAGE_PIXELS = None # Disable image size limit check 40 | 41 | console = Console() 42 | 43 | 44 | def process_batch(args, config): 45 | # Load the dataset 46 | if not isinstance(args.dataset_dir, lance.LanceDataset): 47 | if args.gemini_api_key == "" and args.pixtral_api_key == "": 48 | dataset = transform2lance(dataset_dir=args.dataset_dir) 49 | else: 50 | dataset = transform2lance(dataset_dir=args.dataset_dir, save_binary=False) 51 | 52 | scanner = dataset.scanner( 53 | columns=["uris", "blob", "mime", "captions", "duration", "hash"], 54 | scan_in_order=True, 55 | late_materialization=["blob"], 56 | batch_size=1, 57 | ) 58 | total_rows = dataset.count_rows() 59 | 60 | with Progress( 61 | "[progress.description]{task.description}", 62 | SpinnerColumn(spinner_name="dots"), 63 | MofNCompleteColumn(separator="/"), 64 | BarColumn(bar_width=40, complete_style="green", finished_style="bold green"), 65 | TextColumn("•"), 66 | TaskProgressColumn(), 67 | TextColumn("•"), 68 | TransferSpeedColumn(), 69 | TextColumn("•"), 70 | TimeElapsedColumn(), 71 | TextColumn("•"), 72 | TimeRemainingColumn(), 73 | expand=True, 74 | ) as progress: 75 | task = progress.add_task("[bold cyan]Processing media...", total=total_rows) 76 | 77 | results = [] 78 | scene_detectors = {} 79 | processed_filepaths = [] 80 | for batch in scanner.to_batches(): 81 | filepaths = batch["uris"].to_pylist() 82 | mime = batch["mime"].to_pylist() 83 | duration = batch["duration"].to_pylist() 84 | sha256hash = batch["hash"].to_pylist() 85 | 86 | for filepath, mime, duration, sha256hash in zip( 87 | filepaths, mime, duration, sha256hash 88 | ): 89 | # 创建场景检测器,但异步初始化它(不阻塞主线程) 90 | scene_detector = None 91 | if ( 92 | args.scene_threshold > 0 93 | and args.scene_min_len > 0 94 | and mime.startswith("video") 95 | ): 96 | scene_detector = SceneDetector( 97 | detector=args.scene_detector, 98 | threshold=args.scene_threshold, 99 | min_scene_len=args.scene_min_len, 100 | luma_only=args.scene_luma_only, 101 | console=progress, 102 | ) 103 | # 启动场景检测,直接使用协程对象 104 | coroutine = scene_detector.detect_scenes_async(filepath) 105 | run_async_in_thread(coroutine) 106 | # 保存检测器实例以便后续使用 107 | scene_detectors[filepath] = scene_detector 108 | 109 | if ( 110 | mime.startswith("image") 111 | or duration <= (args.segment_time + 1) * 1000 112 | ): 113 | 114 | output = api_process_batch( 115 | uri=filepath, 116 | mime=mime, 117 | config=config, 118 | args=args, 119 | sha256hash=sha256hash, 120 | progress=progress, 121 | task_id=task, 122 | ) 123 | 124 | output = _postprocess_caption_content( 125 | output, 126 | 
filepath, 127 | args, 128 | ) 129 | 130 | else: 131 | console.print( 132 | f"[blue]{filepath} video > {args.segment_time} seconds[/blue]" 133 | ) 134 | console.print(f"[blue]split video[/blue]") 135 | 136 | # 创建用于分割的字幕文件 137 | subs = pysrt.SubRipFile() 138 | 139 | # 计算分块 140 | duration_seconds = duration / 1000 # 将毫秒转换为秒 141 | chunk_duration = args.segment_time 142 | num_chunks = int( 143 | (duration_seconds + chunk_duration - 1) // chunk_duration 144 | ) 145 | 146 | # 创建字幕条目 147 | for i in range(num_chunks): 148 | start_time = i * chunk_duration 149 | end_time = min((i + 1) * chunk_duration, duration_seconds) 150 | 151 | # 创建字幕条目 152 | sub = pysrt.SubRipItem( 153 | index=i, 154 | start=pysrt.SubRipTime(seconds=start_time), 155 | end=pysrt.SubRipTime(seconds=end_time), 156 | text=f"Chunk {i + 1}", 157 | ) 158 | subs.append(sub) 159 | 160 | for sub in subs: 161 | console.print(f"[blue]Subtitles created:[/blue] {sub}") 162 | try: 163 | split_video_with_imageio_ffmpeg( 164 | Path(filepath), 165 | subs, 166 | save_caption_func=None, 167 | segment_time=args.segment_time, 168 | ) 169 | except Exception as e: 170 | # 使用字幕分割视频 171 | meta_type = "video" if mime.startswith("video") else "audio" 172 | console.print( 173 | f"[red]Error splitting video with imageio-ffmpeg: {e}[/red]" 174 | ) 175 | split_media_stream_clips(Path(filepath), meta_type, subs) 176 | 177 | pathfile = Path(filepath) 178 | clip_dir = pathfile.parent / f"{pathfile.stem}_clip" 179 | 180 | search_pattern = f"*{pathfile.suffix}" 181 | files = sorted(clip_dir.glob(search_pattern)) 182 | 183 | merged_subs = pysrt.SubRipFile() 184 | total_duration = 0 185 | 186 | # 使用全局变量来存储clip_task_id 187 | global clip_task_id 188 | 189 | # 检查是否已经存在clip任务 190 | if "clip_task_id" in globals() and clip_task_id in [ 191 | task.id for task in progress.tasks 192 | ]: 193 | # 重置已存在的clip任务 194 | progress.reset( 195 | clip_task_id, 196 | total=num_chunks, 197 | visible=True, 198 | description=f"[cyan]Processing clips...", 199 | ) 200 | clip_task = clip_task_id 201 | else: 202 | # 创建新的clip任务 203 | clip_task = progress.add_task( 204 | f"[cyan]Processing clips...", total=num_chunks 205 | ) 206 | clip_task_id = clip_task 207 | for i in range(num_chunks): 208 | sub_path = Path(filepath).with_suffix(".srt") 209 | if sub_path.exists(): 210 | sub = pysrt.open(sub_path, encoding="utf-8") 211 | merged_subs.extend(sub) 212 | 213 | console.print( 214 | f"[yellow]Processing chunk {i+1}/{num_chunks}[/yellow]" 215 | ) 216 | uri = files[i] 217 | 218 | chunk_output = api_process_batch( 219 | uri=uri, 220 | mime=mime, 221 | config=config, 222 | args=args, 223 | sha256hash=sha256hash, 224 | progress=progress, 225 | task_id=task, 226 | ) 227 | 228 | console.print( 229 | f"[green]API processing complete for chunk {i+1}[/green]" 230 | ) 231 | 232 | console.print( 233 | f"[yellow]Post-processing chunk output...[/yellow]" 234 | ) 235 | chunk_output = _postprocess_caption_content( 236 | chunk_output, uri, args 237 | ) 238 | 239 | chunk_subs = pysrt.from_string(chunk_output) 240 | # 检查并删除超时的字幕 241 | for sub in list(chunk_subs): # 使用list创建副本以便安全删除 242 | if ( 243 | sub.start.ordinal > args.segment_time * 1000 244 | ): # 转换为毫秒比较 245 | chunk_subs.remove(sub) 246 | 247 | if i > 0: 248 | last_duration = get_video_duration(files[i - 1]) 249 | 250 | total_duration += int(float(last_duration)) 251 | # 将纯毫秒单位转换为分、秒、毫秒 252 | last_duration_minutes = int(total_duration / 60000) 253 | last_duration_seconds = int((total_duration % 60000) / 1000) 254 | last_duration_milliseconds = total_duration % 
1000 255 | 256 | console.print( 257 | f"[yellow]Total shift duration: {last_duration_minutes}m {last_duration_seconds}s {last_duration_milliseconds}ms[/yellow]" 258 | ) 259 | # Shift all subtitles in the chunk 260 | chunk_subs.shift( 261 | minutes=last_duration_minutes, 262 | seconds=last_duration_seconds, 263 | milliseconds=last_duration_milliseconds, 264 | ) 265 | 266 | # Extend merged subtitles with the shifted chunk 267 | merged_subs.extend(chunk_subs) 268 | console.print( 269 | f"[green]Successfully merged chunk {i+1}. Total subtitles: {len(merged_subs)}[/green]" 270 | ) 271 | 272 | progress.update( 273 | clip_task, 274 | advance=1, 275 | refresh=True, 276 | description=f"[yellow]merging complete for chunk [/yellow]", 277 | ) 278 | 279 | # Mark the clip task as completed and hide it 280 | progress.update(clip_task, completed=num_chunks, visible=False) 281 | 282 | # Update indices to continue from the last subtitle 283 | merged_subs.clean_indexes() 284 | # 手动构建 SRT 格式 285 | output = "" 286 | for i, sub in enumerate(merged_subs, start=1): 287 | # 格式: 序号 + 时间戳 + 文本 288 | output += f"{i}\n" 289 | output += f"{sub.start} --> {sub.end}\n" 290 | output += f"{sub.text}\n\n" 291 | if output: 292 | console.print( 293 | f"[green]All subtitles merged successfully. Total: {len(merged_subs)}[/green]" 294 | ) 295 | 296 | for file in files: 297 | file.unlink(missing_ok=True) 298 | 299 | results.append(output) 300 | processed_filepaths.append(filepath) 301 | 302 | filepath_path = Path(filepath) 303 | # Determine caption file extension based on media type 304 | if ( 305 | filepath_path.suffix in BASE_VIDEO_EXTENSIONS 306 | or filepath_path.suffix in BASE_AUDIO_EXTENSIONS 307 | ): 308 | caption_path = filepath_path.with_suffix(".srt") 309 | elif filepath_path.suffix in BASE_APPLICATION_EXTENSIONS: 310 | caption_path = filepath_path.with_suffix(".md") 311 | else: 312 | caption_path = filepath_path.with_suffix(".txt") 313 | console.print(f"[blue]Processing caption for:[/blue] {filepath_path}") 314 | if isinstance(output, dict): 315 | console.print( 316 | f"[blue]Caption content length:[/blue] {len(output['description'])}" 317 | ) 318 | else: 319 | console.print(f"[blue]Caption content length:[/blue] {len(output)}") 320 | 321 | if caption_path.suffix == ".srt": 322 | try: 323 | subs = pysrt.from_string(output) 324 | if scene_detector: 325 | # 检查场景检测是否已经完成 326 | console.print( 327 | f"[bold cyan]{scene_detectors[filepath].get_scene_list()}...[/bold cyan]" 328 | ) 329 | if scene_detectors[filepath].get_scene_list() is None: 330 | scene_list = asyncio.run( 331 | scene_detectors[filepath].ensure_detection_complete( 332 | filepath 333 | ) 334 | ) 335 | else: 336 | scene_list = scene_detectors[filepath].get_scene_list() 337 | # 使用实例方法align_subtitle,传入scene_list参数 338 | console.print( 339 | f"[bold cyan]Aligning subtitles with scene changes...[/bold cyan]" 340 | ) 341 | subs = scene_detectors[filepath].align_subtitle( 342 | subs, 343 | scene_list=scene_list, 344 | console=console, 345 | segment_time=args.segment_time, 346 | ) 347 | subs.save(str(caption_path), encoding="utf-8") 348 | console.print( 349 | f"[green]Saved captions to {caption_path}[/green]" 350 | ) 351 | except Exception as e: 352 | console.print( 353 | f"[yellow]pysrt validation failed: {e}, falling back to direct file write[/yellow]" 354 | ) 355 | try: 356 | caption_path.write_text(output, encoding="utf-8") 357 | console.print( 358 | f"[green]Saved captions to {caption_path}[/green]" 359 | ) 360 | except Exception as e: 361 | 
console.print(f"[red]Error saving SRT file: {e}[/red]") 362 | elif caption_path.suffix == ".md": 363 | try: 364 | with open(caption_path, "w", encoding="utf-8") as f: 365 | f.write(output) 366 | console.print( 367 | f"[green]Saved captions to {caption_path}[/green]" 368 | ) 369 | except Exception as e: 370 | console.print(f"[red]Error saving MD file: {e}[/red]") 371 | else: 372 | try: 373 | if isinstance(output, list): 374 | with open(caption_path, "w", encoding="utf-8") as f: 375 | for line in output: 376 | f.write(line + "\n") 377 | elif isinstance(output, dict): 378 | with open( 379 | filepath_path.with_suffix(".json"), 380 | "w", 381 | encoding="utf-8", 382 | ) as f: 383 | json.dump(output, f, indent=2, ensure_ascii=False) 384 | with open(caption_path, "w", encoding="utf-8") as f: 385 | if "description" in output and output["description"]: 386 | f.write(output["description"]) 387 | else: 388 | f.write("No description available") 389 | else: 390 | caption_path.write_text(output, encoding="utf-8") 391 | console.print( 392 | f"[green]Saved captions to {caption_path}[/green]" 393 | ) 394 | except Exception as e: 395 | console.print(f"[red]Error saving TXT file: {e}[/red]") 396 | 397 | progress.update(task, advance=len(batch)) 398 | 399 | # 处理完所有批次后,将主任务设置为不可见 400 | progress.update(task, visible=False) 401 | 402 | # Update dataset with new captions 403 | if results: 404 | # 确保每个caption都是单个字符串 405 | processed_captions = [] 406 | for caption in results: 407 | if isinstance(caption, list): 408 | # 如果是列表,合并为单个字符串 409 | processed_captions.append("\n".join(caption)) 410 | elif isinstance(caption, dict): 411 | processed_captions.append(json.dumps(caption, ensure_ascii=False)) 412 | else: 413 | processed_captions.append(caption) 414 | 415 | table = pa.table( 416 | { 417 | "uris": pa.array(processed_filepaths, type=pa.string()), 418 | "captions": pa.array( 419 | [[caption] for caption in processed_captions], 420 | type=pa.list_(pa.string()), 421 | ), 422 | } 423 | ) 424 | 425 | dataset.merge_insert(on="uris").when_matched_update_all().execute(table) 426 | 427 | try: 428 | dataset.tags.create("gemini", 1) 429 | except: 430 | dataset.tags.update("gemini", 1) 431 | 432 | console.print("[green]Successfully updated dataset with new captions[/green]") 433 | 434 | extract_from_lance( 435 | dataset, args.dataset_dir, clip_with_caption=not args.not_clip_with_caption 436 | ) 437 | 438 | 439 | def _postprocess_caption_content(output, filepath, args): 440 | """ 441 | postprocess_caption_content 442 | """ 443 | if not output: 444 | console.print(f"[red]No caption content generated for {filepath}[/red]") 445 | return "" 446 | 447 | if isinstance(output, list): 448 | # 检查是否为OCRPageObject对象列表 449 | if ( 450 | len(output) > 0 451 | and hasattr(output[0], "markdown") 452 | and hasattr(output[0], "index") 453 | ): 454 | combined_output = "" 455 | for page in output: 456 | # 添加页面索引作为HTML页眉和页脚 457 | page_index = page.index if hasattr(page, "index") else "unknown" 458 | # 添加HTML页眉 459 | combined_output += f'
\n Page {page_index+1} \n
\n\n' 460 | # 添加页面内容 461 | page_markdown = ( 462 | page.markdown if hasattr(page, "markdown") else str(page) 463 | ) 464 | # 替换图片路径,将图片路径改为上一级目录 465 | if hasattr(page, "images") and args.document_image: 466 | # 查找并替换所有图片引用格式 ![...](filename) 467 | img_pattern = r"!\[(.*?)\]\(([^/)]+)\)" 468 | parent_dir = Path(filepath).stem 469 | page_markdown = re.sub( 470 | img_pattern, 471 | lambda m: f"![{m.group(1)}]({parent_dir}/{m.group(2)})", 472 | page_markdown, 473 | ) 474 | 475 | if hasattr(page, "images") and args.document_image: 476 | for image in page.images: 477 | if hasattr(image, "image_base64") and image.image_base64: 478 | try: 479 | base64_str = image.image_base64 480 | # 处理data URL格式 481 | if base64_str.startswith("data:"): 482 | # 提取实际的base64内容 483 | base64_content = base64_str.split(",", 1)[1] 484 | image_data = base64.b64decode(base64_content) 485 | else: 486 | image_data = base64.b64decode(base64_str) 487 | 488 | image_filename = image.id 489 | image_dir = Path(filepath).with_suffix("") 490 | image_dir.mkdir(parents=True, exist_ok=True) 491 | image_path = image_dir / image_filename 492 | with open(image_path, "wb") as img_file: 493 | img_file.write(image_data) 494 | except Exception as e: 495 | console.print( 496 | f"[yellow]Error saving OCR image: {e}[/yellow]" 497 | ) 498 | # 这里添加页面内容,只添加一次 499 | combined_output += f"{page_markdown}\n\n" 500 | # 添加HTML页脚和分隔符 501 | combined_output += f'\n\n' 502 | combined_output += '
\n\n' 503 | output = combined_output 504 | else: 505 | output = "\n".join(output) 506 | 507 | # 确保字幕内容格式正确 508 | output = output.strip() 509 | if not output.strip(): 510 | console.print(f"[red]Empty caption content for {filepath}[/red]") 511 | return "" 512 | 513 | # 格式化时间戳 - 只处理视频和音频文件的字幕 514 | if ( 515 | Path(filepath).suffix in BASE_VIDEO_EXTENSIONS 516 | or Path(filepath).suffix in BASE_AUDIO_EXTENSIONS 517 | ): 518 | # 确保字幕内容格式正确 519 | output = output.strip() 520 | if not output.strip(): 521 | console.print(f"[red]Empty caption content for {filepath}[/red]") 522 | return "" 523 | 524 | # 使用单一的正则表达式和处理函数来修复时间戳格式 525 | # 创建一个匹配各种时间戳格式的模式 526 | timestamp_pattern = re.compile( 527 | # 匹配格式1: M:SS,mmm (单位数分钟) 528 | r"(? argparse.ArgumentParser: 576 | parser = argparse.ArgumentParser() 577 | 578 | parser.add_argument("dataset_dir", type=str, help="directory for dataset") 579 | 580 | parser.add_argument( 581 | "--gemini_api_key", 582 | type=str, 583 | default="", 584 | help="API key for gemini API", 585 | ) 586 | 587 | parser.add_argument( 588 | "--gemini_model_path", 589 | type=str, 590 | default="gemini-exp-1206", 591 | help="Model path for gemini", 592 | ) 593 | 594 | parser.add_argument( 595 | "--step_api_key", 596 | type=str, 597 | default="", 598 | help="API key for step API", 599 | ) 600 | 601 | parser.add_argument( 602 | "--step_model_path", 603 | type=str, 604 | default="step-1.5v-mini", 605 | help="video model for step", 606 | ) 607 | 608 | parser.add_argument( 609 | "--qwenVL_api_key", 610 | type=str, 611 | default="", 612 | help="API key for qwenVL API", 613 | ) 614 | 615 | parser.add_argument( 616 | "--qwenVL_model_path", 617 | type=str, 618 | default="qwen-vl-max-latest", 619 | help="video model for qwenVL", 620 | ) 621 | 622 | parser.add_argument( 623 | "--pixtral_api_key", 624 | type=str, 625 | default="", 626 | help="API key for pixtral API", 627 | ) 628 | 629 | parser.add_argument( 630 | "--pixtral_model_path", 631 | type=str, 632 | default="pixtral-large-2411", 633 | help="Model path for pixtral", 634 | ) 635 | 636 | parser.add_argument( 637 | "--glm_api_key", 638 | type=str, 639 | default="", 640 | help="API key for glm API", 641 | ) 642 | 643 | parser.add_argument( 644 | "--glm_model_path", 645 | type=str, 646 | default="glm-4v-plus-0111", 647 | help="Model path for glm", 648 | ) 649 | 650 | parser.add_argument( 651 | "--dir_name", 652 | action="store_true", 653 | help="Use the directory name as the dataset name", 654 | ) 655 | 656 | parser.add_argument( 657 | "--mode", 658 | type=str, 659 | default="all", 660 | help="Mode for processing the dataset", 661 | ) 662 | 663 | parser.add_argument( 664 | "--config", 665 | type=str, 666 | default="config/config.toml", 667 | help="Path to config file", 668 | ) 669 | 670 | parser.add_argument( 671 | "--not_clip_with_caption", 672 | action="store_true", 673 | help="Not clip with caption", 674 | ) 675 | 676 | parser.add_argument( 677 | "--wait_time", 678 | type=int, 679 | default=1, 680 | help="Wait time", 681 | ) 682 | 683 | parser.add_argument( 684 | "--max_retries", 685 | type=int, 686 | default=20, 687 | help="Max retries", 688 | ) 689 | 690 | parser.add_argument( 691 | "--segment_time", 692 | type=int, 693 | default=600, 694 | help="Segment time", 695 | ) 696 | 697 | parser.add_argument( 698 | "--ocr", 699 | action="store_true", 700 | help="Use OCR to extract text from image", 701 | ) 702 | 703 | parser.add_argument( 704 | "--document_image", 705 | action="store_true", 706 | help="Use OCR to extract image from document", 707 | 
) 708 | 709 | parser.add_argument( 710 | "--scene_detector", 711 | type=str, 712 | choices=[ 713 | "ContentDetector", 714 | "AdaptiveDetector", 715 | "HashDetector", 716 | "HistogramDetector", 717 | "ThresholdDetector", 718 | ], 719 | default="AdaptiveDetector", 720 | help="Detector to use for scene detection", 721 | ) 722 | 723 | parser.add_argument( 724 | "--scene_threshold", 725 | type=float, 726 | default=0.0, 727 | help="Threshold for scene detection", 728 | ) 729 | 730 | parser.add_argument( 731 | "--scene_min_len", 732 | type=int, 733 | default=15, 734 | help="Minimum length(frames) for scene detection", 735 | ) 736 | 737 | parser.add_argument( 738 | "--scene_luma_only", 739 | action="store_true", 740 | help="Only use luma (brightness) without color changes for scene detection.", 741 | ) 742 | 743 | parser.add_argument( 744 | "--gemini_task", 745 | type=str, 746 | default="", 747 | help="Task for gemini-2.0-flash-exp", 748 | ) 749 | 750 | parser.add_argument( 751 | "--tags_highlightrate", 752 | type=float, 753 | default=0.4, 754 | help="tags_highlightrate for check captions", 755 | ) 756 | 757 | return parser 758 | 759 | 760 | if __name__ == "__main__": 761 | parser = setup_parser() 762 | 763 | args = parser.parse_args() 764 | 765 | config = toml.load(args.config) 766 | 767 | process_batch(args, config) 768 | -------------------------------------------------------------------------------- /module/lanceImport.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset processing utilities for image-caption pairs using Lance format. 3 | This module provides tools for converting image-caption datasets to Lance format 4 | and accessing the data through PyTorch datasets. 5 | """ 6 | 7 | import argparse 8 | import hashlib 9 | from dataclasses import dataclass 10 | from typing import Optional, List, Dict, Any, Callable, Tuple, Union 11 | import imageio.v3 as iio 12 | import lance 13 | import pyarrow as pa 14 | from PIL import Image, ImageMode 15 | from rich.progress import ( 16 | Progress, 17 | SpinnerColumn, 18 | TextColumn, 19 | BarColumn, 20 | TaskProgressColumn, 21 | TimeRemainingColumn, 22 | TimeElapsedColumn, 23 | TransferSpeedColumn, 24 | MofNCompleteColumn, 25 | ) 26 | from rich.console import Console 27 | import mimetypes 28 | from pathlib import Path 29 | from enum import Enum 30 | import numpy as np 31 | from mutagen import File as MutagenFile 32 | 33 | from config.config import ( 34 | get_supported_extensions, 35 | DATASET_SCHEMA, 36 | CONSOLE_COLORS, 37 | ) 38 | 39 | 40 | console = Console() 41 | image_extensions = get_supported_extensions("image") 42 | animation_extensions = get_supported_extensions("animation") 43 | video_extensions = get_supported_extensions("video") 44 | audio_extensions = get_supported_extensions("audio") 45 | application_extensions = get_supported_extensions("application") 46 | 47 | 48 | @dataclass 49 | class Metadata: 50 | """Metadata for media file.""" 51 | 52 | uris: str # File path or URL 53 | mime: str # MIME type 54 | width: int = 0 # Image/video width in pixels 55 | height: int = 0 # Image/video height in pixels 56 | depth: int = 0 # Sample depth/width in bits 57 | channels: int = 0 # Number of channels (RGB=3, RGBA=4, mono=1, stereo=2) 58 | hash: str = "" # SHA256 hash 59 | size: int = 0 # File size in bytes 60 | has_audio: bool = False # True if audio is present 61 | duration: Optional[int] = None # Duration in milliseconds 62 | num_frames: Optional[int] = 1 # Number of frames 63 | frame_rate: float = 
0.0 # Frames/samples per second 64 | blob: bytes = b"" # Binary data 65 | 66 | @property 67 | def filename(self) -> str: 68 | """File name without extension, derived from filepath.""" 69 | return Path(self.uris).stem 70 | 71 | @property 72 | def ext(self) -> str: 73 | """File extension derived from filepath, including dot.""" 74 | return Path(self.uris).suffix 75 | 76 | @property 77 | def bits_per_channel(self) -> int: 78 | """Get bits per channel.""" 79 | return self.depth if self.channels > 0 else 0 80 | 81 | @property 82 | def bit_rate(self) -> int: 83 | """Calculate bit rate in bits per second. 84 | 85 | For audio: channels * depth * frame_rate 86 | For image: channels * depth * width * height * frame_rate 87 | """ 88 | if self.duration == 0: 89 | return 0 90 | 91 | bits_per_sample = self.channels * self.bits_per_channel 92 | if self.width and self.height: # Image/Video 93 | bits_per_frame = bits_per_sample * self.width * self.height 94 | else: # Audio 95 | bits_per_frame = bits_per_sample 96 | 97 | return int(bits_per_frame * self.frame_rate) 98 | 99 | 100 | class VideoImportMode(Enum): 101 | """Import mode for video files.""" 102 | 103 | ALL = 0 # Import complete video with audio 104 | VIDEO_ONLY = 1 # Import video without audio 105 | AUDIO_ONLY = 2 # Import audio only without video 106 | VIDEO_SPLIT_AUDIO = 3 # Split and import both video and audio separately 107 | 108 | @classmethod 109 | def from_int(cls, value: int) -> "VideoImportMode": 110 | """Convert integer to VideoImportMode.""" 111 | for mode in cls: 112 | if mode.value == value: 113 | return mode 114 | raise ValueError(f"Invalid import mode value: {value}") 115 | 116 | 117 | class FileProcessor: 118 | """Utility class for processing files. 119 | 120 | This class provides methods for loading and processing file metadata, 121 | including size, format, and hash calculations. 122 | 123 | Example: 124 | processor = FileProcessor() 125 | metadata = processor.load_metadata("path/to/image.jpg") 126 | if metadata: 127 | print(f"Image size: {metadata.width}x{metadata.height}") 128 | """ 129 | 130 | def _extract_audio_metadata( 131 | self, video: Any, file_path: str, meta: Dict[str, Any], save_binary: bool = True 132 | ) -> Optional[Tuple[Metadata, bytes]]: 133 | """Extract audio metadata from video file. 
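        Example (a sketch only; the `video` handle and `meta` dict are assumed to
        come from imageio as in load_metadata below, and the path is illustrative):

            result = FileProcessor()._extract_audio_metadata(video, "clip.mp4", meta)
            if result is not None:
                audio_meta, wav_bytes = result  # Metadata plus 16-bit PCM sample bytes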
134 | 135 | Args: 136 | video: Video file object 137 | file_path: Path to video file 138 | meta: Video metadata 139 | save_binary: Whether to save binary data 140 | 141 | Returns: 142 | Tuple of (metadata, binary_data) if successful, None if failed 143 | """ 144 | if not meta.get("has_audio"): 145 | console.print( 146 | f"[yellow]Warning: Video file {file_path} has no audio track[/yellow]" 147 | ) 148 | return None 149 | 150 | # Get audio data 151 | audio_data = video.read_audio() # This returns numpy array of audio samples 152 | if audio_data is None: 153 | console.print( 154 | f"[yellow]Warning: Failed to extract audio data from {file_path}[/yellow]" 155 | ) 156 | return None 157 | 158 | # Get audio binary data and calculate hash 159 | binary_data = audio_data.astype(np.int16).tobytes() 160 | audio_hash = hashlib.sha256(binary_data).hexdigest() 161 | 162 | # Create audio metadata 163 | duration = int(meta.get("duration", 0) * 1000) # Convert to milliseconds 164 | audio_metadata = Metadata( 165 | uris=file_path, 166 | mime=f"audio/wav", 167 | width=0, 168 | height=0, 169 | channels=audio_data.shape[1], 170 | depth=16, # 16-bit audio 171 | hash=audio_hash, # Use audio's own hash 172 | size=len(binary_data), 173 | has_audio=True, 174 | duration=duration, 175 | num_frames=int(duration * meta.get("audio_fps", 44100) / 1000), 176 | frame_rate=meta.get("audio_fps", 44100), 177 | blob=binary_data if save_binary else b"", 178 | ) 179 | 180 | return audio_metadata, binary_data 181 | 182 | @staticmethod 183 | def load_metadata( 184 | file_path: str, 185 | save_binary: bool = True, 186 | import_mode: VideoImportMode = VideoImportMode.ALL, 187 | ) -> Optional[Metadata]: 188 | """Load and process image metadata. 189 | 190 | Args: 191 | file_path: Path to the media file 192 | save_binary: If True, store the binary data of the image 193 | import_mode: Mode for importing video components: 194 | - ALL: Complete video with audio 195 | - VIDEO_ONLY: Video without audio 196 | - AUDIO_ONLY: Audio only 197 | - VIDEO_SPLIT_AUDIO: Split video and audio 198 | 199 | Returns: 200 | Metadata object if successful, None if failed 201 | 202 | Raises: 203 | FileNotFoundError: If the image file doesn't exist 204 | IOError: If there's an error reading the file 205 | SyntaxError: If the image format is invalid 206 | """ 207 | try: 208 | if file_path.endswith(image_extensions + animation_extensions): 209 | with Image.open(file_path) as img: 210 | # Get file pointer position 211 | pos = img.fp.tell() 212 | # Reset to beginning 213 | img.fp.seek(0) 214 | # Read data and calculate hash 215 | binary_data = img.fp.read() 216 | image_hash = hashlib.sha256(binary_data).hexdigest() 217 | # Restore position 218 | img.fp.seek(pos) 219 | 220 | # Get animation info if available 221 | duration = 0 # Initialize to 0 for accumulation 222 | n_frames = None 223 | frame_rate = 0.0 224 | 225 | if hasattr(img, "n_frames") and img.n_frames > 1: 226 | n_frames = img.n_frames 227 | # Get duration in milliseconds 228 | for frame in range(img.n_frames): 229 | img.seek(frame) 230 | duration += img.info.get("duration", 0) 231 | # Calculate frame rate from duration and frame count 232 | if duration > 0: 233 | frame_rate = (n_frames * 1000) / duration # Convert to fps 234 | else: 235 | duration = None 236 | 237 | # Get image MIME type, fallback to PIL format 238 | mime_type, _ = mimetypes.guess_type(file_path) 239 | mime = mime_type or f"image/{img.format.lower()}" 240 | 241 | # Get depth based on mode type 242 | channels = len(img.getbands()) 243 | 
mode = img.mode 244 | 245 | # Try different ways to get bit depth 246 | depth = None 247 | # 1. Try img.bits first (most accurate, includes 12-bit) 248 | if hasattr(img, "bits"): 249 | depth = img.bits 250 | # 2. Try to get from tag for TIFF images (can have 12-bit) 251 | elif hasattr(img, "tag_v2"): 252 | bits = img.tag_v2.get(258) # BitsPerSample tag 253 | if bits: 254 | depth = bits[0] if isinstance(bits, tuple) else bits 255 | # 3. Fallback to mode info 256 | if depth is None: 257 | mode_info = ImageMode.getmode(mode) 258 | if mode_info.basetype == "1": 259 | depth = 1 260 | else: 261 | # Convert bytes to bits (note: this will show 16 for 12-bit images) 262 | type_size = int(mode_info.typestr[-1]) 263 | depth = type_size * 8 264 | 265 | return Metadata( 266 | uris=file_path, 267 | mime=mime, 268 | width=img.size[0], 269 | height=img.size[1], 270 | depth=depth, 271 | channels=channels, 272 | hash=image_hash, 273 | size=Path(file_path).stat().st_size, 274 | has_audio=False, 275 | duration=duration, 276 | num_frames=n_frames, 277 | frame_rate=frame_rate, 278 | blob=binary_data if save_binary else b"", 279 | ) 280 | elif file_path.endswith(video_extensions): 281 | try: 282 | # Get video metadata first 283 | meta = iio.immeta(file_path) or {} 284 | 285 | # Get video MIME type 286 | mime_type, _ = mimetypes.guess_type(file_path) 287 | extension = Path(file_path).suffix.lstrip(".") 288 | mime = mime_type or f"video/{extension}" 289 | 290 | # Get basic video info with safety checks 291 | size = meta.get("size", (0, 0)) 292 | width = ( 293 | int(size[0]) 294 | if isinstance(size, (tuple, list)) and len(size) > 0 295 | else 0 296 | ) 297 | height = ( 298 | int(size[1]) 299 | if isinstance(size, (tuple, list)) and len(size) > 1 300 | else 0 301 | ) 302 | 303 | # Handle fps with safety checks 304 | fps = meta.get("fps", 0) 305 | frame_rate = ( 306 | float(fps) 307 | if fps and not np.isinf(fps) and not np.isnan(fps) 308 | else 0.0 309 | ) 310 | 311 | # Handle duration with safety checks 312 | dur = meta.get("duration", 0) 313 | duration = ( 314 | int(dur * 1000) 315 | if dur and not np.isinf(dur) and not np.isnan(dur) 316 | else 0 317 | ) 318 | 319 | # Calculate frames with safety checks 320 | n_frames = meta.get("nframes", 0) 321 | if not n_frames or np.isinf(n_frames) or np.isnan(n_frames): 322 | if frame_rate > 0 and duration > 0: 323 | n_frames = int(frame_rate * (duration / 1000)) 324 | else: 325 | n_frames = 0 326 | 327 | # Get first frame for color info 328 | try: 329 | with iio.imopen(file_path, "r") as file: 330 | first_frame = file.read(index=0) 331 | channels = ( 332 | first_frame.shape[2] 333 | if len(first_frame.shape) > 2 334 | else 1 335 | ) 336 | depth = first_frame.dtype.itemsize * 8 337 | except Exception as e: 338 | console.print( 339 | f"[yellow]Warning: Could not read first frame from {file_path}: {e}[/yellow]" 340 | ) 341 | channels = 3 # Assume RGB 342 | depth = 8 # Assume 8-bit 343 | 344 | # Read video binary data and calculate hash in chunks 345 | hasher = hashlib.sha256() 346 | binary_data = bytearray() 347 | chunk_size = 8192 # 8KB chunks 348 | 349 | with open(file_path, "rb") as f: 350 | while True: 351 | chunk = f.read(chunk_size) 352 | if not chunk: 353 | break 354 | hasher.update(chunk) 355 | if save_binary: 356 | binary_data.extend(chunk) 357 | 358 | video_hash = hasher.hexdigest() 359 | 360 | return Metadata( 361 | uris=file_path, 362 | mime=mime, 363 | width=width, 364 | height=height, 365 | depth=depth, 366 | channels=channels, 367 | hash=video_hash, 368 | 
size=Path(file_path).stat().st_size, 369 | has_audio=meta.get("has_audio", False), 370 | duration=duration, 371 | num_frames=n_frames, 372 | frame_rate=frame_rate, 373 | blob=bytes(binary_data) if save_binary else b"", 374 | ) 375 | except Exception as e: 376 | console.print( 377 | f"[red]Error processing video {file_path}: {str(e)}[/red]" 378 | ) 379 | return None 380 | 381 | elif file_path.endswith(audio_extensions): 382 | try: 383 | # Read audio file as binary first 384 | binary_data = Path(file_path).read_bytes() 385 | audio_hash = hashlib.sha256(binary_data).hexdigest() 386 | 387 | # Get audio MIME type 388 | mime_type, _ = mimetypes.guess_type(file_path) 389 | extension = Path(file_path).suffix.lstrip(".") 390 | mime = mime_type or f"audio/{extension}" 391 | 392 | # Try to get audio metadata using mutagen first 393 | audio = MutagenFile(file_path) 394 | if audio is not None: 395 | # Get duration in milliseconds 396 | duration = int(audio.info.length * 1000) 397 | # Get sample rate 398 | frame_rate = getattr(audio.info, "sample_rate", 44100) 399 | # Get number of channels 400 | channels = getattr(audio.info, "channels", 2) 401 | # Get bit depth if available 402 | depth = getattr(audio.info, "bits_per_sample", 16) 403 | # Calculate number of frames 404 | n_frames = int(audio.info.length * frame_rate) 405 | else: 406 | raise Exception("Could not read audio metadata") 407 | 408 | return Metadata( 409 | uris=file_path, 410 | mime=mime, 411 | width=0, 412 | height=0, 413 | depth=depth, 414 | channels=channels, 415 | hash=audio_hash, 416 | size=Path(file_path).stat().st_size, 417 | has_audio=True, 418 | duration=duration, 419 | num_frames=n_frames, 420 | frame_rate=frame_rate, 421 | blob=binary_data if save_binary else b"", 422 | ) 423 | except Exception as e: 424 | console.print( 425 | f"[red]Error processing audio {file_path}: {str(e)}[/red]" 426 | ) 427 | return None 428 | 429 | elif file_path.endswith(application_extensions): 430 | try: 431 | # Read application file as binary first 432 | binary_data = Path(file_path).read_bytes() 433 | application_hash = hashlib.sha256(binary_data).hexdigest() 434 | 435 | # Get application MIME type 436 | mime_type, _ = mimetypes.guess_type(file_path) 437 | extension = Path(file_path).suffix.lstrip(".") 438 | mime = mime_type or f"application/{extension}" 439 | 440 | return Metadata( 441 | uris=file_path, 442 | mime=mime, 443 | width=0, 444 | height=0, 445 | depth=0, 446 | channels=0, 447 | hash=application_hash, 448 | size=Path(file_path).stat().st_size, 449 | has_audio=False, 450 | duration=0, 451 | num_frames=0, 452 | frame_rate=0, 453 | blob=binary_data if save_binary else b"", 454 | ) 455 | except Exception as e: 456 | console.print( 457 | f"[red]Error processing application {file_path}: {str(e)}[/red]" 458 | ) 459 | return None 460 | 461 | except Exception as e: 462 | console.print( 463 | f"[red]Unexpected error processing {file_path}: {str(e)}[/red]" 464 | ) 465 | return None 466 | 467 | 468 | def load_data( 469 | datasets_dir: str, texts_dir: Optional[str] = None 470 | ) -> List[Dict[str, Any]]: 471 | """ 472 | Load image and caption data from directories. 
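    Example (a minimal sketch; the directory name is illustrative):

        from module.lanceImport import load_data
        data = load_data("datasets/my_set")
        # each item looks like {"file_path": "...", "caption": ["..."]}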
473 | 474 | Args: 475 | datasets_dir: Directory containing images or videos 476 | texts_dir: Optional directory containing caption text files 477 | 478 | Returns: 479 | List of image-caption pairs 480 | """ 481 | data = [] 482 | 483 | if texts_dir: 484 | # Paired directory structure 485 | for file in Path(datasets_dir).iterdir(): 486 | if not file.is_file() or not any( 487 | str(file).endswith(ext) 488 | for ext in ( 489 | image_extensions 490 | + animation_extensions 491 | + video_extensions 492 | + audio_extensions 493 | + application_extensions 494 | ) 495 | ): 496 | continue 497 | 498 | text_path = Path(texts_dir) / (file.stem + ".txt") 499 | srt_path = Path(texts_dir) / (file.stem + ".srt") 500 | md_path = Path(texts_dir) / (file.stem + ".md") 501 | 502 | caption = None 503 | if text_path.exists(): 504 | with open(text_path, "r", encoding="utf-8") as f: 505 | caption = f.read().splitlines() 506 | elif md_path.exists(): 507 | with open(md_path, "r", encoding="utf-8") as f: 508 | caption = [f.read()] # 将整个 Markdown 内容作为单个字符串 509 | elif srt_path.exists(): 510 | with open(srt_path, "r", encoding="utf-8") as f: 511 | caption = [f.read()] # Store entire SRT content as a single string 512 | else: 513 | caption = [] 514 | 515 | data.append({"file_path": str(file), "caption": caption}) 516 | else: 517 | # Single directory structure 518 | datasets_path = Path(datasets_dir).absolute() # 转换为绝对路径 519 | for file_path in datasets_path.rglob("*"): 520 | if not file_path.is_file() or not any( 521 | str(file_path).endswith(ext) 522 | for ext in ( 523 | image_extensions 524 | + animation_extensions 525 | + video_extensions 526 | + audio_extensions 527 | + application_extensions 528 | ) 529 | ): 530 | continue 531 | 532 | text_path = file_path.with_suffix(".txt") 533 | srt_path = file_path.with_suffix(".srt") 534 | md_path = file_path.with_suffix(".md") 535 | 536 | caption = None 537 | if text_path.exists(): 538 | with open(text_path, "r", encoding="utf-8") as f: 539 | caption = f.read().splitlines() 540 | elif md_path.exists(): 541 | with open(md_path, "r", encoding="utf-8") as f: 542 | caption = [f.read()] # 将整个 Markdown 内容作为单个字符串 543 | elif srt_path.exists(): 544 | with open(srt_path, "r", encoding="utf-8") as f: 545 | caption = [f.read()] # Store entire SRT content as a single string 546 | else: 547 | caption = [] 548 | 549 | data.append({"file_path": str(file_path), "caption": caption}) 550 | 551 | return data 552 | 553 | 554 | def process( 555 | data: List[Dict[str, Any]], 556 | save_binary: bool = True, 557 | import_mode: VideoImportMode = VideoImportMode.ALL, 558 | ) -> pa.RecordBatch: 559 | """ 560 | Process image-caption pairs into Lance format. 561 | 562 | Args: 563 | data: List of dictionaries containing file paths and captions. 564 | save_binary: Whether to save binary data. 565 | import_mode: Mode for importing video components: 566 | - ALL: Complete video with audio 567 | - VIDEO_ONLY: Video without audio 568 | - AUDIO_ONLY: Audio only 569 | - VIDEO_SPLIT_AUDIO: Split video and audio 570 | 571 | Returns: 572 | A PyArrow RecordBatch containing the processed data. 
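    Example (a sketch of how this generator is consumed; it mirrors transform2lance
    below, where `schema` is built from DATASET_SCHEMA and `data` comes from load_data):

        import pyarrow as pa
        reader = pa.RecordBatchReader.from_batches(schema, process(data))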
573 | """ 574 | processor = FileProcessor() 575 | 576 | with Progress( 577 | "[progress.description]{task.description}", 578 | SpinnerColumn(spinner_name="dots"), 579 | MofNCompleteColumn(separator="/"), 580 | BarColumn(bar_width=40, complete_style="green", finished_style="bold green"), 581 | TextColumn("•"), 582 | TaskProgressColumn(), 583 | TextColumn("•"), 584 | TransferSpeedColumn(), 585 | TextColumn("•"), 586 | TimeElapsedColumn(), 587 | TextColumn("•"), 588 | TimeRemainingColumn(), 589 | expand=True, 590 | transient=False, # 防止进度条随刷新滚动 591 | ) as progress: 592 | 593 | global console 594 | 595 | console = progress.console 596 | 597 | task = progress.add_task("[green]Processing file...", total=len(data)) 598 | 599 | for item in data: 600 | file_path = item["file_path"] 601 | caption = item["caption"] 602 | 603 | console.print() 604 | 605 | # 根据文件类型选择颜色 606 | suffix = Path(file_path).suffix.lower() 607 | if suffix in image_extensions: 608 | if suffix in animation_extensions: 609 | color = CONSOLE_COLORS["animation"] 610 | media_type = "animation" 611 | else: 612 | color = CONSOLE_COLORS["image"] 613 | media_type = "image" 614 | elif suffix in video_extensions: 615 | color = CONSOLE_COLORS["video"] 616 | media_type = "video" 617 | elif suffix in audio_extensions: 618 | color = CONSOLE_COLORS["audio"] 619 | media_type = "audio" 620 | elif suffix in application_extensions: 621 | color = CONSOLE_COLORS["application"] 622 | media_type = "application" 623 | else: 624 | color = CONSOLE_COLORS["unknown"] 625 | media_type = "unknown" 626 | 627 | console.print( 628 | f"Processing {media_type} file [{color}]'{file_path}'[/{color}]" 629 | ) 630 | console.print(f"Caption: {caption}", style=CONSOLE_COLORS["caption"]) 631 | 632 | metadata = processor.load_metadata(file_path, save_binary, import_mode) 633 | if not metadata: 634 | progress.update(task, advance=1) 635 | continue 636 | 637 | # Get field names and create arrays 638 | field_names = [field[0] for field in DATASET_SCHEMA] 639 | arrays = [] 640 | for field_name, field_type in DATASET_SCHEMA: 641 | if field_name == "filepath": 642 | value = str(Path(file_path).absolute()) 643 | array = pa.array([value], type=field_type) 644 | elif field_name == "captions": 645 | array = pa.array([caption], type=field_type) 646 | elif field_name == "blob": 647 | array = pa.array([getattr(metadata, field_name)], type=field_type) 648 | else: 649 | value = getattr(metadata, field_name) 650 | # Convert None to appropriate default value based on type 651 | if value is None: 652 | if pa.types.is_integer(field_type): 653 | value = 0 654 | elif pa.types.is_floating(field_type): 655 | value = 0.0 656 | elif pa.types.is_boolean(field_type): 657 | value = False 658 | elif pa.types.is_string(field_type): 659 | value = "" 660 | array = pa.array([value], type=field_type) 661 | arrays.append(array) 662 | 663 | batch = pa.RecordBatch.from_arrays( 664 | arrays, 665 | names=field_names, 666 | ) 667 | 668 | yield batch 669 | progress.update(task, advance=1) 670 | 671 | 672 | def transform2lance( 673 | dataset_dir: str, 674 | caption_dir: Optional[str] = None, 675 | output_name: str = "dataset", 676 | save_binary: bool = True, 677 | not_save_disk: bool = False, 678 | import_mode: VideoImportMode = VideoImportMode.ALL, 679 | tag: str = "gemini", 680 | load_condition: Callable[[str, Optional[str]], List[Dict[str, Any]]] = load_data, 681 | ) -> Optional[lance.LanceDataset]: 682 | """ 683 | Transform image-caption pairs into Lance dataset. 
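    Example (a minimal sketch; the directory name is illustrative):

        from module.lanceImport import transform2lance, VideoImportMode
        ds = transform2lance("datasets/my_set", import_mode=VideoImportMode.ALL)
        # writes datasets/my_set/dataset.lance and tags the new version "gemini"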
684 | 685 | Args: 686 | dataset_dir: Directory containing training images 687 | caption_dir: Optional directory containing captions 688 | output_name: Name of output dataset 689 | save_binary: Whether to save binary data in the dataset. 690 | not_save_disk: If True, don't save to disk 691 | import_mode: Mode for importing video components: 692 | - ALL: Complete video with audio 693 | - VIDEO_ONLY: Video only 694 | - AUDIO_ONLY: Audio only 695 | - VIDEO_SPLIT_AUDIO: Split video and audio 696 | load_condition: Function to load data 697 | 698 | Returns: 699 | Lance dataset object or None if error occurs 700 | """ 701 | data = load_condition(dataset_dir, caption_dir) 702 | 703 | schema = pa.schema( 704 | [ 705 | pa.field( 706 | name, 707 | pa_type, 708 | metadata={b"lance-encoding:blob": b"true"} if name == "blob" else None, 709 | ) 710 | for name, pa_type in DATASET_SCHEMA 711 | ] 712 | ) 713 | 714 | try: 715 | reader = pa.RecordBatchReader.from_batches( 716 | schema, process(data, save_binary, import_mode) 717 | ) 718 | 719 | dataset_path = Path(dataset_dir) / f"{output_name}.lance" 720 | mode = "append" if dataset_path.exists() else "create" 721 | 722 | lancedataset = lance.write_dataset( 723 | reader, 724 | str(dataset_path), 725 | schema, 726 | mode=mode if not_save_disk else "overwrite", 727 | ) 728 | 729 | try: 730 | lancedataset.tags.create(tag, 1) 731 | except: 732 | lancedataset.tags.update(tag, 1) 733 | 734 | return lancedataset 735 | 736 | except AttributeError as e: 737 | console.print(f"[red]AttributeError: {e}[/red]") 738 | return None 739 | 740 | 741 | def setup_parser() -> argparse.ArgumentParser: 742 | """Setup argument parser.""" 743 | parser = argparse.ArgumentParser(description="Transform dataset into Lance format") 744 | parser.add_argument( 745 | "dataset_dir", type=str, help="Directory containing training images" 746 | ) 747 | parser.add_argument( 748 | "--caption_dir", 749 | type=str, 750 | default=None, 751 | help="Directory containing caption files", 752 | ) 753 | parser.add_argument( 754 | "--output_name", type=str, default="dataset", help="Name of output dataset" 755 | ) 756 | parser.add_argument( 757 | "--no_save_binary", 758 | action="store_true", 759 | help="Don't save binary data in the dataset", 760 | ) 761 | parser.add_argument( 762 | "--not_save_disk", 763 | action="store_true", 764 | help="Load dataset into memory instead of saving to disk", 765 | ) 766 | parser.add_argument( 767 | "--import_mode", 768 | type=int, 769 | default=0, 770 | choices=[0, 1, 2, 3], 771 | help="Video import mode: 0=Complete video with audio, 1=Video only, " 772 | "2=Audio only, 3=Split video and audio", 773 | ) 774 | parser.add_argument( 775 | "--tag", 776 | type=str, 777 | default="gemini", 778 | help="Tag for the dataset", 779 | ) 780 | return parser 781 | 782 | 783 | if __name__ == "__main__": 784 | parser = setup_parser() 785 | args = parser.parse_args() 786 | 787 | transform2lance( 788 | dataset_dir=args.dataset_dir, 789 | caption_dir=args.caption_dir, 790 | output_name=args.output_name, 791 | save_binary=not args.no_save_binary, 792 | not_save_disk=args.not_save_disk, 793 | import_mode=VideoImportMode.from_int(args.import_mode), 794 | tag=args.tag, 795 | ) 796 | -------------------------------------------------------------------------------- /module/lanceexport.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import lance 3 | import re 4 | from typing import Optional, Union, List, Dict, Any 5 | from rich.console 
import Console 6 | from rich.progress import ( 7 | Progress, 8 | SpinnerColumn, 9 | TextColumn, 10 | BarColumn, 11 | TaskProgressColumn, 12 | TimeRemainingColumn, 13 | TimeElapsedColumn, 14 | TransferSpeedColumn, 15 | MofNCompleteColumn, 16 | ) 17 | from config.config import get_supported_extensions, DATASET_SCHEMA, CONSOLE_COLORS 18 | from utils.stream_util import split_media_stream_clips, split_video_with_imageio_ffmpeg 19 | from pathlib import Path 20 | import pysrt 21 | import json 22 | 23 | console = Console() 24 | image_extensions = get_supported_extensions("image") 25 | animation_extensions = get_supported_extensions("animation") 26 | video_extensions = get_supported_extensions("video") 27 | audio_extensions = get_supported_extensions("audio") 28 | application_extensions = get_supported_extensions("application") 29 | 30 | 31 | def format_duration(duration_ms: int) -> str: 32 | """将毫秒转换为分:秒格式.""" 33 | total_seconds = duration_ms // 1000 34 | minutes = total_seconds // 60 35 | seconds = total_seconds % 60 36 | return f"{minutes}:{seconds:02d}" 37 | 38 | 39 | def save_blob( 40 | uri: Path, 41 | blob: Union[bytes, lance.BlobFile], 42 | metadata: Dict[str, Any], 43 | media_type: str, 44 | ) -> bool: 45 | """Save binary blob to file. 46 | 47 | Args: 48 | uri: Target path 49 | blob: Binary data or BlobFile 50 | metadata: File metadata 51 | media_type: Type of media (image/video/audio) 52 | 53 | Returns: 54 | bool: True if successful 55 | """ 56 | try: 57 | uri.parent.mkdir(parents=True, exist_ok=True) 58 | 59 | # Handle both bytes and BlobFile 60 | if isinstance(blob, lance.BlobFile): 61 | with open(uri, "wb") as f: 62 | while True: 63 | chunk = blob.read(8192) # Read in chunks 64 | if not chunk: 65 | break 66 | f.write(chunk) 67 | else: 68 | uri.write_bytes(blob) 69 | 70 | # Print media-specific metadata 71 | meta_info = [] 72 | if media_type in ["image", "animation"]: 73 | meta_info.extend( 74 | [ 75 | f"{metadata.get('width', 0)}x{metadata.get('height', 0)}", 76 | f"{metadata.get('channels', 0)}ch", 77 | ( 78 | f"{metadata.get('num_frames', 1)} frames" 79 | if metadata.get("num_frames", 1) > 1 80 | else None 81 | ), 82 | ] 83 | ) 84 | elif media_type == "video": 85 | duration = metadata.get("duration", 0) 86 | meta_info.extend( 87 | [ 88 | f"{metadata.get('width', 0)}x{metadata.get('height', 0)}", 89 | f"{format_duration(duration)}", 90 | f"{metadata.get('frame_rate', 0):.1f}fps", 91 | ] 92 | ) 93 | elif media_type == "audio": 94 | duration = metadata.get("duration", 0) 95 | meta_info.extend( 96 | [ 97 | f"{metadata.get('channels', 0)}ch", 98 | f"{metadata.get('frame_rate', 0):.1f}Hz", 99 | f"{format_duration(duration)}", 100 | ] 101 | ) 102 | 103 | elif media_type == "application": 104 | meta_info.extend( 105 | [ 106 | f"{metadata.get('size', 0) / (1024 * 1024):.2f} MB", 107 | ] 108 | ) 109 | 110 | meta_str = ", ".join(filter(None, meta_info)) 111 | console.print() 112 | 113 | # 使用配置的颜色 114 | color = CONSOLE_COLORS.get(media_type, "white") 115 | console.print( 116 | f"[{color}]{media_type}: {uri} ({meta_str}) saved successfully.[/{color}]" 117 | ) 118 | return True 119 | except Exception as e: 120 | console.print(f"[red]Error saving {media_type} {uri}: {e}[/red]") 121 | return False 122 | 123 | 124 | def save_caption(caption_path: str, caption_lines: List[str], media_type: str) -> bool: 125 | """Save caption data to disk. 
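    Example (a minimal sketch; the path and caption lines are illustrative):

        from module.lanceexport import save_caption
        lines = ["1", "00:00:00,000 --> 00:00:02,000", "Hello"]
        save_caption("out/clip01.mp4", lines, "video")
        # video/audio captions are written next to the media as out/clip01.srt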
126 | 127 | Args: 128 | caption_path: Path to save caption file 129 | caption_lines: List of caption lines 130 | media_type: Type of media (image/video/audio) 131 | 132 | Returns: 133 | bool: True if save successful, False otherwise 134 | """ 135 | try: 136 | if not len(caption_lines): 137 | console.print(f"[red]No caption content found for {caption_path}[/red]") 138 | return False 139 | caption_path = Path(caption_path) 140 | caption_path.parent.mkdir(parents=True, exist_ok=True) 141 | 142 | if media_type == "audio" or media_type == "video": 143 | caption_path = caption_path.with_suffix(".srt") 144 | elif media_type == "application": 145 | caption_path = caption_path.with_suffix(".md") 146 | else: 147 | caption_path = caption_path.with_suffix(".txt") 148 | 149 | with open(caption_path, "w", encoding="utf-8") as f: 150 | if caption_path.suffix == ".srt": 151 | # For SRT files, preserve all lines including empty ones 152 | f.write("\n".join(caption_lines)) 153 | elif caption_path.suffix == ".md": 154 | # For MD files, preserve original markdown formatting 155 | f.write("".join(caption_lines)) 156 | else: 157 | # For TXT files, strip empty lines and whitespace 158 | for line in caption_lines: 159 | if "', "") 162 | .replace("", "") 163 | .replace("", "") 164 | ) 165 | 166 | # Check if the content is in JSON format and parse it if possible 167 | if line.strip().startswith("{") and line.strip().endswith("}"): 168 | try: 169 | json_content = line.strip() 170 | parsed_json = json.loads(json_content) 171 | # Format JSON content with indentation for better readability 172 | with open( 173 | caption_path.with_suffix(".json"), "w", encoding="utf-8" 174 | ) as j: 175 | json.dump(parsed_json, j, indent=2, ensure_ascii=False) 176 | f.write(parsed_json["description"]) 177 | except json.JSONDecodeError: 178 | # If not valid JSON, continue with normal text processing 179 | if line and line.strip(): 180 | f.write(line.strip() + "\n") 181 | else: 182 | if line and line.strip(): 183 | f.write(line.strip() + "\n") 184 | 185 | console.print() 186 | console.print( 187 | f"[{CONSOLE_COLORS['text']}]text: {caption_path} saved successfully.[/{CONSOLE_COLORS['text']}]" 188 | ) 189 | return True 190 | except Exception as e: 191 | console.print(f"[red]Error saving caption: {e}[/red]") 192 | return False 193 | 194 | 195 | def save_caption_by_pages(caption_path: Path, caption_lines: List[str]) -> bool: 196 | """将多页文档分割为单独的页面并分别保存 197 | 198 | Args: 199 | caption_path: 保存路径 200 | caption_lines: 包含多页内容的文本列表 201 | 202 | Returns: 203 | bool: 成功返回True,失败返回False 204 | """ 205 | try: 206 | # 合并文本行为单个字符串 207 | if len(caption_lines) == 1: 208 | # 如果只有一个元素(来自.md或.srt文件),直接使用该元素 209 | combined_text = caption_lines[0] 210 | else: 211 | # 如果是多行(来自.txt文件),用换行符连接 212 | combined_text = "\n".join(caption_lines) 213 | 214 | # 使用页眉作为分隔符来分割多个页面 215 | header_pattern = r'(?s)
.*?\s*Page\s+(\d+)\s*.*?' 217 | page_break_pattern = r'
' 218 | 219 | # 分割所有页面 220 | page_contents = [] 221 | page_numbers = [] 222 | 223 | # 查找所有页头位置 224 | header_matches = list(re.finditer(header_pattern, combined_text)) 225 | footer_matches = list(re.finditer(footer_pattern, combined_text)) 226 | 227 | # 如果没有找到页头,整体保存 228 | if not header_matches: 229 | # 没有找到页头,尝试其他方式分割内容 230 | # 尝试使用Markdown标题作为分割点 231 | md_header_pattern = r"^#{1,6}\s+(.+?)$" 232 | md_headers = list( 233 | re.finditer(md_header_pattern, combined_text, re.MULTILINE) 234 | ) 235 | 236 | if md_headers: 237 | # 使用Markdown标题分割内容 238 | console.print( 239 | f"[yellow]No HTML headers found, splitting by Markdown headers.[/yellow]" 240 | ) 241 | 242 | for i in range(len(md_headers)): 243 | header_match = md_headers[i] 244 | header_text = header_match.group(1).strip() 245 | 246 | # 计算当前部分内容的开始位置 247 | start_pos = header_match.start() 248 | 249 | # 计算当前部分内容的结束位置 250 | if i < len(md_headers) - 1: 251 | end_pos = md_headers[i + 1].start() 252 | else: 253 | end_pos = len(combined_text) 254 | 255 | # 提取部分内容 256 | section_content = combined_text[start_pos:end_pos] 257 | 258 | # 创建文件名 (使用标题的前20个字符,去除特殊字符) 259 | safe_header = re.sub(r"[^\w\s-]", "", header_text)[:20].strip() 260 | safe_header = re.sub(r"[-\s]+", "_", safe_header) 261 | 262 | section_filename = ( 263 | f"{caption_path.stem}_{safe_header}{caption_path.suffix}" 264 | ) 265 | section_file_path = caption_path.with_suffix("") / section_filename 266 | 267 | # 保存部分内容 268 | section_file_path.parent.mkdir(parents=True, exist_ok=True) 269 | with open(section_file_path, "w", encoding="utf-8") as f: 270 | f.write(section_content) 271 | 272 | console.print( 273 | f"[{CONSOLE_COLORS['text']}]text: {section_file_path} saved successfully.[/{CONSOLE_COLORS['text']}]" 274 | ) 275 | 276 | return True 277 | else: 278 | # 如果没有任何分割点,保存为单个文件 279 | output_dir = caption_path.with_suffix(".md") 280 | output_dir.mkdir(parents=True, exist_ok=True) 281 | single_file_path = output_dir / f"{caption_path.stem}.md" 282 | 283 | with open(single_file_path, "w", encoding="utf-8") as f: 284 | f.write(combined_text) 285 | console.print( 286 | f"[{CONSOLE_COLORS['text']}]text: {single_file_path} saved successfully.[/{CONSOLE_COLORS['text']}]" 287 | ) 288 | return True 289 | # 分割每个页面的内容 290 | for i in range(len(header_matches)): 291 | header_match = header_matches[i] 292 | page_number = int(header_match.group(1)) 293 | page_numbers.append(page_number) 294 | 295 | # 计算当前页面内容的开始位置(从页头开始) 296 | start_pos = header_match.start() 297 | 298 | # 计算当前页面内容的结束位置 299 | # 先尝试查找对应的页脚 300 | end_pos = None 301 | 302 | # 寻找这个页码对应的页脚 303 | for footer_match in footer_matches: 304 | footer_page = int(footer_match.group(1)) 305 | if footer_page == page_number: 306 | # 结束位置是这个页脚的结束位置 307 | end_pos = footer_match.end() 308 | break 309 | 310 | # 如果没找到对应页脚,则使用下一个页头作为结束位置 311 | if end_pos is None: 312 | if i < len(header_matches) - 1: 313 | end_pos = header_matches[i + 1].start() 314 | else: 315 | end_pos = len(combined_text) 316 | 317 | # 提取页面内容 318 | page_content = combined_text[start_pos:end_pos] 319 | 320 | # 移除页面分隔符 (确保使用多行模式) 321 | page_content = re.sub(page_break_pattern, "", page_content, flags=re.DOTALL) 322 | 323 | # 移除页眉 324 | page_content = re.sub( 325 | r'(?s)
.*?', "", page_content 333 | ) 334 | 335 | # 清理可能的多余空行 336 | page_content = re.sub(r"\n{3,}", "\n\n", page_content) 337 | 338 | # 添加到页面内容列表 339 | page_contents.append((page_number, page_content)) 340 | 341 | # 创建输出目录 342 | output_dir = caption_path.with_suffix("") 343 | output_dir.mkdir(parents=True, exist_ok=True) 344 | 345 | # 保存每一页为独立文件 346 | for page_number, page_content in page_contents: 347 | # 处理图片路径,将路径从子文件夹改为同级 348 | img_pattern = r"!\[(.*?)\]\(([^/]+)/([^/)]+)\)" 349 | 350 | # 检查是否有重复引用的图片 351 | processed_page_content = page_content 352 | matches = list(re.finditer(img_pattern, page_content)) 353 | 354 | if matches: 355 | # 使用字典记录每个图片第一次出现的位置 356 | first_occurrence = {} 357 | 358 | # 找出每个图片第一次出现的位置 359 | for match in matches: 360 | alt_text = match.group(1) 361 | folder = match.group(2) 362 | img_name = match.group(3) 363 | if img_name not in first_occurrence: 364 | first_occurrence[img_name] = match 365 | 366 | # 先处理图片路径,统一格式 367 | processed_page_content = re.sub( 368 | img_pattern, r"![\1](\3)", processed_page_content 369 | ) 370 | 371 | # 移除所有重复的图片,但保留第一次出现的位置 372 | for img_name, match in first_occurrence.items(): 373 | # 计算该图片在文本中所有出现的位置 374 | all_matches = [m for m in matches if m.group(3) == img_name] 375 | 376 | # 如果有多次出现,移除除了第一次之外的所有引用 377 | if len(all_matches) > 1: 378 | # 排序匹配,按位置从前向后处理 379 | sorted_matches = sorted(all_matches, key=lambda m: m.start()) 380 | 381 | # 跳过第一次出现的匹配 382 | for m in sorted_matches[1:]: 383 | # 构建要移除的模式 384 | alt_text = m.group(1) 385 | pattern_to_remove = f"!\\[{re.escape(alt_text)}\\]\\({re.escape(img_name)}\\)" 386 | # 从处理后的内容中移除该模式 387 | processed_page_content = re.sub( 388 | pattern_to_remove, "", processed_page_content, count=1 389 | ) 390 | else: 391 | # 如果没有匹配到图片,只进行路径格式转换 392 | processed_page_content = re.sub(img_pattern, r"![\1](\3)", page_content) 393 | 394 | page_filename = f"{caption_path.stem}_{page_number}.md" 395 | page_file_path = output_dir / page_filename 396 | 397 | # 保存页面内容 398 | with open(page_file_path, "w", encoding="utf-8") as f: 399 | f.write(processed_page_content) 400 | 401 | console.print( 402 | f"[{CONSOLE_COLORS['text']}]text: {page_file_path} saved successfully.[/{CONSOLE_COLORS['text']}]" 403 | ) 404 | 405 | return True 406 | except Exception as e: 407 | console.print(f"[red]Error saving pages: {e}[/red]") 408 | return False 409 | 410 | 411 | def split_md_document(uri: Path, caption_lines: List[str], save_caption_func) -> None: 412 | """分割多页Markdown文档并单独保存每一页 413 | 414 | Args: 415 | uri: 文件路径 416 | caption_lines: 包含多页内容的文本列表 417 | save_caption_func: 用于保存单页内容的函数 418 | """ 419 | try: 420 | # 检查是否包含多页内容 421 | if any( 422 | '
None: 447 | """ 448 | Extract images and captions from Lance dataset. 449 | 450 | Args: 451 | lance_or_path: Path to Lance dataset or Lance dataset object 452 | output_dir: Directory to save extracted images 453 | caption_dir: Optional directory to save caption files 454 | save_binary: Whether to save binary data 455 | """ 456 | ds = ( 457 | lance.dataset(lance_or_path, version=version) 458 | if isinstance(lance_or_path, str) 459 | else lance_or_path 460 | ) 461 | 462 | # Create output directories 463 | output_path = Path(output_dir) 464 | output_path.mkdir(parents=True, exist_ok=True) 465 | 466 | if caption_dir: 467 | captions_dir_path = Path(caption_dir) 468 | captions_dir_path.mkdir(parents=True, exist_ok=True) 469 | 470 | with Progress( 471 | "[progress.description]{task.description}", 472 | SpinnerColumn(spinner_name="dots"), 473 | MofNCompleteColumn(separator="/"), 474 | BarColumn(bar_width=40, complete_style="green", finished_style="bold green"), 475 | TextColumn("•"), 476 | TaskProgressColumn(), 477 | TextColumn("•"), 478 | TransferSpeedColumn(), 479 | TextColumn("•"), 480 | TimeElapsedColumn(), 481 | TextColumn("•"), 482 | TimeRemainingColumn(), 483 | expand=True, 484 | transient=False, # 防止进度条随刷新滚动 485 | ) as progress: 486 | 487 | global console 488 | 489 | console = progress.console 490 | 491 | task = progress.add_task("[green]Extracting files...", total=ds.count_rows()) 492 | 493 | for batch in ds.to_batches(): 494 | # Get all metadata columns 495 | metadata_batch = { 496 | field[0]: batch.column(field[0]).to_pylist() 497 | for field in DATASET_SCHEMA 498 | if field[0] != "blob" # Skip blob to save memory 499 | } 500 | indices = list(range(len(batch))) 501 | blobs = ds.take_blobs(indices, "blob") 502 | 503 | for i in range(len(batch)): 504 | # Create metadata dict for current item 505 | metadata = {key: values[i] for key, values in metadata_batch.items()} 506 | uri = Path(metadata["uris"]) 507 | blob = blobs[i] 508 | 509 | media_type = None 510 | suffix = uri.suffix.lower() 511 | if suffix in image_extensions: 512 | media_type = "image" 513 | elif suffix in animation_extensions: 514 | media_type = "animation" 515 | elif suffix in video_extensions: 516 | media_type = "video" 517 | elif suffix in audio_extensions: 518 | media_type = "audio" 519 | elif suffix in application_extensions: 520 | media_type = "application" 521 | if not uri.exists() and blob: 522 | if media_type: 523 | if not save_blob(uri, blob, metadata, media_type): 524 | progress.advance(task) 525 | continue 526 | else: 527 | console.print( 528 | f"[yellow]Unsupported file format: {suffix}[/yellow]" 529 | ) 530 | progress.advance(task) 531 | continue 532 | 533 | # Save caption if available 534 | caption = metadata.get("captions", []) 535 | if caption: 536 | caption_file_path = captions_dir_path if caption_dir else uri 537 | caption_file_path.parent.mkdir(parents=True, exist_ok=True) 538 | save_caption(caption_file_path, caption, media_type) 539 | 540 | if clip_with_caption and (uri.with_suffix(".srt")).exists(): 541 | subs = pysrt.open(uri.with_suffix(".srt"), encoding="utf-8") 542 | try: 543 | split_video_with_imageio_ffmpeg(uri, subs, save_caption) 544 | except Exception as e: 545 | console.print(f"[red]Error splitting video: {e}[/red]") 546 | split_media_stream_clips( 547 | uri, media_type, subs, save_caption 548 | ) 549 | elif clip_with_caption and (uri.with_suffix(".md")).exists(): 550 | split_md_document(uri, caption, save_caption) 551 | 552 | progress.advance(task) 553 | 554 | 555 | def main(): 556 | 557 | 
parser = argparse.ArgumentParser( 558 | description="Extract images and captions from a Lance dataset" 559 | ) 560 | parser.add_argument("lance_file", help="Path to the .lance file") 561 | parser.add_argument( 562 | "--output_dir", 563 | default="./dataset", 564 | help="Directory to save extracted data", 565 | ) 566 | parser.add_argument( 567 | "--version", 568 | default="gemini", 569 | help="Dataset version", 570 | ) 571 | 572 | parser.add_argument( 573 | "--not_clip_with_caption", 574 | action="store_true", 575 | help="Not clip with caption", 576 | ) 577 | 578 | args = parser.parse_args() 579 | extract_from_lance( 580 | args.lance_file, args.output_dir, args.version, not args.not_clip_with_caption 581 | ) 582 | 583 | 584 | if __name__ == "__main__": 585 | main() 586 | -------------------------------------------------------------------------------- /module/waterdetect.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import time 4 | from PIL import Image 5 | from pathlib import Path 6 | import lance 7 | import pyarrow as pa 8 | from rich.console import Console 9 | from rich.pretty import Pretty 10 | from rich.progress import ( 11 | Progress, 12 | SpinnerColumn, 13 | TextColumn, 14 | BarColumn, 15 | TaskProgressColumn, 16 | TimeRemainingColumn, 17 | TimeElapsedColumn, 18 | TransferSpeedColumn, 19 | MofNCompleteColumn, 20 | ) 21 | import torch 22 | import onnxruntime as ort 23 | from huggingface_hub import hf_hub_download 24 | from module.lanceImport import transform2lance 25 | import concurrent.futures 26 | from transformers import AutoImageProcessor 27 | import shutil 28 | import json 29 | 30 | console = Console() 31 | 32 | FILES = ["model.onnx"] 33 | 34 | 35 | def preprocess_image(image): 36 | """使用processor预处理图像""" 37 | try: 38 | # 转换为RGB模式 39 | if isinstance(image, np.ndarray): 40 | image = Image.fromarray(image).convert("RGB") 41 | elif isinstance(image, str) or isinstance(image, Path): 42 | image = Image.open(image).convert("RGB") 43 | elif not isinstance(image, Image.Image): 44 | raise TypeError("Input must be a PIL image, numpy array, or file path") 45 | 46 | # 使用processor预处理图像 47 | inputs = processor(images=image, return_tensors="pt") 48 | 49 | # 转换为numpy数组返回 50 | return inputs["pixel_values"][0].numpy() 51 | except Exception as e: 52 | console.print(f"[red]preprocess_image error: {str(e)}[/red]") 53 | return None 54 | 55 | 56 | def load_and_preprocess_batch(uris): 57 | """并行加载和预处理一批图像""" 58 | 59 | def load_single_image(uri): 60 | try: 61 | # 直接传入路径,在preprocess_image中处理转换 62 | return preprocess_image(uri) 63 | except Exception as e: 64 | console.print(f"[red]Error processing {uri}: {str(e)}[/red]") 65 | return None 66 | 67 | with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: 68 | batch_images = list(executor.map(load_single_image, uris)) 69 | 70 | # 过滤掉加载失败的图像 71 | valid_images = [(i, img) for i, img in enumerate(batch_images) if img is not None] 72 | images = [img for _, img in valid_images] 73 | 74 | return images 75 | 76 | 77 | def process_batch(images, session, input_name): 78 | """处理图像批次""" 79 | try: 80 | # 图像通过processor处理后的numpy数组,直接堆叠 81 | batch_data = np.ascontiguousarray(np.stack(images)) 82 | # 执行推理 83 | outputs = session.run(None, {input_name: batch_data}) 84 | return outputs[0] 85 | except Exception as e: 86 | console.print(f"[red]Batch processing error: {str(e)}[/red]") 87 | return None 88 | 89 | 90 | def load_model(args): 91 | """加载模型和标签""" 92 | model_path = 
Path(args.model_dir) / args.repo_id.replace("/", "_") / "model.onnx" 93 | 94 | global processor 95 | processor = AutoImageProcessor.from_pretrained(args.repo_id, use_fast=True) 96 | 97 | # 下载模型 98 | if not model_path.exists() or args.force_download: 99 | for file in FILES: 100 | file_path = Path(args.model_dir) / args.repo_id.replace("/", "_") / file 101 | if not file_path.exists() or args.force_download: 102 | file_path = Path( 103 | hf_hub_download( 104 | repo_id=args.repo_id, 105 | filename=file, 106 | local_dir=file_path.parent, 107 | force_download=args.force_download, 108 | ) 109 | ) 110 | console.print(f"[blue]Downloaded {file} to {file_path}[/blue]") 111 | else: 112 | console.print(f"[green]Using existing {file}[/green]") 113 | 114 | # 设置推理提供者 115 | providers = [] 116 | if "TensorrtExecutionProvider" in ort.get_available_providers(): 117 | providers.append("TensorrtExecutionProvider") 118 | console.print("[green]Using TensorRT for inference[/green]") 119 | console.print("[yellow]compile may take a long time, please wait...[/yellow]") 120 | elif "CUDAExecutionProvider" in ort.get_available_providers(): 121 | providers.append("CUDAExecutionProvider") 122 | console.print("[green]Using CUDA for inference[/green]") 123 | elif "ROCMExecutionProvider" in ort.get_available_providers(): 124 | providers.append("ROCMExecutionProvider") 125 | console.print("[green]Using ROCm for inference[/green]") 126 | elif "OpenVINOExecutionProvider" in ort.get_available_providers(): 127 | providers = [("OpenVINOExecutionProvider", {"device_type": "GPU_FP32"})] 128 | console.print("[green]Using OpenVINO for inference[/green]") 129 | else: 130 | providers.append("CPUExecutionProvider") 131 | console.print("[yellow]Using CPU for inference[/yellow]") 132 | 133 | # 创建推理会话 134 | sess_options = ort.SessionOptions() 135 | sess_options.graph_optimization_level = ( 136 | ort.GraphOptimizationLevel.ORT_ENABLE_ALL 137 | ) # 启用所有优化 138 | 139 | if "CPUExecutionProvider" in providers: 140 | # CPU时启用多线程推理 141 | sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL # 启用并行执行 142 | sess_options.inter_op_num_threads = 8 # 设置线程数 143 | sess_options.intra_op_num_threads = 8 # 设置算子内部并行数 144 | 145 | # TensorRT 优化 146 | if "TensorrtExecutionProvider" in providers: 147 | sess_options.enable_mem_pattern = True 148 | sess_options.enable_mem_reuse = True 149 | providers_with_options = [ 150 | ( 151 | "TensorrtExecutionProvider", 152 | { 153 | "trt_fp16_enable": True, # Enable FP16 precision for faster inference 154 | "trt_builder_optimization_level": 3, 155 | "trt_max_partition_iterations": 1000, 156 | "trt_engine_cache_enable": True, 157 | "trt_engine_cache_path": f"{Path(args.model_dir) / args.repo_id.replace('/', '_')}/trt_engines", 158 | "trt_engine_hw_compatible": True, 159 | "trt_force_sequential_engine_build": False, 160 | "trt_context_memory_sharing_enable": True, 161 | "trt_timing_cache_enable": True, 162 | "trt_timing_cache_path": f"{Path(args.model_dir) / args.repo_id.replace('/', '_')}", 163 | "trt_sparsity_enable": True, 164 | "trt_min_subgraph_size": 7, 165 | # "trt_detailed_build_log": True, 166 | }, 167 | ), 168 | ( 169 | "CUDAExecutionProvider", 170 | { 171 | "arena_extend_strategy": "kSameAsRequested", 172 | "cudnn_conv_algo_search": "EXHAUSTIVE", 173 | "do_copy_in_default_stream": True, 174 | "cudnn_conv_use_max_workspace": "1", # 使用最大工作空间 175 | "tunable_op_enable": True, # 启用可调优操作 176 | "tunable_op_tuning_enable": True, # 启用调优 177 | }, 178 | ), 179 | ] 180 | 181 | elif "CUDAExecutionProvider" in providers: 
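# Provider selection above falls back in the order TensorRT -> CUDA -> ROCm ->
# OpenVINO -> CPU. A small, hedged sketch for checking what an onnxruntime
# build actually exposes and which providers a session ended up binding
# (standalone example, not part of this branch):
#
#     import onnxruntime as ort
#     print(ort.get_available_providers())   # e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']
#     sess = ort.InferenceSession(
#         "model.onnx",                       # hypothetical model path
#         providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
#     )
#     print(sess.get_providers())             # providers actually used by the session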
182 | # CUDA GPU 优化 183 | sess_options.enable_mem_pattern = True 184 | sess_options.enable_mem_reuse = True 185 | providers_with_options = [ 186 | ( 187 | "CUDAExecutionProvider", 188 | { 189 | "arena_extend_strategy": "kSameAsRequested", 190 | "cudnn_conv_algo_search": "EXHAUSTIVE", 191 | "do_copy_in_default_stream": True, 192 | "cudnn_conv_use_max_workspace": "1", 193 | "tunable_op_enable": True, 194 | "tunable_op_tuning_enable": True, 195 | }, 196 | ), 197 | ] 198 | else: 199 | providers_with_options = providers 200 | 201 | console.print(f"[cyan]Providers with options:[/cyan]") 202 | console.print(Pretty(providers_with_options, indent_guides=True, expand_all=True)) 203 | start_time = time.time() 204 | ort_sess = ort.InferenceSession( 205 | str(model_path), sess_options=sess_options, providers=providers_with_options 206 | ) 207 | input_name = ort_sess.get_inputs()[0].name 208 | console.print( 209 | f"[green]Model loaded in {time.time() - start_time:.2f} seconds[/green]" 210 | ) 211 | return ort_sess, input_name 212 | 213 | 214 | def main(args): 215 | global console 216 | 217 | watermark_dir = Path(args.train_data_dir) / "watermarked" 218 | no_watermark_dir = Path(args.train_data_dir) / "no_watermark" 219 | # 确保目标文件夹存在 220 | if watermark_dir.exists(): 221 | for symlink in watermark_dir.rglob("*"): 222 | if symlink.is_symlink(): 223 | symlink.unlink() 224 | if no_watermark_dir.exists(): 225 | for symlink in no_watermark_dir.rglob("*"): 226 | if symlink.is_symlink(): 227 | symlink.unlink() 228 | 229 | # 初始化 Lance 数据集 230 | if not isinstance(args.train_data_dir, lance.LanceDataset): 231 | if args.train_data_dir.endswith(".lance"): 232 | dataset = lance.dataset(args.train_data_dir) 233 | elif any( 234 | file.suffix == ".lance" for file in Path(args.train_data_dir).glob("*") 235 | ): 236 | lance_file = next( 237 | file 238 | for file in Path(args.train_data_dir).glob("*") 239 | if file.suffix == ".lance" 240 | ) 241 | dataset = lance.dataset(str(lance_file)) 242 | else: 243 | console.print("[yellow]Converting dataset to Lance format...[/yellow]") 244 | dataset = transform2lance( 245 | args.train_data_dir, 246 | output_name="dataset", 247 | save_binary=False, 248 | not_save_disk=False, 249 | tag="WatermarkDetection", 250 | ) 251 | console.print("[green]Dataset converted to Lance format[/green]") 252 | 253 | else: 254 | dataset = args.train_data_dir 255 | console.print("[green]Using existing Lance dataset[/green]") 256 | 257 | ort_sess, input_name = load_model(args) 258 | 259 | # 先计算图片总数 260 | total_images = len( 261 | dataset.to_table( 262 | columns=["mime"], 263 | filter=("mime LIKE 'image/%'"), 264 | ) 265 | ) 266 | 267 | # 然后创建带columns的scanner处理数据 268 | scanner = dataset.scanner( 269 | columns=["uris", "mime", "captions"], 270 | filter=("mime LIKE 'image/%'"), 271 | scan_in_order=True, 272 | batch_size=args.batch_size, 273 | batch_readahead=16, 274 | fragment_readahead=4, 275 | io_buffer_size=32 * 1024 * 1024, # 32MB buffer 276 | late_materialization=True, 277 | ) 278 | 279 | with Progress( 280 | "[progress.description]{task.description}", 281 | SpinnerColumn(spinner_name="dots"), 282 | MofNCompleteColumn(separator="/"), 283 | BarColumn(bar_width=40, complete_style="green", finished_style="bold green"), 284 | TextColumn("•"), 285 | TaskProgressColumn(), 286 | TextColumn("•"), 287 | TransferSpeedColumn(), 288 | TextColumn("•"), 289 | TimeElapsedColumn(), 290 | TextColumn("•"), 291 | TimeRemainingColumn(), 292 | expand=True, 293 | ) as progress: 294 | task = progress.add_task("[bold 
cyan]Processing images...", total=total_images) 295 | 296 | console = progress.console 297 | 298 | # 用于收集结果的列表 299 | detection_results = [] 300 | 301 | for batch in scanner.to_batches(): 302 | uris = batch["uris"].to_pylist() # 获取文件路径 303 | 304 | # 使用并行处理加载和预处理图像 305 | batch_images = load_and_preprocess_batch(uris) 306 | 307 | if not batch_images: 308 | progress.update(task, advance=len(uris)) 309 | continue 310 | 311 | # 处理批次 312 | probs = process_batch(batch_images, ort_sess, input_name) 313 | # 创建对应的目标文件夹(如果不存在) 314 | if probs is not None: 315 | for path, prob in zip(uris, probs): 316 | # 获取水印检测结果 317 | watermark_prob = prob[1] # 索引1对应"Watermark"标签 318 | 319 | # 根据概率确定是否有水印(阈值可以调整) 320 | has_watermark = watermark_prob > args.thresh 321 | 322 | # 添加到结果列表 323 | detection_results.append((path, float(watermark_prob))) 324 | 325 | # 创建软链接 326 | source_path = Path(path).absolute() 327 | relative_path = source_path.relative_to( 328 | Path(args.train_data_dir).absolute() 329 | ) 330 | target_dir = watermark_dir if has_watermark else no_watermark_dir 331 | target_path = target_dir / relative_path 332 | target_path.parent.mkdir(parents=True, exist_ok=True) 333 | 334 | # 创建软链接 335 | try: 336 | target_path.symlink_to(source_path) 337 | except (FileExistsError, PermissionError) as e: 338 | console.print( 339 | f"[red]Unable to create symlink for {path}: {e}[/red]" 340 | ) 341 | # 如果无法创建软链接,尝试复制文件代替 342 | try: 343 | shutil.copy2(source_path, target_path) 344 | console.print( 345 | f"[yellow]Created copy instead of symlink for {path}[/yellow]" 346 | ) 347 | except Exception as copy_err: 348 | console.print(f"[red]Failed to copy file: {copy_err}[/red]") 349 | 350 | progress.update(task, advance=len(batch["uris"].to_pylist())) 351 | 352 | # 统计水印图片数量 353 | watermark_count = sum(1 for _, prob in detection_results if prob > args.thresh) 354 | total_count = len(detection_results) 355 | 356 | # 按路径层次结构组织结果 357 | path_tree = {} 358 | for path, prob in detection_results: 359 | parts = Path(path).relative_to(Path(args.train_data_dir).absolute()).parts 360 | current = path_tree 361 | for i, part in enumerate(parts): 362 | if i == len(parts) - 1: 363 | # 最后一层是文件名,存储概率 364 | current[part] = f"{prob:.4f}" + ( 365 | "🔴(Watermarked)🔖" if prob > args.thresh else "🟢(No Watermark)📄" 366 | ) 367 | else: 368 | if part not in current: 369 | current[part] = {} 370 | current = current[part] 371 | 372 | # 使用Pretty打印结果树 373 | console.print("\n[bold green]Results:[/bold green]") 374 | console.print(Pretty(path_tree, indent_guides=True, expand_all=True)) 375 | # 保存检测结果树到JSON文件 376 | result_json_path = Path(args.train_data_dir) / "watermark_detection_results.json" 377 | with open(result_json_path, "w", encoding="utf-8") as f: 378 | json.dump(path_tree, f, ensure_ascii=False, indent=2) 379 | console.print(f"[bold green]Results saved to:[/bold green] {result_json_path}") 380 | 381 | # 打印检测结果统计 382 | console.print( 383 | f"🔴Watermarked🔖: {watermark_count} ({watermark_count/total_count*100:.2f}%)" 384 | ) 385 | console.print( 386 | f"🟢No Watermark📄: {total_count - watermark_count} ({(total_count - watermark_count)/total_count*100:.2f}%)" 387 | ) 388 | 389 | 390 | def setup_parser() -> argparse.ArgumentParser: 391 | parser = argparse.ArgumentParser() 392 | parser.add_argument( 393 | "train_data_dir", 394 | type=str, 395 | help="Directory containing images to process", 396 | ) 397 | parser.add_argument( 398 | "--repo_id", 399 | type=str, 400 | default="bdsqlsz/joycaption-watermark-detection-onnx", 401 | help="Repository ID for 
Watermark Detection model on Hugging Face", 402 | ) 403 | parser.add_argument( 404 | "--model_dir", 405 | type=str, 406 | default="watermark_detection", 407 | help="Directory to store Watermark Detection model", 408 | ) 409 | parser.add_argument( 410 | "--force_download", 411 | action="store_true", 412 | help="Force downloading Watermark Detection model", 413 | ) 414 | parser.add_argument( 415 | "--batch_size", 416 | type=int, 417 | default=16, 418 | help="Batch size for inference", 419 | ) 420 | parser.add_argument( 421 | "--thresh", 422 | type=float, 423 | default=1.0, 424 | help="Default threshold for tag confidence", 425 | ) 426 | 427 | return parser 428 | 429 | 430 | if __name__ == "__main__": 431 | parser = setup_parser() 432 | 433 | args = parser.parse_args() 434 | 435 | main(args) 436 | -------------------------------------------------------------------------------- /requirements-uv-linux.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile requirements.txt -o requirements-uv-linux.txt --index-strategy unsafe-best-match --no-build-isolation -p 3.11 3 | accelerate==1.6.0 4 | # via -r requirements.txt 5 | aiohappyeyeballs==2.6.1 6 | # via aiohttp 7 | aiohttp==3.11.18 8 | # via dashscope 9 | aiosignal==1.3.2 10 | # via aiohttp 11 | annotated-types==0.7.0 12 | # via pydantic 13 | anyio==4.9.0 14 | # via 15 | # google-genai 16 | # httpx 17 | # openai 18 | attrs==25.3.0 19 | # via aiohttp 20 | av==14.3.0 21 | # via -r requirements.txt 22 | bitmath==1.3.3.1 23 | # via hbutils 24 | cachetools==5.5.2 25 | # via 26 | # google-auth 27 | # zhipuai 28 | certifi==2025.4.26 29 | # via 30 | # httpcore 31 | # httpx 32 | # requests 33 | chardet==4.0.0 34 | # via 35 | # hbutils 36 | # pysrt 37 | charset-normalizer==3.4.2 38 | # via requests 39 | click==8.1.8 40 | # via scenedetect 41 | coloredlogs==15.0.1 42 | # via onnxruntime-gpu 43 | cssselect==1.3.0 44 | # via pyquery 45 | dashscope==1.23.2 46 | # via -r requirements.txt 47 | deprecation==2.1.0 48 | # via hbutils 49 | distro==1.9.0 50 | # via openai 51 | eval-type-backport==0.2.2 52 | # via mistralai 53 | filelock==3.18.0 54 | # via 55 | # huggingface-hub 56 | # torch 57 | # transformers 58 | flatbuffers==25.2.10 59 | # via onnxruntime-gpu 60 | frozenlist==1.6.0 61 | # via 62 | # aiohttp 63 | # aiosignal 64 | fsspec==2025.3.2 65 | # via 66 | # huggingface-hub 67 | # torch 68 | google-auth==2.40.1 69 | # via google-genai 70 | google-genai==1.14.0 71 | # via -r requirements.txt 72 | h11==0.16.0 73 | # via httpcore 74 | hbutils==0.11.0 75 | # via pyanimeinfo 76 | hf-xet==1.1.0 77 | # via huggingface-hub 78 | httpcore==1.0.9 79 | # via httpx 80 | httpx==0.28.1 81 | # via 82 | # google-genai 83 | # mistralai 84 | # openai 85 | # zhipuai 86 | huggingface-hub==0.31.1 87 | # via 88 | # -r requirements.txt 89 | # accelerate 90 | # tokenizers 91 | # transformers 92 | humanfriendly==10.0 93 | # via coloredlogs 94 | idna==3.10 95 | # via 96 | # anyio 97 | # httpx 98 | # requests 99 | # yarl 100 | imageio==2.37.0 101 | # via -r requirements.txt 102 | imageio-ffmpeg==0.6.0 103 | # via -r requirements.txt 104 | jinja2==3.1.6 105 | # via torch 106 | jiter==0.9.0 107 | # via openai 108 | lxml==4.9.4 109 | # via 110 | # pyanimeinfo 111 | # pyquery 112 | markdown-it-py==3.0.0 113 | # via rich 114 | markupsafe==3.0.2 115 | # via jinja2 116 | mdurl==0.1.2 117 | # via markdown-it-py 118 | mistralai==1.7.0 119 | # via -r requirements.txt 120 | 
mpmath==1.3.0 121 | # via sympy 122 | multidict==6.4.3 123 | # via 124 | # aiohttp 125 | # yarl 126 | mutagen==1.47.0 127 | # via -r requirements.txt 128 | networkx==3.4.2 129 | # via torch 130 | numpy==2.2.5 131 | # via 132 | # accelerate 133 | # imageio 134 | # onnxruntime-gpu 135 | # opencv-contrib-python-rolling 136 | # pandas 137 | # pylance 138 | # scenedetect 139 | # scipy 140 | # torchvision 141 | # transformers 142 | nvidia-cublas-cu12==12.6.4.1 143 | # via 144 | # nvidia-cudnn-cu12 145 | # nvidia-cusolver-cu12 146 | # torch 147 | nvidia-cuda-cupti-cu12==12.6.80 148 | # via torch 149 | nvidia-cuda-nvrtc-cu12==12.6.77 150 | # via torch 151 | nvidia-cuda-runtime-cu12==12.6.77 152 | # via 153 | # tensorrt-cu12-libs 154 | # torch 155 | nvidia-cudnn-cu12==9.5.1.17 156 | # via torch 157 | nvidia-cufft-cu12==11.3.0.4 158 | # via torch 159 | nvidia-cufile-cu12==1.11.1.6 160 | # via torch 161 | nvidia-curand-cu12==10.3.7.77 162 | # via torch 163 | nvidia-cusolver-cu12==11.7.1.2 164 | # via torch 165 | nvidia-cusparse-cu12==12.5.4.2 166 | # via 167 | # nvidia-cusolver-cu12 168 | # torch 169 | nvidia-cusparselt-cu12==0.6.3 170 | # via torch 171 | nvidia-nccl-cu12==2.26.2 172 | # via torch 173 | nvidia-nvjitlink-cu12==12.6.85 174 | # via 175 | # nvidia-cufft-cu12 176 | # nvidia-cusolver-cu12 177 | # nvidia-cusparse-cu12 178 | # torch 179 | nvidia-nvtx-cu12==12.6.77 180 | # via torch 181 | onnxruntime-gpu==1.20.2 182 | # via -r requirements.txt 183 | openai==1.78.0 184 | # via -r requirements.txt 185 | opencv-contrib-python-rolling @ https://github.com/cudawarped/opencv-python-cuda-wheels/releases/download/4.11.0.20250210/opencv_contrib_python_rolling-4.12.0.20250210-cp37-abi3-linux_x86_64.whl 186 | # via -r requirements.txt 187 | packaging==25.0 188 | # via 189 | # accelerate 190 | # deprecation 191 | # hbutils 192 | # huggingface-hub 193 | # onnxruntime-gpu 194 | # pillow-jxl-plugin 195 | # transformers 196 | pandas==2.2.3 197 | # via -r requirements.txt 198 | pillow==11.2.1 199 | # via 200 | # imageio 201 | # pillow-heif 202 | # pillow-jxl-plugin 203 | # rich-pixels 204 | # torchvision 205 | pillow-avif-plugin==1.5.2 206 | # via -r requirements.txt 207 | pillow-heif==0.22.0 208 | # via -r requirements.txt 209 | pillow-jxl-plugin==1.3.2 210 | # via -r requirements.txt 211 | platformdirs==4.3.8 212 | # via scenedetect 213 | propcache==0.3.1 214 | # via 215 | # aiohttp 216 | # yarl 217 | protobuf==6.30.2 218 | # via onnxruntime-gpu 219 | psutil==7.0.0 220 | # via accelerate 221 | pyanimeinfo==0.0.4 222 | # via -r requirements.txt 223 | pyarrow==20.0.0 224 | # via 225 | # -r requirements.txt 226 | # pylance 227 | pyasn1==0.6.1 228 | # via 229 | # pyasn1-modules 230 | # rsa 231 | pyasn1-modules==0.4.2 232 | # via google-auth 233 | pydantic==2.11.4 234 | # via 235 | # google-genai 236 | # mistralai 237 | # openai 238 | # zhipuai 239 | pydantic-core==2.33.2 240 | # via 241 | # pydantic 242 | # zhipuai 243 | pygments==2.19.1 244 | # via rich 245 | pyjwt==2.8.0 246 | # via zhipuai 247 | pylance==0.26.1 248 | # via -r requirements.txt 249 | pymediainfo==7.0.1 250 | # via -r requirements.txt 251 | pyparsing==3.0.9 252 | # via pyrfc6266 253 | pyquery==2.0.1 254 | # via pyanimeinfo 255 | pyrfc6266==1.0.2 256 | # via pyanimeinfo 257 | pysrt==1.1.2 258 | # via -r requirements.txt 259 | python-dateutil==2.9.0.post0 260 | # via 261 | # mistralai 262 | # pandas 263 | pytimeparse==1.1.8 264 | # via hbutils 265 | pytz==2025.2 266 | # via pandas 267 | pyyaml==6.0.2 268 | # via 269 | # accelerate 270 | # 
huggingface-hub 271 | # transformers 272 | regex==2024.11.6 273 | # via transformers 274 | requests==2.32.3 275 | # via 276 | # dashscope 277 | # google-genai 278 | # huggingface-hub 279 | # pyanimeinfo 280 | # transformers 281 | rich==14.0.0 282 | # via 283 | # -r requirements.txt 284 | # rich-pixels 285 | rich-pixels==3.0.1 286 | # via -r requirements.txt 287 | rsa==4.9.1 288 | # via google-auth 289 | safetensors==0.5.3 290 | # via 291 | # accelerate 292 | # transformers 293 | scenedetect==0.6.6 294 | # via -r requirements.txt 295 | scipy==1.15.3 296 | # via -r requirements.txt 297 | setuptools==80.3.1 298 | # via 299 | # hbutils 300 | # triton 301 | six==1.17.0 302 | # via python-dateutil 303 | sniffio==1.3.1 304 | # via 305 | # anyio 306 | # openai 307 | sympy==1.14.0 308 | # via 309 | # onnxruntime-gpu 310 | # torch 311 | tensorrt==10.10.0.31 312 | # via -r requirements.txt 313 | tensorrt-cu12==10.10.0.31 314 | # via tensorrt 315 | tensorrt-cu12-bindings==10.10.0.31 316 | # via tensorrt-cu12 317 | tensorrt-cu12-libs==10.10.0.31 318 | # via tensorrt-cu12 319 | tokenizers==0.21.1 320 | # via transformers 321 | toml==0.10.2 322 | # via -r requirements.txt 323 | torch==2.7.0 324 | # via 325 | # -r requirements.txt 326 | # accelerate 327 | # torchvision 328 | torchvision==0.22.0 329 | # via -r requirements.txt 330 | tqdm==4.67.1 331 | # via 332 | # huggingface-hub 333 | # openai 334 | # pyanimeinfo 335 | # scenedetect 336 | # transformers 337 | transformers==4.51.3 338 | # via -r requirements.txt 339 | triton==3.3.0 340 | # via torch 341 | typing-extensions==4.13.2 342 | # via 343 | # anyio 344 | # google-genai 345 | # huggingface-hub 346 | # openai 347 | # pydantic 348 | # pydantic-core 349 | # torch 350 | # typing-inspection 351 | typing-inspection==0.4.0 352 | # via 353 | # mistralai 354 | # pydantic 355 | tzdata==2025.2 356 | # via pandas 357 | urllib3==2.4.0 358 | # via requests 359 | websocket-client==1.8.0 360 | # via dashscope 361 | websockets==15.0.1 362 | # via google-genai 363 | yarl==1.20.0 364 | # via aiohttp 365 | zhipuai==2.1.5.20250421 366 | # via -r requirements.txt 367 | -------------------------------------------------------------------------------- /requirements-uv.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile requirements.txt -o requirements-uv.txt --index-strategy unsafe-best-match --no-build-isolation -p 3.11 3 | accelerate==1.6.0 4 | # via -r requirements.txt 5 | aiohappyeyeballs==2.4.6 6 | # via aiohttp 7 | aiohttp==3.11.12 8 | # via dashscope 9 | aiosignal==1.3.2 10 | # via aiohttp 11 | annotated-types==0.7.0 12 | # via pydantic 13 | anyio==4.8.0 14 | # via 15 | # google-genai 16 | # httpx 17 | # openai 18 | attrs==25.1.0 19 | # via aiohttp 20 | av==14.0.1 21 | # via -r requirements.txt 22 | bitmath==1.3.3.1 23 | # via hbutils 24 | cachetools==5.5.0 25 | # via 26 | # google-auth 27 | # zhipuai 28 | certifi==2024.12.14 29 | # via 30 | # httpcore 31 | # httpx 32 | # requests 33 | chardet==4.0.0 34 | # via 35 | # hbutils 36 | # pysrt 37 | charset-normalizer==3.4.1 38 | # via requests 39 | click==8.1.8 40 | # via scenedetect 41 | colorama==0.4.6 42 | # via 43 | # click 44 | # tqdm 45 | coloredlogs==15.0.1 46 | # via onnxruntime-gpu 47 | cssselect==1.3.0 48 | # via pyquery 49 | dashscope==1.22.1 50 | # via -r requirements.txt 51 | deprecation==2.1.0 52 | # via hbutils 53 | distro==1.9.0 54 | # via openai 55 | eval-type-backport==0.2.2 56 | # 
via mistralai 57 | filelock==3.18.0 58 | # via 59 | # huggingface-hub 60 | # torch 61 | # transformers 62 | flatbuffers==25.2.10 63 | # via onnxruntime-gpu 64 | frozenlist==1.5.0 65 | # via 66 | # aiohttp 67 | # aiosignal 68 | fsspec==2025.3.2 69 | # via 70 | # huggingface-hub 71 | # torch 72 | google-auth==2.37.0 73 | # via google-genai 74 | google-genai==1.11.0 75 | # via -r requirements.txt 76 | h11==0.14.0 77 | # via httpcore 78 | hbutils==0.11.0 79 | # via pyanimeinfo 80 | hf-xet==1.0.3 81 | # via huggingface-hub 82 | httpcore==1.0.7 83 | # via httpx 84 | httpx==0.28.1 85 | # via 86 | # google-genai 87 | # mistralai 88 | # openai 89 | # zhipuai 90 | huggingface-hub==0.30.2 91 | # via 92 | # -r requirements.txt 93 | # accelerate 94 | # tokenizers 95 | # transformers 96 | humanfriendly==10.0 97 | # via coloredlogs 98 | idna==3.10 99 | # via 100 | # anyio 101 | # httpx 102 | # requests 103 | # yarl 104 | imageio==2.36.1 105 | # via -r requirements.txt 106 | imageio-ffmpeg==0.5.1 107 | # via -r requirements.txt 108 | jinja2==3.1.6 109 | # via torch 110 | jiter==0.8.2 111 | # via openai 112 | lxml==4.9.4 113 | # via 114 | # pyanimeinfo 115 | # pyquery 116 | markdown-it-py==3.0.0 117 | # via rich 118 | markupsafe==3.0.2 119 | # via jinja2 120 | mdurl==0.1.2 121 | # via markdown-it-py 122 | mistralai==1.7.0 123 | # via -r requirements.txt 124 | mpmath==1.3.0 125 | # via sympy 126 | multidict==6.1.0 127 | # via 128 | # aiohttp 129 | # yarl 130 | mutagen==1.47.0 131 | # via -r requirements.txt 132 | networkx==3.4.2 133 | # via torch 134 | numpy==1.26.4 135 | # via 136 | # accelerate 137 | # imageio 138 | # onnxruntime-gpu 139 | # opencv-contrib-python-rolling 140 | # pandas 141 | # pylance 142 | # scenedetect 143 | # scipy 144 | # torchvision 145 | # transformers 146 | nvidia-cuda-runtime-cu12==12.8.90 147 | # via tensorrt-cu12-libs 148 | onnxruntime-gpu==1.20.2 149 | # via -r requirements.txt 150 | openai==1.59.7 151 | # via -r requirements.txt 152 | opencv-contrib-python-rolling @ https://github.com/cudawarped/opencv-python-cuda-wheels/releases/download/4.11.0.20250124/opencv_contrib_python_rolling-4.12.0.86-cp37-abi3-win_amd64.whl 153 | # via -r requirements.txt 154 | packaging==24.2 155 | # via 156 | # accelerate 157 | # deprecation 158 | # hbutils 159 | # huggingface-hub 160 | # onnxruntime-gpu 161 | # pillow-jxl-plugin 162 | # transformers 163 | pandas==2.2.3 164 | # via -r requirements.txt 165 | pillow==10.4.0 166 | # via 167 | # imageio 168 | # pillow-heif 169 | # pillow-jxl-plugin 170 | # rich-pixels 171 | # torchvision 172 | pillow-avif-plugin==1.4.6 173 | # via -r requirements.txt 174 | pillow-heif==0.21.0 175 | # via -r requirements.txt 176 | pillow-jxl-plugin==1.3.0 177 | # via -r requirements.txt 178 | platformdirs==4.3.6 179 | # via scenedetect 180 | propcache==0.2.1 181 | # via 182 | # aiohttp 183 | # yarl 184 | protobuf==6.30.2 185 | # via onnxruntime-gpu 186 | psutil==7.0.0 187 | # via accelerate 188 | pyanimeinfo==0.0.4 189 | # via -r requirements.txt 190 | pyarrow==18.1.0 191 | # via 192 | # -r requirements.txt 193 | # pylance 194 | pyasn1==0.6.1 195 | # via 196 | # pyasn1-modules 197 | # rsa 198 | pyasn1-modules==0.4.1 199 | # via google-auth 200 | pydantic==2.10.4 201 | # via 202 | # google-genai 203 | # mistralai 204 | # openai 205 | # zhipuai 206 | pydantic-core==2.27.2 207 | # via 208 | # pydantic 209 | # zhipuai 210 | pygments==2.18.0 211 | # via rich 212 | pyjwt==2.8.0 213 | # via zhipuai 214 | pylance==0.20.0 215 | # via -r requirements.txt 216 | pymediainfo==6.1.0 
217 | # via -r requirements.txt 218 | pyparsing==3.0.9 219 | # via pyrfc6266 220 | pyquery==2.0.1 221 | # via pyanimeinfo 222 | pyreadline3==3.5.4 223 | # via humanfriendly 224 | pyrfc6266==1.0.2 225 | # via pyanimeinfo 226 | pysrt==1.1.2 227 | # via -r requirements.txt 228 | python-dateutil==2.9.0.post0 229 | # via 230 | # mistralai 231 | # pandas 232 | pytimeparse==1.1.8 233 | # via hbutils 234 | pytz==2024.2 235 | # via pandas 236 | pyyaml==6.0.2 237 | # via 238 | # accelerate 239 | # huggingface-hub 240 | # transformers 241 | regex==2024.11.6 242 | # via transformers 243 | requests==2.32.3 244 | # via 245 | # dashscope 246 | # google-genai 247 | # huggingface-hub 248 | # pyanimeinfo 249 | # transformers 250 | rich==13.9.4 251 | # via 252 | # -r requirements.txt 253 | # rich-pixels 254 | rich-pixels==3.0.1 255 | # via -r requirements.txt 256 | rsa==4.9 257 | # via google-auth 258 | safetensors==0.5.3 259 | # via 260 | # accelerate 261 | # transformers 262 | scenedetect==0.6.6 263 | # via -r requirements.txt 264 | scipy==1.15.3 265 | # via -r requirements.txt 266 | setuptools==75.6.0 267 | # via 268 | # hbutils 269 | # imageio-ffmpeg 270 | six==1.17.0 271 | # via python-dateutil 272 | sniffio==1.3.1 273 | # via 274 | # anyio 275 | # openai 276 | sympy==1.13.3 277 | # via 278 | # onnxruntime-gpu 279 | # torch 280 | tensorrt==10.9.0.34 281 | # via -r requirements.txt 282 | tensorrt-cu12==10.9.0.34 283 | # via tensorrt 284 | tensorrt-cu12-bindings==10.9.0.34 285 | # via tensorrt-cu12 286 | tensorrt-cu12-libs==10.9.0.34 287 | # via tensorrt-cu12 288 | tokenizers==0.21.1 289 | # via transformers 290 | toml==0.10.2 291 | # via -r requirements.txt 292 | torch==2.7.0 293 | # via 294 | # -r requirements.txt 295 | # accelerate 296 | # torchvision 297 | torchvision==0.22.0 298 | # via -r requirements.txt 299 | tqdm==4.67.1 300 | # via 301 | # huggingface-hub 302 | # openai 303 | # pyanimeinfo 304 | # scenedetect 305 | # transformers 306 | transformers==4.51.3 307 | # via -r requirements.txt 308 | typing-extensions==4.12.2 309 | # via 310 | # anyio 311 | # google-genai 312 | # huggingface-hub 313 | # openai 314 | # pydantic 315 | # pydantic-core 316 | # torch 317 | # typing-inspection 318 | typing-inspection==0.4.0 319 | # via mistralai 320 | tzdata==2024.2 321 | # via pandas 322 | urllib3==2.3.0 323 | # via requests 324 | websocket-client==1.8.0 325 | # via dashscope 326 | websockets==14.2 327 | # via google-genai 328 | yarl==1.18.3 329 | # via aiohttp 330 | zhipuai==2.1.5.20250410 331 | # via -r requirements.txt 332 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | av 2 | pylance>=0.20.0 3 | rich 4 | rich_pixels 5 | pandas 6 | pillow-avif-plugin 7 | pillow-heif 8 | pillow-jxl-plugin 9 | pysrt 10 | toml 11 | imageio>=2.31.1 12 | imageio-ffmpeg>=0.4.8 13 | pymediainfo 14 | mutagen 15 | pyarrow>=14.0.1 16 | mistralai>=1.6.0 17 | google-genai>=1.11.0 18 | openAI 19 | dashscope 20 | pyanimeinfo 21 | zhipuai 22 | 23 | #WDtagger 24 | accelerate>=1.6.0 25 | huggingface_hub[hf_xet]>=0.30.2 26 | https://github.com/cudawarped/opencv-python-cuda-wheels/releases/download/4.11.0.20250124/opencv_contrib_python_rolling-4.12.0.86-cp37-abi3-win_amd64.whl; sys_platform == 'win32' 27 | https://github.com/cudawarped/opencv-python-cuda-wheels/releases/download/4.11.0.20250210/opencv_contrib_python_rolling-4.12.0.20250210-cp37-abi3-linux_x86_64.whl; sys_platform == 'linux' 28 | 
torch>=2.7.0 29 | onnxruntime-gpu==1.20.2; sys_platform == 'win32' 30 | onnxruntime-gpu>=1.20.2; sys_platform == 'linux' 31 | tensorrt>=10.9 32 | 33 | scenedetect>=0.6.6 34 | 35 | #WatermarkDetect 36 | transformers>4.50 37 | scipy -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/qinglong-captions/c9376f3bf325675c7dc10b5da702e1e0eb629947/utils/__init__.py -------------------------------------------------------------------------------- /utils/console_util.py: -------------------------------------------------------------------------------- 1 | from rich.console import Console 2 | from rich.text import Text 3 | from rich.panel import Panel 4 | from rich.layout import Layout 5 | from rich.markdown import Markdown 6 | from rich.segment import Segment 7 | from rich.style import Style 8 | from rich_pixels import Pixels 9 | from utils.wdtagger import TagClassifier 10 | 11 | # 全局控制台实例 12 | console = Console() 13 | 14 | 15 | class BaseLayout: 16 | """基础布局类,提供创建Rich布局的基本功能""" 17 | 18 | def __init__(self, panel_height=32, console=None): 19 | """ 20 | 初始化基础布局 21 | 22 | Args: 23 | panel_height: 面板高度 24 | console: Rich控制台实例 25 | """ 26 | self.panel_height = panel_height 27 | self.console = console or globals().get("console", Console()) 28 | self.layout = Layout() 29 | 30 | def create_layout(self): 31 | """创建基本布局结构(由子类实现)""" 32 | pass 33 | 34 | def render(self, title=""): 35 | """ 36 | 渲染布局为面板并返回 37 | 38 | Args: 39 | title: 面板标题 40 | 41 | Returns: 42 | Panel: Rich面板对象 43 | """ 44 | panel = Panel( 45 | self.layout, 46 | title=title, 47 | height=self.panel_height + 2, 48 | padding=0, 49 | ) 50 | return panel 51 | 52 | def print(self, title=""): 53 | """ 54 | 打印布局到控制台 55 | 56 | Args: 57 | title: 面板标题 58 | """ 59 | panel = self.render(title) 60 | self.console.print() 61 | self.console.print() 62 | self.console.print(panel) 63 | 64 | 65 | class CaptionLayout(BaseLayout): 66 | """用于显示图片字幕的布局类""" 67 | 68 | def __init__( 69 | self, 70 | tag_description, 71 | short_description, 72 | long_description, 73 | pixels, 74 | short_highlight_rate=0, 75 | long_highlight_rate=0, 76 | panel_height=32, 77 | console=None, 78 | ): 79 | """ 80 | 初始化字幕布局 81 | 82 | Args: 83 | tag_description: 标签描述 84 | short_description: 短描述 85 | long_description: 长描述 86 | pixels: Rich Pixels对象 87 | short_highlight_rate: 短描述高亮率 88 | long_highlight_rate: 长描述高亮率 89 | panel_height: 面板高度 90 | console: Rich控制台实例 91 | """ 92 | super().__init__(panel_height, console) 93 | tagClassifier = TagClassifier() 94 | # Process tag_description to handle various spacing and comma combinations 95 | cleaned_description = tag_description.replace("<", "").replace(">", "") 96 | processed_tags = [tag.strip() for tag in cleaned_description.split(',') if tag.strip()] 97 | tag_values = tagClassifier.classify(processed_tags).values() 98 | self.tag_description = ",".join([",".join(value) for value in tag_values]) 99 | self.short_description = short_description 100 | self.long_description = long_description 101 | self.pixels = pixels 102 | self.short_highlight_rate = short_highlight_rate 103 | self.long_highlight_rate = long_highlight_rate 104 | self.create_layout() 105 | 106 | def create_layout(self): 107 | """创建字幕布局结构""" 108 | # 创建右侧的垂直布局 109 | right_layout = Layout() 110 | 111 | # 创建上半部分的水平布局(tag和short并排) 112 | top_layout = Layout() 113 | top_layout.split_row( 114 | Layout( 115 | Panel( 116 | 
self.tag_description, 117 | title="tags", 118 | height=self.panel_height // 2, 119 | padding=0, 120 | expand=True, 121 | ), 122 | ratio=1, 123 | ), 124 | Layout( 125 | Panel( 126 | self.short_description, 127 | title=f"short_description - [yellow]hr:[/yellow] {self.short_highlight_rate}", 128 | height=self.panel_height // 2, 129 | padding=0, 130 | expand=True, 131 | ), 132 | ratio=1, 133 | ), 134 | ) 135 | 136 | # 将右侧布局分为上下两部分 137 | right_layout.split_column( 138 | Layout(top_layout, ratio=1), 139 | Layout( 140 | Panel( 141 | self.long_description, 142 | title=f"long_description - [yellow]highlight rate:[/yellow] {self.long_highlight_rate}", 143 | height=self.panel_height // 2, 144 | padding=0, 145 | expand=True, 146 | ) 147 | ), 148 | ) 149 | 150 | # 主布局分为左右两部分 151 | self.layout.split_row( 152 | Layout( 153 | Panel(self.pixels, height=self.panel_height, padding=0, expand=True), 154 | name="image", 155 | ratio=1, 156 | ), 157 | Layout(right_layout, name="caption", ratio=2), 158 | ) 159 | 160 | 161 | class MarkdownLayout(BaseLayout): 162 | """用于显示Markdown内容的布局类""" 163 | 164 | def __init__(self, pixels, markdown_content, panel_height=32, console=None): 165 | """ 166 | 初始化Markdown布局 167 | 168 | Args: 169 | pixels: Rich Pixels对象 170 | markdown_content: Markdown内容 171 | panel_height: 面板高度 172 | console: Rich控制台实例 173 | """ 174 | super().__init__(panel_height, console) 175 | self.pixels = pixels 176 | self.markdown_content = markdown_content 177 | self.create_layout() 178 | 179 | def create_layout(self): 180 | """创建Markdown布局结构""" 181 | # 创建右侧布局(单个Markdown窗口) 182 | right_layout = Layout( 183 | Panel( 184 | Markdown(self.markdown_content), 185 | title="markdown", 186 | padding=0, 187 | expand=True, 188 | ) 189 | ) 190 | 191 | # 如果pixels为空,直接全局渲染markdown内容,否则分为左右两部分 192 | if self.pixels is None: 193 | self.layout.update( 194 | Layout( 195 | Panel( 196 | Markdown(self.markdown_content), 197 | title="markdown", 198 | padding=0, 199 | expand=True, 200 | ), 201 | name="markdown", 202 | ) 203 | ) 204 | else: 205 | self.layout.split_row( 206 | Layout( 207 | Panel( 208 | self.pixels, height=self.panel_height, padding=0, expand=True 209 | ), 210 | name="image", 211 | ratio=1, 212 | ), 213 | Layout(right_layout, name="markdown", ratio=2), 214 | ) 215 | 216 | 217 | class CaptionAndRateLayout(BaseLayout): 218 | """用于显示图片字幕和评分的布局类""" 219 | 220 | def __init__( 221 | self, 222 | tag_description, 223 | rating, 224 | average_score, 225 | long_description, 226 | pixels, 227 | short_highlight_rate=0, 228 | long_highlight_rate=0, 229 | panel_height=32, 230 | console=None, 231 | ): 232 | """ 233 | 初始化字幕布局 234 | 235 | Args: 236 | tag_description: 标签描述 237 | rating: 高亮率 238 | average_score: 平均评分 239 | long_description: 长描述 240 | pixels: Rich Pixels对象 241 | long_highlight_rate: 长描述高亮率 242 | panel_height: 面板高度 243 | console: Rich控制台实例 244 | """ 245 | super().__init__(panel_height, console) 246 | self.tag_description = tag_description 247 | self.long_description = long_description 248 | self.long_highlight_rate = long_highlight_rate 249 | self.pixels = pixels 250 | self.rating_chart = self.create_rating_chart(rating) 251 | self.average_score = average_score 252 | self.create_layout() 253 | 254 | def create_layout(self): 255 | """创建字幕布局结构""" 256 | # 创建右侧的垂直布局 257 | right_layout = Layout() 258 | 259 | # 创建上半部分的水平布局(tag和short并排) 260 | top_layout = Layout() 261 | top_layout.split_row( 262 | Layout( 263 | Panel( 264 | Text(self.tag_description, style="magenta"), 265 | title="tags", 266 | height=self.panel_height // 2, 267 | 
padding=0, 268 | expand=True, 269 | ), 270 | ratio=1, 271 | ), 272 | Layout( 273 | Panel( 274 | self.rating_chart, 275 | title=f"rating - [yellow]average score:[/yellow] {self.average_score}", 276 | height=self.panel_height // 2, 277 | padding=0, 278 | expand=True, 279 | ), 280 | ratio=1, 281 | ), 282 | ) 283 | 284 | # 将右侧布局分为上下两部分 285 | right_layout.split_column( 286 | Layout(top_layout, ratio=1), 287 | Layout( 288 | Panel( 289 | self.long_description, 290 | title=f"long_description - [yellow]highlight rate:[/yellow] {self.long_highlight_rate}", 291 | height=self.panel_height // 2, 292 | padding=0, 293 | expand=True, 294 | ) 295 | ), 296 | ) 297 | 298 | # 主布局分为左右两部分 299 | self.layout.split_row( 300 | Layout( 301 | Panel(self.pixels, height=self.panel_height, padding=0, expand=True), 302 | name="image", 303 | ratio=1, 304 | ), 305 | Layout(right_layout, name="caption", ratio=2), 306 | ) 307 | 308 | def create_rating_chart(self, ratings, max_rating=10): 309 | """创建一个简单的评分图 310 | 311 | Args: 312 | ratings: 字典,包含维度名称和对应评分 313 | max_rating: 最大评分值 314 | """ 315 | if not ratings: 316 | return Pixels.from_ascii("No ratings available") 317 | 318 | # 获取维度列表并清理维度名称 319 | clean_ratings = {} 320 | for key, value in ratings.items(): 321 | # 移除所有非标准字符,不仅仅是方块字符 322 | clean_key = "" 323 | for char in key: 324 | # 只保留字母、数字、空格和常见标点符号 325 | if char.isalnum() or char.isspace() or char in "&/,.:-_()": 326 | clean_key += char 327 | # 如果清理后为空,使用原始键 328 | if not clean_key: 329 | clean_key = f"Dimension {len(clean_ratings) + 1}" 330 | clean_ratings[clean_key] = value 331 | 332 | # 创建彩虹颜色映射 333 | rainbow_colors = [ 334 | "bright_red", # 红色 335 | "orange3", # 橙色 336 | "yellow", # 黄色 337 | "green", # 绿色 338 | "spring_green3", # 青色 339 | "bright_blue", # 蓝色 340 | "blue_violet", # 靛色 341 | "purple", # 紫色 342 | "magenta", # 紫红色 343 | ] 344 | 345 | # 准备行和颜色映射 346 | lines = [] 347 | mapping = {} 348 | 349 | # 为每个维度创建一行 350 | for i, (dimension, rating) in enumerate(clean_ratings.items()): 351 | # 获取对应颜色 352 | color_index = min(i, len(rainbow_colors) - 1) 353 | color = rainbow_colors[color_index] 354 | 355 | # 处理维度名称,只保留第一个&前的内容 356 | short_dim = dimension.split("&")[0].strip() 357 | dim_text = f"{short_dim}:" 358 | # 评分条长度等于评分值 359 | bar_length = int(rating) 360 | 361 | # 使用隐藏的控制字符作为映射键,这些字符在普通文本中不太可能出现 362 | # 使用ASCII 01-31的不可见控制字符,每行使用不同的控制字符 363 | control_char = chr(1 + i) # 使用 SOH, STX, ETX 等不可见控制字符 364 | mapping[control_char] = Segment("■", Style(color=color)) 365 | 366 | # 生成评分条 367 | bar = control_char * bar_length 368 | 369 | # 特殊处理某些维度的最大分数 370 | current_max_rating = ( 371 | 5 372 | if dimension 373 | in ["Storytelling & Concept", "Setting & Environment Integration"] 374 | else max_rating 375 | ) 376 | value_text = f" {rating}/{current_max_rating}" 377 | 378 | # 组合行内容 379 | line = dim_text + bar + value_text 380 | lines.append(line) 381 | 382 | # 组合成ASCII图 383 | ascii_grid = "\n".join(lines) 384 | 385 | # 返回Pixels对象 386 | return Pixels.from_ascii(ascii_grid, mapping) 387 | -------------------------------------------------------------------------------- /utils/stream_util.py: -------------------------------------------------------------------------------- 1 | import av 2 | import re 3 | from typing import Tuple 4 | from rich.progress import Progress, BarColumn, TimeRemainingColumn 5 | from rich.console import Console 6 | from av.audio.format import AudioFormat 7 | from av.audio.layout import AudioLayout 8 | import subprocess 9 | import imageio_ffmpeg 10 | from pymediainfo import MediaInfo 11 | 12 | console = 
Console() 13 | 14 | 15 | def split_media_stream_clips(uri, media_type, subs, save_caption_func=None, **kwargs): 16 | """ 17 | Process media stream and extract clips based on subtitles. 18 | 19 | Args: 20 | uri (str): Path to the media file 21 | media_type (str): Type of media ('video' or 'audio') 22 | subs (pysrt.SubRipFile): Subtitles to process 23 | save_caption_func (callable): Function to save captions 24 | 25 | Returns: 26 | None 27 | """ 28 | with av.open(uri) as in_container: 29 | if media_type != "video": 30 | video_stream = None 31 | else: 32 | video_stream = next( 33 | (s for s in in_container.streams if s.type == "video"), 34 | None, 35 | ) 36 | # Try to get audio stream if available 37 | audio_stream = next( 38 | (s for s in in_container.streams if s.type == "audio"), 39 | None, 40 | ) 41 | 42 | # 添加字幕片段的进度条 43 | with Progress( 44 | "[progress.description]{task.description}", 45 | BarColumn(), 46 | "[progress.percentage]{task.percentage:>3.0f}%", 47 | TimeRemainingColumn(), 48 | console=console, 49 | ) as sub_progress: 50 | sub_task = sub_progress.add_task( 51 | f"[cyan]Processing subtitles for {uri.name}", 52 | total=len(subs), 53 | ) 54 | 55 | for i, sub in enumerate(subs): 56 | # if len(subs) < 2: 57 | # break 58 | clip_path = ( 59 | uri.parent / f"{uri.stem}_clip/{uri.stem}_{sub.index}{uri.suffix}" 60 | ) 61 | clip_path.parent.mkdir(parents=True, exist_ok=True) 62 | 63 | with av.open(str(clip_path), mode="w") as out_container: 64 | # copy encoder settings 65 | if video_stream: 66 | out_video_stream = out_container.add_stream_from_template( 67 | template=video_stream 68 | ) 69 | else: 70 | out_video_stream = None 71 | if audio_stream: 72 | # 为音频流使用特定的设置 73 | if media_type == "video": 74 | codec_name = "aac" 75 | out_audio_stream = out_container.add_stream( 76 | codec_name=codec_name, 77 | rate=48000, # AAC标准采样率 78 | ) 79 | out_audio_stream.layout = AudioLayout( 80 | "mono" 81 | ) # AAC通常使用立体声 82 | out_audio_stream.format = AudioFormat( 83 | "fltp" 84 | ) # AAC使用浮点平面格式 85 | elif uri.suffix == ".mp3": 86 | codec_name = "mp3" 87 | out_audio_stream = out_container.add_stream( 88 | codec_name=codec_name, 89 | rate=16000, 90 | ) 91 | out_audio_stream.layout = AudioLayout("mono") 92 | out_audio_stream.format = AudioFormat("s16p") 93 | else: 94 | codec_name = "pcm_s16le" 95 | out_audio_stream = out_container.add_stream( 96 | codec_name=codec_name, 97 | rate=16000, 98 | ) 99 | out_audio_stream.layout = AudioLayout("mono") 100 | out_audio_stream.format = AudioFormat("s16p") 101 | else: 102 | out_audio_stream = None 103 | 104 | # 正确计算 start 和 end 时间戳, 单位是 video_stream.time_base 105 | # 使用毫秒并根据 video_stream.time_base 转换 106 | start_seconds = ( 107 | sub.start.hours * 3600 108 | + sub.start.minutes * 60 109 | + sub.start.seconds 110 | ) 111 | end_seconds = ( 112 | sub.end.hours * 3600 + sub.end.minutes * 60 + sub.end.seconds 113 | ) 114 | if video_stream: 115 | start_offset = int( 116 | start_seconds 117 | * video_stream.time_base.denominator 118 | / video_stream.time_base.numerator 119 | ) # 开始时间戳偏移量 (基于 video_stream.time_base) 120 | else: 121 | start_offset = int( 122 | start_seconds 123 | * audio_stream.time_base.denominator 124 | / audio_stream.time_base.numerator 125 | ) # 开始时间戳偏移量 (基于 audio_stream.time_base) 126 | # seek to start 127 | in_container.seek( 128 | start_offset, 129 | stream=(video_stream if video_stream else audio_stream), 130 | ) 131 | 132 | # 手动跳过帧 (如果在 seek 之后需要的话) 133 | for frame in in_container.decode(video_stream, audio_stream): 134 | if frame.time > end_seconds: 
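# Past the end of this subtitle window: stop decoding for this clip.
# (Frames earlier than start_seconds are simply skipped by the checks below;
# only frames inside [start_seconds, end_seconds] are re-encoded and muxed.)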
135 | break 136 | 137 | if ( 138 | video_stream 139 | and isinstance(frame, av.VideoFrame) 140 | and frame.time >= start_seconds 141 | ): 142 | for packet in out_video_stream.encode(frame): 143 | out_container.mux(packet) 144 | elif ( 145 | audio_stream 146 | and isinstance(frame, av.AudioFrame) 147 | and frame.time >= start_seconds 148 | ): 149 | for packet in out_audio_stream.encode(frame): 150 | out_container.mux(packet) 151 | 152 | # Flush streams 153 | if out_video_stream: 154 | for packet in out_video_stream.encode(): 155 | out_container.mux(packet) 156 | if out_audio_stream: 157 | for packet in out_audio_stream.encode(): 158 | out_container.mux(packet) 159 | 160 | if save_caption_func: 161 | save_caption_func(clip_path, [sub.text], "image") 162 | sub_progress.advance(sub_task) 163 | 164 | 165 | def split_video_with_imageio_ffmpeg( 166 | uri, subs, save_caption_func=None, segment_time=120, **kwargs 167 | ): 168 | """ 169 | Process media stream and extract clips based on subtitles using ffmpeg. 170 | 171 | Args: 172 | uri (Path): Path to the media file 173 | subs (pysrt.SubRipFile): Subtitles to process 174 | save_caption_func (callable, optional): Function to save captions 175 | """ 176 | ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe() 177 | with Progress( 178 | "[progress.description]{task.description}", 179 | BarColumn(), 180 | "[progress.percentage]{task.percentage:>3.0f}%", 181 | TimeRemainingColumn(), 182 | console=console, 183 | ) as sub_progress: 184 | sub_task = sub_progress.add_task( 185 | f"[cyan]Processing subtitles for {uri.name}", 186 | total=len(subs), 187 | ) 188 | 189 | for i, sub in enumerate(subs): 190 | # if len(subs) < 2: 191 | # break 192 | clip_path = ( 193 | uri.parent / f"{uri.stem}_clip/{uri.stem}_{sub.index}{uri.suffix}" 194 | ) 195 | clip_path.parent.mkdir(parents=True, exist_ok=True) 196 | 197 | # 计算开始和结束时间 198 | start_time = f"{int(sub.start.hours):02d}:{int(sub.start.minutes):02d}:{int(sub.start.seconds):02d}.{int(sub.start.milliseconds):03d}" 199 | duration = ( 200 | (sub.end.hours - sub.start.hours) * 3600 201 | + (sub.end.minutes - sub.start.minutes) * 60 202 | + (sub.end.seconds - sub.start.seconds) 203 | + (sub.end.milliseconds - sub.start.milliseconds) / 1000 204 | ) 205 | 206 | if duration == segment_time: 207 | # 使用segment模式时的输出模板 208 | output_template = str( 209 | uri.parent / f"{uri.stem}_clip/{uri.stem}_%03d{uri.suffix}" 210 | ) 211 | command = [ 212 | ffmpeg_exe, 213 | "-i", 214 | str(uri), # 输入文件 215 | "-f", 216 | "segment", # 使用segment模式 217 | "-c", 218 | "copy", # 拷贝原始编码,速度更快 219 | "-segment_time", 220 | str(segment_time), # 指定片段时长(5分钟) 221 | "-reset_timestamps", 222 | "1", # 重置时间戳 223 | "-y", # 覆盖输出文件 224 | "-break_non_keyframes", 225 | "0", 226 | output_template, # 输出文件模板 227 | ] 228 | else: 229 | if uri.suffix == ".mp3": 230 | audio_codec = "mp3" 231 | elif uri.suffix == ".wav": 232 | audio_codec = "pcm_s16le" 233 | else: 234 | audio_codec = "aac" 235 | # 根据是否是第一个片段来调整命令 236 | if i == 0: 237 | # 第一个片段,-ss 放在 -i 前面以获得更精确的开始时间 238 | command = [ 239 | ffmpeg_exe, 240 | "-ss", 241 | start_time, # 开始时间 242 | "-t", 243 | str(duration), # 持续时间 244 | "-i", 245 | str(uri), # 输入文件 246 | "-c:v", 247 | "libx264", # 重新编码视频流 248 | "-c:a", 249 | audio_codec, # 重新编码音频流 250 | "-vf", 251 | "setpts=PTS-STARTPTS", # 重置视频时间戳 252 | "-af", 253 | "asetpts=PTS-STARTPTS", # 重置音频时间戳 254 | "-y", # 覆盖输出文件 255 | str(clip_path), # 输出文件 256 | ] 257 | else: 258 | # 其他片段,-i 放在前面以确保片段连接 259 | command = [ 260 | ffmpeg_exe, 261 | "-i", 262 | str(uri), # 输入文件 263 | "-ss", 264 
| start_time, # 开始时间 265 | "-t", 266 | str(duration), # 持续时间 267 | "-c:v", 268 | "libx264", # 重新编码视频流 269 | "-c:a", 270 | audio_codec, # 重新编码音频流 271 | "-vf", 272 | "setpts=PTS-STARTPTS", # 重置视频时间戳 273 | "-af", 274 | "asetpts=PTS-STARTPTS", # 重置音频时间戳 275 | "-y", # 覆盖输出文件 276 | str(clip_path), # 输出文件 277 | ] 278 | 279 | console.print(f"Running command: {' '.join(command)}") 280 | try: 281 | # 使用 subprocess.PIPE 并设置 encoding='utf-8' 282 | process = subprocess.Popen( 283 | command, 284 | stdout=subprocess.PIPE, 285 | stderr=subprocess.PIPE, 286 | text=True, 287 | encoding="utf-8", 288 | errors="replace", 289 | ) 290 | stdout, stderr = process.communicate() 291 | 292 | if process.returncode != 0: 293 | console.print(f"[red]Error running ffmpeg:[/red] {stderr}") 294 | raise Exception(f"FFmpeg failed: {stderr}") 295 | 296 | except Exception as e: 297 | console.print(f"[red]Failed to run ffmpeg:[/red] {str(e)}") 298 | raise 299 | 300 | if save_caption_func: 301 | save_caption_func(clip_path, [sub.text], "image") 302 | sub_progress.advance(sub_task) 303 | 304 | if sub.end.minutes - sub.start.minutes == segment_time / 60: 305 | sub_progress.advance(sub_task, advance=4) 306 | break 307 | 308 | 309 | def sanitize_filename(name: str) -> str: 310 | """Sanitizes filenames. 311 | 312 | Requirements: 313 | - Only lowercase alphanumeric characters or dashes (-) 314 | - Cannot begin or end with a dash 315 | - Max length is 40 characters 316 | """ 317 | # Convert to lowercase and replace non-alphanumeric chars with dash 318 | sanitized = re.sub(r"[^a-z0-9-]", "-", name.lower()) 319 | # Replace multiple dashes with single dash 320 | sanitized = re.sub(r"-+", "-", sanitized) 321 | # Remove leading and trailing dashes 322 | sanitized = sanitized.strip("-") 323 | # If empty after sanitization, use a default name 324 | if not sanitized: 325 | sanitized = "file" 326 | # Ensure it starts and ends with alphanumeric character 327 | if sanitized[0] == "-": 328 | sanitized = "f" + sanitized 329 | if sanitized[-1] == "-": 330 | sanitized = sanitized + "f" 331 | # If length exceeds 40, keep the first 20 and last 19 chars with a dash in between 332 | if len(sanitized) > 40: 333 | # Take parts that don't end with dash 334 | first_part = sanitized[:20].rstrip("-") 335 | last_part = sanitized[-19:].rstrip("-") 336 | sanitized = first_part + "-" + last_part 337 | return sanitized 338 | 339 | 340 | def get_video_duration(file_path): 341 | """ 342 | 获取视频片段的精确持续时间,用于字幕偏移计算 343 | 344 | Args: 345 | file_path: 视频文件路径 346 | 347 | Returns: 348 | float: 视频持续时间(毫秒) 349 | """ 350 | for track in MediaInfo.parse(file_path).tracks: 351 | if track.track_type == "Video": 352 | return track.duration 353 | elif track.track_type == "Audio": 354 | return track.duration 355 | return 0 356 | 357 | 358 | def _round_to_16(value: int) -> int: 359 | """将值四舍五入为最接近的16的倍数""" 360 | return (value // 16) * 16 361 | 362 | 363 | def calculate_dimensions( 364 | width, height, max_long_edge: int = None, max_short_edge: int = None 365 | ) -> Tuple[int, int]: 366 | """ 367 | 根据原始尺寸、最长边和最短边的最大值限制计算新尺寸 368 | 369 | Args: 370 | width: 原始宽度 371 | height: 原始高度 372 | max_long_edge: 最长边的最大值 373 | max_short_edge: 最短边的最大值 374 | 375 | Returns: 376 | 调整后的宽度和高度组成的元组 377 | """ 378 | # 设置默认值 379 | if max_long_edge is None and max_short_edge is None: 380 | max_long_edge = 1024 381 | 382 | # 计算原始纵横比 383 | aspect_ratio = width / height 384 | 385 | # 确定长边和短边 386 | is_width_longer = width >= height 387 | 388 | # 将原始尺寸调整为16的倍数 389 | new_width = _round_to_16(width) 390 | new_height = 
_round_to_16(height) 391 | 392 | # 对尺寸进行多轮调整,直到满足所有条件 393 | for _ in range(2): # 最多进行两轮调整就足够了 394 | # 处理最长边的最大值限制 395 | if max_long_edge is not None: 396 | if is_width_longer and new_width > max_long_edge: 397 | new_width = max_long_edge 398 | new_height = _round_to_16(int(new_width / aspect_ratio)) 399 | elif not is_width_longer and new_height > max_long_edge: 400 | new_height = max_long_edge 401 | new_width = _round_to_16(int(new_height * aspect_ratio)) 402 | 403 | # 处理最短边的最大值限制 404 | if max_short_edge is not None: 405 | if is_width_longer and new_height > max_short_edge: 406 | new_height = max_short_edge 407 | new_width = _round_to_16(int(new_height * aspect_ratio)) 408 | elif not is_width_longer and new_width > max_short_edge: 409 | new_width = max_short_edge 410 | new_height = _round_to_16(int(new_width / aspect_ratio)) 411 | 412 | return new_width, new_height 413 | --------------------------------------------------------------------------------
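As a quick sanity check of the helpers above, here is a minimal, hedged usage sketch (standalone script; the import path assumes the repository root is on sys.path):

    from utils.stream_util import calculate_dimensions, sanitize_filename

    # Both edges are rounded down to multiples of 16; the long edge is capped.
    print(calculate_dimensions(1920, 1080, max_long_edge=1024))  # (1024, 576)

    # Lowercase alphanumerics and single dashes only, at most 40 characters.
    print(sanitize_filename("My Video_Clip (Final).MP4"))  # my-video-clip-final-mp4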