├── !del_txt.bat
├── 2tag
│   ├── may.jpg
│   └── march.jpg
├── update_pip.cmd
├── convert_to_bnb_nf4.bat
├── version_checker.bat
├── fix_minicpmo.bat
├── reinstall.bat
├── requirements_ver.txt
├── .gitignore
├── update.cmd
├── install.bat
├── fix_minicpmo.py
├── utils.py
├── convert_to_bnb_nf4.py
├── version_checker.py
├── arg_parser.py
├── batch_processing.bat
├── requirements.txt
├── ide-cap-chan.py
├── README.md
├── image_processor.py
├── LICENSE
└── model_handlers.py
/!del_txt.bat:
--------------------------------------------------------------------------------
1 | del .\2tag\*.txt
--------------------------------------------------------------------------------
/2tag/may.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/2dameneko/ide-cap-chan/HEAD/2tag/may.jpg
--------------------------------------------------------------------------------
/2tag/march.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/2dameneko/ide-cap-chan/HEAD/2tag/march.jpg
--------------------------------------------------------------------------------
/update_pip.cmd:
--------------------------------------------------------------------------------
1 | call .\venv\Scripts\activate.bat
2 | python.exe -m pip install --upgrade pip
--------------------------------------------------------------------------------
/convert_to_bnb_nf4.bat:
--------------------------------------------------------------------------------
1 | call .\venv\Scripts\activate.bat
2 | rem pip install --upgrade bitsandbytes
3 | python "convert_to_bnb_nf4.py" "ToriiGate-v0.3"
--------------------------------------------------------------------------------
/version_checker.bat:
--------------------------------------------------------------------------------
1 | python "version_checker.py"
2 | rem dual call for in/out venv dependencies check
3 | call .\venv\Scripts\activate.bat
4 | rem pip show transformers
5 | rem pip list
6 | python "version_checker.py"
--------------------------------------------------------------------------------
/fix_minicpmo.bat:
--------------------------------------------------------------------------------
1 | REM Run this script if you get the following error:
2 | REM No such file or directory: 'C:\\Users\\username\\.cache\\huggingface\\modules\\transformers_modules\\MiniCPM-o-2_6\\image_processing_minicpmv.py'
3 | python "fix_minicpmo.py"
--------------------------------------------------------------------------------
/reinstall.bat:
--------------------------------------------------------------------------------
1 | python -m venv --system-site-packages venv
2 | call venv\Scripts\activate
3 | rem pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
4 | pip install -r requirements.txt
5 | pause
6 |
--------------------------------------------------------------------------------
/requirements_ver.txt:
--------------------------------------------------------------------------------
1 | huggingface_hub==0.28.1
2 | einops==0.8.1
3 | accelerate==1.4.0
4 | bitsandbytes==0.45.2
5 | qwen-vl-utils==0.0.10
6 | vector-quantize-pytorch==1.21.8
7 | vocos==0.1.0
8 | soundfile==0.13.1
9 | librosa==0.10.2.post1
10 | exllamav2==0.2.7+cu121.torch2.5.0
11 | flash_attn==2.7.0.post2
12 | typing_extensions==4.12.2
13 | transformers==4.48.3
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all folders
2 | */
3 |
4 | # Except these
5 | !2tag/march.jpg
6 | !2tag/may.jpg
7 | !2tag/march.txt
8 | !2tag/may.txt
9 |
10 | # Ignore Python bytecode files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # Ignore virtual environment directories
16 | venv/
17 | .venv/
18 | env/
19 | .env/
20 | ENV/
21 | env.bak/
22 | venv.bak/
23 |
24 | # Ignore IDE and editor files
25 | .vscode/
26 | .idea/
27 | *.swp
28 | *.swo
29 |
30 | # Ignore system files
31 | .DS_Store
32 | Thumbs.db
33 |
--------------------------------------------------------------------------------
/update.cmd:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | REM Pull the latest changes from the repository
4 | git pull
5 |
6 | REM Set default values for environment variables if not already defined
7 | if not defined PYTHON (set PYTHON=python)
8 | if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%")
9 | if not defined VENV_DIR (set "VENV_DIR=%~dp0venv")
10 |
11 | REM Check if the virtual environment exists
12 | if exist "%VENV_DIR%\Scripts\Python.exe" (
13 | REM Activate the virtual environment
14 | call "%VENV_DIR%\Scripts\activate.bat"
15 | set PYTHON="%VENV_DIR%\Scripts\Python.exe"
16 |     echo Virtual environment activated: "%VENV_DIR%\Scripts\Python.exe"
17 | ) else (
18 | echo Virtual environment not found at %VENV_DIR%. Please create it first.
19 | exit /b 1
20 | )
21 |
22 | REM Install project requirements
23 | pip install -r requirements.txt
24 |
25 | REM End of script
26 | echo Update complete.
--------------------------------------------------------------------------------
/install.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | REM Set the directory for the virtual environment
4 | set "VENV_DIR=%~dp0venv"
5 |
6 | REM Check if the virtual environment already exists
7 | if exist "%VENV_DIR%\Scripts\Python.exe" (
8 | echo Virtual environment already found at %VENV_DIR%.
9 | exit /b 1
10 | ) else (
11 | REM Create the virtual environment
12 |     python -m venv --system-site-packages "%VENV_DIR%"
13 | goto :activate
14 | )
15 |
16 | :activate
17 | REM Activate the virtual environment
18 | call "%VENV_DIR%\Scripts\activate.bat"
19 |
20 | REM Upgrade pip
21 | python -m pip install --upgrade pip
22 |
23 | REM Install PyTorch and related packages
24 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
25 |
26 | REM Install project requirements
27 | pip install -r requirements.txt
28 |
29 | REM End of script
30 | echo Virtual environment setup complete.
--------------------------------------------------------------------------------
/fix_minicpmo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | def copy_image_processing_file():
5 | # Get current Windows username
6 | username = os.getenv('USERNAME')
7 |
8 | # Construct paths
9 | source_file = os.path.join('MiniCPM-o-2_6', 'image_processing_minicpmv.py')
10 | destination_dir = os.path.join(
11 | 'C:\\',
12 | 'Documents and Settings',
13 | username,
14 | '.cache',
15 | 'huggingface',
16 | 'modules',
17 | 'transformers_modules',
18 | 'MiniCPM-o-2_6'
19 | )
20 |
21 | # Check if source file exists
22 | if not os.path.exists(source_file):
23 | raise FileNotFoundError(f"Source file '{source_file}' not found")
24 |
25 | # Create destination directory if it doesn't exist
26 | os.makedirs(destination_dir, exist_ok=True)
27 |
28 | # Perform file copy
29 | shutil.copy(source_file, destination_dir)
30 | print(f"File successfully copied to: {destination_dir}")
31 |
32 | if __name__ == "__main__":
33 | copy_image_processing_file()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from PIL import Image
4 |
5 | GPU_TEST_ITERATIONS = 4000
6 | GPU_TEST_SIZE = 1000
7 |
8 | def measure_gpu_speed(device):
9 | """
10 | Measure the speed of a GPU by performing matrix operations.
11 |
12 | Args:
13 | device: The CUDA device to measure
14 |
15 | Returns:
16 | float: A score representing the relative speed of the GPU
17 | """
18 |     dummy_tensor = torch.randn(GPU_TEST_SIZE, GPU_TEST_SIZE, device=device)
19 |     start_time = time.time()
20 |     for _ in range(GPU_TEST_ITERATIONS):
21 |         _ = dummy_tensor @ dummy_tensor
22 |     torch.cuda.synchronize(device)  # wait for queued kernels so the timing reflects real GPU work
23 |     return 1 / (time.time() - start_time)
24 |
25 | def resize_image_proportionally(image, max_width=None, max_height=None):
26 | """
27 | Resize an image proportionally to fit within the specified dimensions.
28 |
29 | Args:
30 | image: PIL Image to resize
31 | max_width: Maximum width
32 | max_height: Maximum height
33 |
34 | Returns:
35 | PIL Image: Resized image
36 | """
37 | if (max_width is None or max_width <= 0) and (max_height is None or max_height <= 0):
38 | return image
39 |
40 | original_width, original_height = image.size
41 |
42 | if ((max_width is None or original_width <= max_width) and
43 | (max_height is None or original_height <= max_height)):
44 | return image
45 |
46 | if max_width and max_height:
47 | width_ratio = max_width / original_width
48 | height_ratio = max_height / original_height
49 | ratio = min(width_ratio, height_ratio)
50 | elif max_width:
51 | ratio = max_width / original_width
52 | else:
53 | ratio = max_height / original_height
54 |
55 | new_width = int(original_width * ratio)
56 | new_height = int(original_height * ratio)
57 |
58 | resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
59 | return resized_image
60 |
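# A minimal usage sketch (hypothetical paths; assumes Pillow and a CUDA device are available):
#
#   from PIL import Image
#   from utils import measure_gpu_speed, resize_image_proportionally
#
#   print(measure_gpu_speed("cuda:0"))  # larger score = faster GPU
#   img = Image.open("2tag/may.jpg")
#   img = resize_image_proportionally(img, max_width=1024, max_height=1024)  # keeps aspect ratio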
--------------------------------------------------------------------------------
/convert_to_bnb_nf4.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import json
4 | from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
5 | import torch
6 | import sys
7 |
8 | # Define the quantization configuration
9 | quantization_config = BitsAndBytesConfig(
10 | load_in_4bit=True,
11 | bnb_4bit_quant_type="nf4",
12 | bnb_4bit_use_double_quant=True,
13 | bnb_4bit_compute_dtype=torch.bfloat16,
14 | )
15 |
16 | # Get model name from command line arguments
17 | model_name = sys.argv[1]
18 |
19 | # Load the model and tokenizer with the quantization configuration
20 | model = AutoModel.from_pretrained(model_name, quantization_config=quantization_config, trust_remote_code=True, low_cpu_mem_usage=True)
21 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
22 |
23 | # Print the original model size
24 | original_size = sum(p.numel() for p in model.parameters())
25 | print(f"Original model size: {original_size / 1e6:.2f} million parameters")
26 |
27 | # Create new model name with suffix
28 | new_model_name = model_name + '-nf4'
29 |
30 | # Save the quantized model with new name
31 | model.save_pretrained(new_model_name)
32 | tokenizer.save_pretrained(new_model_name)
33 |
34 | # Modify config.json to remove the specified lines
35 | with open(os.path.join(new_model_name, 'config.json'), 'r') as file:
36 | config_dict = json.load(file)
37 |
38 | # Remove the specified fields
39 | config_dict["quantization_config"].pop("_load_in_4bit", None)
40 | config_dict["quantization_config"].pop("_load_in_8bit", None)
41 | config_dict["quantization_config"].pop("quant_method", None)
42 |
43 | # Write the modified config back to the file
44 | with open(os.path.join(new_model_name, 'config.json'), 'w') as file:
45 | json.dump(config_dict, file, indent=4)
46 |
47 | # Copy preprocessor_config.json and chat_template.json from input model to new model
48 | shutil.copyfile(os.path.join(model_name, 'preprocessor_config.json'), os.path.join(new_model_name, 'preprocessor_config.json'))
49 | shutil.copyfile(os.path.join(model_name, 'chat_template.json'), os.path.join(new_model_name, 'chat_template.json'))
50 |
51 | print(f"Model has been quantized and saved as {new_model_name}. Config.json has been modified.")
--------------------------------------------------------------------------------
/version_checker.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pkg_resources
3 | from collections import OrderedDict
4 |
5 | def get_installed_versions():
6 | # Regular expression to extract package names from requirements.txt lines
7 | package_pattern = re.compile(r'^([a-zA-Z0-9_-]+)')
8 |
9 | installed_versions = OrderedDict()
10 | missing_packages = []
11 |
12 | try:
13 | # Read existing requirements_ver.txt if it exists
14 | try:
15 | with open('requirements_ver.txt', 'r') as f:
16 | for line in f:
17 | line = line.strip()
18 | if line and '==' in line:
19 | package_name, version = line.split('==')
20 | installed_versions[package_name] = version
21 | except FileNotFoundError:
22 | pass # If file doesn't exist, proceed without loading
23 |
24 | with open('requirements.txt', 'r') as f:
25 | for line in f:
26 | line = line.strip()
27 |
28 | # Skip empty lines and comments
29 | if not line or line.startswith('#'):
30 | continue
31 |
32 | # Extract package name from the line
33 | match = package_pattern.match(line)
34 | if not match:
35 | continue # Skip lines that don't match package pattern
36 |
37 | package_name = match.group(1)
38 |
39 | try:
40 | # Get installed version
41 | version = pkg_resources.get_distribution(package_name).version
42 | installed_versions[package_name] = version
43 | except pkg_resources.DistributionNotFound:
44 | if package_name not in installed_versions:
45 | missing_packages.append(package_name)
46 |
47 | except FileNotFoundError:
48 | print("Error: requirements.txt file not found")
49 | return
50 |
51 | # Write results to requirements_ver.txt
52 | with open('requirements_ver.txt', 'w') as f:
53 | for package_name, version in installed_versions.items():
54 | f.write(f"{package_name}=={version}\n")
55 |
56 | # Print summary
57 | print(f"Successfully wrote {len(installed_versions)} packages to requirements_ver.txt")
58 | if missing_packages:
59 | print(f"Missing packages: {', '.join(missing_packages)}")
60 |
61 | if __name__ == "__main__":
62 | get_installed_versions()
63 |
--------------------------------------------------------------------------------
/arg_parser.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import sys
3 |
4 | def parse_arguments():
5 | parser = ArgumentParser(description='Generate captions for images')
6 |
7 | parser.add_argument('--model_path', type=str, default="", help='Path to the used model')
8 | parser.add_argument('--model_type', type=str, default="exllama2",
9 | help='Model type (supported architectures: idefics3, llava, joy-caption, molmo, qwen2vl, molmo72b, pixtral, exllama2, minicpmo, generic (You can try this option if your model is not listed as supported. No warranties.))')
10 | parser.add_argument('--input_dir', type=str, default="./2tag", help='Path to the folder containing images')
11 | parser.add_argument('--CUDA_VISIBLE_DEVICES', type=str, default="0",
12 | help='Comma-separated list of CUDA devices. WARNING: multi-GPU captioning can overload your power supply unit. Model molmo72b ignores this arg and requires 2x24GB GPU')
13 | parser.add_argument('--caption_suffix', type=str, default=".txt", help='File extension for generated caption files')
14 | parser.add_argument('--tags_suffix', type=str, default=".ttxt", help='File extension for existing image info file (like *booru tags, traits, characters, etc)')
15 | parser.add_argument('--caption_format', type=str, choices=['json', 'markdown', 'short', 'long', 'bbox'], default='long',
16 | help='Format of the generated captions (supported formats: json, markdown, short, long, bbox), (req. ToriiGate-family models)')
17 |
18 | bool_args = [
19 | ('--add_tags', 'Use an additional file as existing *booru tags to enhance captioning, (req. ToriiGate-family models)'),
20 | ('--add_chars', 'Use an additional file as information about represented characters, (req. ToriiGate >= 0.4 model)'),
21 | ('--add_char_traits', 'Use an additional file as information about character traits, (req. ToriiGate >= 0.4 model)'),
22 | ('--add_info', 'Use an additional file as misc information about image, (req. ToriiGate >= 0.4 model)'),
23 | ('--no_chars', 'Do not add any characters to the output, (req. ToriiGate >= 0.4 model)'),
24 | ]
25 |
26 | for arg, help_text in bool_args:
27 | parser.add_argument(arg, default=False, action='store_true', help=help_text)
28 |
29 | args = parser.parse_args()
30 |
31 | check_mutually_exclusive(
32 | args, ['--add_tags', '--add_chars', '--add_char_traits', '--add_info']
33 | )
34 |
35 | check_mutually_exclusive(
36 | args, ['--add_chars', '--no_chars']
37 | )
38 |
39 | return args
40 |
41 | def check_mutually_exclusive(args, arg_names):
42 | args_list = [getattr(args, arg_name.replace('--', '')) for arg_name in arg_names]
43 | if sum(args_list) > 1:
44 | print(f"Error: Only one of the following arguments can be True at a time: {', '.join(arg_names)}")
45 | sys.exit(1)
46 |
--------------------------------------------------------------------------------
/batch_processing.bat:
--------------------------------------------------------------------------------
1 | call .\venv\Scripts\activate.bat
2 | python "ide-cap-chan.py"
3 |
4 | rem Full command line args example:
5 | rem python "ide-cap-chan3.py" --model_path "model_name" --model_type "model_type" --input_dir "folder_with_images_to_tag" --CUDA_VISIBLE_DEVICES "0, 1" --caption_suffix ".txt" --dont_use_tags --tags_suffix ".ttxt"
6 |
7 | rem Local models
8 | rem python "ide-cap-chan.py" --model_path "Minthy_ToriiGate-v0.4-2B-exl2-8bpw"
9 | rem python "ide-cap-chan.py" --model_type "idefics3" --model_path "ToriiGate-v0.3" --CUDA_VISIBLE_DEVICES "1"
10 | rem python "ide-cap-chan.py" --model_type "idefics3" --model_path "ToriiGate-v0.3-nf4" --CUDA_VISIBLE_DEVICES "0"
11 | rem python "ide-cap-chan.py" --model_type "idefics3" --model_path "ToriiGate-v0.3" --CUDA_VISIBLE_DEVICES "0" --add_tags --caption_suffix ".text" --tags_suffix ".txt"
12 | rem python "ide-cap-chan.py" --model_type "qwen2vl" --model_path "Minthy_ToriiGate-v0.4-2B" --CUDA_VISIBLE_DEVICES "0"
13 | rem python "ide-cap-chan.py" --model_type "qwen2vl" --model_path "Vikhr-2-VL-2b-Instruct-experimental" --CUDA_VISIBLE_DEVICES "0"
14 | rem python "ide-cap-chan.py" --model_type "minicpmo" --model_path "MiniCPM-o-2_6" --CUDA_VISIBLE_DEVICES "0"
15 | rem python "ide-cap-chan.py" --model_type "qwen2vl" --model_path "Minthy_ToriiGate-v0.4-7B" --CUDA_VISIBLE_DEVICES "0" --no_chars --add_tags --caption_format "markdown" --caption_suffix ".text" --tags_suffix ".txt"
16 | rem python "ide-cap-chan.py" --model_type "exllama2" --model_path "Minthy_ToriiGate-v0.4-2B-exl2-8bpw" --CUDA_VISIBLE_DEVICES "0" --no_chars --add_tags --caption_format "bbox" --caption_suffix ".text" --tags_suffix ".txt"
17 | rem python "ide-cap-chan.py" --model_type "exllama2" --model_path "Minthy_ToriiGate-v0.4-2B-exl2-8bpw" --CUDA_VISIBLE_DEVICES "0, 1"
18 | rem python "ide-cap-chan.py" --model_type "exllama2" --model_path "Minthy_ToriiGate-v0.4-7B-exl2-8bpw" --CUDA_VISIBLE_DEVICES "0"
19 | rem python "ide-cap-chan.py" --model_type "pixtral" --model_path "Ertugrul_Pixtral-12B-Captioner-Relaxed" --CUDA_VISIBLE_DEVICES "0"
20 | rem python "ide-cap-chan.py" --model_type "molmo" --model_path "ctranslate2-4you_molmo-7B-O-bnb-4bit" --CUDA_VISIBLE_DEVICES "0"
21 | rem python "ide-cap-chan.py" --model_type "molmo72b" --model_path "SeanScripts_Molmo-72B-0924-nf4"
22 | rem python "ide-cap-chan.py" --model_type "qwen2vl" --model_path "Ertugrul_Qwen2-VL-7B-Captioner-Relaxed" --CUDA_VISIBLE_DEVICES "0"
23 | rem python "ide-cap-chan.py" --model_type "idefics3" --model_path "Idefics3-8B-Llama3"
24 | rem python "ide-cap-chan.py" --model_type "llava" --model_path "llava-v1.6-mistral-7b-hf-nf4" --CUDA_VISIBLE_DEVICES "0"
25 | rem python "ide-cap-chan.py" --model_type "llava" --model_path "llava-v1.6-mistral-7b-hf" --CUDA_VISIBLE_DEVICES "0"
26 | rem python "ide-cap-chan.py" --model_type "joy-caption" --model_path "llama-joycaption-alpha-two-hf-llava" --CUDA_VISIBLE_DEVICES "0"
27 | rem python "ide-cap-chan.py" --model_type "llava" --model_path "llama-joycaption-alpha-two-hf-llava" --CUDA_VISIBLE_DEVICES "0"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #git+https://github.com/huggingface/transformers #latest; minicpm_o broken (due to a deprecated logits API)
2 | transformers==4.48.3 #last version compatible with minicpmo
3 | #transformers==4.44.2 #recommended for minicpm_o, but doesn't have support for idefics3
4 | accelerate==1.4.0
5 | huggingface_hub==0.28.1
6 | bitsandbytes==0.45.5
7 | typing_extensions==4.12.2
8 | einops==0.8.1
9 | qwen-vl-utils==0.0.10
10 | vector-quantize-pytorch==1.21.8
11 | vocos==0.1.0
12 | soundfile==0.13.1
13 | librosa==0.10.2.post1
14 |
15 | http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl
16 |
17 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
18 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
19 |
20 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
21 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
22 |
23 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.12"
24 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.12"
25 |
26 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
27 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
28 |
29 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
30 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
31 |
32 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
33 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
34 |
35 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp312-cp312-win_amd64.whl; platform_system == "Windows" and python_version == "3.12"
36 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-win_amd64.whl; platform_system == "Windows" and python_version == "3.12"
37 |
38 | exllamav2@ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
39 | flash_attn@ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
40 |
--------------------------------------------------------------------------------
/ide-cap-chan.py:
--------------------------------------------------------------------------------
1 | # ide-cap-chan v0.96
2 | from arg_parser import parse_arguments
3 | from utils import measure_gpu_speed
4 | from image_processor import process_image_worker
5 | import torch.multiprocessing as mp
6 | from os import walk
7 | from os.path import splitext as split_extension
8 | from pathlib import Path
9 | import queue
10 | import time
11 |
12 | def main():
13 | args = parse_arguments()
14 |
15 | device_ids = list(map(int, args.CUDA_VISIBLE_DEVICES.split(',')))
16 |
17 | supported_model_types = ["idefics3", "llava", "joy-caption", "molmo", "qwen2vl", "molmo72b", "pixtral", "exllama2", "minicpmo", "generic"]
18 | input_model_type = args.model_type.lower()
19 | if input_model_type not in supported_model_types:
20 | print(f"Model type '{input_model_type}' not supported. Supported loaders: {', '.join(supported_model_types)}.")
21 | return
22 |
23 | model_name_or_path = args.model_path or {
24 | 'idefics3': "2dameneko/ToriiGate-v0.3-nf4",
25 | 'llava': "2dameneko/llava-v1.6-mistral-7b-hf-nf4",
26 | 'joy-caption': "fancyfeast/llama-joycaption-alpha-two-hf-llava",
27 | #'qwen2vl': "Ertugrul/Qwen2-VL-7B-Captioner-Relaxed",
28 | #'qwen2vl': "Minthy/ToriiGate-v0.4-2B",
29 | 'qwen2vl': "Vikhrmodels/Vikhr-2-VL-2b-Instruct-experimental",
30 | 'molmo': "cyan2k/molmo-7B-O-bnb-4bit",
31 | 'molmo72b': "SeanScripts/Molmo-72B-0924-nf4",
32 | 'pixtral': "Ertugrul/Pixtral-12B-Captioner-Relaxed",
33 | 'exllama2': "Minthy/ToriiGate-v0.4-2B-exl2-8bpw",
34 | #'exllama2': "Minthy/ToriiGate-v0.4-7B-exl2-8bpw",
35 | 'minicpmo': "openbmb/MiniCPM-o-2_6",
36 | 'generic': None,
37 | }[input_model_type]
38 |
39 | quant_suffixes = ["nf4", "bnb", "4bit"]
40 |     use_nf4 = any(suffix in model_name_or_path for suffix in quant_suffixes) if model_name_or_path else False
41 |
42 | if input_model_type == "joy-caption" and use_nf4:
43 |         print(f"Model type '{input_model_type}' is not supported with nf4 quantization. Disabling it.")
44 | use_nf4 = False
45 |
46 | args_dict = {
47 | 'use_nf4' : use_nf4,
48 | 'caption_suffix' : args.caption_suffix,
49 | 'tags_suffix' : args.tags_suffix,
50 | 'add_tags' : args.add_tags,
51 | 'add_chars': args.add_chars,
52 | 'add_char_traits' : args.add_char_traits,
53 | 'add_info' : args.add_info,
54 | 'no_chars' : args.no_chars,
55 | 'caption_format' : args.caption_format,
56 | }
57 |
58 | input_dir = args.input_dir
59 | image_extensions = [".jpg", ".png", ".webp", ".jpeg"]
60 |
61 | # Measure GPU speeds for informational purposes
62 | gpu_speeds = [(i, measure_gpu_speed(f"cuda:{i}")) for i in device_ids]
63 |
64 | print(f'Using GPU ids: {device_ids}')
65 |     print("GPU speeds:")
66 | for gpu_id, gpu_speed in gpu_speeds:
67 | print(f" {gpu_id} | {gpu_speed:.2f}")
68 | print(f'Using model: {model_name_or_path} (type: {input_model_type})')
69 | print(f'Use quantization: {args_dict.get("use_nf4")}')
70 |
71 | # Find existing captions to avoid reprocessing
72 | existing_captions = []
73 | for root, _, files in walk(input_dir):
74 | for file in files:
75 | file_path = Path(root) / file
76 | if file_path.suffix.lower() == args_dict.get('caption_suffix'):
77 | path, _ = split_extension(str(file_path))
78 | existing_captions.append(path)
79 |
80 | # Create a list of files to process
81 | filelist = []
82 | for root, _, files in walk(input_dir):
83 | for file in files:
84 | file_path = Path(root) / file
85 | if file_path.suffix.lower() in image_extensions:
86 | path, _ = split_extension(str(file_path))
87 | if path not in existing_captions:
88 | filelist.append(file_path)
89 |
90 | if not filelist:
91 | print('There are no files to process.')
92 | return
93 |
94 | # Create a shared job queue
95 | job_queue = mp.Queue()
96 | result_queue = mp.Queue()
97 |
98 | # Put all files in the job queue
99 | for file_path in filelist:
100 | job_queue.put(file_path)
101 |
102 | # Add termination signals (one for each worker)
103 | for _ in range(len(device_ids)):
104 | job_queue.put(None)
105 |
106 | # Create and start worker processes
107 | processes = []
108 | for i, gpu_id in enumerate(device_ids):
109 | p = mp.Process(
110 | target=process_image_worker,
111 | args=(
112 | i, # worker_id
113 | gpu_id, # gpu_id
114 | job_queue,
115 | result_queue,
116 | model_name_or_path,
117 | input_model_type,
118 | args_dict,
119 | len(filelist) # total_files
120 | )
121 | )
122 | p.start()
123 | processes.append(p)
124 |
125 | # Monitor progress
126 | completed_files = 0
127 | total_files = len(filelist)
128 | start_time = time.time()
129 |
130 | while completed_files < total_files:
131 | try:
132 | # Get result from the result queue
133 | result = result_queue.get(timeout=1.0)
134 | if result is not None:
135 | worker_id, gpu_id, file_name, processing_time = result
136 | completed_files += 1
137 |
138 | # Calculate ETA
139 | elapsed = time.time() - start_time
140 | avg_time_per_file = elapsed / completed_files
141 | remaining_time = avg_time_per_file * (total_files - completed_files)
142 |
143 | print(f"Worker {worker_id} (GPU {gpu_id}): {completed_files}/{total_files} - {file_name} - {processing_time:.2f}s - ETA: {remaining_time:.2f}s")
144 | except queue.Empty:
145 | # No result available, just continue
146 | continue
147 |
148 | # Wait for all processes to finish
149 | for p in processes:
150 | p.join()
151 |
152 | print(f"All {total_files} files processed successfully.")
153 |
154 | if __name__ == "__main__":
155 | main()
156 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ide-cap-chan
2 |
3 |
4 |
5 |
6 |
7 | ide-cap-chan is a utility for batch captioning images with natural language using various Vision-Language (VL) models.
8 |
9 | ## Features
10 | * **High-speed processing**: Optimized for rapid batch caption generation with ExLlama2, Qwen2-VL-7B-Instruct, Qwen2-VL-2B-Instruct (Vikhr-family included),
11 | Idefics3-8B-Llama3, LLaVa-NeXT (LLaVa-1.6), Llama JoyCaption Alpha Two, Molmo-7B-O, Molmo-72B, MiniCPM-o-2_6 and Pixtral models
12 | * **Multi-GPU support**: Distribute workloads across multiple GPUs
13 | * **Efficient quantization**: Supports ExLlama2 (exl2), int8, and nf4 quantization for reduced VRAM usage
14 | * **Autoload strategies**: VRAM-optimized loading
15 | * **Model flexibility**: Use default or custom models via CLI arguments.
16 | * **Input flexibility**: Supports Hugging Face, local, and external models
17 | * **Tag integration**: Enhance captions with existing tags/captions
18 | * **Process control**: Interrupt and resume captioning tasks
19 | * **Batch processing**: Recursively process subfolders in input directories
20 |
21 | ## Requirements
22 | * NVIDIA GPU with CUDA support (8GB VRAM minimum for llava, 12GB recommended for Qwen2-VL-7B in exl2, 48GB VRAM total for Molmo-72B)
23 |
24 | ## Installation
25 | 1. Clone the repository:
26 | `git clone https://github.com/2dameneko/ide-cap-chan`
27 | 2. Install dependencies:
28 | - **Windows**: Run `install.bat`
29 | - **Linux**: Create a virtual environment and install requirements:
30 | ```bash
31 | python -m venv venv
32 | source venv/bin/activate
33 | pip install -r requirements.txt
34 | ```
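   Note: `requirements.txt` does not pin PyTorch itself, while the `exllamav2` and `flash_attn` wheels it references are built against torch 2.7.0 + CUDA 12.8. If needed, install the matching build first (this mirrors what `install.bat` does on Windows):
   ```bash
   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
   ```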
35 |
36 | ## Usage
37 | 1. Place images and corresponding tag files in the input folder (default: `2tag`)
38 | 2. Start processing:
39 | - **Windows**: Run `batch_processing.bat`
40 | - **Linux**: Execute `python ide-cap-chan.py`
41 | 3. Specify alternative models using CLI arguments (see the example below)
42 | 4. Customize prompts in `image_processor.py` (modify `system_prompt` and `user_prompt`)
43 |
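A typical invocation, assuming the virtual environment is active (the model names below come from `batch_processing.bat`; substitute your own):

```bash
# Default run: exllama2 loader with the default ToriiGate exl2 model, images taken from ./2tag
python ide-cap-chan.py

# Custom loader/model on GPU 0, reusing existing *booru tags stored next to the images as .txt files
python ide-cap-chan.py --model_type "qwen2vl" --model_path "Minthy_ToriiGate-v0.4-7B" --CUDA_VISIBLE_DEVICES "0" --add_tags --caption_format "markdown" --caption_suffix ".text" --tags_suffix ".txt"
```
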
44 | ## Updating
45 | - **Windows**: Run `update.cmd`
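- **Linux**: a rough equivalent of `update.cmd` (assuming the venv layout from the Installation section):

  ```bash
  git pull
  source venv/bin/activate
  pip install -r requirements.txt
  ```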
46 |
47 | ## Options
48 | Run without arguments for default behavior. Available CLI options (`python ide-cap-chan.py -h`):
49 | | Argument | Description |
50 | |----------|-------------|
51 | | `--model_path` | Path to model (Hugging Face, local, or external) |
52 | | `--model_type` | Model architecture/loader: idefics3, llava, joy-caption, molmo, qwen2vl, molmo72b, pixtral, exllama2, minicpmo, generic (default: `exllama2`) |
53 | | `--input_dir` | Input directory path (default: `2tag`) |
54 | | `--CUDA_VISIBLE_DEVICES` | Comma-separated GPU IDs (default: `0`). **Note**: multi-GPU captioning may strain your PSU; `molmo72b` ignores this argument and auto-splits across GPUs |
55 | | `--caption_suffix` | Caption file extension (default: `.txt`) |
56 | | `--caption_format` | Output format: `json`, `markdown`, `short`, `long`, `bbox` (requires ToriiGate ≥0.4) |
57 | | `--add_tags` | Enhance captions with existing *booru tag files, read from the `--tags_suffix` file (default: `.ttxt`); ToriiGate-family models. See the layout example below |
58 | | `--add_chars` | Enhance captions with character information (requires ToriiGate ≥0.4) |
59 | | `--add_char_traits` | Enhance captions with character traits (requires ToriiGate ≥0.4) |
60 | | `--add_info` | Enhance captions with miscellaneous image info (requires ToriiGate ≥0.4) |
61 | | `--no_chars` | Do not add character names (requires ToriiGate ≥0.4) |
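
For example, with the default suffixes a tag-assisted run (`--add_tags`) expects the extra file next to each image and writes the caption alongside it (file names below are just an illustration):

```
2tag/
├── may.jpg    # image to caption
├── may.ttxt   # existing *booru tags, read when --add_tags is set
└── may.txt    # generated caption, written by the tool
```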
62 |
63 | ## Supported File Formats
64 | `.jpg`, `.png`, `.webp`, `.jpeg`
65 |
66 | ## Version History
67 | * **0.96**: Moved to CUDA 12.8 and PyTorch 2.7, added support for Blackwell GPUs
68 | * **0.95**: Dynamic multi-GPU task queuing instead of splitting based on approximate GPU speed
69 | * **0.9**: Added MiniCPM-o-2_6 loader support, rewritten to a modular design, pinned dependency versions
70 | * **0.8**: Added ExLlama2 loader support (default), ToriiGate-v0.4 features, Molmo-72B auto-split
71 | * **0.7**: Added Molmo/Qwen2VL/Pixtral support, improved multi-GPU quant processing, code refactor
72 | * **0.6**: Internal code improvements
73 | * **0.5**: Added JoyCaption support, code refactor
74 | * **0.4**: Added LLaVA support, updated to PyTorch 2.5.1
75 | * **0.3**: Improved argument handling, fixed extension case sensitivity
76 | * **0.2**:
77 | - Multi-GPU support with load balancing
78 | - nf4 quantization
79 | - Fixed duplicate file filtering
80 | - Updated environment scripts
81 | * **0.1**: Initial release
82 |
83 | ## Note
84 | This project is a proof of concept and not production-ready.
85 |
86 | ## License
87 | [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0)
88 |
89 | ## Credits
90 | - Idefics3 Architecture: [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)
91 | - LLaVA Architecture: [Transformers Documentation](https://huggingface.co/docs/transformers/main/model_doc/llava)
92 | - JoyCaption Code: [fpgaminer/joycaption](https://github.com/fpgaminer/joycaption)
93 | - Qwen2-VL Architecture: [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)
94 | - Qwen2-VL Implementation: [MNeMoNiCuZ/qwen2-vl-7b-captioner-relaxed-batch](https://github.com/MNeMoNiCuZ/qwen2-vl-7b-captioner-relaxed-batch)
95 | - Molmo Architecture: [AllenAI Collection](https://huggingface.co/collections/allenai/molmo-66f379e6fe3b8ef090a8ca19)
96 | - Pixtral Architecture: [Pixtral Documentation](https://huggingface.co/docs/transformers/model_doc/pixtral)
97 | - MiniCPM-o-2_6 Architecture: [MiniCPM-o-2_6 Documentation](https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9)
98 | - Vikhr-2-VL: [Vikhr-2-VL Documentation](https://huggingface.co/Vikhrmodels)
99 | - ExLlamaV2: [ExLlamaV2 Documentation](https://github.com/turboderp-org/exllamav2)
100 |
101 |
102 | **Model Credits**
103 | [ToriiGate](https://huggingface.co/Minthy) · [LLaVA](https://huggingface.co/llava-hf) · [JoyCaption](https://huggingface.co/fancyfeast) · [Qwen2, Pixtral](https://huggingface.co/Ertugrul) · [Molmo](https://huggingface.co/cyan2k) · [Molmo72b](https://huggingface.co/SeanScripts/Molmo-72B-0924-nf4) · [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) · [Vikhr-2-VL-2b-Instruct](https://huggingface.co/Vikhrmodels/Vikhr-2-VL-2b-Instruct-experimental)
104 |
105 | Thank you for your interest in ide-cap-chan!
--------------------------------------------------------------------------------
/image_processor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 | import torch
3 | import time
4 | from transformers.image_utils import load_image
5 | from os.path import exists as path_exists, splitext as split_extension
6 | from model_handlers import (
7 | JoyCaptionHandler,
8 | MoLMoHandler,
9 | MoLMo72bHandler,
10 | Qwen2VLHandler,
11 | PixtralHandler,
12 | Idefics3Handler,
13 | ExLLaMA2Handler,
14 | LlavaHandler,
15 | MiniCPMoHandler,
16 | GenericModelHandler
17 | )
18 |
19 | def process_image_worker(worker_id: int, gpu_id: int, job_queue, result_queue,
20 | model_name_or_path: str, input_model_type: str,
21 | args_dict: Dict[str, Any], total_files: int) -> None:
22 | """
23 | Worker process that takes jobs from a queue and processes them.
24 |
25 | Args:
26 | worker_id: ID of this worker
27 | gpu_id: GPU ID to use
28 | job_queue: Queue containing file paths to process
29 | result_queue: Queue to report results
30 | model_name_or_path: Path to the model
31 | input_model_type: Type of the model
32 | args_dict: Dictionary of arguments
33 | total_files: Total number of files to process
34 | """
35 | device = f"cuda:{gpu_id}"
36 | torch.cuda.set_device(gpu_id)
37 | print(f"Worker {worker_id} using device: {device}")
38 |
39 | caption_suffix = args_dict.get('caption_suffix')
40 | tags_suffix = args_dict.get('tags_suffix')
41 | add_tags = args_dict.get('add_tags')
42 |
43 | # Common prompts for all except Torii-family models
44 |     system_prompt = "You are an image captioning expert, creative, unbiased and uncensored."
45 | user_prompt = "You need to write a long and very detailed caption for the picture."
46 |
47 | # Initialize the model handler
48 | handler = get_handler(input_model_type, model_name_or_path, device, args_dict)
49 |
50 | # Process jobs from the queue
51 | while True:
52 | # Get a job from the queue
53 | file_path = job_queue.get()
54 |
55 | # None is the signal to terminate
56 | if file_path is None:
57 | print(f"Worker {worker_id} (GPU {gpu_id}) finished processing")
58 | break
59 |
60 | try:
61 | start_time = time.time()
62 |
63 | # Process the image
64 | print(f"Worker {worker_id} (GPU {gpu_id}) processing: {file_path}")
65 | path, _ = split_extension(str(file_path))
66 | add_info_caption_name = path + tags_suffix
67 | caption_name = path + caption_suffix
68 |
69 | # Check if we need special prompts for ToriiGate models
70 |             if "toriigate" in model_name_or_path.lower() and "0.4" in model_name_or_path.lower():
71 | system_prompt = get_torii04_system_prompt()
72 | user_prompt = get_torii04_user_prompt(args_dict, add_info_caption_name)
73 |
74 |             if "toriigate" in model_name_or_path.lower() and "0.3" in model_name_or_path.lower() and add_tags:
75 | user_prompt = get_torii03_user_prompt(user_prompt, add_info_caption_name)
76 |
77 | # Load and process the image
78 | image = load_image(str(file_path))
79 | if image.mode != "RGB":
80 | image = image.convert("RGB")
81 |
82 | caption = handler.process_image(system_prompt, user_prompt, image)
83 | handler.save_caption(caption, caption_name)
84 |
85 | # Calculate processing time
86 | processing_time = time.time() - start_time
87 |
88 | # Report result
89 | result_queue.put((worker_id, gpu_id, file_path.name, processing_time))
90 |
91 | except Exception as e:
92 | print(f"Worker {worker_id} (GPU {gpu_id}) error processing {file_path}: {e}")
93 | # Report error but continue processing
94 | result_queue.put((worker_id, gpu_id, f"{file_path.name} (ERROR)", 0.0))
95 |
96 | def get_handler(input_model_type, model_name_or_path, device, args_dict):
97 |     # Map each supported model type to its handler class
98 |     handlers = {
99 |         "exllama2": ExLLaMA2Handler,
100 |         "joy-caption": JoyCaptionHandler,
101 |         "molmo": MoLMoHandler,
102 |         "molmo72b": MoLMo72bHandler,
103 |         "qwen2vl": Qwen2VLHandler,
104 |         "pixtral": PixtralHandler,
105 |         "idefics3": Idefics3Handler,
106 |         "llava": LlavaHandler,
107 |         "minicpmo": MiniCPMoHandler,
108 |         "generic": GenericModelHandler
109 |     }
110 |     if input_model_type not in handlers:
111 |         raise ValueError(f"Unsupported model type: {input_model_type}")
112 |     return handlers[input_model_type](model_name_or_path, device, args_dict)
113 |
114 | def get_torii04_user_prompt(args_dict, add_info_caption_name):
115 | add_tags = args_dict.get('add_tags')
116 | add_chars = args_dict.get('add_chars')
117 | add_char_traits = args_dict.get('add_char_traits')
118 | add_info = args_dict.get('add_info')
119 | no_chars = args_dict.get('no_chars')
120 | caption_format = args_dict.get('caption_format')
121 |
122 | image_info={}
123 |
124 | if path_exists(add_info_caption_name):
125 | if add_tags:
126 | tags = open(add_info_caption_name).read().strip()
127 | image_info["booru_tags"] = tags
128 |
129 | if add_chars:
130 | chars = open(add_info_caption_name).read().strip()
131 | image_info["chars"] = chars
132 |
133 | if add_char_traits:
134 | traits = open(add_info_caption_name).read().strip()
135 | image_info["characters_traits"] = traits
136 |
137 | if add_info:
138 | info = open(add_info_caption_name).read().strip()
139 | image_info["info"] = info
140 |
141 | base_prompt={
142 | 'json': 'Describe the picture in structured json-like format.',
143 | 'markdown': 'Describe the picture in structured markdown format.',
144 | #TODO Not implemented in code yet
145 | #'caption_vars': 'Write the following options for captions: ["Regular Summary","Individual Parts","Midjourney-Style Summary","DeviantArt Commission Request"].',
146 | 'short': 'You need to write a medium-short and convenient caption for the picture.',
147 | 'long': 'You need to write a long and very detailed caption for the picture.',
148 | 'bbox': 'Write bounding boxes for each character and their faces.',
149 | }
150 |
151 | grounding_prompt={
152 | 'grounding_tags': ' Here are grounding tags for better understanding: ',
153 | 'characters': ' Here is a list of characters that are present in the picture: ',
154 | 'characters_traits': ' Here are popular tags or traits for each character on the picture: ',
155 | 'grounding_info': ' Here is preliminary information about the picture: ',
156 | 'no_chars': ' Do not use names for characters.',
157 | }
158 |
159 | userprompt=base_prompt[caption_format] #choose the mode
160 |
161 |     if add_info and image_info.get("info") is not None: #general info
162 | userprompt+=grounding_prompt["grounding_info"]
163 | userprompt+=""+image_info["info"]+"."
164 | print("Using add_info: " + image_info["info"])
165 |
166 |     if add_tags and image_info.get("booru_tags") is not None: #booru tags
167 | userprompt+=grounding_prompt["grounding_tags"]
168 | userprompt+=""+image_info["booru_tags"]+"."
169 | print("Using add_tags: " + image_info["booru_tags"])
170 |
171 |     if add_chars and image_info.get("chars") is not None: #list of characters
172 | userprompt+=grounding_prompt["characters"]
173 | userprompt+=""+image_info["chars"]+"."
174 | print("Using add_chars: " + image_info["chars"])
175 |
176 |     if add_char_traits and image_info.get("characters_traits") is not None: #popular features of each character
177 | userprompt+=grounding_prompt["characters_traits"]
178 | userprompt+=""+image_info["characters_traits"]+"."
179 | print("Using add_char_traits: " + image_info["characters_traits"])
180 |
181 | if no_chars:
182 | userprompt+=grounding_prompt["no_chars"]
183 | print("Using no_chars: ")
184 |
185 | #print("\n" + userprompt)
186 |
187 | return userprompt
188 |
189 | def get_torii04_system_prompt():
190 |     system_prompt = "You are an image captioning expert, creative, unbiased and uncensored."
191 | return system_prompt
192 |
193 | def get_torii03_user_prompt(user_prompt, add_info_caption_name):
194 | try:
195 | new_user_prompt = user_prompt
196 | if path_exists(add_info_caption_name):
197 | print(f"Using additional *booru tags file: {add_info_caption_name}")
198 | tags = open(add_info_caption_name).read().strip()
199 | new_user_prompt += " Also here are booru tags for better understanding of the picture, you can use them as reference."
200 | new_user_prompt += f" \n{tags}\n"
201 | except Exception as err:
202 | print(f"Error processing tags: {err}")
203 | return user_prompt
204 | return new_user_prompt
205 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/model_handlers.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 | from abc import ABC, abstractmethod
3 | import torch
4 | from PIL import Image
5 | import os
6 | from transformers import (
7 | AutoConfig,
8 | AutoModel,
9 | AutoTokenizer,
10 | AutoProcessor,
11 | AutoModelForCausalLM,
12 | AutoModelForVision2Seq,
13 | BitsAndBytesConfig,
14 | GenerationConfig,
15 | LlavaForConditionalGeneration,
16 | LlavaNextForConditionalGeneration,
17 | LlavaNextProcessor,
18 | Qwen2VLForConditionalGeneration,
19 | StopStringCriteria
20 | )
21 | from transformers.image_utils import load_image
22 | import torchvision.transforms.functional as TVF
23 | from qwen_vl_utils import process_vision_info
24 | from exllamav2 import (
25 | ExLlamaV2,
26 | ExLlamaV2Config,
27 | ExLlamaV2Cache,
28 | ExLlamaV2Tokenizer,
29 | ExLlamaV2VisionTower,
30 | )
31 | from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
32 | from utils import resize_image_proportionally
33 | from huggingface_hub import snapshot_download
34 |
35 | class ModelHandler(ABC):  # base handler: resolves/downloads the model directory and defines the captioning interface
36 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
37 | self.model_name_or_path = self.model_loader(model_name_or_path)
38 | self.device = device
39 | self.args_dict = args_dict
40 | self.model = None
41 | self.processor = None
42 | self.tokenizer = None
43 | self.quantization_config = None
44 | self._initialize_model()
45 |
46 | @abstractmethod
47 | def _initialize_model(self):
48 | pass
49 |
50 | @abstractmethod
51 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
52 | pass
53 |
54 | def save_caption(self, caption: str, caption_path: str, encoding: str = "utf-8", errors: str = "ignore"):
55 | with open(caption_path, "w", encoding=encoding, errors=errors) as outf:
56 | outf.write(caption)
57 |
58 | def model_loader(self, model_name_or_path: str) -> str:
59 | local_model_dir = model_name_or_path.split('/')[-1]
60 | if os.path.exists(local_model_dir):
61 | print(f"Model directory '{local_model_dir}' already exists. Using local version.")
62 | return local_model_dir
63 | else:
64 | print(f"Downloading model '{model_name_or_path}'...")
65 | snapshot_download(
66 | repo_id=model_name_or_path,
67 | local_dir=local_model_dir,
68 | )
69 | print(f"Model successfully saved to '{local_model_dir}' directory.")
70 | return local_model_dir
71 |
72 | def _get_quantization_config(self):
73 | use_nf4 = self.args_dict.get('use_nf4')
74 | if use_nf4:
75 | return BitsAndBytesConfig(
76 | load_in_4bit=True,
77 | bnb_4bit_quant_type="nf4",
78 | bnb_4bit_use_double_quant=True,
79 | bnb_4bit_compute_dtype=torch.bfloat16
80 | )
81 | return None
82 |
83 | class ExLLaMA2Handler(ModelHandler):  # vision-language models loaded through exllamav2, with optional autosplit across GPUs
84 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
85 | super().__init__(model_name_or_path, device, args_dict)
86 |
87 | def _initialize_model(self):
88 | self.config = ExLlamaV2Config(self.model_name_or_path)
89 | self.config.max_seq_len = 32768
90 |
91 | self.vision_model = ExLlamaV2VisionTower(self.config)
92 | self.vision_model.load(progress=True)
93 |
94 | self.model = ExLlamaV2(self.config)
95 | self.tokenizer = ExLlamaV2Tokenizer(self.config)
96 |
97 | device_count = torch.cuda.device_count()
98 | free_mem, total_mem = torch.cuda.mem_get_info(self.device)
99 | free_gb = free_mem / (1024 ** 3)
100 | total_gb = total_mem / (1024 ** 3)
101 |
102 | if len(self.args_dict.get('filelist_chunks', [])) == 1 and self.device == "cuda:0" and device_count > 1:
103 | autosplit = True
104 | else:
105 | autosplit = False
106 |
107 | if autosplit:
108 | print(f"VRAM allocation strategy: Autosplit on {device_count} GPUs")
109 | cache = ExLlamaV2Cache(self.model, lazy=True, max_seq_len=self.config.max_seq_len)
110 | self.model.load_autosplit(cache, progress=True)
111 | else:
112 | split = [0.0] * device_count
113 | gpu_id = int(self.device.split(":")[1])
114 | split[gpu_id] = free_gb
115 | print(f"VRAM allocation strategy: allocated {free_gb:.2f} GB on GPU:{gpu_id}")
116 | self.model.load(split, progress=True)
117 | cache = ExLlamaV2Cache(self.model, lazy=False, max_seq_len=self.config.max_seq_len)
118 |
119 | self.generator = ExLlamaV2DynamicGenerator(
120 | model=self.model,
121 | cache=cache,
122 | tokenizer=self.tokenizer,
123 | )
124 |
125 |     def model_loader(self, model_name_or_path: str) -> str:  # same as the base loader, but downloads without local symlinks
126 | local_model_dir = model_name_or_path.split('/')[-1]
127 | if os.path.exists(local_model_dir):
128 | print(f"Model directory '{local_model_dir}' already exists. Using local version.")
129 | return local_model_dir
130 | else:
131 | print(f"Downloading model '{model_name_or_path}'...")
132 | from huggingface_hub import snapshot_download
133 | snapshot_download(
134 | repo_id=model_name_or_path,
135 | local_dir=local_model_dir,
136 | local_dir_use_symlinks=False
137 | )
138 | print(f"Model successfully saved to '{local_model_dir}' directory.")
139 | return local_model_dir
140 |
141 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
142 | image = resize_image_proportionally(image, 2048, 2048)
143 | image_embeddings = [self.vision_model.get_image_embeddings(
144 | model=self.model,
145 | tokenizer=self.tokenizer,
146 | image=image,
147 | )]
148 | placeholders = "\n".join([ie.text_alias for ie in image_embeddings]) + "\n"
149 |
150 | msg_text = (
151 | "<|im_start|>system\n" +
152 | system_prompt +
153 | "<|im_end|>\n" +
154 | "<|im_start|>user\n" +
155 | placeholders +
156 | user_prompt +
157 | "<|im_end|>\n" +
158 | "<|im_start|>assistant\n"
159 | )
160 |
161 | gen_settings = ExLlamaV2Sampler.Settings()
162 | gen_settings.temperature = 0.6
163 | gen_settings.top_p = 0.9
164 | gen_settings.top_k = 40
165 |
166 | output = self.generator.generate(
167 | prompt=msg_text,
168 | max_new_tokens=512,
169 | add_bos=True,
170 | encode_special_tokens=True,
171 | decode_special_tokens=True,
172 | stop_conditions=[self.tokenizer.eos_token_id],
173 | gen_settings=gen_settings,
174 | embeddings=image_embeddings,
175 | )
176 |
177 | return output.split('<|im_start|>assistant\n')[-1].strip()
178 |
179 | class JoyCaptionHandler(ModelHandler):  # JoyCaption (LLaVA-based) checkpoints
180 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
181 | super().__init__(model_name_or_path, device, args_dict)
182 |
183 | def _initialize_model(self):
184 | self.quantization_config = self._get_quantization_config()
185 | self.model = LlavaForConditionalGeneration.from_pretrained(
186 | self.model_name_or_path,
187 | torch_dtype=torch.bfloat16,
188 | quantization_config=self.quantization_config,
189 | device_map=self.device
190 | )
191 | self.processor = AutoProcessor.from_pretrained(
192 | self.model_name_or_path,
193 | trust_remote_code=True,
194 | torch_dtype=torch.bfloat16,
195 | quantization_config=self.quantization_config,
196 | use_fast=False
197 | )
198 | self.model.eval()
199 |
200 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
201 | if image.size != (384, 384):
202 | image = image.resize((384, 384), Image.LANCZOS)
203 |         # Unused preprocessing: the self.processor() call below recomputes pixel_values from the PIL image
204 |         # pixel_values = TVF.pil_to_tensor(image) / 255.0
205 |         # pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
206 |         # pixel_values = pixel_values.unsqueeze(0).to(self.device)
207 |
208 | convo = [
209 | {"role": "system", "content": system_prompt},
210 | {"role": "user", "content": user_prompt},
211 | ]
212 |
213 |         convo_string = self.processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
214 |
215 | inputs = self.processor(text=[convo_string], images=[image], return_tensors="pt").to(self.device)
216 | inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
217 |
218 | generate_ids = self.model.generate(
219 | **inputs,
220 | max_new_tokens=300,
221 | do_sample=True,
222 | suppress_tokens=None,
223 | use_cache=True,
224 | temperature=0.6,
225 | top_k=None,
226 | top_p=0.9,
227 | )[0]
228 |
229 | generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
230 |
231 | caption = self.processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
232 | caption = caption.strip()
233 |
234 | return caption
235 |
236 | class MoLMoHandler(ModelHandler):  # Molmo models loaded on a single device
237 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
238 | super().__init__(model_name_or_path, device, args_dict)
239 |
240 | def _initialize_model(self):
241 | self.processor = AutoProcessor.from_pretrained(
242 | self.model_name_or_path,
243 | trust_remote_code=True,
244 | torch_dtype='auto',
245 | use_fast=False
246 | )
247 | self.model = AutoModelForCausalLM.from_pretrained(
248 | self.model_name_or_path,
249 | trust_remote_code=True,
250 | torch_dtype='auto',
251 | attn_implementation='eager', # sdpa or flash_attention_2 or "eager"
252 | device_map=self.device
253 | )
254 | self.model.eval()
255 |
256 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
257 | image = resize_image_proportionally(image, 1024, 1024)
258 |
259 | user_only_prompt = system_prompt + " " + user_prompt
260 |
261 | inputs = self.processor.process(images=image, text=user_only_prompt)
262 | inputs = {k: v.to(self.device).unsqueeze(0) for k, v in inputs.items()}
263 | prompt_tokens = inputs["input_ids"].size(1)
264 |
265 | output = self.model.generate_from_batch(
266 | inputs,
267 | generation_config=GenerationConfig(
268 | max_new_tokens=512,
269 | ),
270 | stopping_criteria=[StopStringCriteria(tokenizer=self.processor.tokenizer, stop_strings=["<|endoftext|>"])],
271 | tokenizer=self.processor.tokenizer,
272 | )
273 |
274 | generated_tokens = output[0, prompt_tokens:]
275 | caption = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
276 | return caption
277 |
278 | class MoLMo72bHandler(ModelHandler):  # Molmo-72B with a hand-built multi-GPU layer map and CPU-resident vision backbone
279 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
280 | super().__init__(model_name_or_path, device, args_dict)
281 |
282 | def _initialize_model(self):
283 | gpus = []
284 | if torch.cuda.is_available():
285 | for i in range(torch.cuda.device_count()):
286 | device_props = torch.cuda.get_device_properties(i)
287 | compute_cap = device_props.major * 10 + device_props.minor
288 | vram = device_props.total_memory
289 | gpus.append((i, compute_cap, vram))
290 |
291 | sorted_gpus = sorted(gpus, key=lambda x: (-x[1], -x[2]))
292 |
293 | config = AutoConfig.from_pretrained(self.model_name_or_path, trust_remote_code=True)
294 | n_layers = config.num_hidden_layers
295 |
296 |         fixed_vram_main = 1 * 1024**3  # keep ~1 GB of headroom on the first (main) GPU
297 |         PER_LAYER_VRAM = 0.75 * 1024**3  # rough VRAM footprint assumed per transformer block
298 |         SAFETY_MARGIN = 1  # fraction of each GPU's total VRAM treated as usable (1 = no extra margin)
299 |
300 |         device_map = {"model.vision_backbone": "cpu"}  # keep the vision backbone on CPU; it is cast to float32 after loading
301 | if sorted_gpus:
302 | layer_allocations = []
303 | remaining_layers = n_layers
304 |
305 | for i, (dev_id, _, vram) in enumerate(sorted_gpus):
306 | if remaining_layers <= 0:
307 | break
308 |
309 | available_vram = vram * SAFETY_MARGIN
310 | if i == 0:
311 | available_vram -= fixed_vram_main
312 |
313 | max_possible_layers = int(available_vram // PER_LAYER_VRAM)
314 | allocate_layers = min(max_possible_layers, remaining_layers)
315 | layer_allocations.append((dev_id, allocate_layers))
316 | remaining_layers -= allocate_layers
317 |
318 | if remaining_layers > 0:
319 | layer_allocations[-1] = (layer_allocations[-1][0], layer_allocations[-1][1] + remaining_layers)
320 |
321 | current_layer = 0
322 | for dev_id, layers in layer_allocations:
323 | end_layer = current_layer + layers
324 | for layer_idx in range(current_layer, end_layer):
325 | device_map[f"model.transformer.blocks.{layer_idx}"] = dev_id
326 | current_layer = end_layer
327 |
328 | main_gpu = sorted_gpus[0][0]
329 | secondary_gpu = sorted_gpus[1][0] if len(sorted_gpus) > 1 else main_gpu
330 | device_map.update({
331 | "model.transformer.wte": main_gpu,
332 | "model.transformer.ln_f": main_gpu,
333 | "model.transformer.ff_out": secondary_gpu,
334 | })
335 | else:
336 | device_map = "auto"
337 |
338 | self.model = AutoModelForCausalLM.from_pretrained(
339 | self.model_name_or_path,
340 | device_map=device_map,
341 | attn_implementation='eager', # sdpa or flash_attention_2 or "eager"
342 | torch_dtype=torch.bfloat16,
343 | use_safetensors=True,
344 | trust_remote_code=True
345 | )
346 | self.processor = AutoProcessor.from_pretrained(self.model_name_or_path, trust_remote_code=True, use_fast=False)
347 |         self.model.model.vision_backbone.float()  # CPU-resident vision backbone runs in float32
348 |
349 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
350 | image = resize_image_proportionally(image, 1024, 1024)
351 |
352 | user_only_prompt = system_prompt + " " + user_prompt
353 |
354 | inputs = self.processor.process(images=image, text=user_only_prompt)
355 | inputs = {k: v.to(self.device).unsqueeze(0) for k, v in inputs.items()}
356 | prompt_tokens = inputs["input_ids"].size(1)
357 |
358 | output = self.model.generate_from_batch(
359 | inputs,
360 | generation_config=GenerationConfig(
361 | max_new_tokens=512,
362 | ),
363 | stopping_criteria=[StopStringCriteria(tokenizer=self.processor.tokenizer, stop_strings=["<|qqqq|>"])],
364 | tokenizer=self.processor.tokenizer,
365 | )
366 |
367 | generated_tokens = output[0, prompt_tokens:]
368 | caption = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
369 | return caption
370 |
371 | class Qwen2VLHandler(ModelHandler):  # Qwen2-VL via transformers and qwen_vl_utils
372 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
373 | super().__init__(model_name_or_path, device, args_dict)
374 |
375 | def _initialize_model(self):
376 | self.model = Qwen2VLForConditionalGeneration.from_pretrained(
377 | self.model_name_or_path,
378 | torch_dtype='auto',
379 | quantization_config=self._get_quantization_config(),
380 | device_map=self.device
381 | )
382 | self.processor = AutoProcessor.from_pretrained(self.model_name_or_path)
383 |
384 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
385 | image = resize_image_proportionally(image, 1024, 1024)
386 | messages = [
387 | {
388 | "role": "system",
389 | "content": [{"type": "text", "text": system_prompt}]
390 | },
391 | {
392 | "role": "user",
393 | "content": [
394 | {
395 | "type": "image",
396 | 'image': image
397 | },
398 | {
399 | "type": "text",
400 | "text": user_prompt
401 | }
402 | ]
403 | }
404 | ]
405 |
406 | text = self.processor.apply_chat_template(
407 | messages, tokenize=False, add_generation_prompt=True
408 | )
409 | image_inputs, _ = process_vision_info(messages)
410 | inputs = self.processor(
411 | text=[text],
412 | images=image_inputs,
413 | padding=True,
414 | return_tensors="pt",
415 | )
416 | inputs = inputs.to(self.device)
417 |
418 | generated_ids = self.model.generate(**inputs, max_new_tokens=512)
419 | generated_ids_trimmed = [
420 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
421 | ]
422 | caption = self.processor.batch_decode(
423 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
424 | )[0]
425 | return caption
426 |
427 | class PixtralHandler(ModelHandler):  # Pixtral (LLaVA-style), always loaded in 4-bit NF4
428 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
429 | super().__init__(model_name_or_path, device, args_dict)
430 |
431 | def _initialize_model(self):
432 | self.quantization_config = self._get_quantization_config()
433 |
434 | self.model = LlavaForConditionalGeneration.from_pretrained(
435 | self.model_name_or_path,
436 | torch_dtype=torch.bfloat16,
437 | quantization_config=self.quantization_config,
438 | device_map=self.device
439 | )
440 | self.processor = AutoProcessor.from_pretrained(
441 | self.model_name_or_path,
442 | trust_remote_code=True,
443 | torch_dtype=torch.bfloat16,
444 |             use_fast=False,
445 | quantization_config=self.quantization_config
446 | )
447 | self.model.eval()
448 |
449 | def _get_quantization_config(self):
450 | return BitsAndBytesConfig(
451 | load_in_4bit=True,
452 | bnb_4bit_quant_type="nf4",
453 | bnb_4bit_use_double_quant=True,
454 | bnb_4bit_compute_dtype=torch.bfloat16
455 | )
456 |
457 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
458 | image = resize_image_proportionally(image, 768, 768)
459 |
460 | user_only_prompt = system_prompt + " " + user_prompt
461 |
462 | conversation = [
463 | {
464 | "role": "user",
465 | "content": [
466 | {"type": "text", "text": user_only_prompt},
467 | {"type": "image"},
468 | ],
469 | }
470 | ]
471 |
472 | prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
473 |
474 | inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device)
475 |
476 | with torch.no_grad():
477 | with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
478 | generate_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.3, use_cache=True, top_k=20)
479 | caption = self.processor.batch_decode(generate_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
480 | return caption
481 |
482 | class Idefics3Handler(ModelHandler):  # Idefics3-based models via AutoModelForVision2Seq
483 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
484 | super().__init__(model_name_or_path, device, args_dict)
485 |
486 | def _initialize_model(self):
487 | self.processor = AutoProcessor.from_pretrained(self.model_name_or_path)
488 | self.model = AutoModelForVision2Seq.from_pretrained(
489 | self.model_name_or_path,
490 | torch_dtype=torch.bfloat16,
491 | device_map=self.device
492 | )
493 | self.model.eval()
494 |
495 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
496 | image = resize_image_proportionally(image, 2048, 2048)
497 | messages = [
498 | {
499 | "role": "system",
500 | "content": [
501 | {"type": "text", "text": system_prompt}
502 | ]
503 | },
504 | {
505 | "role": "user",
506 | "content": [
507 | {"type": "image"},
508 | {"type": "text", "text": user_prompt}
509 | ]
510 | }
511 | ]
512 |
513 | prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
514 | inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
515 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
516 |
517 | with torch.no_grad():
518 | generated_ids = self.model.generate(**inputs, max_new_tokens=512)
519 | generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
520 | caption = generated_texts[0].split("Assistant: ")[1]
521 | return caption
522 |
523 | class LlavaHandler(ModelHandler):  # LLaVA-NeXT models
524 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
525 | super().__init__(model_name_or_path, device, args_dict)
526 |
527 | def _initialize_model(self):
528 | self.model = LlavaNextForConditionalGeneration.from_pretrained(
529 | self.model_name_or_path,
530 | torch_dtype=torch.float16,
531 | #low_cpu_mem_usage=True,
532 | vision_feature_select_strategy="default",
533 | attn_implementation='flash_attention_2', # sdpa or flash_attention_2 or "eager"
534 | device_map=self.device
535 | )
536 | self.processor = LlavaNextProcessor.from_pretrained(
537 | self.model_name_or_path,
538 | #padding_side="left",
539 | #vision_feature_select_strategy="default",
540 | #patch_size=32,
541 | use_fast=False
542 | )
543 | self.processor.patch_size = self.model.config.vision_config.patch_size
544 | self.processor.vision_feature_select_strategy = self.model.config.vision_feature_select_strategy
545 |
546 | self.model.eval()
547 |
548 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
549 | image = resize_image_proportionally(image, 768, 768)
550 | messages = [
551 | {
552 | "role": "system",
553 | "content": [
554 | {"type": "text", "text": system_prompt}
555 | ]
556 | },
557 | {
558 | "role": "user",
559 | "content": [
560 | {"type": "image"},
561 | {"type": "text", "text": user_prompt}
562 | ]
563 | }
564 | ]
565 |
566 | prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
567 | inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
568 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
569 |
570 | with torch.no_grad():
571 | generated_ids = self.model.generate(**inputs, max_new_tokens=512)
572 | generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
573 | caption = generated_texts[0].split("[/INST] ")[1]
574 | return caption
575 |
576 | class MiniCPMoHandler(ModelHandler):  # MiniCPM-o (trust_remote_code) with vision/audio/TTS modules initialized
577 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
578 | super().__init__(model_name_or_path, device, args_dict)
579 |
580 | def _initialize_model(self):
581 | self.quantization_config = self._get_quantization_config()
582 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)
583 | self.model = AutoModel.from_pretrained(
584 | self.model_name_or_path,
585 | attn_implementation='eager', # sdpa or flash_attention_2 or "eager"
586 | torch_dtype=torch.bfloat16,
587 | trust_remote_code=True,
588 | local_files_only=True,
589 | init_vision=True,
590 | init_audio=True,
591 | init_tts=True,
592 | quantization_config=self.quantization_config,
593 | device_map=self.device
594 | )
595 |
596 | self.model.eval()
597 | self.model.init_tts()
598 | self.model.tts.float()
599 |
600 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
601 | image = resize_image_proportionally(image, 2048, 2048)
602 |
603 | user_only_prompt = system_prompt + " " + user_prompt
604 |
605 | msgs = [{'role': 'user', 'content': [image, user_only_prompt]}]
606 | caption = self.model.chat(
607 | image=None,
608 | msgs=msgs,
609 | tokenizer=self.tokenizer
610 | )
611 |
612 | return caption
613 |
614 | class GenericModelHandler(ModelHandler):  # fallback: AutoModel exposing a MiniCPM-style chat() interface
615 | def __init__(self, model_name_or_path: str, device: str, args_dict: Dict[str, Any]):
616 | super().__init__(model_name_or_path, device, args_dict)
617 |
618 | def _initialize_model(self):
619 | self.quantization_config = self._get_quantization_config()
620 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)
621 | self.model = AutoModel.from_pretrained(
622 | self.model_name_or_path,
623 | torch_dtype=torch.bfloat16,
624 | trust_remote_code=True,
625 | quantization_config=self.quantization_config,
626 | device_map=self.device
627 | )
628 |
629 | self.model.eval()
630 |
631 | def process_image(self, system_prompt: str, user_prompt: str, image: Image.Image) -> str:
632 | image = resize_image_proportionally(image, 1024, 1024)
633 |
634 | user_only_prompt = system_prompt + " " + user_prompt
635 |
636 | msgs = [{'role': 'user', 'content': [image, user_only_prompt]}]
637 | caption = self.model.chat(
638 | image=None,
639 | msgs=msgs,
640 | tokenizer=self.tokenizer
641 | )
642 |
643 | return caption
644 |
--------------------------------------------------------------------------------